prediction: split up the tools into separate functions

This commit is contained in:
2026-06-08 12:00:14 +02:00
parent 0b7fc7736a
commit aa84381d97
+102 -85
View File
@@ -74,6 +74,12 @@ impl Into<BandcampResult> for bandcamp::Album {
}
}
#[derive(Default, Debug)]
struct ToolResults {
result: Option<String>,
messages: Vec<ConversationEntry>
}
impl Session {
fn from_initial_messages(messages: Vec<ChatCompletionRequestMessage>) -> Self {
let mut conversation = vec![];
@@ -98,6 +104,80 @@ impl Session {
}
}
async fn tool_stage_event(&mut self, args: StageEventArgs) -> ToolResults {
ToolResults {
messages: vec![ConversationEntry::StageDirection(args.text)],
..Default::default()
}
}
async fn tool_computer_event(&mut self, args: StageEventArgs) -> ToolResults {
ToolResults {
messages: vec![ConversationEntry::ShipComputer(args.text)],
..Default::default()
}
}
async fn tool_bandcamp_scan(&mut self, args: BandcampQueryArgs) -> ToolResults {
let mut messages = vec![];
messages.push(ConversationEntry::SystemMessage(format!("Fetching artifacts from Bandcamp with {:?}", args).into()));
let mut json_results = vec![];
if let Ok(results) = bandcamp::search(args.query.as_str()).await {
for result in results {
match result {
SearchResultItem::Artist(data) => {
let result: BandcampResult = bandcamp::fetch_artist(data.artist_id).await.unwrap().into();
json_results.push(result);
},
SearchResultItem::Album(data) => {
let result: BandcampResult = bandcamp::fetch_album(data.band_id, data.album_id).await.unwrap().into();
json_results.push(result);
}
_ => ()
}
}
}
messages.push(ConversationEntry::ShipComputer(format!("Artifact scan for '{}' complete. {} results.", args.query, json_results.len()).into()));
ToolResults {
result: Some(serde_json::to_string(&json_results).unwrap()),
messages
}
}
async fn tool_artifact_query(&mut self, args: BeatsQueryArgs) -> ToolResults {
let mut messages = vec![];
let mut beets_cmd = Command::new("beet");
beets_cmd.args(["export", "-f", "json", "-i", "title,label,year,genres,album,artist"]);
if let Some(artist) = args.artist {
beets_cmd.arg(format!("artist:{}", artist));
}
if let Some(genre) = args.genre {
beets_cmd.arg(format!("genre:{}", genre));
}
if let Some(album) = args.album {
beets_cmd.arg(format!("album:{}", album));
}
if let Some(title) = args.title {
beets_cmd.arg(format!("title:{}", title));
}
if let Some(year) = args.year {
beets_cmd.arg(format!("year:{}", year));
}
let result = if let Ok(output) = beets_cmd.stdout(Stdio::piped()).spawn().unwrap().wait_with_output() {
messages.push(ConversationEntry::ShipComputer(format!("Executing archive query {:?}", beets_cmd)));
Some(minify::json::minify(str::from_utf8(&output.stdout).unwrap()))
} else {
messages.push(ConversationEntry::ShipComputer("Unable to execute query!".into()));
None
};
ToolResults {
result,
messages
}
}
async fn regenerate_options(&mut self, direction: &StageDirection) {
loop {
let direction_message: ChatCompletionRequestMessage = ChatCompletionRequestSystemMessageArgs::default().content(serde_json::to_string(&direction).unwrap()).build().unwrap().into();
@@ -176,100 +256,37 @@ impl Session {
.build().unwrap().into();
self.messages.push(assistant_messages);
let mut results = vec![];
let mut messages = vec![];
for call in calls {
match call {
ChatCompletionMessageToolCalls::Function(call) => {
match call.function.name.as_str() {
"log_stage_event" => {
let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
.tool_call_id(call.id.clone())
.build().unwrap()
));
messages.push(ConversationEntry::StageDirection(args.text));
},
"log_ship_computer_message" => {
let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
.tool_call_id(call.id.clone())
.build().unwrap()
));
messages.push(ConversationEntry::ShipComputer(args.text));
},
"bandcamp_artifact_scan" => {
let args: BandcampQueryArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
self.insert_conversation(ConversationEntry::SystemMessage(format!("Fetching artifacts from Bandcamp with {:?}", args).into()));
let mut json_results = vec![];
if let Ok(results) = bandcamp::search(args.query.as_str()).await {
for result in results {
match result {
SearchResultItem::Artist(data) => {
let result: BandcampResult = bandcamp::fetch_artist(data.artist_id).await.unwrap().into();
json_results.push(result);
},
SearchResultItem::Album(data) => {
let result: BandcampResult = bandcamp::fetch_album(data.band_id, data.album_id).await.unwrap().into();
json_results.push(result);
}
_ => ()
}
}
}
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
.tool_call_id(call.id.clone())
.content(serde_json::to_string(&json_results).unwrap())
.build().unwrap()
));
messages.push(ConversationEntry::ShipComputer(format!("Artifact scan for '{}' complete. {} results.", args.query, json_results.len()).into()));
},
"archive_query" => {
let args: BeatsQueryArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
let mut beets_cmd = Command::new("beet");
beets_cmd.arg("export").arg("-f").arg("json").arg("-i").arg("title,label,year,genres,album,artist");
if let Some(artist) = args.artist {
beets_cmd.arg(format!("artist:{}", artist));
}
if let Some(genre) = args.genre {
beets_cmd.arg(format!("genre:{}", genre));
}
if let Some(album) = args.album {
beets_cmd.arg(format!("album:{}", album));
}
if let Some(title) = args.title {
beets_cmd.arg(format!("title:{}", title));
}
if let Some(year) = args.year {
beets_cmd.arg(format!("year:{}", year));
}
if let Ok(output) = beets_cmd.stdout(Stdio::piped()).spawn().unwrap().wait_with_output() {
let minified = minify::json::minify(str::from_utf8(&output.stdout).unwrap());
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
.tool_call_id(call.id.clone())
.content(minified)
.build().unwrap()
));
messages.push(ConversationEntry::ShipComputer(format!("Executing archive query {:?}", beets_cmd)));
} else {
messages.push(ConversationEntry::ShipComputer("Unable to execute query!".into()));
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
.tool_call_id(call.id.clone())
.content("")
.build().unwrap()
));
}
}
_ => panic!("Unknown function was called")
}
let func_name = call.function.name.as_str();
let args = call.function.arguments.as_str();
let tool_result = match func_name {
"log_stage_event" => self.tool_stage_event(serde_json::from_str(args).unwrap()).await,
"log_ship_computer_message" => self.tool_computer_event(serde_json::from_str(args).unwrap()).await,
"bandcamp_artifact_scan" => self.tool_bandcamp_scan(serde_json::from_str(args).unwrap()).await,
"archive_query" => self.tool_artifact_query(serde_json::from_str(args).unwrap()).await,
_ => unreachable!()
};
results.push((&call.id, tool_result));
},
_ => panic!("Unknown tool was called")
}
}
self.messages.append(&mut results);
for msg in messages {
let mut tool_messages = vec![];
for (id, mut result) in results {
let mut msg = ChatCompletionRequestToolMessageArgs::default();
msg.tool_call_id(id);
if let Some(output) = result.result {
msg.content(output);
}
self.messages.push(ChatCompletionRequestMessage::Tool(msg.build().unwrap()));
tool_messages.append(&mut result.messages);
}
for msg in tool_messages {
self.insert_conversation(msg);
}
}
if let Some(content) = message.message.content.as_ref() {
if let Ok(options) = serde_json::from_str(content.as_str()) {