use std::process::{Command, Stdio}; use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageArgs, ChatCompletionTool, ChatCompletionTools, CreateChatCompletionRequestArgs, FinishReason, FunctionObjectArgs, ResponseFormat, ResponseFormatJsonSchema}}; use bandcamp::SearchResultItem; use chrono::{DateTime, Utc}; use schemars::{JsonSchema, schema_for}; use serde::{Deserialize, Serialize}; use crate::{SaveData, scene::{ConversationEntry, Scene, StageActions, StageDirection}}; const SYSTEM_PROMPT: &str = include_str!("system-prompt.txt"); #[derive(JsonSchema, Deserialize, Serialize, Debug, Clone)] pub struct PossibleResponse { pub text: String, pub stage_direction: Option } #[derive(Default, Debug, JsonSchema, Deserialize, Serialize, Clone)] pub struct GeneratedResponses { pub responses: Vec, } #[derive(Debug)] struct Session { client: Client, conversation: Vec, header_message: ChatCompletionRequestMessage, messages: Vec, reply_options: GeneratedResponses } #[derive(Debug, Default, Serialize, Deserialize, Clone, JsonSchema)] struct StageEventArgs { text: String, } #[derive(Debug, Default, Serialize, Deserialize, Clone, JsonSchema)] struct BeatsQueryArgs { artist: Option, album: Option, genre: Option, title: Option, year: Option } #[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)] struct BandcampQueryArgs { query: String } #[derive(Debug, Serialize, Deserialize, Clone)] enum BandcampResult { Artist { name: String, bio: Option, location: Option }, Album { title: String, about: Option, credits: Option, release_date: DateTime, artist: String } } impl Into for bandcamp::Artist { fn into(self) -> BandcampResult { BandcampResult::Artist { name: self.name, bio: self.bio, location: self.location } } } impl Into for bandcamp::Album { fn into(self) -> BandcampResult { BandcampResult::Album { about: self.about, title: self.title, artist: self.band.name, credits: self.credits, release_date: self.release_date } } } impl Session { fn from_initial_messages(messages: Vec) -> Self { let mut conversation = vec![]; for msg in &messages { if let Ok(conversation_msg) = msg.clone().try_into() { conversation.push(conversation_msg); } } Self { client: Default::default(), conversation, header_message: ChatCompletionRequestSystemMessageArgs::default().content(SYSTEM_PROMPT).build().unwrap().into(), messages, reply_options: Default::default() } } fn insert_actions(&mut self, actions: &StageActions) { for addition in &actions.additions { self.insert_conversation(addition.clone()); } } async fn regenerate_options(&mut self, direction: &StageDirection) { loop { let direction_message: ChatCompletionRequestMessage = ChatCompletionRequestSystemMessageArgs::default().content(serde_json::to_string(&direction).unwrap()).build().unwrap().into(); let mut full_conversation = vec![ self.header_message.clone(), direction_message ]; full_conversation.append(&mut self.messages.clone()); let tools = vec![ ChatCompletionTools::Function(ChatCompletionTool { function: FunctionObjectArgs::default() .name("log_stage_event") .description("Inserts an event into the current scene script") .parameters(schema_for!(StageEventArgs)) .build().unwrap() }), ChatCompletionTools::Function(ChatCompletionTool { function: FunctionObjectArgs::default() .name("log_ship_computer_message") .description("Inserts a message from the ship computer into the scene script") .parameters(schema_for!(StageEventArgs)) .build().unwrap() }), ChatCompletionTools::Function(ChatCompletionTool { function: FunctionObjectArgs::default() .name("archive_query") .description("Queries the ship's musical artifact archives for tracks matching the given search parameters") .parameters(schema_for!(BeatsQueryArgs)) .build().unwrap() }), ChatCompletionTools::Function(ChatCompletionTool { function: FunctionObjectArgs::default() .name("bandcamp_artifact_scan") .description("Scans Bandcamp to find artifacts to use in the scene that match the given search parameters") .parameters(schema_for!(BandcampQueryArgs)) .build().unwrap() }) ]; let request = CreateChatCompletionRequestArgs::default() .messages(full_conversation.clone()) .model("gpt-5.4") .tools(tools) .max_completion_tokens(1024u32) .response_format(ResponseFormat::JsonSchema { json_schema: ResponseFormatJsonSchema { description: None, name: "responses".into(), schema: schema_for!(GeneratedResponses).into(), strict: None } }) .build().unwrap(); let response = self.client.chat().create(request).await.unwrap_or_else(|err| { panic!("{} {:?}", err, full_conversation); }); if let Some(message) = response.choices.first() { match message.finish_reason { Some(FinishReason::ContentFilter) => { self.insert_conversation(ConversationEntry::SystemMessage("Content filter triggered.".into())); return; }, Some(FinishReason::Length) => { self.insert_conversation(ConversationEntry::SystemMessage("Maximum token count exceeded!".into())); return; }, _ => () } if let Some(calls) = &message.message.tool_calls { let assistant_messages: ChatCompletionRequestMessage = ChatCompletionRequestAssistantMessageArgs::default() .tool_calls(calls.clone()) .build().unwrap().into(); self.messages.push(assistant_messages); let mut results = vec![]; let mut messages = vec![]; for call in calls { match call { ChatCompletionMessageToolCalls::Function(call) => { match call.function.name.as_str() { "log_stage_event" => { let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap(); results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default() .tool_call_id(call.id.clone()) .build().unwrap() )); messages.push(ConversationEntry::StageDirection(args.text)); }, "log_ship_computer_message" => { let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap(); results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default() .tool_call_id(call.id.clone()) .build().unwrap() )); messages.push(ConversationEntry::ShipComputer(args.text)); }, "bandcamp_artifact_scan" => { let args: BandcampQueryArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap(); self.insert_conversation(ConversationEntry::SystemMessage(format!("Fetching artifacts from Bandcamp with {:?}", args).into())); let mut json_results = vec![]; if let Ok(results) = bandcamp::search(args.query.as_str()).await { for result in results { match result { SearchResultItem::Artist(data) => { let result: BandcampResult = bandcamp::fetch_artist(data.artist_id).await.unwrap().into(); json_results.push(result); }, SearchResultItem::Album(data) => { let result: BandcampResult = bandcamp::fetch_album(data.band_id, data.album_id).await.unwrap().into(); json_results.push(result); } _ => () } } } results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default() .tool_call_id(call.id.clone()) .content(serde_json::to_string(&json_results).unwrap()) .build().unwrap() )); messages.push(ConversationEntry::ShipComputer(format!("Artifact scan for '{}' complete. {} results.", args.query, json_results.len()).into())); }, "archive_query" => { let args: BeatsQueryArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap(); let mut beets_cmd = Command::new("beet"); beets_cmd.arg("export").arg("-f").arg("json").arg("-i").arg("title,label,year,genres,album,artist"); if let Some(artist) = args.artist { beets_cmd.arg(format!("artist:{}", artist)); } if let Some(genre) = args.genre { beets_cmd.arg(format!("genre:{}", genre)); } if let Some(album) = args.album { beets_cmd.arg(format!("album:{}", album)); } if let Some(title) = args.title { beets_cmd.arg(format!("title:{}", title)); } if let Some(year) = args.year { beets_cmd.arg(format!("year:{}", year)); } if let Ok(output) = beets_cmd.stdout(Stdio::piped()).spawn().unwrap().wait_with_output() { let minified = minify::json::minify(str::from_utf8(&output.stdout).unwrap()); results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default() .tool_call_id(call.id.clone()) .content(minified) .build().unwrap() )); messages.push(ConversationEntry::ShipComputer(format!("Executing archive query {:?}", beets_cmd))); } else { messages.push(ConversationEntry::ShipComputer("Unable to execute query!".into())); results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default() .tool_call_id(call.id.clone()) .content("") .build().unwrap() )); } } _ => panic!("Unknown function was called") } }, _ => panic!("Unknown tool was called") } } self.messages.append(&mut results); for msg in messages { self.insert_conversation(msg); } } if let Some(content) = message.message.content.as_ref() { if let Ok(options) = serde_json::from_str(content.as_str()) { self.reply_options = options; return; } else { self.insert_conversation(ConversationEntry::SystemMessage("Received invalid JSON! Trying again.".into())); } } } else { self.insert_conversation(ConversationEntry::SystemMessage("No messages were received! Trying again.".into())); } } } fn as_scene(&self) -> Scene { Scene::new(self.reply_options.clone(), self.conversation.clone()) } fn insert_conversation(&mut self, entry: ConversationEntry) { self.conversation.push(entry.clone()); if let Ok(next_msg) = entry.try_into() { self.messages.push(next_msg); } } } pub async fn start_prediction(mut sys_message_src: tokio::sync::mpsc::Receiver, initial_messages: Vec) -> (tokio::sync::watch::Sender, tokio::sync::watch::Receiver) { let (prediction_in, prediction_out) = tokio::sync::watch::channel(Scene::default()); let (prediction_request_in, mut prediction_request_out) = tokio::sync::watch::channel(StageActions::default()); let mut session = Session::from_initial_messages(initial_messages); // Send the initial scene to the UI, after we have loaded the session from the first messages. prediction_in.send(session.as_scene()).unwrap(); tokio::spawn(async move { loop { tokio::select! { maybe_message = sys_message_src.recv() => { if let Some(message) = maybe_message { session.insert_conversation(ConversationEntry::SystemMessage(message)); prediction_in.send(session.as_scene()).unwrap(); } }, maybe_request = prediction_request_out.changed() => { if maybe_request.is_ok() { let next_cxt = prediction_request_out.borrow().clone(); session.insert_actions(&next_cxt); let mut save_data = SaveData { direction: next_cxt.direction, messages: session.messages.clone() }; save_data.save(); session.regenerate_options(&save_data.direction).await; save_data.messages = session.messages.clone(); save_data.save(); prediction_in.send(session.as_scene()).unwrap(); } } }; } }); (prediction_request_in, prediction_out) }