prediction: rewrite the messaging to use a loop for self-executing chains, add bandcamp and beets tools
This commit is contained in:
+222
-67
@@ -1,4 +1,8 @@
|
||||
use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionTool, ChatCompletionTools, CreateChatCompletionRequestArgs, FunctionObjectArgs, ResponseFormat, ResponseFormatJsonSchema}};
|
||||
use std::process::{Command, Stdio};
|
||||
|
||||
use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageArgs, ChatCompletionTool, ChatCompletionTools, CreateChatCompletionRequestArgs, FinishReason, FunctionObjectArgs, ResponseFormat, ResponseFormatJsonSchema}};
|
||||
use bandcamp::SearchResultItem;
|
||||
use chrono::{DateTime, Utc};
|
||||
use schemars::{JsonSchema, schema_for};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -32,6 +36,44 @@ struct StageEventArgs {
|
||||
text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize, Clone, JsonSchema)]
|
||||
struct BeatsQueryArgs {
|
||||
artist: Option<String>,
|
||||
album: Option<String>,
|
||||
genre: Option<String>,
|
||||
title: Option<String>,
|
||||
year: Option<u32>
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)]
|
||||
struct BandcampQueryArgs {
|
||||
query: String
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
enum BandcampResult {
|
||||
Artist { name: String, bio: Option<String>, location: Option<String> },
|
||||
Album { title: String, about: Option<String>, credits: Option<String>, release_date: DateTime<Utc>, artist: String }
|
||||
}
|
||||
|
||||
impl Into<BandcampResult> for bandcamp::Artist {
|
||||
fn into(self) -> BandcampResult {
|
||||
BandcampResult::Artist { name: self.name, bio: self.bio, location: self.location }
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<BandcampResult> for bandcamp::Album {
|
||||
fn into(self) -> BandcampResult {
|
||||
BandcampResult::Album {
|
||||
about: self.about,
|
||||
title: self.title,
|
||||
artist: self.band.name,
|
||||
credits: self.credits,
|
||||
release_date: self.release_date
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Session {
|
||||
fn from_initial_messages(messages: Vec<ChatCompletionRequestMessage>) -> Self {
|
||||
let mut conversation = vec![];
|
||||
@@ -56,77 +98,190 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
async fn regenerate_options(&mut self, direction: &StageDirection) -> Option<Scene> {
|
||||
let direction_message: ChatCompletionRequestMessage = ChatCompletionRequestSystemMessageArgs::default().content(serde_json::to_string(&direction).unwrap()).build().unwrap().into();
|
||||
let mut full_conversation = vec![
|
||||
self.header_message.clone(),
|
||||
direction_message
|
||||
];
|
||||
full_conversation.append(&mut self.messages.clone());
|
||||
async fn regenerate_options(&mut self, direction: &StageDirection) {
|
||||
loop {
|
||||
let direction_message: ChatCompletionRequestMessage = ChatCompletionRequestSystemMessageArgs::default().content(serde_json::to_string(&direction).unwrap()).build().unwrap().into();
|
||||
let mut full_conversation = vec![
|
||||
self.header_message.clone(),
|
||||
direction_message
|
||||
];
|
||||
full_conversation.append(&mut self.messages.clone());
|
||||
|
||||
let tools = vec![
|
||||
ChatCompletionTools::Function(ChatCompletionTool {
|
||||
function: FunctionObjectArgs::default()
|
||||
.name("log_stage_event")
|
||||
.description("Inserts an event into the current scene script")
|
||||
.parameters(schema_for!(StageEventArgs))
|
||||
.build().unwrap()
|
||||
}),
|
||||
ChatCompletionTools::Function(ChatCompletionTool {
|
||||
function: FunctionObjectArgs::default()
|
||||
.name("log_ship_computer_message")
|
||||
.description("Inserts a message from the ship computer into the scene script")
|
||||
.parameters(schema_for!(StageEventArgs))
|
||||
.build().unwrap()
|
||||
}),
|
||||
ChatCompletionTools::Function(ChatCompletionTool {
|
||||
function: FunctionObjectArgs::default()
|
||||
.name("archive_query")
|
||||
.description("Queries the ship's musical artifact archives for tracks matching the given search parameters")
|
||||
.parameters(schema_for!(BeatsQueryArgs))
|
||||
.build().unwrap()
|
||||
}),
|
||||
ChatCompletionTools::Function(ChatCompletionTool {
|
||||
function: FunctionObjectArgs::default()
|
||||
.name("bandcamp_artifact_scan")
|
||||
.description("Scans Bandcamp to find artifacts to use in the scene that match the given search parameters")
|
||||
.parameters(schema_for!(BandcampQueryArgs))
|
||||
.build().unwrap()
|
||||
})
|
||||
];
|
||||
let request = CreateChatCompletionRequestArgs::default()
|
||||
.messages(full_conversation.clone())
|
||||
.model("gpt-5.4")
|
||||
.tools(tools)
|
||||
.max_completion_tokens(1024u32)
|
||||
.response_format(ResponseFormat::JsonSchema {
|
||||
json_schema: ResponseFormatJsonSchema {
|
||||
description: None,
|
||||
name: "responses".into(),
|
||||
schema: schema_for!(GeneratedResponses).into(),
|
||||
strict: None
|
||||
}
|
||||
})
|
||||
.build().unwrap();
|
||||
|
||||
let tools = vec![
|
||||
ChatCompletionTools::Function(ChatCompletionTool {
|
||||
function: FunctionObjectArgs::default()
|
||||
.name("log_stage_event")
|
||||
.description("Inserts an event into the current scene script")
|
||||
.parameters(schema_for!(StageEventArgs))
|
||||
.build().unwrap()
|
||||
}),
|
||||
ChatCompletionTools::Function(ChatCompletionTool {
|
||||
function: FunctionObjectArgs::default()
|
||||
.name("log_ship_computer_message")
|
||||
.description("Inserts a message from the ship computer into the scene script")
|
||||
.parameters(schema_for!(StageEventArgs))
|
||||
.build().unwrap()
|
||||
})
|
||||
];
|
||||
let response = self.client.chat().create(request).await.unwrap_or_else(|err| {
|
||||
panic!("{} {:?}", err, full_conversation);
|
||||
});
|
||||
|
||||
let request = CreateChatCompletionRequestArgs::default()
|
||||
.messages(full_conversation)
|
||||
.model("gpt-5.4")
|
||||
.tools(tools)
|
||||
.max_completion_tokens(350u32)
|
||||
.response_format(ResponseFormat::JsonSchema {
|
||||
json_schema: ResponseFormatJsonSchema {
|
||||
description: None,
|
||||
name: "responses".into(),
|
||||
schema: schema_for!(GeneratedResponses).into(),
|
||||
strict: None
|
||||
if let Some(message) = response.choices.first() {
|
||||
|
||||
match message.finish_reason {
|
||||
Some(FinishReason::ContentFilter) => {
|
||||
self.insert_conversation(ConversationEntry::SystemMessage("Content filter triggered.".into()));
|
||||
return;
|
||||
},
|
||||
Some(FinishReason::Length) => {
|
||||
self.insert_conversation(ConversationEntry::SystemMessage("Maximum token count exceeded!".into()));
|
||||
return;
|
||||
},
|
||||
_ => ()
|
||||
}
|
||||
})
|
||||
.build().unwrap();
|
||||
|
||||
let response = self.client.chat().create(request).await.unwrap();
|
||||
if let Some(calls) = &message.message.tool_calls {
|
||||
let assistant_messages: ChatCompletionRequestMessage = ChatCompletionRequestAssistantMessageArgs::default()
|
||||
.tool_calls(calls.clone())
|
||||
.build().unwrap().into();
|
||||
self.messages.push(assistant_messages);
|
||||
let mut results = vec![];
|
||||
let mut messages = vec![];
|
||||
for call in calls {
|
||||
match call {
|
||||
ChatCompletionMessageToolCalls::Function(call) => {
|
||||
match call.function.name.as_str() {
|
||||
"log_stage_event" => {
|
||||
let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
|
||||
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
|
||||
.tool_call_id(call.id.clone())
|
||||
.build().unwrap()
|
||||
));
|
||||
messages.push(ConversationEntry::StageDirection(args.text));
|
||||
},
|
||||
"log_ship_computer_message" => {
|
||||
let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
|
||||
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
|
||||
.tool_call_id(call.id.clone())
|
||||
.build().unwrap()
|
||||
));
|
||||
messages.push(ConversationEntry::ShipComputer(args.text));
|
||||
},
|
||||
"bandcamp_artifact_scan" => {
|
||||
let args: BandcampQueryArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
|
||||
self.insert_conversation(ConversationEntry::SystemMessage(format!("Fetching artifacts from Bandcamp with {:?}", args).into()));
|
||||
let mut json_results = vec![];
|
||||
if let Ok(results) = bandcamp::search(args.query.as_str()).await {
|
||||
for result in results {
|
||||
match result {
|
||||
SearchResultItem::Artist(data) => {
|
||||
let result: BandcampResult = bandcamp::fetch_artist(data.artist_id).await.unwrap().into();
|
||||
json_results.push(result);
|
||||
},
|
||||
SearchResultItem::Album(data) => {
|
||||
let result: BandcampResult = bandcamp::fetch_album(data.band_id, data.album_id).await.unwrap().into();
|
||||
json_results.push(result);
|
||||
}
|
||||
_ => ()
|
||||
}
|
||||
}
|
||||
}
|
||||
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
|
||||
.tool_call_id(call.id.clone())
|
||||
.content(serde_json::to_string(&json_results).unwrap())
|
||||
.build().unwrap()
|
||||
));
|
||||
messages.push(ConversationEntry::ShipComputer(format!("Artifact scan for '{}' complete. {} results.", args.query, json_results.len()).into()));
|
||||
},
|
||||
"archive_query" => {
|
||||
let args: BeatsQueryArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
|
||||
let mut beets_cmd = Command::new("beet");
|
||||
beets_cmd.arg("export").arg("-f").arg("json").arg("-i").arg("title,label,year,genres,album,artist");
|
||||
if let Some(artist) = args.artist {
|
||||
beets_cmd.arg(format!("artist:{}", artist));
|
||||
}
|
||||
if let Some(genre) = args.genre {
|
||||
beets_cmd.arg(format!("genre:{}", genre));
|
||||
}
|
||||
if let Some(album) = args.album {
|
||||
beets_cmd.arg(format!("album:{}", album));
|
||||
}
|
||||
if let Some(title) = args.title {
|
||||
beets_cmd.arg(format!("title:{}", title));
|
||||
}
|
||||
if let Some(year) = args.year {
|
||||
beets_cmd.arg(format!("year:{}", year));
|
||||
}
|
||||
if let Ok(output) = beets_cmd.stdout(Stdio::piped()).spawn().unwrap().wait_with_output() {
|
||||
let minified = minify::json::minify(str::from_utf8(&output.stdout).unwrap());
|
||||
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
|
||||
.tool_call_id(call.id.clone())
|
||||
.content(minified)
|
||||
.build().unwrap()
|
||||
));
|
||||
messages.push(ConversationEntry::ShipComputer(format!("Executing archive query {:?}", beets_cmd)));
|
||||
} else {
|
||||
messages.push(ConversationEntry::ShipComputer("Unable to execute query!".into()));
|
||||
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
|
||||
.tool_call_id(call.id.clone())
|
||||
.content("")
|
||||
.build().unwrap()
|
||||
));
|
||||
}
|
||||
}
|
||||
_ => panic!("Unknown function was called")
|
||||
}
|
||||
},
|
||||
_ => panic!("Unknown tool was called")
|
||||
}
|
||||
}
|
||||
self.messages.append(&mut results);
|
||||
for msg in messages {
|
||||
self.insert_conversation(msg);
|
||||
}
|
||||
|
||||
if let Some(message) = response.choices.first() {
|
||||
if let Some(calls) = &message.message.tool_calls {
|
||||
for call in calls {
|
||||
match call {
|
||||
ChatCompletionMessageToolCalls::Function(call) => {
|
||||
match call.function.name.as_str() {
|
||||
"log_stage_event" => {
|
||||
let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
|
||||
self.insert_conversation(ConversationEntry::StageDirection(args.text));
|
||||
},
|
||||
"log_ship_computer_message" => {
|
||||
let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
|
||||
self.insert_conversation(ConversationEntry::ShipComputer(args.text));
|
||||
},
|
||||
_ => panic!("Unknown function was called")
|
||||
}
|
||||
},
|
||||
_ => panic!("Unknown tool was called")
|
||||
}
|
||||
if let Some(content) = message.message.content.as_ref() {
|
||||
if let Ok(options) = serde_json::from_str(content.as_str()) {
|
||||
self.reply_options = options;
|
||||
return;
|
||||
} else {
|
||||
self.insert_conversation(ConversationEntry::SystemMessage("Received invalid JSON! Trying again.".into()));
|
||||
}
|
||||
}
|
||||
Some(self.as_scene())
|
||||
} else {
|
||||
self.reply_options = serde_json::from_str(message.message.content.as_ref().unwrap().as_str()).unwrap();
|
||||
Some(self.as_scene())
|
||||
self.insert_conversation(ConversationEntry::SystemMessage("No messages were received! Trying again.".into()));
|
||||
}
|
||||
} else {
|
||||
//FIXME: Handle tool calls
|
||||
panic!();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -173,11 +328,11 @@ pub async fn start_prediction(mut sys_message_src: tokio::sync::mpsc::Receiver<S
|
||||
|
||||
save_data.save();
|
||||
|
||||
if let Some(next_scene) = session.regenerate_options(&save_data.direction).await {
|
||||
save_data.messages = session.messages.clone();
|
||||
save_data.save();
|
||||
prediction_in.send(next_scene).unwrap();
|
||||
}
|
||||
session.regenerate_options(&save_data.direction).await;
|
||||
|
||||
save_data.messages = session.messages.clone();
|
||||
save_data.save();
|
||||
prediction_in.send(session.as_scene()).unwrap();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user