431 lines
18 KiB
Rust
431 lines
18 KiB
Rust
use std::process::{Command, Stdio};
|
|
|
|
use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs, ChatCompletionTool, ChatCompletionTools, CreateChatCompletionRequestArgs, FinishReason, FunctionObjectArgs, ResponseFormat, ResponseFormatJsonSchema}};
|
|
use bandcamp::SearchResultItem;
|
|
use schemars::{JsonSchema, schema_for};
|
|
use serde::{Deserialize, Serialize};
|
|
use serde_json::{Serializer, ser::CompactFormatter};
|
|
use tokio::sync::{mpsc, watch};
|
|
|
|
use crate::{SaveData, artifacts::{Artifact, BandcampResult}, scene::{ConversationEntry, PredictionAction, Scene, Scenery, StageDirection}};
|
|
|
|
|
|
/// System prompt used as the first (header) message of every chat completion request the
/// session sends. Embedded at compile time from `system-prompt.txt`, which `include_str!`
/// resolves relative to this source file.
const SYSTEM_PROMPT: &str = include_str!("system-prompt.txt");
|
|
|
|
// One reply option produced by the model. It is part of the `GeneratedResponses` JSON
// schema that is sent as the request's response format, so the field names are part of the
// contract with the model. Plain `//` comments are used on purpose: the `JsonSchema` derive
// turns `///` doc comments into schema descriptions, which would change what the model sees.
#[derive(JsonSchema, Deserialize, Serialize, Debug, Clone)]
pub struct PossibleResponse {
    // Text of this reply option.
    pub text: String,
    // Optional stage-direction text accompanying the reply; the model may omit it.
    pub stage_direction: Option<String>
}
|
|
|
|
// Structured output the model is asked to return: the list of reply options. Its schema is
// used as the JSON-schema response format of each request, and content that parses into
// this type replaces the session's current reply options. The default value is an empty
// list. Plain `//` comments keep the derived JSON schema (and thus the prompt) unchanged.
#[derive(Default, Debug, JsonSchema, Deserialize, Serialize, Clone)]
pub struct GeneratedResponses {
    // Reply options in the order the model returned them.
    pub responses: Vec<PossibleResponse>,
}
|
|
|
|
/// State of one prediction session, owned by the background prediction task.
#[derive(Debug)]
struct Session {
    /// OpenAI client used for chat completion requests (created with its default config).
    client: Client<OpenAIConfig>,
    /// Conversation log that is exposed through [`Scene`].
    conversation: Vec<ConversationEntry>,
    /// System-prompt message placed first in every request.
    header_message: ChatCompletionRequestMessage,
    /// Model-facing chat history (including tool calls/results); this is what gets saved.
    messages: Vec<ChatCompletionRequestMessage>,
    /// Most recently generated reply options.
    reply_options: GeneratedResponses,
    /// Stage direction, serialized into its own system message on every request.
    direction: StageDirection,
    /// Scene state (including artifacts added by tools), serialized into every request.
    scenery: Scenery,
    /// Running total of `total_tokens` reported by the API for this session.
    tokens_consumed: usize,
    /// "Thinking" flag: set to `true` while reply options are being regenerated.
    activity_notify: watch::Sender<bool>
}
|
|
|
|
// Event payload of the `log_stage_event` tool. Serde's default external tagging means the
// model sends e.g. `{"ShipComputer": "..."}`; each variant is recorded as the matching
// `ConversationEntry` variant. Plain `//` comments keep the derived JSON schema unchanged.
#[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)]
enum StageEvent {
    // A line spoken by the ship computer.
    ShipComputer(String),
    // A free-form stage direction.
    StageDirection(String)
}
|
|
|
|
// Argument object of the `log_stage_event` tool; its schema is the tool's parameter schema.
// Plain `//` comments keep that schema unchanged.
#[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)]
struct StageEventArgs {
    // The event to insert into the scene script.
    event: StageEvent
}
|
|
|
|
// Arguments of the `archive_query` tool. Every field is optional; each one that is present
// is passed to `beet export` as a `field:value` query term. Plain `//` comments keep the
// derived JSON schema (the tool's parameter schema) unchanged.
#[derive(Debug, Default, Serialize, Deserialize, Clone, JsonSchema)]
struct BeatsQueryArgs {
    // Becomes `artist:<value>`.
    artist: Option<String>,
    // Becomes `album:<value>`.
    album: Option<String>,
    // Becomes `genre:<value>`.
    genre: Option<String>,
    // Becomes `title:<value>`.
    title: Option<String>,
    // Becomes `year:<value>`.
    year: Option<u32>
}
|
|
|
|
// Arguments of the `bandcamp_artifact_scan` tool. Plain `//` comments keep the derived
// JSON schema (the tool's parameter schema) unchanged.
#[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)]
struct BandcampQueryArgs {
    // Free-text query passed to `bandcamp::search`.
    query: String
}
|
|
|
|
/// Outcome of handling one tool call.
#[derive(Default, Debug)]
struct ToolResults {
    /// Content of the tool message returned to the model; `None` leaves it unset.
    result: Option<String>,
    /// Entries passed to `insert_conversation` after all tool responses are recorded.
    messages: Vec<ConversationEntry>
}
|
|
|
|
impl Session {
|
|
fn from_initial_messages(messages: Vec<ChatCompletionRequestMessage>, scenery: Scenery, direction: StageDirection, activity_notify: watch::Sender<bool>) -> Self {
|
|
let mut conversation = vec![];
|
|
for msg in &messages {
|
|
if let Ok(conversation_msg) = msg.clone().try_into() {
|
|
conversation.push(conversation_msg);
|
|
}
|
|
}
|
|
|
|
Self {
|
|
client: Default::default(),
|
|
conversation,
|
|
header_message: ChatCompletionRequestSystemMessageArgs::default().content(SYSTEM_PROMPT).build().unwrap().into(),
|
|
messages,
|
|
reply_options: Default::default(),
|
|
scenery,
|
|
direction,
|
|
tokens_consumed: 0,
|
|
activity_notify,
|
|
}
|
|
}
|
|
|
|
async fn tool_stage_event(&mut self, args: StageEventArgs) -> ToolResults {
|
|
let msg = match args.event {
|
|
StageEvent::ShipComputer(text) => ConversationEntry::ShipComputer(text),
|
|
StageEvent::StageDirection(text) => ConversationEntry::StageDirection(text)
|
|
};
|
|
ToolResults {
|
|
messages: vec![msg],
|
|
..Default::default()
|
|
}
|
|
}
|
|
|
|
async fn tool_bandcamp_scan(&mut self, args: BandcampQueryArgs) -> ToolResults {
|
|
let mut messages = vec![];
|
|
messages.push(ConversationEntry::SystemMessage(format!("Fetching artifacts from Bandcamp with {:?}", args).into()));
|
|
let mut json_results = vec![];
|
|
if let Ok(results) = bandcamp::search(args.query.as_str()).await {
|
|
for result in results {
|
|
match result {
|
|
SearchResultItem::Artist(data) => {
|
|
let result = bandcamp::fetch_artist(data.artist_id).await.unwrap().into();
|
|
json_results.push(result);
|
|
},
|
|
SearchResultItem::Album(data) => {
|
|
let result = bandcamp::fetch_album(data.band_id, data.album_id).await.unwrap().into();
|
|
json_results.push(result);
|
|
}
|
|
_ => ()
|
|
}
|
|
}
|
|
}
|
|
let artifact_count = json_results.len();
|
|
messages.push(ConversationEntry::ShipComputer(format!("Relay scan for '{}' complete. {} artifacts added to the archive.", args.query, artifact_count).into()));
|
|
|
|
self.scenery.artifacts.append(&mut json_results);
|
|
|
|
ToolResults {
|
|
result: Some(format!("{} artifacts were added to the archive.", artifact_count)),
|
|
messages
|
|
}
|
|
}
|
|
|
|
async fn tool_artifact_query(&mut self, args: BeatsQueryArgs) -> ToolResults {
|
|
let mut messages = vec![];
|
|
let mut beets_cmd = Command::new("beet");
|
|
beets_cmd.args(["export", "-f", "json", "-i", "title,label,year,genres,album,artist"]);
|
|
if let Some(artist) = args.artist {
|
|
beets_cmd.arg(format!("artist:{}", artist));
|
|
}
|
|
if let Some(genre) = args.genre {
|
|
beets_cmd.arg(format!("genre:{}", genre));
|
|
}
|
|
if let Some(album) = args.album {
|
|
beets_cmd.arg(format!("album:{}", album));
|
|
}
|
|
if let Some(title) = args.title {
|
|
beets_cmd.arg(format!("title:{}", title));
|
|
}
|
|
if let Some(year) = args.year {
|
|
beets_cmd.arg(format!("year:{}", year));
|
|
}
|
|
if let Ok(output) = beets_cmd.stdout(Stdio::piped()).spawn().unwrap().wait_with_output() {
|
|
messages.push(ConversationEntry::ShipComputer(format!("Executing archive query {:?}", beets_cmd)));
|
|
self.scenery.artifacts.push(Artifact::BeetsTrack(serde_json::from_str(str::from_utf8(&output.stdout).unwrap()).unwrap()));
|
|
} else {
|
|
messages.push(ConversationEntry::ShipComputer("Unable to execute query!".into()));
|
|
};
|
|
|
|
ToolResults {
|
|
result: None,
|
|
messages
|
|
}
|
|
}
|
|
|
|
fn generate_conversation(&self, direction: &StageDirection) -> Vec<ChatCompletionRequestMessage> {
|
|
let mut json_buf = vec![];
|
|
let mut ser = Serializer::with_formatter(&mut json_buf, CompactFormatter);
|
|
direction.serialize(&mut ser).unwrap();
|
|
let direction_message: ChatCompletionRequestMessage = ChatCompletionRequestSystemMessageArgs::default()
|
|
.content(String::from_utf8(json_buf).unwrap())
|
|
.build().unwrap().into();
|
|
|
|
let mut json_buf = vec![];
|
|
let mut ser = Serializer::with_formatter(&mut json_buf, CompactFormatter);
|
|
self.scenery.serialize(&mut ser).unwrap();
|
|
let scenery_message: ChatCompletionRequestMessage = ChatCompletionRequestSystemMessageArgs::default()
|
|
.content(String::from_utf8(json_buf).unwrap())
|
|
.build().unwrap().into();
|
|
let mut full_conversation = vec![
|
|
self.header_message.clone(),
|
|
direction_message,
|
|
scenery_message,
|
|
];
|
|
full_conversation.append(&mut self.messages.clone());
|
|
|
|
full_conversation
|
|
}
|
|
|
|
async fn regenerate_options(&mut self) {
|
|
self.activity_notify.send_if_modified(|x| { if !*x { *x = true; true } else { false }});
|
|
loop {
|
|
let full_conversation = self.generate_conversation(&self.direction);
|
|
|
|
let tools = vec![
|
|
ChatCompletionTools::Function(ChatCompletionTool {
|
|
function: FunctionObjectArgs::default()
|
|
.name("log_stage_event")
|
|
.description("Inserts an event into the current scene script")
|
|
.parameters(schema_for!(StageEventArgs))
|
|
.build().unwrap()
|
|
}),
|
|
ChatCompletionTools::Function(ChatCompletionTool {
|
|
function: FunctionObjectArgs::default()
|
|
.name("archive_query")
|
|
.description("Queries the ship's musical artifact archives for tracks matching the given search parameters")
|
|
.parameters(schema_for!(BeatsQueryArgs))
|
|
.build().unwrap()
|
|
}),
|
|
ChatCompletionTools::Function(ChatCompletionTool {
|
|
function: FunctionObjectArgs::default()
|
|
.name("bandcamp_artifact_scan")
|
|
.description("Scans Bandcamp to find artifacts to use in the scene that match the given search parameters. To find an artist, provide only the artist name. To find an album, provide the artist and the album.")
|
|
.parameters(schema_for!(BandcampQueryArgs))
|
|
.build().unwrap()
|
|
})
|
|
];
|
|
let request = CreateChatCompletionRequestArgs::default()
|
|
.messages(full_conversation)
|
|
.model("gpt-5.4-mini")
|
|
.tools(tools)
|
|
.max_completion_tokens(1024u32)
|
|
.response_format(ResponseFormat::JsonSchema {
|
|
json_schema: ResponseFormatJsonSchema {
|
|
description: None,
|
|
name: "responses".into(),
|
|
schema: schema_for!(GeneratedResponses).into(),
|
|
strict: None
|
|
}
|
|
})
|
|
.build().unwrap();
|
|
|
|
let response = self.client.chat().create(request).await.unwrap_or_else(|err| {
|
|
panic!("OpenAI Panic: {}", err);
|
|
});
|
|
|
|
if let Some(usage) = response.usage {
|
|
self.tokens_consumed += usage.total_tokens as usize;
|
|
}
|
|
|
|
if let Some(message) = response.choices.first() {
|
|
|
|
match message.finish_reason {
|
|
Some(FinishReason::ContentFilter) => {
|
|
self.insert_conversation(ConversationEntry::SystemMessage("Content filter triggered.".into()));
|
|
break;
|
|
},
|
|
Some(FinishReason::Length) => {
|
|
self.insert_conversation(ConversationEntry::SystemMessage("Maximum token count exceeded!".into()));
|
|
break;
|
|
},
|
|
_ => ()
|
|
}
|
|
|
|
if let Some(calls) = &message.message.tool_calls {
|
|
let assistant_messages: ChatCompletionRequestMessage = ChatCompletionRequestAssistantMessageArgs::default()
|
|
.tool_calls(calls.clone())
|
|
.build().unwrap().into();
|
|
self.messages.push(assistant_messages);
|
|
let mut results = vec![];
|
|
for call in calls {
|
|
match call {
|
|
ChatCompletionMessageToolCalls::Function(call) => {
|
|
let func_name = call.function.name.as_str();
|
|
let args = call.function.arguments.as_str();
|
|
let tool_result = match func_name {
|
|
"log_stage_event" => self.tool_stage_event(serde_json::from_str(args).unwrap()).await,
|
|
"bandcamp_artifact_scan" => self.tool_bandcamp_scan(serde_json::from_str(args).unwrap()).await,
|
|
"archive_query" => self.tool_artifact_query(serde_json::from_str(args).unwrap()).await,
|
|
_ => unreachable!()
|
|
};
|
|
results.push((&call.id, tool_result));
|
|
},
|
|
_ => panic!("Unknown tool was called")
|
|
}
|
|
}
|
|
|
|
let mut tool_messages = vec![];
|
|
for (id, mut result) in results {
|
|
let mut msg = ChatCompletionRequestToolMessageArgs::default();
|
|
msg.tool_call_id(id);
|
|
if let Some(output) = result.result {
|
|
msg.content(output);
|
|
}
|
|
self.messages.push(ChatCompletionRequestMessage::Tool(msg.build().unwrap()));
|
|
tool_messages.append(&mut result.messages);
|
|
}
|
|
for msg in tool_messages {
|
|
self.insert_conversation(msg);
|
|
}
|
|
}
|
|
if let Some(content) = message.message.content.as_ref() {
|
|
if let Ok(options) = serde_json::from_str(content.as_str()) {
|
|
self.reply_options = options;
|
|
break;
|
|
} else {
|
|
self.insert_conversation(ConversationEntry::SystemMessage("Received invalid JSON! Trying again.".into()));
|
|
}
|
|
}
|
|
} else {
|
|
self.insert_conversation(ConversationEntry::SystemMessage("No messages were received! Trying again.".into()));
|
|
}
|
|
}
|
|
self.activity_notify.send_if_modified(|x| { if *x { *x = false; true } else { false }});
|
|
}
|
|
|
|
fn as_scene(&self) -> Scene {
|
|
Scene::new(self.reply_options.clone(), self.conversation.clone(), self.scenery.clone(), self.tokens_consumed, self.direction.clone())
|
|
}
|
|
|
|
fn insert_conversation(&mut self, entry: ConversationEntry) {
|
|
self.conversation.push(entry.clone());
|
|
|
|
if let Ok(next_msg) = entry.try_into() {
|
|
self.messages.push(next_msg);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Handle the UI side uses to drive a prediction session running in a background task.
#[derive(Debug)]
pub struct SessionControl {
    /// Sends [`PredictionAction`]s to the background task.
    event_sink: mpsc::Sender<PredictionAction>,
    /// Latest [`Scene`] published by the background task.
    scene_watch: watch::Receiver<Scene>,
    /// "Thinking" flag; `true` while reply options are being regenerated.
    activity_watch: watch::Receiver<bool>
}
|
|
|
|
/// An update reported by `SessionControl::changed`.
#[derive(Debug)]
pub enum SessionUpdate {
    /// A newly published scene.
    Scene(Scene),
    /// The new value of the "thinking" flag (`true` while reply options are regenerated).
    Thinking(bool)
}
|
|
|
|
impl SessionControl {
|
|
pub async fn insert(&self, action: PredictionAction) {
|
|
self.event_sink.send(action).await.unwrap();
|
|
}
|
|
|
|
pub async fn log(&self, message: String) {
|
|
self.insert(PredictionAction::ConversationAppend(ConversationEntry::SystemMessage(message))).await;
|
|
}
|
|
|
|
pub async fn regenerate_options(&self) {
|
|
self.insert(PredictionAction::GeneratePredictions).await;
|
|
}
|
|
|
|
pub async fn changed(&mut self) -> SessionUpdate {
|
|
tokio::select! {
|
|
_ = self.activity_watch.changed() => {
|
|
SessionUpdate::Thinking(*self.activity_watch.borrow_and_update())
|
|
},
|
|
_ = self.scene_watch.changed() => {
|
|
SessionUpdate::Scene(self.scene_watch.borrow_and_update().clone())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn start_prediction(saved_session: SaveData) -> SessionControl {
|
|
let (prediction_in, prediction_out) = tokio::sync::watch::channel(Scene::default());
|
|
let (activity_notify_sink, activity_notify_src) = tokio::sync::watch::channel(false);
|
|
|
|
let (action_sink, mut action_src) = mpsc::channel(5);
|
|
|
|
let mut session = Session::from_initial_messages(saved_session.messages, saved_session.scenery, saved_session.direction, activity_notify_sink);
|
|
|
|
// Send the initial scene to the UI, after we have loaded the session from the first messages.
|
|
prediction_in.send(session.as_scene()).unwrap();
|
|
|
|
tokio::spawn(async move {
|
|
loop {
|
|
if let Some(evt) = action_src.recv().await {
|
|
let do_regen = match evt {
|
|
PredictionAction::ConversationAppend(msg) => {
|
|
let do_regen = match msg {
|
|
ConversationEntry::Eva(_) | ConversationEntry::ShipComputer(_) | ConversationEntry::User(_) => true,
|
|
_ => false
|
|
};
|
|
session.insert_conversation(msg);
|
|
|
|
do_regen
|
|
},
|
|
PredictionAction::SetEpisodeNumber(num) => {
|
|
session.direction.episode_number = num;
|
|
if let Err(err) = session.direction.reload_mixxx_playlist() {
|
|
session.insert_conversation(ConversationEntry::SystemMessage(format!("Failed to load mixxx playlist: {:?}.", err).into()));
|
|
} else {
|
|
session.insert_conversation(ConversationEntry::SystemMessage("Mixxx playlist reloaded.".into()));
|
|
}
|
|
false
|
|
},
|
|
PredictionAction::GeneratePredictions => {
|
|
true
|
|
},
|
|
PredictionAction::SetNarrative(narrative) => {
|
|
session.direction.narrative = narrative;
|
|
session.insert_conversation(ConversationEntry::SystemMessage("Updated stage direction narrative".into()));
|
|
true
|
|
},
|
|
PredictionAction::SetShowEndTime(end_time) => {
|
|
session.direction.end_time = end_time;
|
|
false
|
|
}
|
|
};
|
|
|
|
let save_data = SaveData {
|
|
direction: session.direction.clone(),
|
|
messages: session.messages.clone(),
|
|
scenery: session.scenery.clone()
|
|
};
|
|
|
|
save_data.save();
|
|
|
|
if do_regen {
|
|
session.reply_options.responses.clear();
|
|
}
|
|
|
|
prediction_in.send(session.as_scene()).unwrap();
|
|
|
|
if do_regen {
|
|
session.regenerate_options().await;
|
|
prediction_in.send(session.as_scene()).unwrap();
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
SessionControl {
|
|
event_sink: action_sink,
|
|
scene_watch: prediction_out,
|
|
activity_watch: activity_notify_src
|
|
}
|
|
} |