From 51db51b63613aee041fbbda5593754fe51da111e Mon Sep 17 00:00:00 2001 From: Victoria Fischer Date: Fri, 5 Jun 2026 10:04:19 +0200 Subject: [PATCH] prediction: rewrite the messaging to use a loop for self-executing chains, add bandcamp and beets tools --- Cargo.lock | 60 +++++++++ Cargo.toml | 2 + src/prediction.rs | 289 ++++++++++++++++++++++++++++++++---------- src/system-prompt.txt | 7 + 4 files changed, 291 insertions(+), 67 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8f15fe0..b7e05bf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -238,6 +238,22 @@ dependencies = [ "windows-link", ] +[[package]] +name = "bandcamp" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc8f990cfc739590270d0413b821f09b28505d442204fd0f3fd99196df394d52" +dependencies = [ + "chrono", + "lazy_static", + "regex", + "reqwest 0.12.28", + "serde", + "serde_json", + "snafu", + "url", +] + [[package]] name = "base64" version = "0.22.1" @@ -938,6 +954,7 @@ name = "eva_cohost" version = "0.1.0" dependencies = [ "async-openai", + "bandcamp", "chrono", "color-eyre", "crossterm", @@ -947,6 +964,7 @@ dependencies = [ "iref 4.0.0", "jack", "json-ld", + "minify", "oximedia-metering", "ratatui", "rc-writer", @@ -1433,6 +1451,7 @@ dependencies = [ "tokio", "tokio-rustls", "tower-service", + "webpki-roots", ] [[package]] @@ -2313,6 +2332,12 @@ dependencies = [ "unicase", ] +[[package]] +name = "minify" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e93bacfc6ce0cf3e41da4d9415904090e1f7ca8d105c1396907f78d8fee42635" + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -3308,6 +3333,8 @@ dependencies = [ "native-tls", "percent-encoding", "pin-project-lite", + "quinn", + "rustls", "rustls-pki-types", "serde", "serde_json", @@ -3315,6 +3342,7 @@ dependencies = [ "sync_wrapper", "tokio", "tokio-native-tls", + "tokio-rustls", "tower", "tower-http", "tower-service", @@ -3322,6 +3350,7 @@ dependencies = [ "wasm-bindgen", "wasm-bindgen-futures", "web-sys", + "webpki-roots", ] [[package]] @@ -3476,6 +3505,7 @@ checksum = "ef86cd5876211988985292b91c96a8f2d298df24e75989a43a3c73f2d4d8168b" dependencies = [ "aws-lc-rs", "once_cell", + "ring", "rustls-pki-types", "rustls-webpki", "subtle", @@ -3881,6 +3911,27 @@ version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" +[[package]] +name = "snafu" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e84b3f4eacbf3a1ce05eac6763b4d629d60cbc94d632e4092c54ade71f1e1a2" +dependencies = [ + "snafu-derive", +] + +[[package]] +name = "snafu-derive" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1c97747dbf44bb1ca44a561ece23508e99cb592e862f22222dcf42f51d1e451" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "socket2" version = "0.6.3" @@ -4889,6 +4940,15 @@ dependencies = [ "rustls-pki-types", ] +[[package]] +name = "webpki-roots" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52f5ee44c96cf55f1b349600768e3ece3a8f26010c05265ab73f945bb1a2eb9d" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "wezterm-bidi" version = "0.2.3" diff --git a/Cargo.toml b/Cargo.toml index b8170e7..8b0ebe3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ edition = "2024" [dependencies] async-openai = { version = "0.40.2", features = ["completions", "full"] } +bandcamp = "0.3.4" chrono = { version = "0.4.44", features = ["serde"] } color-eyre = "0.6.5" crossterm = { version = "0.29.0", features = ["event-stream"] } @@ -14,6 +15,7 @@ hound = "3.5.1" iref = { version = "4.0.0", features = ["url", "serde"] } jack = "0.13.5" json-ld = { version = "0.21.4", features = ["reqwest", "serde"] } +minify = "1.3.0" oximedia-metering = "0.1.7" ratatui = "0.30.0" rc-writer = "1.1.10" diff --git a/src/prediction.rs b/src/prediction.rs index b84950f..9f3a2b3 100644 --- a/src/prediction.rs +++ b/src/prediction.rs @@ -1,4 +1,8 @@ -use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionTool, ChatCompletionTools, CreateChatCompletionRequestArgs, FunctionObjectArgs, ResponseFormat, ResponseFormatJsonSchema}}; +use std::process::{Command, Stdio}; + +use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageArgs, ChatCompletionTool, ChatCompletionTools, CreateChatCompletionRequestArgs, FinishReason, FunctionObjectArgs, ResponseFormat, ResponseFormatJsonSchema}}; +use bandcamp::SearchResultItem; +use chrono::{DateTime, Utc}; use schemars::{JsonSchema, schema_for}; use serde::{Deserialize, Serialize}; @@ -32,6 +36,44 @@ struct StageEventArgs { text: String, } +#[derive(Debug, Default, Serialize, Deserialize, Clone, JsonSchema)] +struct BeatsQueryArgs { + artist: Option, + album: Option, + genre: Option, + title: Option, + year: Option +} + +#[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)] +struct BandcampQueryArgs { + query: String +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +enum BandcampResult { + Artist { name: String, bio: Option, location: Option }, + Album { title: String, about: Option, credits: Option, release_date: DateTime, artist: String } +} + +impl Into for bandcamp::Artist { + fn into(self) -> BandcampResult { + BandcampResult::Artist { name: self.name, bio: self.bio, location: self.location } + } +} + +impl Into for bandcamp::Album { + fn into(self) -> BandcampResult { + BandcampResult::Album { + about: self.about, + title: self.title, + artist: self.band.name, + credits: self.credits, + release_date: self.release_date + } + } +} + impl Session { fn from_initial_messages(messages: Vec) -> Self { let mut conversation = vec![]; @@ -56,77 +98,190 @@ impl Session { } } - async fn regenerate_options(&mut self, direction: &StageDirection) -> Option { - let direction_message: ChatCompletionRequestMessage = ChatCompletionRequestSystemMessageArgs::default().content(serde_json::to_string(&direction).unwrap()).build().unwrap().into(); - let mut full_conversation = vec![ - self.header_message.clone(), - direction_message - ]; - full_conversation.append(&mut self.messages.clone()); + async fn regenerate_options(&mut self, direction: &StageDirection) { + loop { + let direction_message: ChatCompletionRequestMessage = ChatCompletionRequestSystemMessageArgs::default().content(serde_json::to_string(&direction).unwrap()).build().unwrap().into(); + let mut full_conversation = vec![ + self.header_message.clone(), + direction_message + ]; + full_conversation.append(&mut self.messages.clone()); + let tools = vec![ + ChatCompletionTools::Function(ChatCompletionTool { + function: FunctionObjectArgs::default() + .name("log_stage_event") + .description("Inserts an event into the current scene script") + .parameters(schema_for!(StageEventArgs)) + .build().unwrap() + }), + ChatCompletionTools::Function(ChatCompletionTool { + function: FunctionObjectArgs::default() + .name("log_ship_computer_message") + .description("Inserts a message from the ship computer into the scene script") + .parameters(schema_for!(StageEventArgs)) + .build().unwrap() + }), + ChatCompletionTools::Function(ChatCompletionTool { + function: FunctionObjectArgs::default() + .name("archive_query") + .description("Queries the ship's musical artifact archives for tracks matching the given search parameters") + .parameters(schema_for!(BeatsQueryArgs)) + .build().unwrap() + }), + ChatCompletionTools::Function(ChatCompletionTool { + function: FunctionObjectArgs::default() + .name("bandcamp_artifact_scan") + .description("Scans Bandcamp to find artifacts to use in the scene that match the given search parameters") + .parameters(schema_for!(BandcampQueryArgs)) + .build().unwrap() + }) + ]; + let request = CreateChatCompletionRequestArgs::default() + .messages(full_conversation.clone()) + .model("gpt-5.4") + .tools(tools) + .max_completion_tokens(1024u32) + .response_format(ResponseFormat::JsonSchema { + json_schema: ResponseFormatJsonSchema { + description: None, + name: "responses".into(), + schema: schema_for!(GeneratedResponses).into(), + strict: None + } + }) + .build().unwrap(); - let tools = vec![ - ChatCompletionTools::Function(ChatCompletionTool { - function: FunctionObjectArgs::default() - .name("log_stage_event") - .description("Inserts an event into the current scene script") - .parameters(schema_for!(StageEventArgs)) - .build().unwrap() - }), - ChatCompletionTools::Function(ChatCompletionTool { - function: FunctionObjectArgs::default() - .name("log_ship_computer_message") - .description("Inserts a message from the ship computer into the scene script") - .parameters(schema_for!(StageEventArgs)) - .build().unwrap() - }) - ]; + let response = self.client.chat().create(request).await.unwrap_or_else(|err| { + panic!("{} {:?}", err, full_conversation); + }); - let request = CreateChatCompletionRequestArgs::default() - .messages(full_conversation) - .model("gpt-5.4") - .tools(tools) - .max_completion_tokens(350u32) - .response_format(ResponseFormat::JsonSchema { - json_schema: ResponseFormatJsonSchema { - description: None, - name: "responses".into(), - schema: schema_for!(GeneratedResponses).into(), - strict: None + if let Some(message) = response.choices.first() { + + match message.finish_reason { + Some(FinishReason::ContentFilter) => { + self.insert_conversation(ConversationEntry::SystemMessage("Content filter triggered.".into())); + return; + }, + Some(FinishReason::Length) => { + self.insert_conversation(ConversationEntry::SystemMessage("Maximum token count exceeded!".into())); + return; + }, + _ => () } - }) - .build().unwrap(); - let response = self.client.chat().create(request).await.unwrap(); + if let Some(calls) = &message.message.tool_calls { + let assistant_messages: ChatCompletionRequestMessage = ChatCompletionRequestAssistantMessageArgs::default() + .tool_calls(calls.clone()) + .build().unwrap().into(); + self.messages.push(assistant_messages); + let mut results = vec![]; + let mut messages = vec![]; + for call in calls { + match call { + ChatCompletionMessageToolCalls::Function(call) => { + match call.function.name.as_str() { + "log_stage_event" => { + let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap(); + results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default() + .tool_call_id(call.id.clone()) + .build().unwrap() + )); + messages.push(ConversationEntry::StageDirection(args.text)); + }, + "log_ship_computer_message" => { + let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap(); + results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default() + .tool_call_id(call.id.clone()) + .build().unwrap() + )); + messages.push(ConversationEntry::ShipComputer(args.text)); + }, + "bandcamp_artifact_scan" => { + let args: BandcampQueryArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap(); + self.insert_conversation(ConversationEntry::SystemMessage(format!("Fetching artifacts from Bandcamp with {:?}", args).into())); + let mut json_results = vec![]; + if let Ok(results) = bandcamp::search(args.query.as_str()).await { + for result in results { + match result { + SearchResultItem::Artist(data) => { + let result: BandcampResult = bandcamp::fetch_artist(data.artist_id).await.unwrap().into(); + json_results.push(result); + }, + SearchResultItem::Album(data) => { + let result: BandcampResult = bandcamp::fetch_album(data.band_id, data.album_id).await.unwrap().into(); + json_results.push(result); + } + _ => () + } + } + } + results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default() + .tool_call_id(call.id.clone()) + .content(serde_json::to_string(&json_results).unwrap()) + .build().unwrap() + )); + messages.push(ConversationEntry::ShipComputer(format!("Artifact scan for '{}' complete. {} results.", args.query, json_results.len()).into())); + }, + "archive_query" => { + let args: BeatsQueryArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap(); + let mut beets_cmd = Command::new("beet"); + beets_cmd.arg("export").arg("-f").arg("json").arg("-i").arg("title,label,year,genres,album,artist"); + if let Some(artist) = args.artist { + beets_cmd.arg(format!("artist:{}", artist)); + } + if let Some(genre) = args.genre { + beets_cmd.arg(format!("genre:{}", genre)); + } + if let Some(album) = args.album { + beets_cmd.arg(format!("album:{}", album)); + } + if let Some(title) = args.title { + beets_cmd.arg(format!("title:{}", title)); + } + if let Some(year) = args.year { + beets_cmd.arg(format!("year:{}", year)); + } + if let Ok(output) = beets_cmd.stdout(Stdio::piped()).spawn().unwrap().wait_with_output() { + let minified = minify::json::minify(str::from_utf8(&output.stdout).unwrap()); + results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default() + .tool_call_id(call.id.clone()) + .content(minified) + .build().unwrap() + )); + messages.push(ConversationEntry::ShipComputer(format!("Executing archive query {:?}", beets_cmd))); + } else { + messages.push(ConversationEntry::ShipComputer("Unable to execute query!".into())); + results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default() + .tool_call_id(call.id.clone()) + .content("") + .build().unwrap() + )); + } + } + _ => panic!("Unknown function was called") + } + }, + _ => panic!("Unknown tool was called") + } + } + self.messages.append(&mut results); + for msg in messages { + self.insert_conversation(msg); + } - if let Some(message) = response.choices.first() { - if let Some(calls) = &message.message.tool_calls { - for call in calls { - match call { - ChatCompletionMessageToolCalls::Function(call) => { - match call.function.name.as_str() { - "log_stage_event" => { - let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap(); - self.insert_conversation(ConversationEntry::StageDirection(args.text)); - }, - "log_ship_computer_message" => { - let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap(); - self.insert_conversation(ConversationEntry::ShipComputer(args.text)); - }, - _ => panic!("Unknown function was called") - } - }, - _ => panic!("Unknown tool was called") + } + if let Some(content) = message.message.content.as_ref() { + if let Ok(options) = serde_json::from_str(content.as_str()) { + self.reply_options = options; + return; + } else { + self.insert_conversation(ConversationEntry::SystemMessage("Received invalid JSON! Trying again.".into())); } } - Some(self.as_scene()) } else { - self.reply_options = serde_json::from_str(message.message.content.as_ref().unwrap().as_str()).unwrap(); - Some(self.as_scene()) + self.insert_conversation(ConversationEntry::SystemMessage("No messages were received! Trying again.".into())); } - } else { - //FIXME: Handle tool calls - panic!(); } } @@ -173,11 +328,11 @@ pub async fn start_prediction(mut sys_message_src: tokio::sync::mpsc::Receiver