diff --git a/src/prediction.rs b/src/prediction.rs index dfb3e1b..3985725 100644 --- a/src/prediction.rs +++ b/src/prediction.rs @@ -74,6 +74,12 @@ impl Into for bandcamp::Album { } } +#[derive(Default, Debug)] +struct ToolResults { + result: Option, + messages: Vec +} + impl Session { fn from_initial_messages(messages: Vec) -> Self { let mut conversation = vec![]; @@ -98,6 +104,80 @@ impl Session { } } + async fn tool_stage_event(&mut self, args: StageEventArgs) -> ToolResults { + ToolResults { + messages: vec![ConversationEntry::StageDirection(args.text)], + ..Default::default() + } + } + + async fn tool_computer_event(&mut self, args: StageEventArgs) -> ToolResults { + ToolResults { + messages: vec![ConversationEntry::ShipComputer(args.text)], + ..Default::default() + } + } + + async fn tool_bandcamp_scan(&mut self, args: BandcampQueryArgs) -> ToolResults { + let mut messages = vec![]; + messages.push(ConversationEntry::SystemMessage(format!("Fetching artifacts from Bandcamp with {:?}", args).into())); + let mut json_results = vec![]; + if let Ok(results) = bandcamp::search(args.query.as_str()).await { + for result in results { + match result { + SearchResultItem::Artist(data) => { + let result: BandcampResult = bandcamp::fetch_artist(data.artist_id).await.unwrap().into(); + json_results.push(result); + }, + SearchResultItem::Album(data) => { + let result: BandcampResult = bandcamp::fetch_album(data.band_id, data.album_id).await.unwrap().into(); + json_results.push(result); + } + _ => () + } + } + } + messages.push(ConversationEntry::ShipComputer(format!("Artifact scan for '{}' complete. {} results.", args.query, json_results.len()).into())); + + ToolResults { + result: Some(serde_json::to_string(&json_results).unwrap()), + messages + } + } + + async fn tool_artifact_query(&mut self, args: BeatsQueryArgs) -> ToolResults { + let mut messages = vec![]; + let mut beets_cmd = Command::new("beet"); + beets_cmd.args(["export", "-f", "json", "-i", "title,label,year,genres,album,artist"]); + if let Some(artist) = args.artist { + beets_cmd.arg(format!("artist:{}", artist)); + } + if let Some(genre) = args.genre { + beets_cmd.arg(format!("genre:{}", genre)); + } + if let Some(album) = args.album { + beets_cmd.arg(format!("album:{}", album)); + } + if let Some(title) = args.title { + beets_cmd.arg(format!("title:{}", title)); + } + if let Some(year) = args.year { + beets_cmd.arg(format!("year:{}", year)); + } + let result = if let Ok(output) = beets_cmd.stdout(Stdio::piped()).spawn().unwrap().wait_with_output() { + messages.push(ConversationEntry::ShipComputer(format!("Executing archive query {:?}", beets_cmd))); + Some(minify::json::minify(str::from_utf8(&output.stdout).unwrap())) + } else { + messages.push(ConversationEntry::ShipComputer("Unable to execute query!".into())); + None + }; + + ToolResults { + result, + messages + } + } + async fn regenerate_options(&mut self, direction: &StageDirection) { loop { let direction_message: ChatCompletionRequestMessage = ChatCompletionRequestSystemMessageArgs::default().content(serde_json::to_string(&direction).unwrap()).build().unwrap().into(); @@ -176,100 +256,37 @@ impl Session { .build().unwrap().into(); self.messages.push(assistant_messages); let mut results = vec![]; - let mut messages = vec![]; for call in calls { match call { ChatCompletionMessageToolCalls::Function(call) => { - match call.function.name.as_str() { - "log_stage_event" => { - let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap(); - results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default() - .tool_call_id(call.id.clone()) - .build().unwrap() - )); - messages.push(ConversationEntry::StageDirection(args.text)); - }, - "log_ship_computer_message" => { - let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap(); - results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default() - .tool_call_id(call.id.clone()) - .build().unwrap() - )); - messages.push(ConversationEntry::ShipComputer(args.text)); - }, - "bandcamp_artifact_scan" => { - let args: BandcampQueryArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap(); - self.insert_conversation(ConversationEntry::SystemMessage(format!("Fetching artifacts from Bandcamp with {:?}", args).into())); - let mut json_results = vec![]; - if let Ok(results) = bandcamp::search(args.query.as_str()).await { - for result in results { - match result { - SearchResultItem::Artist(data) => { - let result: BandcampResult = bandcamp::fetch_artist(data.artist_id).await.unwrap().into(); - json_results.push(result); - }, - SearchResultItem::Album(data) => { - let result: BandcampResult = bandcamp::fetch_album(data.band_id, data.album_id).await.unwrap().into(); - json_results.push(result); - } - _ => () - } - } - } - results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default() - .tool_call_id(call.id.clone()) - .content(serde_json::to_string(&json_results).unwrap()) - .build().unwrap() - )); - messages.push(ConversationEntry::ShipComputer(format!("Artifact scan for '{}' complete. {} results.", args.query, json_results.len()).into())); - }, - "archive_query" => { - let args: BeatsQueryArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap(); - let mut beets_cmd = Command::new("beet"); - beets_cmd.arg("export").arg("-f").arg("json").arg("-i").arg("title,label,year,genres,album,artist"); - if let Some(artist) = args.artist { - beets_cmd.arg(format!("artist:{}", artist)); - } - if let Some(genre) = args.genre { - beets_cmd.arg(format!("genre:{}", genre)); - } - if let Some(album) = args.album { - beets_cmd.arg(format!("album:{}", album)); - } - if let Some(title) = args.title { - beets_cmd.arg(format!("title:{}", title)); - } - if let Some(year) = args.year { - beets_cmd.arg(format!("year:{}", year)); - } - if let Ok(output) = beets_cmd.stdout(Stdio::piped()).spawn().unwrap().wait_with_output() { - let minified = minify::json::minify(str::from_utf8(&output.stdout).unwrap()); - results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default() - .tool_call_id(call.id.clone()) - .content(minified) - .build().unwrap() - )); - messages.push(ConversationEntry::ShipComputer(format!("Executing archive query {:?}", beets_cmd))); - } else { - messages.push(ConversationEntry::ShipComputer("Unable to execute query!".into())); - results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default() - .tool_call_id(call.id.clone()) - .content("") - .build().unwrap() - )); - } - } - _ => panic!("Unknown function was called") - } + let func_name = call.function.name.as_str(); + let args = call.function.arguments.as_str(); + let tool_result = match func_name { + "log_stage_event" => self.tool_stage_event(serde_json::from_str(args).unwrap()).await, + "log_ship_computer_message" => self.tool_computer_event(serde_json::from_str(args).unwrap()).await, + "bandcamp_artifact_scan" => self.tool_bandcamp_scan(serde_json::from_str(args).unwrap()).await, + "archive_query" => self.tool_artifact_query(serde_json::from_str(args).unwrap()).await, + _ => unreachable!() + }; + results.push((&call.id, tool_result)); }, _ => panic!("Unknown tool was called") } } - self.messages.append(&mut results); - for msg in messages { + + let mut tool_messages = vec![]; + for (id, mut result) in results { + let mut msg = ChatCompletionRequestToolMessageArgs::default(); + msg.tool_call_id(id); + if let Some(output) = result.result { + msg.content(output); + } + self.messages.push(ChatCompletionRequestMessage::Tool(msg.build().unwrap())); + tool_messages.append(&mut result.messages); + } + for msg in tool_messages { self.insert_conversation(msg); } - } if let Some(content) = message.message.content.as_ref() { if let Ok(options) = serde_json::from_str(content.as_str()) {