343 lines
17 KiB
Rust
343 lines
17 KiB
Rust
use std::process::{Command, Stdio};
|
|
|
|
use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageArgs, ChatCompletionTool, ChatCompletionTools, CreateChatCompletionRequestArgs, FinishReason, FunctionObjectArgs, ResponseFormat, ResponseFormatJsonSchema}};
|
|
use bandcamp::SearchResultItem;
|
|
use chrono::{DateTime, Utc};
|
|
use schemars::{JsonSchema, schema_for};
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
use crate::{SaveData, scene::{ConversationEntry, Scene, StageActions, StageDirection}};
|
|
|
|
|
|
/// System prompt for every chat completion request: `Session` wraps it in the
/// system message that is placed first in each request it sends.
/// Embedded at compile time from `system-prompt.txt`, resolved relative to this
/// source file.
const SYSTEM_PROMPT: &str = include_str!("system-prompt.txt");
|
|
|
|
#[derive(JsonSchema, Deserialize, Serialize, Debug, Clone)]
|
|
pub struct PossibleResponse {
|
|
pub text: String,
|
|
pub stage_direction: Option<String>
|
|
}
|
|
|
|
#[derive(Default, Debug, JsonSchema, Deserialize, Serialize, Clone)]
|
|
pub struct GeneratedResponses {
|
|
pub responses: Vec<PossibleResponse>,
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
struct Session {
|
|
client: Client<OpenAIConfig>,
|
|
conversation: Vec<ConversationEntry>,
|
|
header_message: ChatCompletionRequestMessage,
|
|
messages: Vec<ChatCompletionRequestMessage>,
|
|
reply_options: GeneratedResponses
|
|
}
|
|
|
|
#[derive(Debug, Default, Serialize, Deserialize, Clone, JsonSchema)]
|
|
struct StageEventArgs {
|
|
text: String,
|
|
}
|
|
|
|
#[derive(Debug, Default, Serialize, Deserialize, Clone, JsonSchema)]
|
|
struct BeatsQueryArgs {
|
|
artist: Option<String>,
|
|
album: Option<String>,
|
|
genre: Option<String>,
|
|
title: Option<String>,
|
|
year: Option<u32>
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)]
|
|
struct BandcampQueryArgs {
|
|
query: String
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
|
enum BandcampResult {
|
|
Artist { name: String, bio: Option<String>, location: Option<String> },
|
|
Album { title: String, about: Option<String>, credits: Option<String>, release_date: DateTime<Utc>, artist: String }
|
|
}
|
|
|
|
impl Into<BandcampResult> for bandcamp::Artist {
|
|
fn into(self) -> BandcampResult {
|
|
BandcampResult::Artist { name: self.name, bio: self.bio, location: self.location }
|
|
}
|
|
}
|
|
|
|
impl Into<BandcampResult> for bandcamp::Album {
|
|
fn into(self) -> BandcampResult {
|
|
BandcampResult::Album {
|
|
about: self.about,
|
|
title: self.title,
|
|
artist: self.band.name,
|
|
credits: self.credits,
|
|
release_date: self.release_date
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Session {
|
|
fn from_initial_messages(messages: Vec<ChatCompletionRequestMessage>) -> Self {
|
|
let mut conversation = vec![];
|
|
for msg in &messages {
|
|
if let Ok(conversation_msg) = msg.clone().try_into() {
|
|
conversation.push(conversation_msg);
|
|
}
|
|
}
|
|
|
|
Self {
|
|
client: Default::default(),
|
|
conversation,
|
|
header_message: ChatCompletionRequestSystemMessageArgs::default().content(SYSTEM_PROMPT).build().unwrap().into(),
|
|
messages,
|
|
reply_options: Default::default()
|
|
}
|
|
}
|
|
|
|
fn insert_actions(&mut self, actions: &StageActions) {
|
|
for addition in &actions.additions {
|
|
self.insert_conversation(addition.clone());
|
|
}
|
|
}
|
|
|
|
async fn regenerate_options(&mut self, direction: &StageDirection) {
|
|
loop {
|
|
let direction_message: ChatCompletionRequestMessage = ChatCompletionRequestSystemMessageArgs::default().content(serde_json::to_string(&direction).unwrap()).build().unwrap().into();
|
|
let mut full_conversation = vec![
|
|
self.header_message.clone(),
|
|
direction_message
|
|
];
|
|
full_conversation.append(&mut self.messages.clone());
|
|
|
|
let tools = vec![
|
|
ChatCompletionTools::Function(ChatCompletionTool {
|
|
function: FunctionObjectArgs::default()
|
|
.name("log_stage_event")
|
|
.description("Inserts an event into the current scene script")
|
|
.parameters(schema_for!(StageEventArgs))
|
|
.build().unwrap()
|
|
}),
|
|
ChatCompletionTools::Function(ChatCompletionTool {
|
|
function: FunctionObjectArgs::default()
|
|
.name("log_ship_computer_message")
|
|
.description("Inserts a message from the ship computer into the scene script")
|
|
.parameters(schema_for!(StageEventArgs))
|
|
.build().unwrap()
|
|
}),
|
|
ChatCompletionTools::Function(ChatCompletionTool {
|
|
function: FunctionObjectArgs::default()
|
|
.name("archive_query")
|
|
.description("Queries the ship's musical artifact archives for tracks matching the given search parameters")
|
|
.parameters(schema_for!(BeatsQueryArgs))
|
|
.build().unwrap()
|
|
}),
|
|
ChatCompletionTools::Function(ChatCompletionTool {
|
|
function: FunctionObjectArgs::default()
|
|
.name("bandcamp_artifact_scan")
|
|
.description("Scans Bandcamp to find artifacts to use in the scene that match the given search parameters")
|
|
.parameters(schema_for!(BandcampQueryArgs))
|
|
.build().unwrap()
|
|
})
|
|
];
|
|
let request = CreateChatCompletionRequestArgs::default()
|
|
.messages(full_conversation.clone())
|
|
.model("gpt-5.4")
|
|
.tools(tools)
|
|
.max_completion_tokens(1024u32)
|
|
.response_format(ResponseFormat::JsonSchema {
|
|
json_schema: ResponseFormatJsonSchema {
|
|
description: None,
|
|
name: "responses".into(),
|
|
schema: schema_for!(GeneratedResponses).into(),
|
|
strict: None
|
|
}
|
|
})
|
|
.build().unwrap();
|
|
|
|
let response = self.client.chat().create(request).await.unwrap_or_else(|err| {
|
|
panic!("{} {:?}", err, full_conversation);
|
|
});
|
|
|
|
if let Some(message) = response.choices.first() {
|
|
|
|
match message.finish_reason {
|
|
Some(FinishReason::ContentFilter) => {
|
|
self.insert_conversation(ConversationEntry::SystemMessage("Content filter triggered.".into()));
|
|
return;
|
|
},
|
|
Some(FinishReason::Length) => {
|
|
self.insert_conversation(ConversationEntry::SystemMessage("Maximum token count exceeded!".into()));
|
|
return;
|
|
},
|
|
_ => ()
|
|
}
|
|
|
|
if let Some(calls) = &message.message.tool_calls {
|
|
let assistant_messages: ChatCompletionRequestMessage = ChatCompletionRequestAssistantMessageArgs::default()
|
|
.tool_calls(calls.clone())
|
|
.build().unwrap().into();
|
|
self.messages.push(assistant_messages);
|
|
let mut results = vec![];
|
|
let mut messages = vec![];
|
|
for call in calls {
|
|
match call {
|
|
ChatCompletionMessageToolCalls::Function(call) => {
|
|
match call.function.name.as_str() {
|
|
"log_stage_event" => {
|
|
let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
|
|
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
|
|
.tool_call_id(call.id.clone())
|
|
.build().unwrap()
|
|
));
|
|
messages.push(ConversationEntry::StageDirection(args.text));
|
|
},
|
|
"log_ship_computer_message" => {
|
|
let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
|
|
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
|
|
.tool_call_id(call.id.clone())
|
|
.build().unwrap()
|
|
));
|
|
messages.push(ConversationEntry::ShipComputer(args.text));
|
|
},
|
|
"bandcamp_artifact_scan" => {
|
|
let args: BandcampQueryArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
|
|
self.insert_conversation(ConversationEntry::SystemMessage(format!("Fetching artifacts from Bandcamp with {:?}", args).into()));
|
|
let mut json_results = vec![];
|
|
if let Ok(results) = bandcamp::search(args.query.as_str()).await {
|
|
for result in results {
|
|
match result {
|
|
SearchResultItem::Artist(data) => {
|
|
let result: BandcampResult = bandcamp::fetch_artist(data.artist_id).await.unwrap().into();
|
|
json_results.push(result);
|
|
},
|
|
SearchResultItem::Album(data) => {
|
|
let result: BandcampResult = bandcamp::fetch_album(data.band_id, data.album_id).await.unwrap().into();
|
|
json_results.push(result);
|
|
}
|
|
_ => ()
|
|
}
|
|
}
|
|
}
|
|
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
|
|
.tool_call_id(call.id.clone())
|
|
.content(serde_json::to_string(&json_results).unwrap())
|
|
.build().unwrap()
|
|
));
|
|
messages.push(ConversationEntry::ShipComputer(format!("Artifact scan for '{}' complete. {} results.", args.query, json_results.len()).into()));
|
|
},
|
|
"archive_query" => {
|
|
let args: BeatsQueryArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
|
|
let mut beets_cmd = Command::new("beet");
|
|
beets_cmd.arg("export").arg("-f").arg("json").arg("-i").arg("title,label,year,genres,album,artist");
|
|
if let Some(artist) = args.artist {
|
|
beets_cmd.arg(format!("artist:{}", artist));
|
|
}
|
|
if let Some(genre) = args.genre {
|
|
beets_cmd.arg(format!("genre:{}", genre));
|
|
}
|
|
if let Some(album) = args.album {
|
|
beets_cmd.arg(format!("album:{}", album));
|
|
}
|
|
if let Some(title) = args.title {
|
|
beets_cmd.arg(format!("title:{}", title));
|
|
}
|
|
if let Some(year) = args.year {
|
|
beets_cmd.arg(format!("year:{}", year));
|
|
}
|
|
if let Ok(output) = beets_cmd.stdout(Stdio::piped()).spawn().unwrap().wait_with_output() {
|
|
let minified = minify::json::minify(str::from_utf8(&output.stdout).unwrap());
|
|
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
|
|
.tool_call_id(call.id.clone())
|
|
.content(minified)
|
|
.build().unwrap()
|
|
));
|
|
messages.push(ConversationEntry::ShipComputer(format!("Executing archive query {:?}", beets_cmd)));
|
|
} else {
|
|
messages.push(ConversationEntry::ShipComputer("Unable to execute query!".into()));
|
|
results.push(ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessageArgs::default()
|
|
.tool_call_id(call.id.clone())
|
|
.content("")
|
|
.build().unwrap()
|
|
));
|
|
}
|
|
}
|
|
_ => panic!("Unknown function was called")
|
|
}
|
|
},
|
|
_ => panic!("Unknown tool was called")
|
|
}
|
|
}
|
|
self.messages.append(&mut results);
|
|
for msg in messages {
|
|
self.insert_conversation(msg);
|
|
}
|
|
|
|
}
|
|
if let Some(content) = message.message.content.as_ref() {
|
|
if let Ok(options) = serde_json::from_str(content.as_str()) {
|
|
self.reply_options = options;
|
|
return;
|
|
} else {
|
|
self.insert_conversation(ConversationEntry::SystemMessage("Received invalid JSON! Trying again.".into()));
|
|
}
|
|
}
|
|
} else {
|
|
self.insert_conversation(ConversationEntry::SystemMessage("No messages were received! Trying again.".into()));
|
|
}
|
|
}
|
|
}
|
|
|
|
fn as_scene(&self) -> Scene {
|
|
Scene::new(self.reply_options.clone(), self.conversation.clone())
|
|
}
|
|
|
|
fn insert_conversation(&mut self, entry: ConversationEntry) {
|
|
self.conversation.push(entry.clone());
|
|
|
|
if let Ok(next_msg) = entry.try_into() {
|
|
self.messages.push(next_msg);
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn start_prediction(mut sys_message_src: tokio::sync::mpsc::Receiver<String>, initial_messages: Vec<ChatCompletionRequestMessage>) -> (tokio::sync::watch::Sender<StageActions>, tokio::sync::watch::Receiver<Scene>) {
|
|
let (prediction_in, prediction_out) = tokio::sync::watch::channel(Scene::default());
|
|
let (prediction_request_in, mut prediction_request_out) = tokio::sync::watch::channel(StageActions::default());
|
|
|
|
let mut session = Session::from_initial_messages(initial_messages);
|
|
|
|
// Send the initial scene to the UI, after we have loaded the session from the first messages.
|
|
prediction_in.send(session.as_scene()).unwrap();
|
|
|
|
tokio::spawn(async move {
|
|
loop {
|
|
tokio::select! {
|
|
maybe_message = sys_message_src.recv() => {
|
|
if let Some(message) = maybe_message {
|
|
session.insert_conversation(ConversationEntry::SystemMessage(message));
|
|
prediction_in.send(session.as_scene()).unwrap();
|
|
}
|
|
},
|
|
maybe_request = prediction_request_out.changed() => {
|
|
if maybe_request.is_ok() {
|
|
let next_cxt = prediction_request_out.borrow().clone();
|
|
session.insert_actions(&next_cxt);
|
|
|
|
let mut save_data = SaveData {
|
|
direction: next_cxt.direction,
|
|
messages: session.messages.clone()
|
|
};
|
|
|
|
save_data.save();
|
|
|
|
session.regenerate_options(&save_data.direction).await;
|
|
|
|
save_data.messages = session.messages.clone();
|
|
save_data.save();
|
|
prediction_in.send(session.as_scene()).unwrap();
|
|
}
|
|
}
|
|
};
|
|
}
|
|
});
|
|
|
|
(prediction_request_in, prediction_out)
|
|
} |