prediction: split up the tools into separate functions
This commit is contained in:
+102
-85
@@ -74,6 +74,12 @@ impl Into<BandcampResult> for bandcamp::Album {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
struct ToolResults {
|
||||
result: Option<String>,
|
||||
messages: Vec<ConversationEntry>
|
||||
}
|
||||
|
||||
impl Session {
|
||||
fn from_initial_messages(messages: Vec<ChatCompletionRequestMessage>) -> Self {
|
||||
let mut conversation = vec![];
|
||||
@@ -98,6 +104,80 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
async fn tool_stage_event(&mut self, args: StageEventArgs) -> ToolResults {
|
||||
ToolResults {
|
||||
messages: vec![ConversationEntry::StageDirection(args.text)],
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
async fn tool_computer_event(&mut self, args: StageEventArgs) -> ToolResults {
|
||||
ToolResults {
|
||||
messages: vec![ConversationEntry::ShipComputer(args.text)],
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
async fn tool_bandcamp_scan(&mut self, args: BandcampQueryArgs) -> ToolResults {
|
||||
let mut messages = vec![];
|
||||
messages.push(ConversationEntry::SystemMessage(format!("Fetching artifacts from Bandcamp with {:?}", args).into()));
|
||||
let mut json_results = vec![];
|
||||
if let Ok(results) = bandcamp::search(args.query.as_str()).await {
|
||||
for result in results {
|
||||
match result {
|
||||
SearchResultItem::Artist(data) => {
|
||||
let result: BandcampResult = bandcamp::fetch_artist(data.artist_id).await.unwrap().into();
|
||||
json_results.push(result);
|
||||
},
|
||||
SearchResultItem::Album(data) => {
|
||||
let result: BandcampResult = bandcamp::fetch_album(data.band_id, data.album_id).await.unwrap().into();
|
||||
json_results.push(result);
|
||||
}
|
||||
_ => ()
|
||||
}
|
||||
}
|
||||
}
|
||||
messages.push(ConversationEntry::ShipComputer(format!("Artifact scan for '{}' complete. {} results.", args.query, json_results.len()).into()));
|
||||
|
||||
ToolResults {
|
||||
result: Some(serde_json::to_string(&json_results).unwrap()),
|
||||
messages
|
||||
}
|
||||
}
|
||||
|
||||
async fn tool_artifact_query(&mut self, args: BeatsQueryArgs) -> ToolResults {
|
||||
let mut messages = vec![];
|
||||
let mut beets_cmd = Command::new("beet");
|
||||
beets_cmd.args(["export", "-f", "json", "-i", "title,label,year,genres,album,artist"]);
|
||||
if let Some(artist) = args.artist {
|
||||
beets_cmd.arg(format!("artist:{}", artist));
|
||||
}
|
||||
if let Some(genre) = args.genre {
|
||||
beets_cmd.arg(format!("genre:{}", genre));
|
||||
}
|
||||
if let Some(album) = args.album {
|
||||
beets_cmd.arg(format!("album:{}", album));
|
||||
}
|
||||
if let Some(title) = args.title {
|
||||
beets_cmd.arg(format!("title:{}", title));
|
||||
}
|
||||
if let Some(year) = args.year {
|
||||
beets_cmd.arg(format!("year:{}", year));
|
||||
}
|
||||
let result = if let Ok(output) = beets_cmd.stdout(Stdio::piped()).spawn().unwrap().wait_with_output() {
|
||||
messages.push(ConversationEntry::ShipComputer(format!("Executing archive query {:?}", beets_cmd)));
|
||||
Some(minify::json::minify(str::from_utf8(&output.stdout).unwrap()))
|
||||
} else {
|
||||
messages.push(ConversationEntry::ShipComputer("Unable to execute query!".into()));
|
||||
None
|
||||
};
|
||||
|
||||
ToolResults {
|
||||
result,
|
||||
messages
|
||||
}
|
||||
}
|
||||
|
||||
async fn regenerate_options(&mut self, direction: &StageDirection) {
|
||||
loop {
|
||||
let direction_message: ChatCompletionRequestMessage = ChatCompletionRequestSystemMessageArgs::default().content(serde_json::to_string(&direction).unwrap()).build().unwrap().into();
|
||||
@@ -176,100 +256,37 @@ impl Session {
|
||||
.build().unwrap().into();
|
||||
self.messages.push(assistant_messages);
|
||||
let mut results = vec![];
|
||||
let mut messages = vec![];
|
||||
for call in calls {
|
||||
match call {
|
||||
ChatCompletionMessageToolCalls::Function(call) => {
|
||||
match call.function.name.as_str() {
|
||||
"log_stage_event" => {
|
||||
let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
|
||||
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
|
||||
.tool_call_id(call.id.clone())
|
||||
.build().unwrap()
|
||||
));
|
||||
messages.push(ConversationEntry::StageDirection(args.text));
|
||||
},
|
||||
"log_ship_computer_message" => {
|
||||
let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
|
||||
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
|
||||
.tool_call_id(call.id.clone())
|
||||
.build().unwrap()
|
||||
));
|
||||
messages.push(ConversationEntry::ShipComputer(args.text));
|
||||
},
|
||||
"bandcamp_artifact_scan" => {
|
||||
let args: BandcampQueryArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
|
||||
self.insert_conversation(ConversationEntry::SystemMessage(format!("Fetching artifacts from Bandcamp with {:?}", args).into()));
|
||||
let mut json_results = vec![];
|
||||
if let Ok(results) = bandcamp::search(args.query.as_str()).await {
|
||||
for result in results {
|
||||
match result {
|
||||
SearchResultItem::Artist(data) => {
|
||||
let result: BandcampResult = bandcamp::fetch_artist(data.artist_id).await.unwrap().into();
|
||||
json_results.push(result);
|
||||
},
|
||||
SearchResultItem::Album(data) => {
|
||||
let result: BandcampResult = bandcamp::fetch_album(data.band_id, data.album_id).await.unwrap().into();
|
||||
json_results.push(result);
|
||||
}
|
||||
_ => ()
|
||||
}
|
||||
}
|
||||
}
|
||||
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
|
||||
.tool_call_id(call.id.clone())
|
||||
.content(serde_json::to_string(&json_results).unwrap())
|
||||
.build().unwrap()
|
||||
));
|
||||
messages.push(ConversationEntry::ShipComputer(format!("Artifact scan for '{}' complete. {} results.", args.query, json_results.len()).into()));
|
||||
},
|
||||
"archive_query" => {
|
||||
let args: BeatsQueryArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
|
||||
let mut beets_cmd = Command::new("beet");
|
||||
beets_cmd.arg("export").arg("-f").arg("json").arg("-i").arg("title,label,year,genres,album,artist");
|
||||
if let Some(artist) = args.artist {
|
||||
beets_cmd.arg(format!("artist:{}", artist));
|
||||
}
|
||||
if let Some(genre) = args.genre {
|
||||
beets_cmd.arg(format!("genre:{}", genre));
|
||||
}
|
||||
if let Some(album) = args.album {
|
||||
beets_cmd.arg(format!("album:{}", album));
|
||||
}
|
||||
if let Some(title) = args.title {
|
||||
beets_cmd.arg(format!("title:{}", title));
|
||||
}
|
||||
if let Some(year) = args.year {
|
||||
beets_cmd.arg(format!("year:{}", year));
|
||||
}
|
||||
if let Ok(output) = beets_cmd.stdout(Stdio::piped()).spawn().unwrap().wait_with_output() {
|
||||
let minified = minify::json::minify(str::from_utf8(&output.stdout).unwrap());
|
||||
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
|
||||
.tool_call_id(call.id.clone())
|
||||
.content(minified)
|
||||
.build().unwrap()
|
||||
));
|
||||
messages.push(ConversationEntry::ShipComputer(format!("Executing archive query {:?}", beets_cmd)));
|
||||
} else {
|
||||
messages.push(ConversationEntry::ShipComputer("Unable to execute query!".into()));
|
||||
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
|
||||
.tool_call_id(call.id.clone())
|
||||
.content("")
|
||||
.build().unwrap()
|
||||
));
|
||||
}
|
||||
}
|
||||
_ => panic!("Unknown function was called")
|
||||
}
|
||||
let func_name = call.function.name.as_str();
|
||||
let args = call.function.arguments.as_str();
|
||||
let tool_result = match func_name {
|
||||
"log_stage_event" => self.tool_stage_event(serde_json::from_str(args).unwrap()).await,
|
||||
"log_ship_computer_message" => self.tool_computer_event(serde_json::from_str(args).unwrap()).await,
|
||||
"bandcamp_artifact_scan" => self.tool_bandcamp_scan(serde_json::from_str(args).unwrap()).await,
|
||||
"archive_query" => self.tool_artifact_query(serde_json::from_str(args).unwrap()).await,
|
||||
_ => unreachable!()
|
||||
};
|
||||
results.push((&call.id, tool_result));
|
||||
},
|
||||
_ => panic!("Unknown tool was called")
|
||||
}
|
||||
}
|
||||
self.messages.append(&mut results);
|
||||
for msg in messages {
|
||||
|
||||
let mut tool_messages = vec![];
|
||||
for (id, mut result) in results {
|
||||
let mut msg = ChatCompletionRequestToolMessageArgs::default();
|
||||
msg.tool_call_id(id);
|
||||
if let Some(output) = result.result {
|
||||
msg.content(output);
|
||||
}
|
||||
self.messages.push(ChatCompletionRequestMessage::Tool(msg.build().unwrap()));
|
||||
tool_messages.append(&mut result.messages);
|
||||
}
|
||||
for msg in tool_messages {
|
||||
self.insert_conversation(msg);
|
||||
}
|
||||
|
||||
}
|
||||
if let Some(content) = message.message.content.as_ref() {
|
||||
if let Ok(options) = serde_json::from_str(content.as_str()) {
|
||||
|
||||
Reference in New Issue
Block a user