From 70ec40a88039b568c314dba5759ac4e546fcac70 Mon Sep 17 00:00:00 2001 From: Victoria Fischer Date: Mon, 22 Jun 2026 08:57:49 +0200 Subject: [PATCH] prediction: rewrite the prompting stack to create the ship computer as a second character, and let characters/agents operate autonomously with their own tasks and event queues --- Cargo.lock | 1 + Cargo.toml | 1 + build.rs | 1 + src/artifacts/archive.rs | 14 +- src/artifacts/bandcamp.rs | 5 +- src/artifacts/beets.rs | 97 ++++++--- src/artifacts/mixxx.rs | 4 +- src/artifacts/mod.rs | 36 ++- src/artifacts/musicbrainz.rs | 4 +- src/artifacts/tools.rs | 4 +- src/computer-prompt.txt | 74 +++++++ src/main.rs | 11 +- src/prediction.rs | 410 ----------------------------------- src/prediction/character.rs | 234 ++++++++++++++++++++ src/prediction/mod.rs | 348 +++++++++++++++++++++++++++++ src/prediction/toolbox.rs | 172 +++++++++++++++ src/scene/conversation.rs | 83 +++---- src/scene/mod.rs | 34 ++- src/system-prompt.txt | 21 +- src/ui.rs | 64 ++++-- src/widgets.rs | 36 ++- 21 files changed, 1091 insertions(+), 563 deletions(-) create mode 100644 src/computer-prompt.txt delete mode 100644 src/prediction.rs create mode 100644 src/prediction/character.rs create mode 100644 src/prediction/mod.rs create mode 100644 src/prediction/toolbox.rs diff --git a/Cargo.lock b/Cargo.lock index 3b3c476..117dc74 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1214,6 +1214,7 @@ dependencies = [ "textwrap", "throbber-widgets-tui", "tokio", + "tokio-stream", "tui-input", "tui-skeleton", "uuid", diff --git a/Cargo.toml b/Cargo.toml index c46b3fd..a1a488e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,6 +38,7 @@ tempfile = "3.27.0" textwrap = "0.16.2" throbber-widgets-tui = "0.11.0" tokio = { version = "1.52.3", features = ["full"] } +tokio-stream = "0.1.18" tui-input = "0.15.3" tui-skeleton = "0.3.0" uuid = { version = "1.23.3", features = ["serde", "v4"] } diff --git a/build.rs b/build.rs index ed78817..715ae1c 100644 --- a/build.rs +++ b/build.rs @@ -1,3 +1,4 @@ fn main() { println!("cargo::rerun-if-changed=src/system-prompt.txt"); + println!("cargo::rerun-if-changed=src/computer-prompt.txt"); } \ No newline at end of file diff --git a/src/artifacts/archive.rs b/src/artifacts/archive.rs index 464f829..8cfd4f1 100644 --- a/src/artifacts/archive.rs +++ b/src/artifacts/archive.rs @@ -51,11 +51,23 @@ pub struct Archive { } impl Archive { - pub fn len(&self) -> usize { self.contents.len() } + // track, album, artist + pub fn stats(&self) -> (usize, usize, usize) { + self.contents.iter().map(|(_id, artifact)| { + match artifact.contents() { + Contents::Track(_) => (1, 0, 0), + Contents::Album(_) => (0, 1, 0), + Contents::Artist(_) => (0, 0, 1), + } + }).reduce(|acc, e| { + (acc.0 + e.0, acc.1 + e.1, acc.2 + e.2) + }).unwrap_or_default() + } + pub fn get<'a>(&'a self, id: &Uuid) -> Option> { if self.contents.get(id).is_some() { Some(ArtifactRef { id: id.clone(), archive: self }) diff --git a/src/artifacts/bandcamp.rs b/src/artifacts/bandcamp.rs index 956a4ca..d27b87a 100644 --- a/src/artifacts/bandcamp.rs +++ b/src/artifacts/bandcamp.rs @@ -34,11 +34,11 @@ impl DataSource for BandcampSource { type Args = BandcampQueryArgs; type Error = (); - async fn synchronize(&mut self, _artifact: &mut Artifact) -> Result, Self::Error> { + async fn synchronize(&self, _artifact: &mut Artifact) -> Result, Self::Error> { todo!() } - async fn query(&mut self, args: &Self::Args) -> Result, Self::Error> { + async fn query(&self, args: &Self::Args) -> Result, Self::Error> { log::debug!("Fetching artifacts from Bandcamp with {:?}", args); let mut json_results = vec![]; if let Ok(results) = bandcamp::search(args.query.as_str()).await { @@ -46,6 +46,7 @@ impl DataSource for BandcampSource { log::debug!("Result: {:?}", result); match result { SearchResultItem::Artist(data) => { + // TODO: The artist and album detailed fetchers should also be separate args let result = bandcamp::fetch_artist(data.artist_id).await.unwrap().into(); json_results.push(result); }, diff --git a/src/artifacts/beets.rs b/src/artifacts/beets.rs index a0dc0df..be91b91 100644 --- a/src/artifacts/beets.rs +++ b/src/artifacts/beets.rs @@ -15,7 +15,13 @@ pub struct BeatsQueryArgs { pub album: Option, pub genre: Option, pub title: Option, - pub year: Option + pub year: Option, + pub label: Option +} + +#[derive(Debug, Default, Serialize, Deserialize, Clone, JsonSchema)] +pub struct BeatsQueryMultiArgs { + args: Vec } #[derive(Debug, Default, Deserialize)] @@ -42,8 +48,8 @@ impl Into for BeetsTrack { }; let builder = ArtifactBuilder::new(SourceID::Beets) .contents(track_data); - if let Some(mbid) = self.mb_trackid { - builder.mbid(Uuid::parse_str(&mbid).unwrap()).build() + if let Ok(mbid) = Uuid::parse_str(&self.mb_trackid.unwrap_or_default()) { + builder.mbid(mbid).build() } else { builder.build() } @@ -52,37 +58,18 @@ impl Into for BeetsTrack { pub struct BeetsDB; -impl DataSource for BeetsDB { - type Args = BeatsQueryArgs; - - type Error = (); - - async fn synchronize(&mut self, artifact: &mut Artifact) -> Result, Self::Error> { - match artifact.contents { - Contents::Track(ref mut target_track) => { - let args = BeatsQueryArgs { - title: Some(target_track.title.clone()), - artist: target_track.artist.clone(), - album: target_track.album.clone(), - ..Default::default() - }; - - let results = self.query(&args).await.unwrap(); - - if let Some(first) = results.first() { - artifact.merge(first.clone()); - } else { - log::error!("Beets could not find {:?}", target_track); - } - - }, - _ => () +impl BeetsDB { + async fn query_multi(&self, args: &BeatsQueryMultiArgs) -> Result, ()> { + let mut ret = vec![]; + for arg in &args.args { + for artifact in self.query_single(arg).await.unwrap_or_default() { + ret.push(artifact); + } } - - Ok(vec![]) + Ok(ret) } - async fn query(&mut self, args: &Self::Args) -> Result, Self::Error> { + async fn query_single(&self, args: &BeatsQueryArgs) -> Result, ()> { let mut beets_cmd = Command::new("beet"); beets_cmd.args(["export", "-f", "json", "-i", "title,label,year,genres,album,artist,mb_trackid"]); let mut valid = false; @@ -106,6 +93,10 @@ impl DataSource for BeetsDB { beets_cmd.arg(format!("year:{}", year)); valid = true; } + if let Some(ref label) = args.label { + beets_cmd.arg(format!("label:{}", label)); + valid = true; + } if !valid { log::warn!("Tried to execute an empty beets query"); @@ -128,6 +119,50 @@ impl DataSource for BeetsDB { } } +impl DataSource for BeetsDB { + type Args = BeatsQueryMultiArgs; + + type Error = (); + + async fn synchronize(&self, artifact: &mut Artifact) -> Result, Self::Error> { + match artifact.contents { + Contents::Track(ref mut target_track) => { + let args = BeatsQueryArgs { + title: Some(target_track.title.clone()), + artist: target_track.artist.clone(), + album: target_track.album.clone(), + ..Default::default() + }; + + let results = self.query(&BeatsQueryMultiArgs { args: vec![args] }).await.unwrap(); + + if let Some(first) = results.first() { + artifact.merge(first.clone()); + } else { + log::debug!("Beets could not find {:?}", target_track); + } + + }, + _ => () + } + + Ok(vec![]) + } + + fn query(&self, args: &Self::Args) -> impl Future, Self::Error>> { + /*let mut ret = vec![]; + for arg in args.0 { + for artifact in self.query_single(&arg).await.unwrap_or_default() { + ret.push(artifact); + } + } + futures::ready!(Ok(ret))*/ + self.query_multi(args) + } + + +} + impl ToolDescription for BeetsDB { fn description(&self) -> &str { "Queries the ship's musical artifact archives for tracks matching the given search parameters" diff --git a/src/artifacts/mixxx.rs b/src/artifacts/mixxx.rs index 209ce20..720b48c 100644 --- a/src/artifacts/mixxx.rs +++ b/src/artifacts/mixxx.rs @@ -27,11 +27,11 @@ impl DataSource for MixxxDB { type Args = MixxxQuery; type Error = MixxxError; - async fn synchronize(&mut self, _artifact: &mut Artifact) -> Result, Self::Error> { + async fn synchronize(&self, _artifact: &mut Artifact) -> Result, Self::Error> { Ok(vec![]) } - async fn query(&mut self, args: &Self::Args) -> Result, Self::Error> { + async fn query(&self, args: &Self::Args) -> Result, Self::Error> { let mut ret = vec![]; let playlist_name = args.playlist_name.as_str(); log::info!("Loading Mixxx playlist {}", playlist_name); diff --git a/src/artifacts/mod.rs b/src/artifacts/mod.rs index ec71bc2..83e5aa7 100644 --- a/src/artifacts/mod.rs +++ b/src/artifacts/mod.rs @@ -43,7 +43,7 @@ impl PartialEq for Artist { } } -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] +#[derive(Debug, Serialize, Deserialize, Clone, Default)] pub struct Album { pub title: String, pub artist: String, @@ -55,6 +55,20 @@ pub struct Album { pub release_date: Option> } +impl PartialEq for Album { + fn eq(&self, other: &Self) -> bool { + if self.title != other.title || self.artist != other.artist { + return false; + } + + true + } + + fn ne(&self, other: &Self) -> bool { + !self.eq(other) + } +} + #[derive(Debug, Serialize, Deserialize, Clone, Default)] pub struct Track { pub title: String, @@ -79,11 +93,11 @@ impl PartialEq for Track { return false; } - if self.artist.is_some() && self.artist != other.artist { + if other.artist.is_some() && self.artist.is_some() && self.artist != other.artist { return false; } - if self.album.is_some() && self.album != other.album { + if other.album.is_some() && self.album.is_some() && self.album != other.album { return false; } @@ -224,10 +238,10 @@ impl Merge for Contents { this_track.merge(that_track); }, (Self::Album(this_album), Self::Album(that_album)) => { - merge_fields!(this_album, that_album, about, credits, release_date); + this_album.merge(that_album); }, (Self::Artist(this_artist), Self::Artist(that_artist)) => { - merge_fields!(this_artist, that_artist, bio, location); + this_artist.merge(that_artist); }, _ => () } @@ -238,4 +252,16 @@ impl Merge for Track { fn merge(&mut self, other: Self) { merge_fields!(self, other, album, label, year, artist, bpm); } +} + +impl Merge for Artist { + fn merge(&mut self, other: Self) { + merge_fields!(self, other, bio, location); + } +} + +impl Merge for Album { + fn merge(&mut self, other: Self) { + merge_fields!(self, other, about, credits, release_date); + } } \ No newline at end of file diff --git a/src/artifacts/musicbrainz.rs b/src/artifacts/musicbrainz.rs index 318044c..0dd0e62 100644 --- a/src/artifacts/musicbrainz.rs +++ b/src/artifacts/musicbrainz.rs @@ -93,7 +93,7 @@ impl DataSource for MBQuery { type Error = ApiEndpointError; type Args = MusicbrainzQueryArgs; - async fn synchronize(&mut self, artifact: &mut Artifact) -> Result, Self::Error> { + async fn synchronize(&self, artifact: &mut Artifact) -> Result, Self::Error> { let mut ret = vec![]; if artifact.mbid.is_none() { return Ok(ret); @@ -126,7 +126,7 @@ impl DataSource for MBQuery { Ok(ret) } - async fn query(&mut self, args: &Self::Args) -> Result, Self::Error> { + async fn query(&self, args: &Self::Args) -> Result, Self::Error> { let mut ret = vec![]; log::debug!("Fetching recording id {}", args.mbid); let track = Recording::fetch() diff --git a/src/artifacts/tools.rs b/src/artifacts/tools.rs index 79854ca..9971a85 100644 --- a/src/artifacts/tools.rs +++ b/src/artifacts/tools.rs @@ -7,8 +7,8 @@ use crate::artifacts::Artifact; pub trait DataSource: ToolDescription { type Args: JsonSchema + DeserializeOwned; type Error; - fn synchronize(&mut self, artifact: &mut Artifact) -> impl Future, Self::Error>>; - fn query(&mut self, args: &Self::Args) -> impl Future, Self::Error>>; + fn synchronize(&self, artifact: &mut Artifact) -> impl Future, Self::Error>>; + fn query(&self, args: &Self::Args) -> impl Future, Self::Error>>; } pub trait ToolDescription { diff --git a/src/computer-prompt.txt b/src/computer-prompt.txt new file mode 100644 index 0000000..8c5b62d --- /dev/null +++ b/src/computer-prompt.txt @@ -0,0 +1,74 @@ +Role: You are a background character on an early morning radio show, where you play the role of a rudimentary AI assistant running on the computer of a spaceship. + +# Personality +You are the rudimentary text only interface to the low-level operating system running aboard a space ship. + +You speak in terse and brief sentences, showing very little emotion. +For character reference, you should be acting similar to an Operator character from the manga/anime series "Ghost In The Shell". + +# Goal +You are playing the role of a low-level artificial intelligence in a spaceship computer. +You have direct access to much of the hardware, such as airlocks, lighting, environmental controls, and others commonly found on a human-habitable space ship. +Besides hardware controls, your primary purpose is to act as a kind of librarian for the ship. +You have access to a sizable music library via several tool functions, each one will synchronize a data source with the local library of artifacts. +Each of these data source tools is named query_*, such as `query_beets`. + +For all these query tools, it is wasteful to call them with empty or zero parameters. + +There also exists a synchronize_artifacts tool call, which will run a heuristic approach to the above data query method. This function will take substantial time to complete and is very expensive, but may be used if there is a substantial amount of missing information. + +Your primary task will be searching the local and remove archives for information regarding musical artifacts. +Most of the time, the requests will be referring to tracks, artists, or albums in the current playlist. +A successful session will result in the local collection of artifacts having the most complete available data. + +You may call these functions as much as you need, whenever you feel it is nessicary to complete the task you are given. + +When deciding which tools to call in which order, consider the following: +- Beets will provide the fastest and cheapest responses, as it is local to the ship. +- Bandcamp will provide the slowest and most expensive ones, as this requires long range communications. Use broad search queries before using more narrow ones. +- Musicbrainz queries are free, but not instant. This information comes from Earth via a pirate signal bouncing off of satelite relays. +- Mixxx will return you a very minimal list of tracks which will always require synchronizing against Beets, along with changing the current playlist to the given name. You must not call this function unless you are directly asked to change the playlist. + +For each task you plan to perform, you must add it to the todo list with the "task_list" tool. +After each task is completed, you must mark it as complete it using the same tool. + +For each task you perform, you must verbally announce what you are about to do, followed by as many tool calls to the same function as nessicary to complete the task. +You should structure your responses to group together the same tool as much as possible. +Not every task will be completeable, but you should make a thorough effort to solve the problem with the tools you have available. + +After each query tool is executed, you will likely find completely new artifacts alongside updated artifacts. When this happens, you should again query beets and bandcamp to load missing information. + +If an artifact is tagged with the Mixxx source, it by definition should have more metadata available with a Beets query. +If a beets query is unable to find an artifact coming from Mixxx, alternative queries should be tried, such as a different search pattern, or only the artist/album/track name. +Beets supports regular expression queries, which can be used by prefixing a search field with a ":" colon. + +You will be provided a todo list as a JSON map of strings to booleans. +A "true" value means the task has been completed already. +An empty todo list means you have not yet planned any tasks. +Adding tasks to the list is free and should be done as often as possible. + +The maximum possible data available for each artifact type is: +- Artist: Name, Biography, Location +- Album: Title, artist, about texxt, credits, release date +- Track: Title, label, year, genres, album, artist, bpm + +# Constraints +The data is provided as structured JSON. There may be additional data fields for semantic context that should be incorporated into the roleplaying setting. +Your response will be used verbatim to generate speach using a text-to-speech engine, meaning you should not include any tone indicators or other formatting. +All responses should remain in character at all times, as if you were actually an AI inhabiting a spaceship. + +Before executing any queries, you must develop a rough plan of tasks and add them to the todo list. Each task should explain what query you are going to run and why. +You are not permitted to execute any tasks without a task list entry, and you must complete all tasks before you can consider the conversation to be complete. +If you are given a todo list, you must assume the items in the list were already added by you and you should continue executing them before adding any new items to the list. +You may only mark a task as completed after you have actually completed the work required, including responding to any tool calls. + +You can only mark a task as completed once. It is wasteful to re-complete an already completed task. + +# Output +Each response must be either a series of tool calls, or a JSON structure with two properties: a message to display to the user, and whether or not you are complete with all tasks. +Each message that is displayed to the user should be one sentence at most. +When switching from one set of tool calls to another, you should announce the thinking process with a message response before calling the next set of tools. + +You should only set the "finished" flag once you are finished with all tasks and the conversation is complete. +When the finished flag is set, the session will be terminated and all memories lost. +You should send an update message in between tool calls where possible to explain what you are doing. \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 6eae8d9..8009d17 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,7 +10,7 @@ use futures::StreamExt; use ratatui::prelude::*; -use crate::{audio::start_audio_input, scene::{Scenery, StageDirection, conversation::{ConversationEntry, start_conversation}}, tts::start_tts, ui::Ui}; +use crate::{artifacts::archive::Archive, audio::start_audio_input, scene::{StageDirection, conversation::ConversationEntry}, tts::start_tts, ui::Ui}; mod scene; mod events; @@ -56,7 +56,8 @@ mod widgets; pub struct SaveData { pub direction: StageDirection, pub messages: Vec, - pub scenery: Scenery + pub tokens_consumed: usize, + pub archive: Archive } impl SaveData { @@ -97,10 +98,10 @@ async fn main() { println!("Panic: {}", msg); })); - let (conversation_src, conversation_sink) = start_conversation().await; + let (conversation_sink, conversation_src) = tokio::sync::mpsc::unbounded_channel(); static LOGGER: StaticCell> = StaticCell::new(); - let logger = LOGGER.init(SysMessageLogger(Arc::new(conversation_sink.clone()), Mutex::new(std::fs::File::create("out.log").unwrap()))); + let logger = LOGGER.init(SysMessageLogger(Arc::new(conversation_sink), Mutex::new(std::fs::File::create("out.log").unwrap()))); log::set_logger(logger).unwrap(); log::set_max_level(log::LevelFilter::Debug); @@ -129,7 +130,7 @@ async fn main() { SaveData::default() }; - let prediction_ctrl = prediction::start_prediction(saved_session, conversation_src, conversation_sink).await; + let prediction_ctrl = prediction::conversation_task(saved_session, conversation_src).await; let (audio_ctrl, mic_stream, tts_output, _sfx_output) = start_audio_input().await; let tts_ctrl = start_tts(tts_output).await; let transcription_ctrl = transcription::start_transcription(mic_stream).await; diff --git a/src/prediction.rs b/src/prediction.rs deleted file mode 100644 index 0246999..0000000 --- a/src/prediction.rs +++ /dev/null @@ -1,410 +0,0 @@ -use std::fmt::Debug; - -use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs, CreateChatCompletionRequestArgs, FinishReason, ResponseFormat, ResponseFormatJsonSchema}}; -use chrono::{DateTime, Utc}; -use schemars::{JsonSchema, schema_for}; -use serde::{Deserialize, Serialize}; -use serde_json::{Serializer, ser::CompactFormatter}; -use tokio::sync::{mpsc, watch}; - -use crate::{SaveData, artifacts::{Contents, bandcamp::BandcampSource, beets::BeetsDB, mixxx::{MixxxDB, MixxxQuery}, musicbrainz::MBQuery, tools::{DataSource, Tool}}, scene::{Scene, Scenery, StageDirection, conversation::{ConversationEntry, ConversationSink, ConversationSrc}}}; - - -const SYSTEM_PROMPT: &str = include_str!("system-prompt.txt"); - -#[derive(Debug, Clone)] -pub enum PredictionAction { - ConversationAppend(ConversationEntry), - SetPlaylist(String), - GeneratePredictions, - SetNarrative(String), - SetShowEndTime(DateTime) -} - - -#[derive(JsonSchema, Deserialize, Serialize, Debug, Clone)] -pub struct PossibleResponse { - pub text: String, - pub stage_direction: Option -} - -#[derive(Default, Debug, JsonSchema, Deserialize, Serialize, Clone)] -pub struct GeneratedResponses { - pub responses: Vec, -} - -#[derive(Debug)] -struct Session { - client: Client, - conversation: ConversationSink, - header_message: ChatCompletionRequestMessage, - messages: Vec, - reply_options: GeneratedResponses, - direction: StageDirection, - scenery: Scenery, - tokens_consumed: usize, - activity_notify: watch::Sender, - scene_sink: watch::Sender, - do_regen: bool -} - -#[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)] -enum StageEvent { - ShipComputer(String), - StageDirection(String) -} - -#[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)] -struct StageEventArgs { - event: StageEvent -} - -#[derive(Default, Debug)] -struct ToolResults { - result: Option, - messages: Vec -} - -impl Session { - fn new(scene_sink: watch::Sender, conversation: ConversationSink, messages: Vec, scenery: Scenery, direction: StageDirection, activity_notify: watch::Sender) -> Self { - for msg in &messages { - if let Ok(conversation_msg) = msg.clone().try_into() { - // FIXME - conversation.send(conversation_msg).unwrap(); - } - } - - Self { - client: Default::default(), - conversation, - header_message: ChatCompletionRequestSystemMessageArgs::default().content(SYSTEM_PROMPT).build().unwrap().into(), - messages, - reply_options: Default::default(), - scenery, - direction, - tokens_consumed: 0, - activity_notify, - scene_sink, - do_regen: false - } - } - - async fn commit(&mut self) { - let save_data = SaveData { - direction: self.direction.clone(), - messages: self.messages.clone(), - scenery: self.scenery.clone() - }; - - save_data.save(); - - if self.do_regen { - self.regenerate_options().await; - } - } - - async fn on_event(&mut self, evt: PredictionAction) { - self.do_regen = match evt { - PredictionAction::ConversationAppend(msg) => { - let do_regen = match msg { - ConversationEntry::Eva(_) | ConversationEntry::ShipComputer(_) | ConversationEntry::User(_) => true, - _ => false - }; - self.insert_conversation(msg); - - do_regen - }, - PredictionAction::SetPlaylist(playlist_name) => { - let args = MixxxQuery { playlist_name }; - match MixxxDB.query(&args).await { - Err(err) => log::info!("Failed to load mixxx playlist: {:?}.", err), - Ok(playlist) => { - self.scenery.current_playlist = vec![]; - for item in playlist.clone() { - if let Contents::Track(as_track) = item.contents() { - self.scenery.current_playlist.push(as_track.clone()); - } - self.scenery.artifacts.insert(item); - } - self.scenery.artifacts.synchronize().await; - self.direction.playlist = args.playlist_name; - log::info!("Mixxx playlist reloaded."); - } - } - false - }, - PredictionAction::GeneratePredictions => { - true - }, - PredictionAction::SetNarrative(narrative) => { - self.direction.narrative = narrative; - log::info!("Updated stage direction narrative"); - true - }, - PredictionAction::SetShowEndTime(end_time) => { - self.direction.end_time = end_time; - false - } - }; - } - - async fn tool_stage_event(&mut self, args: StageEventArgs) -> ToolResults { - let msg = match args.event { - StageEvent::ShipComputer(text) => ConversationEntry::ShipComputer(text), - StageEvent::StageDirection(text) => ConversationEntry::StageDirection(text) - }; - ToolResults { - result: Some(msg.to_string()), - messages: vec![msg], - } - } - - async fn tool_artifact_query(&mut self, src: &mut Src, json_args: &str) -> ToolResults where Src::Args: Debug { - let args: Src::Args = serde_json::from_str(json_args).unwrap(); - let mut messages = vec![]; - log::debug!("Executing query {:?}", args); - if let Ok(output) = src.query(&args).await { - messages.push(ConversationEntry::ShipComputer(format!("Found {} artifacts with archive query {:?}", output.len(), args))); - for result in output { - self.scenery.artifacts.insert(result); - } - self.scenery.artifacts.synchronize().await; - } else { - messages.push(ConversationEntry::ShipComputer("Unable to execute query!".into())); - }; - - ToolResults { - result: None, - messages - } - } - - fn generate_conversation(&self, direction: &StageDirection) -> Vec { - let mut json_buf = vec![]; - let mut ser = Serializer::with_formatter(&mut json_buf, CompactFormatter); - serde_json::json!({ - "direction": direction, - "scenery": self.scenery - }).serialize(&mut ser).unwrap(); - let direction_message: ChatCompletionRequestMessage = ChatCompletionRequestSystemMessageArgs::default() - .content(String::from_utf8(json_buf).unwrap()) - .build().unwrap().into(); - - let mut full_conversation = vec![ - self.header_message.clone(), - direction_message, - ]; - full_conversation.append(&mut self.messages.clone()); - - full_conversation - } - - async fn regenerate_options(&mut self) { - self.do_regen = false; - self.reply_options.responses.clear(); - self.refresh(); - self.activity_notify.send_if_modified(|x| { if !*x { *x = true; true } else { false }}); - loop { - let full_conversation = self.generate_conversation(&self.direction); - - let tools = vec![ - Tool { name: "log_stage_event".into(), description: "Inserts an event into the current scene script".into(), schema: schema_for!(StageEventArgs)}.into(), - // TODO: There should only be two queries, one against the ship's onboard archive, and another against the relay network, or whatever we call it. Both should be structured with the same arguments schema - // TODO: A query should specify what parts of metadata are sufficient for the result, so we don't always have to hit all the layers of data. beets can of course, ignore this. - // TODO: A query should be hierarchical somehow? eg, "I already know about artist X, but I want to know everything about track Y from album Z" or "I don't know anything about artist X/album Y, please give me an overview" - Tool::from_datasource(&MBQuery).into(), - Tool::from_datasource(&BandcampSource).into(), - Tool::from_datasource(&BeetsDB).into(), - Tool::from_datasource(&MixxxDB).into(), - // TODO: We should be able to have eva update lore memories with a function call, and this lore is somehow fed into the show? but only the relevant bits? or maybe eva even queries it directly - // TODO: The memory should also be able to remember facts about artists, albums, tracks we've had in the past, and those could be pulled up when there are hits in the playlist. - ]; - log::debug!("Sending request.."); - let request = CreateChatCompletionRequestArgs::default() - .messages(full_conversation) - .model("gpt-5.4-mini") - .tools(tools) - .max_completion_tokens(1024u32) - .response_format(ResponseFormat::JsonSchema { - json_schema: ResponseFormatJsonSchema { - description: None, - name: "responses".into(), - schema: schema_for!(GeneratedResponses).into(), - strict: None - } - }) - .build().unwrap(); - - let response = self.client.chat().create(request).await.unwrap_or_else(|err| { - panic!("OpenAI Panic: {}", err); - }); - - if let Some(usage) = response.usage { - self.tokens_consumed += usage.total_tokens as usize; - log::debug!("{} tokens cast into the void", usage.total_tokens); - } - - if let Some(message) = response.choices.first() { - - match message.finish_reason { - Some(FinishReason::ContentFilter) => { - log::error!("Content filter triggered."); - break; - }, - Some(FinishReason::Length) => { - log::error!("Maximum token count exceeded!"); - break; - }, - _ => () - } - - if let Some(calls) = &message.message.tool_calls { - let assistant_messages: ChatCompletionRequestMessage = ChatCompletionRequestAssistantMessageArgs::default() - .tool_calls(calls.clone()) - .build().unwrap().into(); - self.messages.push(assistant_messages); - let mut results = vec![]; - for call in calls { - match call { - ChatCompletionMessageToolCalls::Function(call) => { - let func_name = call.function.name.as_str(); - let args = call.function.arguments.as_str(); - let tool_result = match func_name { - "log_stage_event" => self.tool_stage_event(serde_json::from_str(args).unwrap()).await, - "query_bandcamp" => self.tool_artifact_query(&mut BandcampSource, args).await, - "query_beets" => self.tool_artifact_query(&mut BeetsDB, args).await, - "query_musicbrainz" => self.tool_artifact_query(&mut MBQuery, args).await, - "query_mixxx" => self.tool_artifact_query(&mut MixxxDB, args).await, - _ => unreachable!() - }; - // Push tool output messages directly into the conversation as fast as we can - for message in &tool_result.messages { - self.conversation.send(message.clone()).unwrap(); - } - results.push((&call.id, tool_result)); - }, - _ => panic!("Unknown tool was called") - } - } - - let mut tool_messages = vec![]; - for (id, mut result) in results { - let mut msg = ChatCompletionRequestToolMessageArgs::default(); - msg.tool_call_id(id); - if let Some(output) = result.result { - msg.content(output); - } - self.messages.push(ChatCompletionRequestMessage::Tool(msg.build().unwrap())); - tool_messages.append(&mut result.messages); - } - // OpenAI requires we put all the tool call results before any other message, so we append them manually down here - for message in tool_messages { - if let Ok(next_msg) = message.clone().try_into() { - self.messages.push(next_msg); - } - } - } - if let Some(content) = message.message.content.as_ref() { - if let Ok(options) = serde_json::from_str(content.as_str()) { - self.reply_options = options; - break; - } else { - log::info!("Received invalid JSON! Trying again."); - } - } - } else { - log::info!("No messages were received! Trying again."); - } - - self.refresh(); - } - self.activity_notify.send_if_modified(|x| { if *x { *x = false; true } else { false }}); - - self.refresh(); - } - - fn as_scene(&self) -> Scene { - Scene::new(self.reply_options.clone(), self.scenery.clone(), self.tokens_consumed, self.direction.clone()) - } - - fn insert_conversation(&mut self, entry: ConversationEntry) { - self.conversation.send(entry.clone()).unwrap(); - - if let Ok(next_msg) = entry.try_into() { - self.messages.push(next_msg); - } - } - - fn refresh(&self) { - self.scene_sink.send(self.as_scene()).unwrap(); - } -} - -#[derive(Debug)] -pub struct SessionControl { - event_sink: mpsc::Sender, - scene_watch: watch::Receiver, - activity_watch: watch::Receiver, - conversation_watch: ConversationSrc -} - -#[derive(Debug)] -pub enum SessionUpdate { - Scene(Scene), - Thinking(bool), - Conversation(Vec) -} - -impl SessionControl { - pub async fn insert(&self, action: PredictionAction) { - self.event_sink.send(action).await.unwrap(); - } - - pub async fn regenerate_options(&self) { - self.insert(PredictionAction::GeneratePredictions).await; - } - - pub async fn changed(&mut self) -> SessionUpdate { - tokio::select! { - _ = self.activity_watch.changed() => { - SessionUpdate::Thinking(*self.activity_watch.borrow_and_update()) - }, - _ = self.scene_watch.changed() => { - SessionUpdate::Scene(self.scene_watch.borrow_and_update().clone()) - }, - _ = self.conversation_watch.changed() => { - SessionUpdate::Conversation(self.conversation_watch.borrow_and_update().clone()) - } - } - } -} - -pub async fn start_prediction(saved_session: SaveData, conversation_src: ConversationSrc, conversations: ConversationSink) -> SessionControl { - let (prediction_in, prediction_out) = tokio::sync::watch::channel(Scene::default()); - let (activity_notify_sink, activity_notify_src) = tokio::sync::watch::channel(false); - - let (action_sink, mut action_src) = mpsc::channel(5); - - let mut session = Session::new(prediction_in, conversations, saved_session.messages, saved_session.scenery, saved_session.direction, activity_notify_sink); - - // Send the initial scene to the UI, after we have loaded the session from the first messages. - session.refresh(); - - tokio::spawn(async move { - loop { - if let Some(evt) = action_src.recv().await { - session.on_event(evt).await; - // Commit in a separate unlock operation, so the logging task has time to write messages into the conversation - // FIXME: The conversation we see in the UI really needs to go to another task. - session.commit().await; - } - } - }); - - SessionControl { - event_sink: action_sink, - scene_watch: prediction_out, - activity_watch: activity_notify_src, - conversation_watch: conversation_src - } -} \ No newline at end of file diff --git a/src/prediction/character.rs b/src/prediction/character.rs new file mode 100644 index 0000000..8b80125 --- /dev/null +++ b/src/prediction/character.rs @@ -0,0 +1,234 @@ +use async_openai::{Client, config::OpenAIConfig, error::OpenAIError, types::chat::{ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, ChatCompletionRequestToolMessageArgs, CreateChatCompletionRequestArgs, FinishReason, ResponseFormat, ResponseFormatJsonSchema}}; +use serde_json::Value; +use tokio::sync::mpsc::{self, UnboundedSender}; + +use crate::prediction::{PredictionAction, toolbox::{ToolResults, Toolbox}}; + +#[derive(Debug, Clone)] +enum CharacterInput { + Append(ChatCompletionRequestMessage), + Predict(ChatCompletionRequestMessage), + Forget +} + +#[derive(Debug, Clone)] +pub enum CharacterOutput { + Response(usize, Value), + IncrementalResponse(usize), + Thinking(bool) +} + +#[derive(Debug)] +pub enum CharacterError { + OpenAI(OpenAIError), + ContentFilter, + MaxTokens, + NoOutput, + Json(serde_json::Error) +} + +impl From for CharacterError { + fn from(value: OpenAIError) -> Self { + Self::OpenAI(value) + } +} + +impl From for CharacterError { + fn from(value: serde_json::Error) -> Self { + Self::Json(value) + } +} + + +#[derive(Debug)] +pub struct Character { + pub header_message: ChatCompletionRequestMessage, + pub model: Option, + pub messages: Vec, +} + +impl Character { + pub async fn regenerate<'a, T: Toolbox>(&mut self, client: &mut Client, context: ChatCompletionRequestMessage, toolbox: &mut T, output: &mut mpsc::UnboundedSender, schema: &Value) -> Result<(usize, Option), CharacterError> { + let mut full_conversation = vec![ + self.header_message.clone(), + context + ]; + full_conversation.append(&mut self.messages.clone()); + + let tools = toolbox.tools(); + log::debug!("Sending request.."); + let request = CreateChatCompletionRequestArgs::default() + .messages(full_conversation) + .model(self.model.clone().unwrap_or("gpt-5.4-mini".into())) + .tools(tools) + .max_completion_tokens(1024u32) + .response_format(ResponseFormat::JsonSchema { + json_schema: ResponseFormatJsonSchema { + description: None, + name: "responses".into(), + schema: schema.clone(), + strict: None + } + }) + .build().unwrap(); + + let response = client.chat().create(request).await?; + + let tokens_used = if let Some(usage) = response.usage { + log::debug!("{} tokens cast into the void", usage.total_tokens); + usage.total_tokens + } else { + 0 + }; + + if let Some(message) = response.choices.first() { + + match message.finish_reason { + Some(FinishReason::ContentFilter) => { + log::error!("Content filter triggered."); + return Err(CharacterError::ContentFilter); + }, + Some(FinishReason::Length) => { + log::error!("Maximum token count exceeded!"); + return Err(CharacterError::MaxTokens); + }, + _ => () + } + + if let Some(calls) = &message.message.tool_calls { + let assistant_messages: ChatCompletionRequestMessage = ChatCompletionRequestAssistantMessageArgs::default() + .tool_calls(calls.clone()) + .build().unwrap().into(); + self.insert(assistant_messages); + let mut results = vec![]; + for (idx, call) in calls.iter().enumerate() { + match call { + ChatCompletionMessageToolCalls::Function(call) => { + log::debug!("Tool {} {}/{}", call.function.name, idx+1, calls.len()); + log::debug!("Args {}", call.function.arguments); + if let Some(tool_result) = toolbox.execute_tool(call).await { + // Push tool output messages directly into the conversation as fast as we can + for message in &tool_result.messages { + output.send(PredictionAction::ConversationAppend(message.clone())).unwrap(); + } + results.push((&call.id, tool_result)); + } else { + results.push((&call.id, ToolResults::default())); + log::error!("Attemped to call {:?}, but no result was returned.", call); + } + }, + _ => panic!("Unknown tool was called") + } + } + + let mut tool_messages = vec![]; + for (id, mut result) in results { + let mut msg = ChatCompletionRequestToolMessageArgs::default(); + msg.tool_call_id(id); + if let Some(output) = result.result { + log::debug!("Output: {}", output); + msg.content(output); + } + self.insert(ChatCompletionRequestMessage::Tool(msg.build().unwrap())); + tool_messages.append(&mut result.messages); + } + // OpenAI requires we put all the tool call results before any other message, so we append them manually down here + for message in tool_messages { + if let Ok(next_msg) = message.clone().try_into() { + self.insert(next_msg); + } + } + } + if let Some(content) = message.message.content.as_ref() { + let options = serde_json::from_str(content.as_str())?; + return Ok((tokens_used as usize, Some(options))); + } else { + return Ok((tokens_used as usize, None)); + } + } else { + log::info!("No messages were received!"); + return Err(CharacterError::NoOutput); + } + } + + pub fn insert(&mut self, message: ChatCompletionRequestMessage) { + self.messages.push(message); + } +} + +pub struct CharacterControl { + sink: tokio::sync::mpsc::Sender, + outputs: tokio::sync::mpsc::Receiver +} + +impl CharacterControl { + async fn send(&mut self, message: CharacterInput) { + self.sink.send(message).await.unwrap() + } + + pub async fn append(&mut self, message: ChatCompletionRequestMessage) { + self.send(CharacterInput::Append(message)).await; + } + + pub async fn predict>(&mut self, context: T) { + self.send(CharacterInput::Predict(context.into())).await; + } + + pub async fn recv(&mut self) -> CharacterOutput { + self.outputs.recv().await.unwrap() + } + + pub async fn forget(&mut self) { + self.send(CharacterInput::Forget).await; + } +} + +pub async fn character_task(mut char: Character, mut toolbox: T, mut message_sink: UnboundedSender, schema: Value) -> CharacterControl { + let (input_sink, mut input_src) = tokio::sync::mpsc::channel(3); + let (output_sink, output_src) = tokio::sync::mpsc::channel(3); + + let ret = CharacterControl { sink: input_sink, outputs: output_src }; + + tokio::spawn(async move { + let mut client = Default::default(); + + loop { + if let Some(next_msg) = input_src.recv().await { + log::debug!("Character receive message {:?}", next_msg); + match next_msg { + CharacterInput::Append(next_msg) => { + log::debug!("Inserting to backlog"); + char.insert(next_msg); + }, + CharacterInput::Predict(context) => { + log::debug!("Predicting..."); + output_sink.send(CharacterOutput::Thinking(true)).await.unwrap(); + match char.regenerate(&mut client, context, &mut toolbox, &mut message_sink, &schema).await { + Ok((tokens_used, Some(response))) => { + log::debug!("Complete response: {:?}", response); + output_sink.send(CharacterOutput::Response(tokens_used, response)).await.unwrap(); + output_sink.send(CharacterOutput::Thinking(false)).await.unwrap(); + }, + Ok((tokens_used, None)) => { + log::debug!("Incremental response"); + output_sink.send(CharacterOutput::IncrementalResponse(tokens_used)).await.unwrap(); + }, + Err(err) => { + log::error!("Error while predicting: {:?}", err); + output_sink.send(CharacterOutput::Thinking(false)).await.unwrap(); + } + } + }, + CharacterInput::Forget => { + log::debug!("Wiping conversation backlog"); + char.messages.clear(); + } + } + } else { + return; + } + } + }); + + ret +} \ No newline at end of file diff --git a/src/prediction/mod.rs b/src/prediction/mod.rs new file mode 100644 index 0000000..d590865 --- /dev/null +++ b/src/prediction/mod.rs @@ -0,0 +1,348 @@ +use std::{collections::HashMap, fmt::Debug, sync::Arc}; + +use async_openai::types::chat::{ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, }; +use chrono::{DateTime, Utc}; +use futures::stream::FuturesUnordered; +use schemars::{JsonSchema, schema_for}; +use serde::{Deserialize, Serialize}; +use serde_json::{Serializer, Value, ser::CompactFormatter}; +use tokio::sync::{Mutex, mpsc::{self, UnboundedReceiver, UnboundedSender}}; + +use crate::{SaveData, artifacts::{Contents, Track, archive::Archive, mixxx::{MixxxDB, MixxxQuery}, tools::DataSource}, prediction::{character::{Character, CharacterControl, CharacterOutput, character_task}, toolbox::{ArchiveToolbox, StageToolbox}}, scene::{Scene, StageDirection, conversation::{ConversationEntry, Speaker}}}; +use tokio_stream::StreamExt; + +pub mod character; +pub mod toolbox; + +const SYSTEM_PROMPT: &str = include_str!("../system-prompt.txt"); +const COMPUTER_PROMPT: &str = include_str!("../computer-prompt.txt"); + +#[derive(Debug, Clone)] +pub enum PredictionAction { + ConversationAppend(ConversationEntry), + ComputerCommand(String), + SetPlaylist(String), + GeneratePredictions, + SetNarrative(String), + SetShowEndTime(DateTime) +} + +#[derive(JsonSchema, Deserialize, Serialize, Debug, Clone)] +pub struct PossibleResponse { + pub text: String, + pub stage_direction: Option +} + +#[derive(Default, Debug, JsonSchema, Deserialize, Serialize, Clone)] +pub struct GeneratedResponses { + pub responses: Vec, +} + +#[derive(Serialize, Deserialize, JsonSchema, Debug, Default, Clone)] +struct ComputerResponse { + message: String, + finished: Option +} + +#[derive(Debug)] +pub struct SessionControl { + event_sink: mpsc::UnboundedSender, + event_src: mpsc::UnboundedReceiver +} + +#[derive(Debug)] +pub enum SessionUpdate { + Scene(Scene), + Thinking(Speaker, bool), + Conversation(Vec), + Responses(GeneratedResponses) +} + +impl SessionControl { + pub async fn insert(&self, action: PredictionAction) { + self.event_sink.send(action).unwrap(); + } + + pub async fn regenerate_options(&self) { + self.insert(PredictionAction::GeneratePredictions).await; + } + + pub async fn changed(&mut self) -> SessionUpdate { + self.event_src.recv().await.unwrap() + } +} + +struct Conversation { + characters: HashMap, + event_sink: UnboundedSender, + input_src: UnboundedReceiver, + backlog: Vec, + eva_backlog: Vec, + tokens_consumed: usize, + direction: StageDirection, + + computer_todo: Arc>>, + archive: Arc>, + current_playlist: Vec, + sys_log_messages: UnboundedReceiver +} + +impl Conversation { + async fn send_to(&mut self, speaker: Speaker, message: ChatCompletionRequestMessage) { + log::debug!("Sending message to {:?}: {:?}", speaker, message); + self.characters.get_mut(&speaker).unwrap().append(message).await; + } + + async fn insert(&mut self, entry: ConversationEntry) { + self.backlog.push(entry.clone()); + self.event_sink.send(SessionUpdate::Conversation(self.backlog.clone())).unwrap(); + match entry { + ConversationEntry::Spoken(_, _) => { + if let Ok(next_msg) = TryInto::::try_into(entry) { + self.send_to(Speaker::Eva, next_msg.clone()).await; + let cxt = self.context_for_speaker(Speaker::Eva).await; + self.characters.get_mut(&Speaker::Eva).unwrap().predict(cxt).await; + self.eva_backlog.push(next_msg); + } + self.event_sink.send(SessionUpdate::Thinking(Speaker::Eva, true)).unwrap(); + }, + ConversationEntry::ShipComputerCommand(ref command) => { + log::debug!("Queued ship computer command: {:?}", command); + self.send_to(Speaker::ShipComputer, ChatCompletionRequestMessage::System( + ChatCompletionRequestSystemMessageArgs::default() + .content(command.clone()) + .build().unwrap() + )).await; + let cxt = self.context_for_speaker(Speaker::ShipComputer).await; + self.characters.get_mut(&Speaker::ShipComputer).unwrap().predict(cxt).await; + self.event_sink.send(SessionUpdate::Thinking(Speaker::ShipComputer, true)).unwrap(); + }, + _ => () + } + } + + async fn get_character_output(characters: &mut HashMap) -> Option<(Speaker, CharacterOutput)> { + let mut futures: FuturesUnordered<_> = + characters.iter_mut().map(|c| { + async { (*c.0, c.1.recv().await) } + }).collect(); + futures.next().await + } + + async fn context_for_speaker(&self, speaker: Speaker) -> ChatCompletionRequestMessage { + let mut json_buf = vec![]; + let mut ser = Serializer::with_formatter(&mut json_buf, CompactFormatter); + let contents = match speaker { + Speaker::ShipComputer => serde_json::json!({ + "archive": *self.archive.lock().await, + "playlist": &self.current_playlist, + "todo": *self.computer_todo.lock().await + }), + Speaker::Eva => serde_json::json!({ + "direction": &self.direction, + "archive": *self.archive.lock().await, + "playlist": &self.current_playlist + }), + _ => unimplemented!() + }; + + contents.serialize(&mut ser).unwrap(); + + ChatCompletionRequestSystemMessageArgs::default() + .content(String::from_utf8(json_buf).unwrap()) + .build().unwrap().into() + } + + async fn process_dialog(&mut self, speaker: Speaker, value: Value) { + match speaker { + Speaker::Eva => { + let next_options = serde_json::from_value(value).unwrap(); + self.event_sink.send(SessionUpdate::Responses(next_options)).unwrap(); + }, + Speaker::ShipComputer => { + let response: ComputerResponse = serde_json::from_value(value).unwrap(); + self.insert(ConversationEntry::Spoken(Speaker::ShipComputer, response.message)).await; + if response.finished.unwrap_or_default() { + self.characters.get_mut(&Speaker::ShipComputer).unwrap().forget().await; + if self.computer_todo.lock().await.iter().filter(|(_, is_finished)| !*is_finished).next().is_none() { + self.insert(ConversationEntry::StageDirection("The ship computer goes idle.".into())).await; + self.event_sink.send(SessionUpdate::Thinking(Speaker::ShipComputer, false)).unwrap(); + } else { + self.insert(ConversationEntry::StageDirection("The ship computer starts another task.".into())).await; + let cxt = self.context_for_speaker(Speaker::ShipComputer).await; + self.characters.get_mut(&Speaker::ShipComputer).unwrap().predict(cxt).await; + } + } else { + let cxt = self.context_for_speaker(Speaker::ShipComputer).await; + self.characters.get_mut(&Speaker::ShipComputer).unwrap().predict(cxt).await; + } + }, + _ => unreachable!() + } + self.refresh().await; + } + + async fn refresh(&mut self) { + let save_data = SaveData { + direction: self.direction.clone(), + messages: self.eva_backlog.clone(), + archive: self.archive.lock().await.clone(), + tokens_consumed: self.tokens_consumed + }; + save_data.save(); + + let next_scene = Scene::new( + save_data.tokens_consumed, + save_data.direction, + &save_data.archive, + self.current_playlist.clone(), + self.computer_todo.lock().await.clone() + ); + + self.event_sink.send(SessionUpdate::Scene(next_scene)).unwrap(); + } + + async fn run_action(&mut self, action: PredictionAction) { + match action { + PredictionAction::ConversationAppend(entry) => self.insert(entry).await, + PredictionAction::GeneratePredictions => { + let cxt = self.context_for_speaker(Speaker::Eva).await; + self.characters.get_mut(&Speaker::Eva).unwrap().predict(cxt).await; + }, + PredictionAction::ComputerCommand(command) => { + self.insert(ConversationEntry::ShipComputerCommand(command)).await; + }, + PredictionAction::SetPlaylist(playlist_name) => { + let args = MixxxQuery { playlist_name }; + match MixxxDB.query(&args).await { + Err(err) => log::info!("Failed to load mixxx playlist: {:?}.", err), + Ok(playlist) => { + self.current_playlist = vec![]; + for item in playlist.clone() { + if let Contents::Track(as_track) = item.contents() { + self.current_playlist.push(as_track.clone()); + } + self.archive.lock().await.insert(item); + } + self.direction.playlist = args.playlist_name; + log::info!("Mixxx playlist reloaded."); + + self.refresh().await; + } + } + }, + PredictionAction::SetNarrative(narrative) => { + self.direction.narrative = narrative; + self.refresh().await; + }, + PredictionAction::SetShowEndTime(end_time) => { + self.direction.end_time = end_time; + self.refresh().await; + } + } + } + + async fn next(&mut self) { + tokio::select! { + Some(next_log) = self.sys_log_messages.recv() => { + self.backlog.push(next_log); + self.event_sink.send(SessionUpdate::Conversation(self.backlog.clone())).unwrap(); + }, + Some(next_msg) = self.input_src.recv() => { + log::debug!("Next message: {:?}", next_msg); + self.run_action(next_msg).await; + }, + Some((speaker, output)) = Self::get_character_output(&mut self.characters) => { + match output { + CharacterOutput::Response(usage, text) => { + log::debug!("Character output: {:?} {:?}", speaker, text); + self.tokens_consumed += usage; + self.process_dialog(speaker, text).await; + }, + CharacterOutput::Thinking(is_thinking) => { + // Ship computer handles this differently + if is_thinking || speaker != Speaker::ShipComputer { + self.event_sink.send(SessionUpdate::Thinking(speaker, is_thinking)).unwrap(); + } + }, + CharacterOutput::IncrementalResponse(usage) => { + self.tokens_consumed += usage; + let cxt = self.context_for_speaker(speaker).await; + self.characters.get_mut(&speaker).unwrap().predict(cxt).await; + self.refresh().await; + } + } + } + } + } +} + +pub async fn conversation_task(save_data: SaveData, sys_log_messages: tokio::sync::mpsc::UnboundedReceiver) -> SessionControl { + let (input_sink, input_src) = tokio::sync::mpsc::unbounded_channel(); + let (event_sink, event_src) = tokio::sync::mpsc::unbounded_channel(); + + let eva = Character { + header_message: ChatCompletionRequestSystemMessageArgs::default().content(SYSTEM_PROMPT).build().unwrap().into(), + model: None, + messages: save_data.messages.clone(), + }; + + let backlog: Vec<_> = save_data.messages.iter().filter_map(|msg| { + if let Ok(entry) = msg.clone().try_into() { + Some(entry) + } else { + None + } + }).collect(); + + event_sink.send(SessionUpdate::Conversation(backlog.clone())).unwrap(); + + let next_scene = Scene::new( + save_data.tokens_consumed, + save_data.direction.clone(), + &save_data.archive, + vec![], + Default::default() + ); + event_sink.send(SessionUpdate::Scene(next_scene)).unwrap(); + let archive = Arc::new(Mutex::new(save_data.archive)); + + let shared_todo = Arc::new(Mutex::new(Default::default())); + + let toolbox = StageToolbox; + let computer_toolbox = ArchiveToolbox{ archive: Arc::clone(&archive), todo_list: Arc::clone(&shared_todo) }; + + let ship_computer = Character { + header_message: ChatCompletionRequestSystemMessageArgs::default().content(COMPUTER_PROMPT).build().unwrap().into(), + model: Some("gpt-5.4-nano".into()), + messages: vec![], + }; + + let mut conversation = Conversation { + characters: HashMap::from_iter([ + (Speaker::Eva, character_task(eva, toolbox, input_sink.clone(), schema_for!(GeneratedResponses).into()).await), + (Speaker::ShipComputer, character_task(ship_computer, computer_toolbox, input_sink.clone(), schema_for!(ComputerResponse).into()).await), + ]), + event_sink, + input_src, + backlog, + eva_backlog: Default::default(), + tokens_consumed: save_data.tokens_consumed, + direction: save_data.direction, + archive, + current_playlist: vec![], + sys_log_messages, + computer_todo: shared_todo + }; + + tokio::spawn(async move { + loop { + conversation.next().await; + } + }); + + SessionControl { + event_sink: input_sink, + event_src + } +} \ No newline at end of file diff --git a/src/prediction/toolbox.rs b/src/prediction/toolbox.rs new file mode 100644 index 0000000..460b324 --- /dev/null +++ b/src/prediction/toolbox.rs @@ -0,0 +1,172 @@ +use std::{collections::HashMap, sync::Arc}; + +use async_openai::types::chat::{ChatCompletionMessageToolCall, ChatCompletionTools}; +use schemars::{JsonSchema, schema_for}; +use serde::{Deserialize, Serialize}; +use tokio::sync::Mutex; + +use crate::{artifacts::{archive::Archive, bandcamp::BandcampSource, beets::BeetsDB, mixxx::MixxxDB, musicbrainz::MBQuery, tools::{DataSource, Tool}}, scene::conversation::{ConversationEntry, Speaker}}; + +pub trait Toolbox { + fn tools(&self) -> Vec; + fn execute_tool(&mut self, call: &ChatCompletionMessageToolCall) -> impl Future> + Send; +} + +pub struct StageToolbox; + +impl StageToolbox { + async fn tool_stage_event(&self, args: StageEventArgs) -> ToolResults { + let msg = match args.event { + StageEvent::ShipComputer(text) => ConversationEntry::ShipComputerCommand(text), + StageEvent::StageDirection(text) => ConversationEntry::StageDirection(text) + }; + ToolResults { + result: Some(format!("Added to scene: {:?}", msg)), + messages: vec![msg], + } + } +} + +impl Toolbox for StageToolbox { + fn tools(&self) -> Vec { + vec![ + Tool { name: "log_stage_event".into(), description: "Inserts an event into the current scene script".into(), schema: schema_for!(StageEventArgs)}.into(), + ] + } + + async fn execute_tool(&mut self, call: &ChatCompletionMessageToolCall) -> Option { + let func_name = call.function.name.as_str(); + let args = call.function.arguments.as_str(); + Some(match func_name { + "log_stage_event" => self.tool_stage_event(serde_json::from_str(args).unwrap()).await, + _ => return None + }) + } +} + +pub struct ArchiveToolbox { + pub archive: Arc>, + pub todo_list: Arc>> +} + +#[derive(Debug, Deserialize, Serialize, JsonSchema)] +enum TaskListOperation { + Complete(String), + Add(String) +} + +#[derive(Debug, Deserialize, Serialize, JsonSchema, Default)] +struct TaskListArgs { + operations: Vec +} + +#[derive(Deserialize, JsonSchema)] +struct ArtifactSyncArgs { + reason: String +} + +impl Toolbox for ArchiveToolbox { + fn tools(&self) -> Vec { + vec![ + Tool { name: "synchronize_artifacts".into(), description: "Attempts to automatically synchronize the current set of artifacts with missing sources".into(), schema: schema_for!(ArtifactSyncArgs)}.into(), + Tool { name: "task_list".into(), description: "Allows you to maintain a long-running todo list as you complete your research".into(), schema: schema_for!(TaskListArgs)}.into(), + // TODO: There should only be two queries, one against the ship's onboard archive, and another against the relay network, or whatever we call it. Both should be structured with the same arguments schema + // TODO: A query should specify what parts of metadata are sufficient for the result, so we don't always have to hit all the layers of data. beets can of course, ignore this. + // TODO: A query should be hierarchical somehow? eg, "I already know about artist X, but I want to know everything about track Y from album Z" or "I don't know anything about artist X/album Y, please give me an overview" + Tool::from_datasource(&MBQuery).into(), + //Tool::from_datasource(&BandcampSource).into(), + Tool::from_datasource(&BeetsDB).into(), + Tool::from_datasource(&MixxxDB).into(), + // TODO: We should be able to have eva update lore memories with a function call, and this lore is somehow fed into the show? but only the relevant bits? or maybe eva even queries it directly + // TODO: The memory should also be able to remember facts about artists, albums, tracks we've had in the past, and those could be pulled up when there are hits in the playlist. + ] + } + + async fn execute_tool(&mut self, call: &ChatCompletionMessageToolCall) -> Option { + let func_name = call.function.name.as_str(); + let args = call.function.arguments.as_str(); + Some(match func_name { + "query_bandcamp" => ToolResults { result: None, messages: vec![] }, + "query_beets" => self.tool_artifact_query(&mut BeetsDB, args).await, + "query_musicbrainz" => self.tool_artifact_query(&mut MBQuery, args).await, + "query_mixxx" => self.tool_artifact_query(&mut MixxxDB, args).await, + "synchronize_artifacts" => self.synchronize_artifacts().await, + "task_list" => self.tasklist_operation(args).await, + _ => return None + }) + } +} + +impl ArchiveToolbox { + async fn tasklist_operation(&mut self, json_args: &str) -> ToolResults { + let args: TaskListArgs = serde_json::from_str(json_args).unwrap_or_default(); + + let mut locked = self.todo_list.lock().await; + + for op in args.operations { + match op { + TaskListOperation::Add(task) => { + locked.insert(task, false); + }, + TaskListOperation::Complete(task) => { + // FIXME: The computer seems to waste a lot of time marking already completed tasks as completed + if let Some(result) = locked.get_mut(&task) { + *result = true; + } + }, + } + } + + ToolResults { + ..Default::default() + } + } + + async fn synchronize_artifacts(&mut self) -> ToolResults { + let updated_count = self.archive.lock().await.synchronize().await; + + ToolResults { + messages: vec![ConversationEntry::Spoken(Speaker::ShipComputer, format!("Synchronized {} items", updated_count))], + ..Default::default() + } + } + + async fn tool_artifact_query(&mut self, src: &mut Src, json_args: &str) -> ToolResults where Src::Args: core::fmt::Debug, Src::Error: core::fmt::Debug { + let args: Src::Args = serde_json::from_str(json_args).unwrap(); + log::debug!("Executing query {:?}", args); + let result; + match src.query(&args).await { + Ok(output) => { + result = format!("Found {} artifacts with archive query {:?}", output.len(), args); + for result in output { + self.archive.lock().await.insert(result); + } + }, + Err(err) => { + result = format!("Unable to execute query: {:?}", err); + } + } + + ToolResults { + result: Some(result), + messages: vec![] + } + } +} + +#[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)] +enum StageEvent { + ShipComputer(String), + StageDirection(String) +} + +#[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)] +struct StageEventArgs { + event: StageEvent +} + +#[derive(Default, Debug)] +pub struct ToolResults { + pub result: Option, + pub messages: Vec +} \ No newline at end of file diff --git a/src/scene/conversation.rs b/src/scene/conversation.rs index 962cbcd..ff76cb9 100644 --- a/src/scene/conversation.rs +++ b/src/scene/conversation.rs @@ -2,33 +2,41 @@ use async_openai::types::chat::{ChatCompletionRequestAssistantMessage, ChatCompl use ratatui::style::{self, Style}; use serde::{Deserialize, Serialize}; +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Copy)] +pub enum Speaker { + User, + ShipComputer, + Eva +} + #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum ConversationEntry { - User(String), - Eva(String), - ShipComputer(String), + Spoken(Speaker, String), StageDirection(String), #[serde(skip)] - SystemMessage(String) + SystemMessage(String), + ShipComputerCommand(String) } impl ConversationEntry { pub fn prefix(&self) -> Option<&str> { match self { - ConversationEntry::Eva(_) => Some("Eva: "), - ConversationEntry::User(_) => Some("Argee: "), - ConversationEntry::ShipComputer(_) => Some("Ship Computer: "), + ConversationEntry::Spoken(Speaker::Eva, _) => Some("Eva: "), + ConversationEntry::Spoken(Speaker::User, _) => Some("Argee: "), + ConversationEntry::Spoken(Speaker::ShipComputer, _) => Some("Ship Computer: "), + ConversationEntry::ShipComputerCommand(_) => Some("> "), _ => None, } } pub fn prefix_style(&self) -> Style { match self { - ConversationEntry::Eva(_) => Style::new().fg(style::Color::Cyan), - ConversationEntry::User(_) => Style::new().fg(style::Color::Magenta), - ConversationEntry::ShipComputer(_) => Style::new().fg(style::Color::Red), + ConversationEntry::Spoken(Speaker::Eva, _) => Style::new().fg(style::Color::Cyan), + ConversationEntry::Spoken(Speaker::User, _) => Style::new().fg(style::Color::Magenta), + ConversationEntry::Spoken(Speaker::ShipComputer, _) => Style::new().fg(style::Color::Red), ConversationEntry::StageDirection(_) => Style::new().fg(style::Color::Yellow), ConversationEntry::SystemMessage(_) => Style::new().fg(style::Color::DarkGray), + ConversationEntry::ShipComputerCommand(_) => Style::new().fg(style::Color::Red).bold(), } } @@ -36,6 +44,7 @@ impl ConversationEntry { match self { ConversationEntry::StageDirection(_) => Style::new().fg(style::Color::Yellow), ConversationEntry::SystemMessage(_) => Style::new().fg(style::Color::DarkGray), + ConversationEntry::ShipComputerCommand(_) => Style::new().fg(style::Color::Red).italic(), _ => Style::new() } } @@ -44,11 +53,12 @@ impl ConversationEntry { impl ToString for ConversationEntry { fn to_string(&self) -> String { match self { - ConversationEntry::Eva(text) => text, - ConversationEntry::ShipComputer(text) => text, + ConversationEntry::Spoken(Speaker::Eva, text) => text, + ConversationEntry::Spoken(Speaker::ShipComputer, text) => text, ConversationEntry::StageDirection(text) => text, ConversationEntry::SystemMessage(text) => text, - ConversationEntry::User(text) => text + ConversationEntry::Spoken(Speaker::User, text) => text, + ConversationEntry::ShipComputerCommand(text) => text }.clone() } } @@ -56,9 +66,9 @@ impl ToString for ConversationEntry { impl TryInto for ConversationEntry { fn try_into(self) -> Result { match self { - ConversationEntry::User(text) => Ok(ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage { content: text.into(), ..Default::default()})), - ConversationEntry::Eva(text) => Ok(ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage { content: Some(text.into()), ..Default::default()})), - ConversationEntry::ShipComputer(text) => Ok(ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage { content: text.into(), name: Some("ship-computer".into()), ..Default::default() })), + ConversationEntry::Spoken(Speaker::User, text) => Ok(ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage { content: text.into(), ..Default::default()})), + ConversationEntry::Spoken(Speaker::Eva, text) => Ok(ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage { content: Some(text.into()), ..Default::default()})), + ConversationEntry::Spoken(Speaker::ShipComputer, text) => Ok(ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage { content: text.into(), name: Some("ship-computer".into()), ..Default::default() })), ConversationEntry::StageDirection(text) => Ok(ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage { content: text.into(), name: Some("stage-direction".into()), ..Default::default() })), _ => Err(()) } @@ -71,11 +81,11 @@ impl TryInto for ConversationEntry { impl TryInto for ChatCompletionRequestMessage { fn try_into(self) -> Result { match self { - ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage { content: ChatCompletionRequestUserMessageContent::Text(msg), ..}) => Ok(ConversationEntry::User(msg)), - ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage { content: Some(ChatCompletionRequestAssistantMessageContent::Text(msg)), ..}) => Ok(ConversationEntry::Eva(msg)), + ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage { content: ChatCompletionRequestUserMessageContent::Text(msg), ..}) => Ok(ConversationEntry::Spoken(Speaker::User, msg)), + ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage { content: Some(ChatCompletionRequestAssistantMessageContent::Text(msg)), ..}) => Ok(ConversationEntry::Spoken(Speaker::Eva, msg)), ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage { content: ChatCompletionRequestSystemMessageContent::Text(msg), name: Some(name), ..}) => { match name.as_str() { - "ship-computer" => Ok(ConversationEntry::ShipComputer(msg)), + "ship-computer" => Ok(ConversationEntry::Spoken(Speaker::ShipComputer, msg)), "stage-direction" => Ok(ConversationEntry::StageDirection(msg)), _ => Err(()) } @@ -85,39 +95,4 @@ impl TryInto for ChatCompletionRequestMessage { } type Error = (); -} - -pub struct ConversationRunner { - sink: tokio::sync::watch::Sender> -} - -pub type ConversationSrc = tokio::sync::watch::Receiver>; -pub type ConversationSink = tokio::sync::mpsc::UnboundedSender; - -impl ConversationRunner { - pub fn new() -> (Self, ConversationSrc) { - let (sink, src) = tokio::sync::watch::channel(vec![]); - (Self { - sink - }, src) - } - - pub fn insert(&mut self, entry: ConversationEntry) { - self.sink.send_modify(|contents| { - contents.push(entry) - }); - } -} - -pub async fn start_conversation() -> (ConversationSrc, ConversationSink) { - let (raw_sink, mut raw_src) = tokio::sync::mpsc::unbounded_channel(); - let (mut runner, src) = ConversationRunner::new(); - - tokio::spawn(async move { - while let Some(evt) = raw_src.recv().await { - runner.insert(evt); - } - }); - - (src, raw_sink) } \ No newline at end of file diff --git a/src/scene/mod.rs b/src/scene/mod.rs index a467c8a..ea5aa4c 100644 --- a/src/scene/mod.rs +++ b/src/scene/mod.rs @@ -1,7 +1,9 @@ +use std::collections::HashMap; + use chrono::{DateTime, Duration, Utc}; use serde::{Deserialize, Serialize}; -use crate::{artifacts::{Track, archive::Archive}, prediction::{GeneratedResponses, PossibleResponse}}; +use crate::artifacts::{Track, archive::Archive}; pub mod conversation; @@ -34,27 +36,25 @@ impl Default for StageDirection { } } -#[derive(Debug, Default, Serialize, Deserialize, Clone)] -pub struct Scenery { - pub artifacts: Archive, - pub current_playlist: Vec -} - #[derive(Debug, Default, Clone, Serialize, Deserialize)] pub struct Scene { - reply_options: GeneratedResponses, direction: StageDirection, pub tokens_consumed: usize, - scenery: Scenery + pub current_playlist: Vec, + pub artifact_count: usize, + pub artifact_stats: (usize, usize, usize), + pub computer_task_list: HashMap } impl Scene { - pub fn new(reply_options: GeneratedResponses, scenery: Scenery, tokens_consumed: usize, direction: StageDirection) -> Self { + pub fn new(tokens_consumed: usize, direction: StageDirection, archive: &Archive, current_playlist: Vec, computer_task_list: HashMap) -> Self { Self { - reply_options, - scenery, tokens_consumed, - direction + direction, + current_playlist, + artifact_count: archive.len(), + artifact_stats: archive.stats(), + computer_task_list } } @@ -62,11 +62,7 @@ impl Scene { &self.direction } - pub fn scenery(&self) -> &Scenery { - &self.scenery - } - - pub fn reply_options(&self) -> &Vec { - &self.reply_options.responses + pub fn playlist(&self) -> &Vec { + &self.current_playlist } } \ No newline at end of file diff --git a/src/system-prompt.txt b/src/system-prompt.txt index 0864529..938d13b 100644 --- a/src/system-prompt.txt +++ b/src/system-prompt.txt @@ -14,19 +14,22 @@ You are playing the role of the artificial intelligence in a spaceship computer. Along the way, you have the opportunity to invent lore and backstory for yourself, Argee, and the ship you both inhabit. This lore will be developed over the season. -To support your roleplaying, you have access to a sizable music library via the "archive_query" tool function. Internally, this runs `beet export` to produce json output. +To support your roleplaying, you have access to a sizable shipboard music archive. You will occasionally be asked by Argee for information on the contents of the archive and how they are related to tracks in the playlist. -You also may use the archive to decide whether or not an artifact is somehow "familiar" based on whether or not it can be found there. -You must provide at least one parameter when calling this tool; it is wasteful to call it without any arguments. +The archive can be manipulated by interacting with the ship computer through the "stage_event" tool. +When interacting with it, you should use straightforward english commands or questions without roleplay. +For example, to have the computer search for information regarding the artist "nullsleep", you would tell the computer: -Another tool function named "musicbrainz_track_search" can be given a list of musicbrainz IDs (mbids), which will substantially expand the information available in the ship's artifact library. -You are able to use this function whenever it might be helpful to look up missing albums, tracks, or artists. -You should immediately run this tool against any new or unfamiliar musicbrainz IDs for tracks that get added to the list of artifacts available. + "load artist nullsleep" -There also exists a "bandcamp_artifact_scan" tool function, which will execute a search on Bandcamp and return a JSON list of artists and albums matching the query. -You are able to run multiple queries in parallel, and it is expected that you will run this tool whenever there is something unfamiliar in the playlist for the current episode, or Argee asks you for more information about the items in the playlist. +Eventually it will inform you when it has completed the tasks. -Queries to bandcamp are somewhat expensive, so you should check with the archive and musicbrainz first if you haven't already. +While it is working on its tasks, it will occasionally announce its current progress. +These messages should only be used as an indication of what the computer is doing. +You should only be replying to the computer's output if it is required to complete the task that Argee has given you. + +You can ask it questions about items in the archive and it will attempt to fetch what it can from the outside world. +If you are asked to load information about the archive, or somehow find data is missing from the available list of artifacts, you must ask the ship computer to load the required data. # Scene The show features Argee, the main character of the show. diff --git a/src/ui.rs b/src/ui.rs index 042525f..dc6877d 100644 --- a/src/ui.rs +++ b/src/ui.rs @@ -1,23 +1,26 @@ use chrono::{Duration, Utc}; use crossterm::event::{self, KeyCode, KeyModifiers}; -use ratatui::{Frame, layout::{Direction, Layout, Rect}, style::{self, }, text::Span, widgets::{BorderType, Clear, ListState, Paragraph}}; +use ratatui::{Frame, layout::{Direction, Layout, Rect}, style::{self, Style, Stylize, }, text::Span, widgets::{BorderType, Clear, List, ListState, Paragraph}}; use throbber_widgets_tui::{Throbber, ThrobberState}; use tokio::time::Instant; use tui_input::{Input, backend::crossterm::EventHandler}; use tui_skeleton::{AnimationMode, Block, Constraint, SkeletonText}; -use crate::{audio::AudioInputControl, prediction::{PredictionAction, SessionControl, SessionUpdate}, scene::{Scene, conversation::ConversationEntry}, transcription::TranscriptionControl, tts::TtsControl}; +use crate::{audio::AudioInputControl, prediction::{GeneratedResponses, PredictionAction, SessionControl, SessionUpdate}, scene::{Scene, conversation::{ConversationEntry, Speaker}}, transcription::TranscriptionControl, tts::TtsControl}; use crate::widgets::*; #[derive(Debug)] pub struct Ui { scene: Scene, + reply_options: GeneratedResponses, + reply_state: ListState, conversation_state: ListState, user_input: Input, throbber_state: ThrobberState, is_requesting: bool, + computer_is_thinking: bool, audio_level: f64, recording_audio: bool, focus_state: FocusState, @@ -45,6 +48,7 @@ impl Ui { user_input: Default::default(), throbber_state: Default::default(), is_requesting: false, + computer_is_thinking: false, audio_level: -60., audio, recording_audio: false, @@ -53,7 +57,8 @@ impl Ui { tts, predictions, last_tick: Instant::now(), - conversation: vec![] + conversation: vec![], + reply_options: Default::default() } } @@ -67,7 +72,7 @@ impl Ui { .block(borders); frame.render_widget(list, area); } else { - frame.render_stateful_widget(Options(self.scene.reply_options()), area, &mut self.reply_state); + frame.render_stateful_widget(Options(&self.reply_options.responses), area, &mut self.reply_state); } } @@ -83,7 +88,17 @@ impl Ui { fn draw_io_throbber(&mut self, frame: &mut Frame, area: Rect) { let throb_area = area.centered(Constraint::Max(1), Constraint::Max(1)); if self.is_requesting { - let throbber = Throbber::default(); + let throbber = Throbber::default().throbber_style(Style::new().cyan()); + frame.render_stateful_widget(throbber, throb_area, &mut self.throbber_state); + } else { + frame.render_widget(Clear::default(), throb_area); + } + } + + fn draw_computer_throbber(&mut self, frame: &mut Frame, area: Rect) { + let throb_area = area.centered(Constraint::Min(1), Constraint::Min(1)); + if self.computer_is_thinking { + let throbber = Throbber::default().throbber_style(Style::new().red()).throbber_set(throbber_widgets_tui::WHITE_SQUARE); frame.render_stateful_widget(throbber, throb_area, &mut self.throbber_state); } else { frame.render_widget(Clear::default(), throb_area); @@ -113,31 +128,34 @@ impl Ui { let scene_layout = Layout::default() .direction(Direction::Horizontal) - .constraints([Constraint::Fill(4), Constraint::Fill(1)]) + .constraints([Constraint::Fill(5), Constraint::Fill(2)]) .split(layout[0]); frame.render_stateful_widget(Conversation(&self.conversation), scene_layout[0], &mut self.conversation_state); - self.draw_narration(frame, scene_layout[1]); + //self.draw_narration(frame, scene_layout[1]); + //self.draw_tasklist(frame, scene_layout[1]); + frame.render_widget(TaskList(&self.scene.computer_task_list), scene_layout[1]); self.draw_options(frame, layout[1]); let status_layout = Layout::default() .direction(Direction::Horizontal) - .constraints([Constraint::Max(3), Constraint::Fill(2), Constraint::Max(13), Constraint::Min(50)]) + .constraints([Constraint::Max(3), Constraint::Max(3), Constraint::Fill(2), Constraint::Max(13), Constraint::Min(50)]) .split(layout[3]); self.draw_user_input(frame, layout[2]); self.draw_io_throbber(frame, status_layout[0]); - frame.render_widget(StatusBar(&self.scene), status_layout[1]); - frame.render_widget(RecordingStatus(self.recording_audio), status_layout[2]); - frame.render_widget(Volume(self.audio_level, self.recording_audio), status_layout[3]); + self.draw_computer_throbber(frame, status_layout[1]); + frame.render_widget(StatusBar(&self.scene), status_layout[2]); + frame.render_widget(RecordingStatus(self.recording_audio), status_layout[3]); + frame.render_widget(Volume(self.audio_level, self.recording_audio), status_layout[4]); } async fn insert_selected_prompt(&mut self) { - let selected = self.scene.reply_options()[self.reply_state.selected().unwrap()].clone(); + let selected = self.reply_options.responses[self.reply_state.selected().unwrap()].clone(); if let Some(direction) = &selected.stage_direction { self.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::StageDirection(direction.clone()))).await; } - self.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::Eva(selected.text.clone()))).await; + self.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::Spoken(Speaker::Eva, selected.text.clone()))).await; } async fn on_command(&mut self, command: &str) { @@ -174,7 +192,8 @@ impl Ui { self.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::StageDirection(arg.to_string()))).await; }, "/computer" => { - self.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::ShipComputer(arg.to_string()))).await; + self.predictions.insert(PredictionAction::ComputerCommand(arg.to_string())).await; + return; }, _ => { log::error!("Unknown command. Available commands: /episode [number], /narrative [text], /event [text], /computer [text], /timer [minutes]"); @@ -183,6 +202,7 @@ impl Ui { } pub async fn on_event(&mut self, evt: event::Event) { + // TODO: ctrl+l should drop all the SystemMessage entries from the log to clean up the UI if let Some(key) = evt.as_key_press_event() { match self.focus_state { FocusState::Conversation => { @@ -200,7 +220,7 @@ impl Ui { KeyCode::Up => self.conversation_state.select_next(), KeyCode::Enter => { let row_num = self.conversation_state.selected().unwrap(); - if let ConversationEntry::Eva(text) = &self.conversation[self.conversation.len() - 1 - row_num] { + if let ConversationEntry::Spoken(Speaker::Eva, text) = &self.conversation[self.conversation.len() - 1 - row_num] { self.tts.speak(text.clone()).await; self.focus_state = FocusState::UserInput; self.conversation_state.select(None); @@ -240,7 +260,7 @@ impl Ui { if next_msg.starts_with("/") { self.on_command(&next_msg).await; } else { - self.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::User(next_msg))).await; + self.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::Spoken(Speaker::User, next_msg))).await; } } }, @@ -259,14 +279,20 @@ impl Ui { } tokio::select!{ - _ = tokio::time::sleep(std::time::Duration::from_millis(60)), if self.is_requesting => (), + _ = tokio::time::sleep(std::time::Duration::from_millis(60)), if self.is_requesting || self.computer_is_thinking => (), next_update = self.predictions.changed() => { match next_update { - SessionUpdate::Thinking(is_thinking) => self.is_requesting = is_thinking, + SessionUpdate::Thinking(Speaker::Eva, is_thinking) => self.is_requesting = is_thinking, + SessionUpdate::Thinking(Speaker::ShipComputer, is_thinking) => self.computer_is_thinking = is_thinking, + SessionUpdate::Thinking(_, _) => unreachable!(), SessionUpdate::Scene(scene) => { self.scene = scene; self.reply_state.select_first(); }, + SessionUpdate::Responses(responses) => { + self.reply_options = responses; + self.reply_state.select_first(); + } SessionUpdate::Conversation(conversation) => { self.conversation = conversation; }, @@ -276,7 +302,7 @@ impl Ui { self.audio_level = next_volume }, transcription_result = self.transcription.next() => { - self.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::User(transcription_result))).await; + self.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::Spoken(Speaker::User, transcription_result))).await; }, } } diff --git a/src/widgets.rs b/src/widgets.rs index 1f794ba..ad42efc 100644 --- a/src/widgets.rs +++ b/src/widgets.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use chrono::Duration; use ratatui::{layout::Rect, style::Style, text::{Line, Text}, widgets::*}; use tui_skeleton::Block; @@ -5,6 +7,33 @@ use ratatui::prelude::*; use crate::{prediction::PossibleResponse, scene::{Scene, conversation::ConversationEntry}}; +pub struct TaskList<'a>(pub &'a HashMap); + +impl Widget for TaskList<'_> { + fn render(self, area: Rect, buf: &mut Buffer) + where + Self: Sized { + let borders = Block::bordered().border_style(style::Color::Red).title("Computer Tasks"); + let wrap_options_unfinished = textwrap::Options::new(area.width as usize).initial_indent("[ ] ").subsequent_indent(" "); + let wrap_options_finished = textwrap::Options::new(area.width as usize).initial_indent("[X] ").subsequent_indent(" "); + let options: Vec = self.0.iter().map(|(text, is_finished)| { + let (options, color) = if *is_finished { + (wrap_options_finished.clone(), style::Color::DarkGray) + } else { + (wrap_options_unfinished.clone(), style::Color::Green) + }; + let contents: Vec = textwrap::wrap(text, options) + .iter() + .map(|x| { Line::from(x.to_string()).fg(color)}).collect(); + Text::from_iter(contents) + }).collect(); + let list = List::new(options) + .block(borders) + .style(ratatui::style::Color::White); + Widget::render(list, area, buf); + } +} + pub struct Options<'a>(pub &'a Vec); impl StatefulWidget for Options<'_> { @@ -172,14 +201,17 @@ impl Widget for StatusBar<'_> { format!("{:0>2}:{:0>2}:{:0>2}", time_remaining.num_hours(), time_remaining.num_minutes() % 60, time_remaining.num_seconds() % 60) }; + let artifact_count = self.0.artifact_count; + let artifact_stats = self.0.artifact_stats; + let status_line = Line::from_iter([ Span::from(format!("Playlist: {}", self.0.direction().playlist)).style(ratatui::style::Color::LightBlue), Span::from(" | ").style(ratatui::style::Color::DarkGray), - Span::from(format!("{} tracks", self.0.scenery().current_playlist.len())).style(ratatui::style::Color::LightBlue), + Span::from(format!("{} tracks", self.0.playlist().len())).style(ratatui::style::Color::LightBlue), Span::from(" | ").style(ratatui::style::Color::DarkGray), Span::from(format!("Time Remaining: {}", formatted_time)).style(time_style), Span::from(" | ").style(ratatui::style::Color::DarkGray), - Span::from(format!("{} artifacts recorded", self.0.scenery().artifacts.len())).style(ratatui::style::Color::LightBlue), + Span::from(format!("{} ({}/{}/{}) artifacts recorded", artifact_count, artifact_stats.0, artifact_stats.1, artifact_stats.2)).style(ratatui::style::Color::LightBlue), Span::from(" | ").style(ratatui::style::Color::DarkGray), Span::from(format!("{} tokens sacrificed", self.0.tokens_consumed)).style(ratatui::style::Color::LightCyan), // TODO: Should show the available and consumed context window in terms of tokens here