use std::sync::Arc; use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs, ChatCompletionTool, ChatCompletionTools, CreateChatCompletionRequestArgs, FinishReason, FunctionObjectArgs, ResponseFormat, ResponseFormatJsonSchema}}; use bandcamp::SearchResultItem; use chrono::{DateTime, Utc}; use schemars::{JsonSchema, schema_for}; use serde::{Deserialize, Serialize}; use serde_json::{Serializer, ser::CompactFormatter}; use tokio::sync::{RwLock, mpsc, watch}; use crate::{SaveData, artifacts::{Album, Artifact, Artist, Track, beets::BeatsQueryArgs, bandcamp::BandcampQueryArgs, mixxx::MixxxDB}, scene::{Scene, Scenery, StageDirection, conversation::ConversationEntry}}; const SYSTEM_PROMPT: &str = include_str!("system-prompt.txt"); #[derive(Debug, Clone)] pub enum PredictionAction { ConversationAppend(ConversationEntry), SetPlaylist(String), GeneratePredictions, SetNarrative(String), SetShowEndTime(DateTime) } #[derive(JsonSchema, Deserialize, Serialize, Debug, Clone)] pub struct PossibleResponse { pub text: String, pub stage_direction: Option } #[derive(Default, Debug, JsonSchema, Deserialize, Serialize, Clone)] pub struct GeneratedResponses { pub responses: Vec, } #[derive(Debug)] struct Session { client: Client, conversation: Vec, header_message: ChatCompletionRequestMessage, messages: Vec, reply_options: GeneratedResponses, direction: StageDirection, scenery: Scenery, tokens_consumed: usize, activity_notify: watch::Sender, scene_sink: watch::Sender } #[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)] enum StageEvent { ShipComputer(String), StageDirection(String) } #[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)] struct StageEventArgs { event: StageEvent } #[derive(Default, Debug)] struct ToolResults { result: Option, messages: Vec } impl Session { fn new(scene_sink: watch::Sender, messages: Vec, scenery: Scenery, direction: StageDirection, activity_notify: watch::Sender) -> Self { let mut conversation = vec![]; for msg in &messages { if let Ok(conversation_msg) = msg.clone().try_into() { conversation.push(conversation_msg); } } Self { client: Default::default(), conversation, header_message: ChatCompletionRequestSystemMessageArgs::default().content(SYSTEM_PROMPT).build().unwrap().into(), messages, reply_options: Default::default(), scenery, direction, tokens_consumed: 0, activity_notify, scene_sink } } async fn tool_stage_event(&mut self, args: StageEventArgs) -> ToolResults { let msg = match args.event { StageEvent::ShipComputer(text) => ConversationEntry::ShipComputer(text), StageEvent::StageDirection(text) => ConversationEntry::StageDirection(text) }; ToolResults { messages: vec![msg], ..Default::default() } } async fn tool_bandcamp_scan(&mut self, args: BandcampQueryArgs) -> ToolResults { let mut messages = vec![]; log::info!("Fetching artifacts from Bandcamp with {:?}", args); let mut json_results = vec![]; if let Ok(results) = bandcamp::search(args.query.as_str()).await { for result in results { log::debug!("Result: {:?}", result); match result { SearchResultItem::Artist(data) => { let result = Artifact::Artist(Artist { name: data.name, location: data.location, ..Default::default() }); json_results.push(result); }, SearchResultItem::Album(data) => { let result = Artifact::Album(Album { title: data.name, artist: data.band_name, ..Default::default() }); json_results.push(result); }, SearchResultItem::Track(data) => { let result = Artifact::Track(Track { title: data.name, artist: Some(data.band_name), album: data.album_name, ..Default::default() }); json_results.push(result); } _ => () } } } let artifact_count = json_results.len(); messages.push(ConversationEntry::ShipComputer(format!("Relay scan for '{}' complete. {} artifacts added to the archive.", args.query, artifact_count).into())); for track in &json_results { if let Some(merge_target) = self.scenery.artifacts.iter_mut().find(|a| { *a == track }) { merge_target.merge(track.clone()); } else { self.scenery.artifacts.push(track.clone()); } } ToolResults { result: Some(format!("{} artifacts were added to the archive.", artifact_count)), messages } } async fn tool_artifact_query(&mut self, args: BeatsQueryArgs) -> ToolResults { let mut messages = vec![]; messages.push(ConversationEntry::ShipComputer(format!("Executing archive query {:?}", args))); log::info!("Executing beets query {:?}", args); if let Ok(output) = args.execute() { for track in &output { if let Some(merge_target) = self.scenery.artifacts.iter_mut().find(|a| { *a == track }) { merge_target.merge(track.clone()); } else { self.scenery.artifacts.push(track.clone()); } } } else { messages.push(ConversationEntry::ShipComputer("Unable to execute query!".into())); }; ToolResults { result: None, messages } } fn generate_conversation(&self, direction: &StageDirection) -> Vec { let mut json_buf = vec![]; let mut ser = Serializer::with_formatter(&mut json_buf, CompactFormatter); direction.serialize(&mut ser).unwrap(); let direction_message: ChatCompletionRequestMessage = ChatCompletionRequestSystemMessageArgs::default() .content(String::from_utf8(json_buf).unwrap()) .build().unwrap().into(); let mut json_buf = vec![]; let mut ser = Serializer::with_formatter(&mut json_buf, CompactFormatter); self.scenery.serialize(&mut ser).unwrap(); let scenery_message: ChatCompletionRequestMessage = ChatCompletionRequestSystemMessageArgs::default() .content(String::from_utf8(json_buf).unwrap()) .build().unwrap().into(); let mut full_conversation = vec![ self.header_message.clone(), direction_message, scenery_message, ]; full_conversation.append(&mut self.messages.clone()); full_conversation } async fn regenerate_options(&mut self) { self.reply_options.responses.clear(); self.refresh(); self.activity_notify.send_if_modified(|x| { if !*x { *x = true; true } else { false }}); loop { let full_conversation = self.generate_conversation(&self.direction); let tools = vec![ ChatCompletionTools::Function(ChatCompletionTool { function: FunctionObjectArgs::default() .name("log_stage_event") .description("Inserts an event into the current scene script") .parameters(schema_for!(StageEventArgs)) .build().unwrap() }), // TODO: There should only be two queries, one against the ship's onboard archive, and another against the relay network, or whatever we call it. Both should be structured with the same arguments schema // TODO: A relay search should try to grab first from beets, then musicbrainz, then from bandcamp. // TODO: A query should specify what parts of metadata are sufficient for the result, so we don't always have to hit all the layers of data. beets can of course, ignore this. // TODO: A query should be hierarchical somehow? eg, "I already know about artist X, but I want to know everything about track Y from album Z" or "I don't know anything about artist X/album Y, please give me an overview" ChatCompletionTools::Function(ChatCompletionTool { function: FunctionObjectArgs::default() .name("archive_query") .description("Queries the ship's musical artifact archives for tracks matching the given search parameters") .parameters(schema_for!(BeatsQueryArgs)) .build().unwrap() }), ChatCompletionTools::Function(ChatCompletionTool { function: FunctionObjectArgs::default() .name("bandcamp_artifact_scan") .description("Scans Bandcamp to find artifacts to use in the scene that match the given search parameters. To find an artist, provide only the artist name. To find an album, provide the artist and the album.") .parameters(schema_for!(BandcampQueryArgs)) .build().unwrap() }) ]; log::info!("Sending request.."); let request = CreateChatCompletionRequestArgs::default() .messages(full_conversation) .model("gpt-5.4-mini") .tools(tools) .max_completion_tokens(1024u32) .response_format(ResponseFormat::JsonSchema { json_schema: ResponseFormatJsonSchema { description: None, name: "responses".into(), schema: schema_for!(GeneratedResponses).into(), strict: None } }) .build().unwrap(); let response = self.client.chat().create(request).await.unwrap_or_else(|err| { panic!("OpenAI Panic: {}", err); }); if let Some(usage) = response.usage { self.tokens_consumed += usage.total_tokens as usize; log::info!("{} tokens cast into the void", usage.total_tokens); } if let Some(message) = response.choices.first() { match message.finish_reason { Some(FinishReason::ContentFilter) => { log::error!("Content filter triggered."); break; }, Some(FinishReason::Length) => { log::error!("Maximum token count exceeded!"); break; }, _ => () } if let Some(calls) = &message.message.tool_calls { let assistant_messages: ChatCompletionRequestMessage = ChatCompletionRequestAssistantMessageArgs::default() .tool_calls(calls.clone()) .build().unwrap().into(); self.messages.push(assistant_messages); let mut results = vec![]; for call in calls { match call { ChatCompletionMessageToolCalls::Function(call) => { let func_name = call.function.name.as_str(); let args = call.function.arguments.as_str(); let tool_result = match func_name { "log_stage_event" => self.tool_stage_event(serde_json::from_str(args).unwrap()).await, "bandcamp_artifact_scan" => self.tool_bandcamp_scan(serde_json::from_str(args).unwrap()).await, "archive_query" => self.tool_artifact_query(serde_json::from_str(args).unwrap()).await, _ => unreachable!() }; results.push((&call.id, tool_result)); }, _ => panic!("Unknown tool was called") } } let mut tool_messages = vec![]; for (id, mut result) in results { let mut msg = ChatCompletionRequestToolMessageArgs::default(); msg.tool_call_id(id); if let Some(output) = result.result { msg.content(output); } self.messages.push(ChatCompletionRequestMessage::Tool(msg.build().unwrap())); tool_messages.append(&mut result.messages); } for msg in tool_messages { self.insert_conversation(msg); } } if let Some(content) = message.message.content.as_ref() { if let Ok(options) = serde_json::from_str(content.as_str()) { self.reply_options = options; break; } else { log::info!("Received invalid JSON! Trying again."); } } } else { log::info!("No messages were received! Trying again."); } self.refresh(); } self.activity_notify.send_if_modified(|x| { if *x { *x = false; true } else { false }}); self.refresh(); } fn as_scene(&self) -> Scene { Scene::new(self.reply_options.clone(), self.conversation.clone(), self.scenery.clone(), self.tokens_consumed, self.direction.clone()) } fn insert_conversation(&mut self, entry: ConversationEntry) { self.conversation.push(entry.clone()); if let Ok(next_msg) = entry.try_into() { self.messages.push(next_msg); } self.refresh(); } fn refresh(&self) { self.scene_sink.send(self.as_scene()).unwrap(); } } #[derive(Debug)] pub struct SessionControl { event_sink: mpsc::Sender, scene_watch: watch::Receiver, activity_watch: watch::Receiver } #[derive(Debug)] pub enum SessionUpdate { Scene(Scene), Thinking(bool) } impl SessionControl { pub async fn insert(&self, action: PredictionAction) { self.event_sink.send(action).await.unwrap(); } pub async fn regenerate_options(&self) { self.insert(PredictionAction::GeneratePredictions).await; } pub async fn changed(&mut self) -> SessionUpdate { tokio::select! { _ = self.activity_watch.changed() => { SessionUpdate::Thinking(*self.activity_watch.borrow_and_update()) }, _ = self.scene_watch.changed() => { SessionUpdate::Scene(self.scene_watch.borrow_and_update().clone()) } } } } pub async fn start_prediction(saved_session: SaveData, mut messages: tokio::sync::mpsc::UnboundedReceiver) -> SessionControl { let (prediction_in, prediction_out) = tokio::sync::watch::channel(Scene::default()); let (activity_notify_sink, activity_notify_src) = tokio::sync::watch::channel(false); let (action_sink, mut action_src) = mpsc::channel(5); let session = Session::new(prediction_in, saved_session.messages, saved_session.scenery, saved_session.direction, activity_notify_sink); // Send the initial scene to the UI, after we have loaded the session from the first messages. session.refresh(); let shared_session = Arc::new(RwLock::new(session)); let log_session = Arc::clone(&shared_session); tokio::spawn(async move { loop { if let Some(msg) = messages.recv().await { log_session.write().await.insert_conversation(ConversationEntry::SystemMessage(msg)); } } }); tokio::spawn(async move { loop { if let Some(evt) = action_src.recv().await { let mut session = shared_session.write().await; let do_regen = match evt { PredictionAction::ConversationAppend(msg) => { let do_regen = match msg { ConversationEntry::Eva(_) | ConversationEntry::ShipComputer(_) | ConversationEntry::User(_) => true, _ => false }; session.insert_conversation(msg); do_regen }, PredictionAction::SetPlaylist(playlist_name) => { match MixxxDB::load(&playlist_name) { Err(err) => log::info!("Failed to load mixxx playlist: {:?}.", err), Ok(playlist) => { session.scenery.artifacts.merge(playlist.clone()); session.scenery.current_playlist = playlist; session.direction.playlist = playlist_name; log::info!("Mixxx playlist reloaded."); } } false }, PredictionAction::GeneratePredictions => { true }, PredictionAction::SetNarrative(narrative) => { session.direction.narrative = narrative; log::info!("Updated stage direction narrative"); true }, PredictionAction::SetShowEndTime(end_time) => { session.direction.end_time = end_time; false } }; let save_data = SaveData { direction: session.direction.clone(), messages: session.messages.clone(), scenery: session.scenery.clone() }; save_data.save(); if do_regen { drop(session); shared_session.write().await.regenerate_options().await; } } } }); SessionControl { event_sink: action_sink, scene_watch: prediction_out, activity_watch: activity_notify_src } }