Files
eva-pwm-cohost/src/prediction.rs
T

343 lines
17 KiB
Rust

use std::process::{Command, Stdio};
use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageArgs, ChatCompletionTool, ChatCompletionTools, CreateChatCompletionRequestArgs, FinishReason, FunctionObjectArgs, ResponseFormat, ResponseFormatJsonSchema}};
use bandcamp::SearchResultItem;
use chrono::{DateTime, Utc};
use schemars::{JsonSchema, schema_for};
use serde::{Deserialize, Serialize};
use crate::{SaveData, scene::{ConversationEntry, Scene, StageActions, StageDirection}};
/// System prompt text, embedded at compile time from `system-prompt.txt`
/// (resolved relative to this source file). It becomes the content of the
/// session's header message, which is sent first in every completion request.
const SYSTEM_PROMPT: &str = include_str!("system-prompt.txt");
// One candidate reply: the reply `text` plus an optional `stage_direction` string.
// Instances are deserialized from the model's JSON output as elements of
// `GeneratedResponses::responses`.
//
// Plain `//` comments are used on purpose: this type derives `JsonSchema`, and
// schemars copies `///` doc comments into the generated schema as descriptions,
// which would change the schema sent to the model.
#[derive(JsonSchema, Deserialize, Serialize, Debug, Clone)]
pub struct PossibleResponse {
// The reply text itself.
pub text: String,
// Optional stage direction attached to this reply.
pub stage_direction: Option<String>
}
// The structured payload the model is asked to return: its JSON schema is sent
// as the request's `response_format`, and the model's message content is
// deserialized back into this type. `Default` yields an empty option list.
//
// Plain `//` comments are used on purpose: schemars copies `///` doc comments
// into the generated schema as descriptions, which would change the request.
#[derive(Default, Debug, JsonSchema, Deserialize, Serialize, Clone)]
pub struct GeneratedResponses {
// Candidate replies to choose from.
pub responses: Vec<PossibleResponse>,
}
/// State of one prediction session: the API client, the conversation shown in
/// the scene, the message history sent to the model, and the most recently
/// generated reply options.
#[derive(Debug)]
struct Session {
/// Client used to issue chat completion requests.
client: Client<OpenAIConfig>,
/// Conversation entries handed to `Scene::new` when a scene is built.
conversation: Vec<ConversationEntry>,
/// System message built from `SYSTEM_PROMPT`; placed first in every request.
header_message: ChatCompletionRequestMessage,
/// Message history sent to the model after the header and direction messages.
/// Includes assistant tool-call messages and their tool replies.
messages: Vec<ChatCompletionRequestMessage>,
/// Reply options parsed from the last successful completion.
reply_options: GeneratedResponses
}
// Arguments shared by the `log_stage_event` and `log_ship_computer_message`
// tools: a single piece of text to insert into the scene.
//
// Plain `//` comments are used on purpose: this type's schema is sent to the
// model as the tool's parameters, and schemars copies `///` doc comments into
// that schema as descriptions.
#[derive(Debug, Default, Serialize, Deserialize, Clone, JsonSchema)]
struct StageEventArgs {
// Text of the stage direction or ship computer message.
text: String,
}
// Arguments of the `archive_query` tool. Every field is optional; each field
// that is present is turned into a `field:value` query term appended to the
// `beet export` command line.
//
// Plain `//` comments are used on purpose: this type's schema is sent to the
// model as the tool's parameters, and schemars copies `///` doc comments into
// that schema as descriptions.
#[derive(Debug, Default, Serialize, Deserialize, Clone, JsonSchema)]
struct BeatsQueryArgs {
// Becomes the `artist:` query term.
artist: Option<String>,
// Becomes the `album:` query term.
album: Option<String>,
// Becomes the `genre:` query term.
genre: Option<String>,
// Becomes the `title:` query term.
title: Option<String>,
// Becomes the `year:` query term.
year: Option<u32>
}
// Arguments of the `bandcamp_artifact_scan` tool.
//
// Plain `//` comments are used on purpose: this type's schema is sent to the
// model as the tool's parameters, and schemars copies `///` doc comments into
// that schema as descriptions.
#[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)]
struct BandcampQueryArgs {
// Search string passed to `bandcamp::search`.
query: String
}
/// Reduced view of a Bandcamp search hit. A list of these is serialized to JSON
/// and returned to the model as the result of the `bandcamp_artifact_scan` tool.
#[derive(Debug, Serialize, Deserialize, Clone)]
enum BandcampResult {
/// An artist page, built from a `bandcamp::Artist`.
Artist { name: String, bio: Option<String>, location: Option<String> },
/// An album, built from a `bandcamp::Album`; `artist` is the album's band name.
Album { title: String, about: Option<String>, credits: Option<String>, release_date: DateTime<Utc>, artist: String }
}
/// Reduces a fetched Bandcamp artist to the fields reported back to the model.
///
/// Implemented as `From` rather than `Into`: the standard library's blanket
/// impl still provides `artist.into()`, so existing call sites are unaffected,
/// and `BandcampResult::from(artist)` becomes available as well.
impl From<bandcamp::Artist> for BandcampResult {
    fn from(artist: bandcamp::Artist) -> Self {
        BandcampResult::Artist {
            name: artist.name,
            bio: artist.bio,
            location: artist.location,
        }
    }
}
/// Reduces a fetched Bandcamp album to the fields reported back to the model.
/// The album's `band.name` is exposed as `artist`.
///
/// Implemented as `From` rather than `Into`: the standard library's blanket
/// impl still provides `album.into()`, so existing call sites are unaffected,
/// and `BandcampResult::from(album)` becomes available as well.
impl From<bandcamp::Album> for BandcampResult {
    fn from(album: bandcamp::Album) -> Self {
        BandcampResult::Album {
            about: album.about,
            title: album.title,
            artist: album.band.name,
            credits: album.credits,
            release_date: album.release_date,
        }
    }
}
impl Session {
fn from_initial_messages(messages: Vec<ChatCompletionRequestMessage>) -> Self {
let mut conversation = vec![];
for msg in &messages {
if let Ok(conversation_msg) = msg.clone().try_into() {
conversation.push(conversation_msg);
}
}
Self {
client: Default::default(),
conversation,
header_message: ChatCompletionRequestSystemMessageArgs::default().content(SYSTEM_PROMPT).build().unwrap().into(),
messages,
reply_options: Default::default()
}
}
fn insert_actions(&mut self, actions: &StageActions) {
for addition in &actions.additions {
self.insert_conversation(addition.clone());
}
}
async fn regenerate_options(&mut self, direction: &StageDirection) {
loop {
let direction_message: ChatCompletionRequestMessage = ChatCompletionRequestSystemMessageArgs::default().content(serde_json::to_string(&direction).unwrap()).build().unwrap().into();
let mut full_conversation = vec![
self.header_message.clone(),
direction_message
];
full_conversation.append(&mut self.messages.clone());
let tools = vec![
ChatCompletionTools::Function(ChatCompletionTool {
function: FunctionObjectArgs::default()
.name("log_stage_event")
.description("Inserts an event into the current scene script")
.parameters(schema_for!(StageEventArgs))
.build().unwrap()
}),
ChatCompletionTools::Function(ChatCompletionTool {
function: FunctionObjectArgs::default()
.name("log_ship_computer_message")
.description("Inserts a message from the ship computer into the scene script")
.parameters(schema_for!(StageEventArgs))
.build().unwrap()
}),
ChatCompletionTools::Function(ChatCompletionTool {
function: FunctionObjectArgs::default()
.name("archive_query")
.description("Queries the ship's musical artifact archives for tracks matching the given search parameters")
.parameters(schema_for!(BeatsQueryArgs))
.build().unwrap()
}),
ChatCompletionTools::Function(ChatCompletionTool {
function: FunctionObjectArgs::default()
.name("bandcamp_artifact_scan")
.description("Scans Bandcamp to find artifacts to use in the scene that match the given search parameters")
.parameters(schema_for!(BandcampQueryArgs))
.build().unwrap()
})
];
let request = CreateChatCompletionRequestArgs::default()
.messages(full_conversation.clone())
.model("gpt-5.4")
.tools(tools)
.max_completion_tokens(1024u32)
.response_format(ResponseFormat::JsonSchema {
json_schema: ResponseFormatJsonSchema {
description: None,
name: "responses".into(),
schema: schema_for!(GeneratedResponses).into(),
strict: None
}
})
.build().unwrap();
let response = self.client.chat().create(request).await.unwrap_or_else(|err| {
panic!("{} {:?}", err, full_conversation);
});
if let Some(message) = response.choices.first() {
match message.finish_reason {
Some(FinishReason::ContentFilter) => {
self.insert_conversation(ConversationEntry::SystemMessage("Content filter triggered.".into()));
return;
},
Some(FinishReason::Length) => {
self.insert_conversation(ConversationEntry::SystemMessage("Maximum token count exceeded!".into()));
return;
},
_ => ()
}
if let Some(calls) = &message.message.tool_calls {
let assistant_messages: ChatCompletionRequestMessage = ChatCompletionRequestAssistantMessageArgs::default()
.tool_calls(calls.clone())
.build().unwrap().into();
self.messages.push(assistant_messages);
let mut results = vec![];
let mut messages = vec![];
for call in calls {
match call {
ChatCompletionMessageToolCalls::Function(call) => {
match call.function.name.as_str() {
"log_stage_event" => {
let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
.tool_call_id(call.id.clone())
.build().unwrap()
));
messages.push(ConversationEntry::StageDirection(args.text));
},
"log_ship_computer_message" => {
let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
.tool_call_id(call.id.clone())
.build().unwrap()
));
messages.push(ConversationEntry::ShipComputer(args.text));
},
"bandcamp_artifact_scan" => {
let args: BandcampQueryArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
self.insert_conversation(ConversationEntry::SystemMessage(format!("Fetching artifacts from Bandcamp with {:?}", args).into()));
let mut json_results = vec![];
if let Ok(results) = bandcamp::search(args.query.as_str()).await {
for result in results {
match result {
SearchResultItem::Artist(data) => {
let result: BandcampResult = bandcamp::fetch_artist(data.artist_id).await.unwrap().into();
json_results.push(result);
},
SearchResultItem::Album(data) => {
let result: BandcampResult = bandcamp::fetch_album(data.band_id, data.album_id).await.unwrap().into();
json_results.push(result);
}
_ => ()
}
}
}
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
.tool_call_id(call.id.clone())
.content(serde_json::to_string(&json_results).unwrap())
.build().unwrap()
));
messages.push(ConversationEntry::ShipComputer(format!("Artifact scan for '{}' complete. {} results.", args.query, json_results.len()).into()));
},
"archive_query" => {
let args: BeatsQueryArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
let mut beets_cmd = Command::new("beet");
beets_cmd.arg("export").arg("-f").arg("json").arg("-i").arg("title,label,year,genres,album,artist");
if let Some(artist) = args.artist {
beets_cmd.arg(format!("artist:{}", artist));
}
if let Some(genre) = args.genre {
beets_cmd.arg(format!("genre:{}", genre));
}
if let Some(album) = args.album {
beets_cmd.arg(format!("album:{}", album));
}
if let Some(title) = args.title {
beets_cmd.arg(format!("title:{}", title));
}
if let Some(year) = args.year {
beets_cmd.arg(format!("year:{}", year));
}
if let Ok(output) = beets_cmd.stdout(Stdio::piped()).spawn().unwrap().wait_with_output() {
let minified = minify::json::minify(str::from_utf8(&output.stdout).unwrap());
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
.tool_call_id(call.id.clone())
.content(minified)
.build().unwrap()
));
messages.push(ConversationEntry::ShipComputer(format!("Executing archive query {:?}", beets_cmd)));
} else {
messages.push(ConversationEntry::ShipComputer("Unable to execute query!".into()));
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
.tool_call_id(call.id.clone())
.content("")
.build().unwrap()
));
}
}
_ => panic!("Unknown function was called")
}
},
_ => panic!("Unknown tool was called")
}
}
self.messages.append(&mut results);
for msg in messages {
self.insert_conversation(msg);
}
}
if let Some(content) = message.message.content.as_ref() {
if let Ok(options) = serde_json::from_str(content.as_str()) {
self.reply_options = options;
return;
} else {
self.insert_conversation(ConversationEntry::SystemMessage("Received invalid JSON! Trying again.".into()));
}
}
} else {
self.insert_conversation(ConversationEntry::SystemMessage("No messages were received! Trying again.".into()));
}
}
}
fn as_scene(&self) -> Scene {
Scene::new(self.reply_options.clone(), self.conversation.clone())
}
fn insert_conversation(&mut self, entry: ConversationEntry) {
self.conversation.push(entry.clone());
if let Ok(next_msg) = entry.try_into() {
self.messages.push(next_msg);
}
}
}
/// Starts the background prediction task and returns its two channel ends.
///
/// The task owns a [`Session`] seeded from `initial_messages` and reacts to two
/// inputs:
/// * a string received on `sys_message_src` is inserted as a system message and
///   the updated scene is published;
/// * a new [`StageActions`] value on the returned sender is applied to the
///   session, saved, used to regenerate the reply options, saved again, and the
///   resulting scene is published.
///
/// Returns the sender used to request predictions and the receiver on which
/// scenes are published. The initial scene is published before the task starts.
///
/// The task ends once both inputs are closed or once nobody is receiving scenes
/// any more. A closed input only disables its own branch, so the loop never
/// spins on an already-closed channel.
pub async fn start_prediction(mut sys_message_src: tokio::sync::mpsc::Receiver<String>, initial_messages: Vec<ChatCompletionRequestMessage>) -> (tokio::sync::watch::Sender<StageActions>, tokio::sync::watch::Receiver<Scene>) {
    let (prediction_in, prediction_out) = tokio::sync::watch::channel(Scene::default());
    let (prediction_request_in, mut prediction_request_out) = tokio::sync::watch::channel(StageActions::default());
    let mut session = Session::from_initial_messages(initial_messages);
    // Send the initial scene to the UI, after we have loaded the session from the first messages.
    // `prediction_out` is still alive here, so this send cannot fail.
    prediction_in.send(session.as_scene()).unwrap();
    tokio::spawn(async move {
        loop {
            tokio::select! {
                // `None` (channel closed) fails the pattern and disables this branch.
                Some(message) = sys_message_src.recv() => {
                    session.insert_conversation(ConversationEntry::SystemMessage(message));
                    if prediction_in.send(session.as_scene()).is_err() {
                        // No scene receivers are left, so there is nothing to update.
                        break;
                    }
                },
                // `Err` (request sender dropped) fails the pattern and disables this branch.
                Ok(()) = prediction_request_out.changed() => {
                    let next_cxt = prediction_request_out.borrow().clone();
                    session.insert_actions(&next_cxt);
                    let mut save_data = SaveData {
                        direction: next_cxt.direction,
                        messages: session.messages.clone()
                    };
                    // Save before the request so the additions survive a failure.
                    save_data.save();
                    session.regenerate_options(&save_data.direction).await;
                    save_data.messages = session.messages.clone();
                    save_data.save();
                    if prediction_in.send(session.as_scene()).is_err() {
                        break;
                    }
                },
                // Both inputs are closed: nothing can ever arrive again.
                else => break,
            }
        }
    });
    (prediction_request_in, prediction_out)
}