prediction: split up the tools into separate functions
This commit is contained in:
+102
-85
@@ -74,6 +74,12 @@ impl Into<BandcampResult> for bandcamp::Album {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Default, Debug)]
|
||||||
|
struct ToolResults {
|
||||||
|
result: Option<String>,
|
||||||
|
messages: Vec<ConversationEntry>
|
||||||
|
}
|
||||||
|
|
||||||
impl Session {
|
impl Session {
|
||||||
fn from_initial_messages(messages: Vec<ChatCompletionRequestMessage>) -> Self {
|
fn from_initial_messages(messages: Vec<ChatCompletionRequestMessage>) -> Self {
|
||||||
let mut conversation = vec![];
|
let mut conversation = vec![];
|
||||||
@@ -98,6 +104,80 @@ impl Session {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn tool_stage_event(&mut self, args: StageEventArgs) -> ToolResults {
|
||||||
|
ToolResults {
|
||||||
|
messages: vec![ConversationEntry::StageDirection(args.text)],
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn tool_computer_event(&mut self, args: StageEventArgs) -> ToolResults {
|
||||||
|
ToolResults {
|
||||||
|
messages: vec![ConversationEntry::ShipComputer(args.text)],
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn tool_bandcamp_scan(&mut self, args: BandcampQueryArgs) -> ToolResults {
|
||||||
|
let mut messages = vec![];
|
||||||
|
messages.push(ConversationEntry::SystemMessage(format!("Fetching artifacts from Bandcamp with {:?}", args).into()));
|
||||||
|
let mut json_results = vec![];
|
||||||
|
if let Ok(results) = bandcamp::search(args.query.as_str()).await {
|
||||||
|
for result in results {
|
||||||
|
match result {
|
||||||
|
SearchResultItem::Artist(data) => {
|
||||||
|
let result: BandcampResult = bandcamp::fetch_artist(data.artist_id).await.unwrap().into();
|
||||||
|
json_results.push(result);
|
||||||
|
},
|
||||||
|
SearchResultItem::Album(data) => {
|
||||||
|
let result: BandcampResult = bandcamp::fetch_album(data.band_id, data.album_id).await.unwrap().into();
|
||||||
|
json_results.push(result);
|
||||||
|
}
|
||||||
|
_ => ()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
messages.push(ConversationEntry::ShipComputer(format!("Artifact scan for '{}' complete. {} results.", args.query, json_results.len()).into()));
|
||||||
|
|
||||||
|
ToolResults {
|
||||||
|
result: Some(serde_json::to_string(&json_results).unwrap()),
|
||||||
|
messages
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn tool_artifact_query(&mut self, args: BeatsQueryArgs) -> ToolResults {
|
||||||
|
let mut messages = vec![];
|
||||||
|
let mut beets_cmd = Command::new("beet");
|
||||||
|
beets_cmd.args(["export", "-f", "json", "-i", "title,label,year,genres,album,artist"]);
|
||||||
|
if let Some(artist) = args.artist {
|
||||||
|
beets_cmd.arg(format!("artist:{}", artist));
|
||||||
|
}
|
||||||
|
if let Some(genre) = args.genre {
|
||||||
|
beets_cmd.arg(format!("genre:{}", genre));
|
||||||
|
}
|
||||||
|
if let Some(album) = args.album {
|
||||||
|
beets_cmd.arg(format!("album:{}", album));
|
||||||
|
}
|
||||||
|
if let Some(title) = args.title {
|
||||||
|
beets_cmd.arg(format!("title:{}", title));
|
||||||
|
}
|
||||||
|
if let Some(year) = args.year {
|
||||||
|
beets_cmd.arg(format!("year:{}", year));
|
||||||
|
}
|
||||||
|
let result = if let Ok(output) = beets_cmd.stdout(Stdio::piped()).spawn().unwrap().wait_with_output() {
|
||||||
|
messages.push(ConversationEntry::ShipComputer(format!("Executing archive query {:?}", beets_cmd)));
|
||||||
|
Some(minify::json::minify(str::from_utf8(&output.stdout).unwrap()))
|
||||||
|
} else {
|
||||||
|
messages.push(ConversationEntry::ShipComputer("Unable to execute query!".into()));
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
ToolResults {
|
||||||
|
result,
|
||||||
|
messages
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn regenerate_options(&mut self, direction: &StageDirection) {
|
async fn regenerate_options(&mut self, direction: &StageDirection) {
|
||||||
loop {
|
loop {
|
||||||
let direction_message: ChatCompletionRequestMessage = ChatCompletionRequestSystemMessageArgs::default().content(serde_json::to_string(&direction).unwrap()).build().unwrap().into();
|
let direction_message: ChatCompletionRequestMessage = ChatCompletionRequestSystemMessageArgs::default().content(serde_json::to_string(&direction).unwrap()).build().unwrap().into();
|
||||||
@@ -176,100 +256,37 @@ impl Session {
|
|||||||
.build().unwrap().into();
|
.build().unwrap().into();
|
||||||
self.messages.push(assistant_messages);
|
self.messages.push(assistant_messages);
|
||||||
let mut results = vec![];
|
let mut results = vec![];
|
||||||
let mut messages = vec![];
|
|
||||||
for call in calls {
|
for call in calls {
|
||||||
match call {
|
match call {
|
||||||
ChatCompletionMessageToolCalls::Function(call) => {
|
ChatCompletionMessageToolCalls::Function(call) => {
|
||||||
match call.function.name.as_str() {
|
let func_name = call.function.name.as_str();
|
||||||
"log_stage_event" => {
|
let args = call.function.arguments.as_str();
|
||||||
let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
|
let tool_result = match func_name {
|
||||||
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
|
"log_stage_event" => self.tool_stage_event(serde_json::from_str(args).unwrap()).await,
|
||||||
.tool_call_id(call.id.clone())
|
"log_ship_computer_message" => self.tool_computer_event(serde_json::from_str(args).unwrap()).await,
|
||||||
.build().unwrap()
|
"bandcamp_artifact_scan" => self.tool_bandcamp_scan(serde_json::from_str(args).unwrap()).await,
|
||||||
));
|
"archive_query" => self.tool_artifact_query(serde_json::from_str(args).unwrap()).await,
|
||||||
messages.push(ConversationEntry::StageDirection(args.text));
|
_ => unreachable!()
|
||||||
},
|
};
|
||||||
"log_ship_computer_message" => {
|
results.push((&call.id, tool_result));
|
||||||
let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
|
|
||||||
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
|
|
||||||
.tool_call_id(call.id.clone())
|
|
||||||
.build().unwrap()
|
|
||||||
));
|
|
||||||
messages.push(ConversationEntry::ShipComputer(args.text));
|
|
||||||
},
|
|
||||||
"bandcamp_artifact_scan" => {
|
|
||||||
let args: BandcampQueryArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
|
|
||||||
self.insert_conversation(ConversationEntry::SystemMessage(format!("Fetching artifacts from Bandcamp with {:?}", args).into()));
|
|
||||||
let mut json_results = vec![];
|
|
||||||
if let Ok(results) = bandcamp::search(args.query.as_str()).await {
|
|
||||||
for result in results {
|
|
||||||
match result {
|
|
||||||
SearchResultItem::Artist(data) => {
|
|
||||||
let result: BandcampResult = bandcamp::fetch_artist(data.artist_id).await.unwrap().into();
|
|
||||||
json_results.push(result);
|
|
||||||
},
|
|
||||||
SearchResultItem::Album(data) => {
|
|
||||||
let result: BandcampResult = bandcamp::fetch_album(data.band_id, data.album_id).await.unwrap().into();
|
|
||||||
json_results.push(result);
|
|
||||||
}
|
|
||||||
_ => ()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
|
|
||||||
.tool_call_id(call.id.clone())
|
|
||||||
.content(serde_json::to_string(&json_results).unwrap())
|
|
||||||
.build().unwrap()
|
|
||||||
));
|
|
||||||
messages.push(ConversationEntry::ShipComputer(format!("Artifact scan for '{}' complete. {} results.", args.query, json_results.len()).into()));
|
|
||||||
},
|
|
||||||
"archive_query" => {
|
|
||||||
let args: BeatsQueryArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
|
|
||||||
let mut beets_cmd = Command::new("beet");
|
|
||||||
beets_cmd.arg("export").arg("-f").arg("json").arg("-i").arg("title,label,year,genres,album,artist");
|
|
||||||
if let Some(artist) = args.artist {
|
|
||||||
beets_cmd.arg(format!("artist:{}", artist));
|
|
||||||
}
|
|
||||||
if let Some(genre) = args.genre {
|
|
||||||
beets_cmd.arg(format!("genre:{}", genre));
|
|
||||||
}
|
|
||||||
if let Some(album) = args.album {
|
|
||||||
beets_cmd.arg(format!("album:{}", album));
|
|
||||||
}
|
|
||||||
if let Some(title) = args.title {
|
|
||||||
beets_cmd.arg(format!("title:{}", title));
|
|
||||||
}
|
|
||||||
if let Some(year) = args.year {
|
|
||||||
beets_cmd.arg(format!("year:{}", year));
|
|
||||||
}
|
|
||||||
if let Ok(output) = beets_cmd.stdout(Stdio::piped()).spawn().unwrap().wait_with_output() {
|
|
||||||
let minified = minify::json::minify(str::from_utf8(&output.stdout).unwrap());
|
|
||||||
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
|
|
||||||
.tool_call_id(call.id.clone())
|
|
||||||
.content(minified)
|
|
||||||
.build().unwrap()
|
|
||||||
));
|
|
||||||
messages.push(ConversationEntry::ShipComputer(format!("Executing archive query {:?}", beets_cmd)));
|
|
||||||
} else {
|
|
||||||
messages.push(ConversationEntry::ShipComputer("Unable to execute query!".into()));
|
|
||||||
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
|
|
||||||
.tool_call_id(call.id.clone())
|
|
||||||
.content("")
|
|
||||||
.build().unwrap()
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => panic!("Unknown function was called")
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
_ => panic!("Unknown tool was called")
|
_ => panic!("Unknown tool was called")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
self.messages.append(&mut results);
|
|
||||||
for msg in messages {
|
let mut tool_messages = vec![];
|
||||||
|
for (id, mut result) in results {
|
||||||
|
let mut msg = ChatCompletionRequestToolMessageArgs::default();
|
||||||
|
msg.tool_call_id(id);
|
||||||
|
if let Some(output) = result.result {
|
||||||
|
msg.content(output);
|
||||||
|
}
|
||||||
|
self.messages.push(ChatCompletionRequestMessage::Tool(msg.build().unwrap()));
|
||||||
|
tool_messages.append(&mut result.messages);
|
||||||
|
}
|
||||||
|
for msg in tool_messages {
|
||||||
self.insert_conversation(msg);
|
self.insert_conversation(msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
if let Some(content) = message.message.content.as_ref() {
|
if let Some(content) = message.message.content.as_ref() {
|
||||||
if let Ok(options) = serde_json::from_str(content.as_str()) {
|
if let Ok(options) = serde_json::from_str(content.as_str()) {
|
||||||
|
|||||||
Reference in New Issue
Block a user