400 lines
16 KiB
Rust
400 lines
16 KiB
Rust
use std::{fmt::Debug, sync::Arc};
|
|
|
|
use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs, CreateChatCompletionRequestArgs, FinishReason, ResponseFormat, ResponseFormatJsonSchema}};
|
|
use chrono::{DateTime, Utc};
|
|
use schemars::{JsonSchema, schema_for};
|
|
use serde::{Deserialize, Serialize};
|
|
use serde_json::{Serializer, ser::CompactFormatter};
|
|
use tokio::sync::{RwLock, mpsc, watch};
|
|
|
|
use crate::{SaveData, artifacts::{Contents, bandcamp::BandcampSource, beets::BeetsDB, mixxx::{MixxxDB, MixxxQuery}, musicbrainz::MBQuery, tools::{DataSource, Tool}}, scene::{Scene, Scenery, StageDirection, conversation::ConversationEntry}};
|
|
|
|
|
|
/// System prompt text embedded at compile time from `system-prompt.txt`; it becomes the
/// content of the header system message built in `Session::new`.
const SYSTEM_PROMPT: &str = include_str!("system-prompt.txt");
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub enum PredictionAction {
|
|
ConversationAppend(ConversationEntry),
|
|
SetPlaylist(String),
|
|
GeneratePredictions,
|
|
SetNarrative(String),
|
|
SetShowEndTime(DateTime<Utc>)
|
|
}
|
|
|
|
|
|
#[derive(JsonSchema, Deserialize, Serialize, Debug, Clone)]
|
|
pub struct PossibleResponse {
|
|
pub text: String,
|
|
pub stage_direction: Option<String>
|
|
}
|
|
|
|
#[derive(Default, Debug, JsonSchema, Deserialize, Serialize, Clone)]
|
|
pub struct GeneratedResponses {
|
|
pub responses: Vec<PossibleResponse>,
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
struct Session {
|
|
client: Client<OpenAIConfig>,
|
|
conversation: Vec<ConversationEntry>,
|
|
header_message: ChatCompletionRequestMessage,
|
|
messages: Vec<ChatCompletionRequestMessage>,
|
|
reply_options: GeneratedResponses,
|
|
direction: StageDirection,
|
|
scenery: Scenery,
|
|
tokens_consumed: usize,
|
|
activity_notify: watch::Sender<bool>,
|
|
scene_sink: watch::Sender<Scene>
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)]
|
|
enum StageEvent {
|
|
ShipComputer(String),
|
|
StageDirection(String)
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)]
|
|
struct StageEventArgs {
|
|
event: StageEvent
|
|
}
|
|
|
|
#[derive(Default, Debug)]
|
|
struct ToolResults {
|
|
result: Option<String>,
|
|
messages: Vec<ConversationEntry>
|
|
}
|
|
|
|
impl Session {
|
|
fn new(scene_sink: watch::Sender<Scene>, messages: Vec<ChatCompletionRequestMessage>, scenery: Scenery, direction: StageDirection, activity_notify: watch::Sender<bool>) -> Self {
|
|
let mut conversation = vec![];
|
|
for msg in &messages {
|
|
if let Ok(conversation_msg) = msg.clone().try_into() {
|
|
conversation.push(conversation_msg);
|
|
}
|
|
}
|
|
|
|
Self {
|
|
client: Default::default(),
|
|
conversation,
|
|
header_message: ChatCompletionRequestSystemMessageArgs::default().content(SYSTEM_PROMPT).build().unwrap().into(),
|
|
messages,
|
|
reply_options: Default::default(),
|
|
scenery,
|
|
direction,
|
|
tokens_consumed: 0,
|
|
activity_notify,
|
|
scene_sink
|
|
}
|
|
}
|
|
|
|
async fn tool_stage_event(&mut self, args: StageEventArgs) -> ToolResults {
|
|
let msg = match args.event {
|
|
StageEvent::ShipComputer(text) => ConversationEntry::ShipComputer(text),
|
|
StageEvent::StageDirection(text) => ConversationEntry::StageDirection(text)
|
|
};
|
|
ToolResults {
|
|
result: Some(msg.to_string()),
|
|
messages: vec![msg],
|
|
}
|
|
}
|
|
|
|
async fn tool_artifact_query<Src: DataSource>(&mut self, src: &mut Src, json_args: &str) -> ToolResults where Src::Args: Debug {
|
|
let args: Src::Args = serde_json::from_str(json_args).unwrap();
|
|
let mut messages = vec![];
|
|
log::debug!("Executing query {:?}", args);
|
|
if let Ok(output) = src.query(&args).await {
|
|
messages.push(ConversationEntry::ShipComputer(format!("Found {} artifacts with archive query {:?}", output.len(), args)));
|
|
for result in output {
|
|
self.scenery.artifacts.insert(result);
|
|
}
|
|
self.scenery.artifacts.synchronize().await;
|
|
} else {
|
|
messages.push(ConversationEntry::ShipComputer("Unable to execute query!".into()));
|
|
};
|
|
|
|
ToolResults {
|
|
result: None,
|
|
messages
|
|
}
|
|
}
|
|
|
|
fn generate_conversation(&self, direction: &StageDirection) -> Vec<ChatCompletionRequestMessage> {
|
|
let mut json_buf = vec![];
|
|
let mut ser = Serializer::with_formatter(&mut json_buf, CompactFormatter);
|
|
serde_json::json!({
|
|
"direction": direction,
|
|
"scenery": self.scenery
|
|
}).serialize(&mut ser).unwrap();
|
|
let direction_message: ChatCompletionRequestMessage = ChatCompletionRequestSystemMessageArgs::default()
|
|
.content(String::from_utf8(json_buf).unwrap())
|
|
.build().unwrap().into();
|
|
|
|
let mut full_conversation = vec![
|
|
self.header_message.clone(),
|
|
direction_message,
|
|
];
|
|
full_conversation.append(&mut self.messages.clone());
|
|
|
|
full_conversation
|
|
}
|
|
|
|
async fn regenerate_options(&mut self) {
|
|
self.reply_options.responses.clear();
|
|
self.refresh();
|
|
self.activity_notify.send_if_modified(|x| { if !*x { *x = true; true } else { false }});
|
|
loop {
|
|
let full_conversation = self.generate_conversation(&self.direction);
|
|
|
|
let tools = vec![
|
|
Tool { name: "log_stage_event".into(), description: "Inserts an event into the current scene script".into(), schema: schema_for!(StageEventArgs)}.into(),
|
|
// TODO: There should only be two queries, one against the ship's onboard archive, and another against the relay network, or whatever we call it. Both should be structured with the same arguments schema
|
|
// TODO: A query should specify what parts of metadata are sufficient for the result, so we don't always have to hit all the layers of data. beets can of course, ignore this.
|
|
// TODO: A query should be hierarchical somehow? eg, "I already know about artist X, but I want to know everything about track Y from album Z" or "I don't know anything about artist X/album Y, please give me an overview"
|
|
Tool::from_datasource(&MBQuery).into(),
|
|
Tool::from_datasource(&BandcampSource).into(),
|
|
Tool::from_datasource(&BeetsDB).into(),
|
|
Tool::from_datasource(&MixxxDB).into(),
|
|
// TODO: We should be able to have eva update lore memories with a function call, and this lore is somehow fed into the show? but only the relevant bits? or maybe eva even queries it directly
|
|
// TODO: The memory should also be able to remember facts about artists, albums, tracks we've had in the past, and those could be pulled up when there are hits in the playlist.
|
|
];
|
|
log::debug!("Sending request..");
|
|
let request = CreateChatCompletionRequestArgs::default()
|
|
.messages(full_conversation)
|
|
.model("gpt-5.4-mini")
|
|
.tools(tools)
|
|
.max_completion_tokens(1024u32)
|
|
.response_format(ResponseFormat::JsonSchema {
|
|
json_schema: ResponseFormatJsonSchema {
|
|
description: None,
|
|
name: "responses".into(),
|
|
schema: schema_for!(GeneratedResponses).into(),
|
|
strict: None
|
|
}
|
|
})
|
|
.build().unwrap();
|
|
|
|
let response = self.client.chat().create(request).await.unwrap_or_else(|err| {
|
|
panic!("OpenAI Panic: {}", err);
|
|
});
|
|
|
|
if let Some(usage) = response.usage {
|
|
self.tokens_consumed += usage.total_tokens as usize;
|
|
log::debug!("{} tokens cast into the void", usage.total_tokens);
|
|
}
|
|
|
|
if let Some(message) = response.choices.first() {
|
|
|
|
match message.finish_reason {
|
|
Some(FinishReason::ContentFilter) => {
|
|
log::error!("Content filter triggered.");
|
|
break;
|
|
},
|
|
Some(FinishReason::Length) => {
|
|
log::error!("Maximum token count exceeded!");
|
|
break;
|
|
},
|
|
_ => ()
|
|
}
|
|
|
|
if let Some(calls) = &message.message.tool_calls {
|
|
let assistant_messages: ChatCompletionRequestMessage = ChatCompletionRequestAssistantMessageArgs::default()
|
|
.tool_calls(calls.clone())
|
|
.build().unwrap().into();
|
|
self.messages.push(assistant_messages);
|
|
let mut results = vec![];
|
|
for call in calls {
|
|
match call {
|
|
ChatCompletionMessageToolCalls::Function(call) => {
|
|
let func_name = call.function.name.as_str();
|
|
let args = call.function.arguments.as_str();
|
|
let tool_result = match func_name {
|
|
"log_stage_event" => self.tool_stage_event(serde_json::from_str(args).unwrap()).await,
|
|
"query_bandcamp" => self.tool_artifact_query(&mut BandcampSource, args).await,
|
|
"query_beets" => self.tool_artifact_query(&mut BeetsDB, args).await,
|
|
"query_musicbrainz" => self.tool_artifact_query(&mut MBQuery, args).await,
|
|
"query_mixxx" => self.tool_artifact_query(&mut MixxxDB, args).await,
|
|
_ => unreachable!()
|
|
};
|
|
results.push((&call.id, tool_result));
|
|
},
|
|
_ => panic!("Unknown tool was called")
|
|
}
|
|
}
|
|
|
|
let mut tool_messages = vec![];
|
|
for (id, mut result) in results {
|
|
let mut msg = ChatCompletionRequestToolMessageArgs::default();
|
|
msg.tool_call_id(id);
|
|
if let Some(output) = result.result {
|
|
msg.content(output);
|
|
}
|
|
self.messages.push(ChatCompletionRequestMessage::Tool(msg.build().unwrap()));
|
|
tool_messages.append(&mut result.messages);
|
|
}
|
|
for msg in tool_messages {
|
|
self.insert_conversation(msg);
|
|
}
|
|
}
|
|
if let Some(content) = message.message.content.as_ref() {
|
|
if let Ok(options) = serde_json::from_str(content.as_str()) {
|
|
self.reply_options = options;
|
|
break;
|
|
} else {
|
|
log::info!("Received invalid JSON! Trying again.");
|
|
}
|
|
}
|
|
} else {
|
|
log::info!("No messages were received! Trying again.");
|
|
}
|
|
|
|
self.refresh();
|
|
}
|
|
self.activity_notify.send_if_modified(|x| { if *x { *x = false; true } else { false }});
|
|
|
|
self.refresh();
|
|
}
|
|
|
|
fn as_scene(&self) -> Scene {
|
|
Scene::new(self.reply_options.clone(), self.conversation.clone(), self.scenery.clone(), self.tokens_consumed, self.direction.clone())
|
|
}
|
|
|
|
fn insert_conversation(&mut self, entry: ConversationEntry) {
|
|
self.conversation.push(entry.clone());
|
|
|
|
if let Ok(next_msg) = entry.try_into() {
|
|
self.messages.push(next_msg);
|
|
}
|
|
|
|
self.refresh();
|
|
}
|
|
|
|
fn refresh(&self) {
|
|
self.scene_sink.send(self.as_scene()).unwrap();
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub struct SessionControl {
|
|
event_sink: mpsc::Sender<PredictionAction>,
|
|
scene_watch: watch::Receiver<Scene>,
|
|
activity_watch: watch::Receiver<bool>
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub enum SessionUpdate {
|
|
Scene(Scene),
|
|
Thinking(bool)
|
|
}
|
|
|
|
impl SessionControl {
|
|
pub async fn insert(&self, action: PredictionAction) {
|
|
self.event_sink.send(action).await.unwrap();
|
|
}
|
|
|
|
pub async fn regenerate_options(&self) {
|
|
self.insert(PredictionAction::GeneratePredictions).await;
|
|
}
|
|
|
|
pub async fn changed(&mut self) -> SessionUpdate {
|
|
tokio::select! {
|
|
_ = self.activity_watch.changed() => {
|
|
SessionUpdate::Thinking(*self.activity_watch.borrow_and_update())
|
|
},
|
|
_ = self.scene_watch.changed() => {
|
|
SessionUpdate::Scene(self.scene_watch.borrow_and_update().clone())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn start_prediction(saved_session: SaveData, mut messages: tokio::sync::mpsc::UnboundedReceiver<String>) -> SessionControl {
|
|
let (prediction_in, prediction_out) = tokio::sync::watch::channel(Scene::default());
|
|
let (activity_notify_sink, activity_notify_src) = tokio::sync::watch::channel(false);
|
|
|
|
let (action_sink, mut action_src) = mpsc::channel(5);
|
|
|
|
let session = Session::new(prediction_in, saved_session.messages, saved_session.scenery, saved_session.direction, activity_notify_sink);
|
|
|
|
// Send the initial scene to the UI, after we have loaded the session from the first messages.
|
|
session.refresh();
|
|
|
|
let shared_session = Arc::new(RwLock::new(session));
|
|
|
|
let log_session = Arc::clone(&shared_session);
|
|
tokio::spawn(async move {
|
|
loop {
|
|
if let Some(msg) = messages.recv().await {
|
|
log_session.write().await.insert_conversation(ConversationEntry::SystemMessage(msg));
|
|
}
|
|
}
|
|
});
|
|
|
|
tokio::spawn(async move {
|
|
loop {
|
|
if let Some(evt) = action_src.recv().await {
|
|
let mut session = shared_session.write().await;
|
|
let do_regen = match evt {
|
|
PredictionAction::ConversationAppend(msg) => {
|
|
let do_regen = match msg {
|
|
ConversationEntry::Eva(_) | ConversationEntry::ShipComputer(_) | ConversationEntry::User(_) => true,
|
|
_ => false
|
|
};
|
|
session.insert_conversation(msg);
|
|
|
|
do_regen
|
|
},
|
|
PredictionAction::SetPlaylist(playlist_name) => {
|
|
let args = MixxxQuery { playlist_name };
|
|
match MixxxDB.query(&args).await {
|
|
Err(err) => log::info!("Failed to load mixxx playlist: {:?}.", err),
|
|
Ok(playlist) => {
|
|
session.scenery.current_playlist = vec![];
|
|
for item in playlist.clone() {
|
|
if let Contents::Track(as_track) = item.contents() {
|
|
session.scenery.current_playlist.push(as_track.clone());
|
|
}
|
|
session.scenery.artifacts.insert(item);
|
|
}
|
|
session.scenery.artifacts.synchronize().await;
|
|
session.direction.playlist = args.playlist_name;
|
|
log::info!("Mixxx playlist reloaded.");
|
|
}
|
|
}
|
|
false
|
|
},
|
|
PredictionAction::GeneratePredictions => {
|
|
true
|
|
},
|
|
PredictionAction::SetNarrative(narrative) => {
|
|
session.direction.narrative = narrative;
|
|
log::info!("Updated stage direction narrative");
|
|
true
|
|
},
|
|
PredictionAction::SetShowEndTime(end_time) => {
|
|
session.direction.end_time = end_time;
|
|
false
|
|
}
|
|
};
|
|
|
|
let save_data = SaveData {
|
|
direction: session.direction.clone(),
|
|
messages: session.messages.clone(),
|
|
scenery: session.scenery.clone()
|
|
};
|
|
|
|
save_data.save();
|
|
|
|
if do_regen {
|
|
drop(session);
|
|
shared_session.write().await.regenerate_options().await;
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
SessionControl {
|
|
event_sink: action_sink,
|
|
scene_watch: prediction_out,
|
|
activity_watch: activity_notify_src
|
|
}
|
|
} |