From 89125d2defc8be3c7a25db28d39db920316e002f Mon Sep 17 00:00:00 2001 From: Victoria Fischer Date: Wed, 17 Jun 2026 22:16:19 +0200 Subject: [PATCH] prediction: split out maintenance (and thereby the logging interface) of the UI conversation to a separate task, so log::* can work in realtime. --- src/main.rs | 12 ++++---- src/prediction.rs | 62 +++++++++++++++++++-------------------- src/scene/conversation.rs | 35 ++++++++++++++++++++++ src/scene/mod.rs | 10 ++----- src/ui.rs | 15 ++++++---- 5 files changed, 84 insertions(+), 50 deletions(-) diff --git a/src/main.rs b/src/main.rs index 84d3585..6eae8d9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,7 +10,7 @@ use futures::StreamExt; use ratatui::prelude::*; -use crate::{audio::start_audio_input, scene::{Scenery, StageDirection}, tts::start_tts, ui::Ui}; +use crate::{audio::start_audio_input, scene::{Scenery, StageDirection, conversation::{ConversationEntry, start_conversation}}, tts::start_tts, ui::Ui}; mod scene; mod events; @@ -66,7 +66,7 @@ impl SaveData { } } -struct SysMessageLogger(Arc>, Mutex); +struct SysMessageLogger(Arc>, Mutex); impl log::Log for SysMessageLogger { fn enabled(&self, _metadata: &log::Metadata) -> bool { @@ -79,7 +79,7 @@ impl log::Log for SysMessageLogger { let msg = format!("{}", record.args()); write!(self.1.lock().unwrap(), "{}\n", msg).unwrap(); if record.level() <= LevelFilter::Info { - self.0.send(msg).unwrap(); + self.0.send(ConversationEntry::SystemMessage(msg)).unwrap(); } } } @@ -97,10 +97,10 @@ async fn main() { println!("Panic: {}", msg); })); - let (sys_message_sink, sys_message_src) = tokio::sync::mpsc::unbounded_channel(); + let (conversation_src, conversation_sink) = start_conversation().await; static LOGGER: StaticCell> = StaticCell::new(); - let logger = LOGGER.init(SysMessageLogger(Arc::new(sys_message_sink), Mutex::new(std::fs::File::create("out.log").unwrap()))); + let logger = LOGGER.init(SysMessageLogger(Arc::new(conversation_sink.clone()), Mutex::new(std::fs::File::create("out.log").unwrap()))); log::set_logger(logger).unwrap(); log::set_max_level(log::LevelFilter::Debug); @@ -129,7 +129,7 @@ async fn main() { SaveData::default() }; - let prediction_ctrl = prediction::start_prediction(saved_session, sys_message_src).await; + let prediction_ctrl = prediction::start_prediction(saved_session, conversation_src, conversation_sink).await; let (audio_ctrl, mic_stream, tts_output, _sfx_output) = start_audio_input().await; let tts_ctrl = start_tts(tts_output).await; let transcription_ctrl = transcription::start_transcription(mic_stream).await; diff --git a/src/prediction.rs b/src/prediction.rs index 605cbfb..0246999 100644 --- a/src/prediction.rs +++ b/src/prediction.rs @@ -1,13 +1,13 @@ -use std::{fmt::Debug, sync::Arc}; +use std::fmt::Debug; use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs, CreateChatCompletionRequestArgs, FinishReason, ResponseFormat, ResponseFormatJsonSchema}}; use chrono::{DateTime, Utc}; use schemars::{JsonSchema, schema_for}; use serde::{Deserialize, Serialize}; use serde_json::{Serializer, ser::CompactFormatter}; -use tokio::sync::{RwLock, mpsc, watch}; +use tokio::sync::{mpsc, watch}; -use crate::{SaveData, artifacts::{Contents, bandcamp::BandcampSource, beets::BeetsDB, mixxx::{MixxxDB, MixxxQuery}, musicbrainz::MBQuery, tools::{DataSource, Tool}}, scene::{Scene, Scenery, StageDirection, conversation::ConversationEntry}}; +use crate::{SaveData, artifacts::{Contents, bandcamp::BandcampSource, beets::BeetsDB, mixxx::{MixxxDB, MixxxQuery}, musicbrainz::MBQuery, tools::{DataSource, Tool}}, scene::{Scene, Scenery, StageDirection, conversation::{ConversationEntry, ConversationSink, ConversationSrc}}}; const SYSTEM_PROMPT: &str = include_str!("system-prompt.txt"); @@ -36,7 +36,7 @@ pub struct GeneratedResponses { #[derive(Debug)] struct Session { client: Client, - conversation: Vec, + conversation: ConversationSink, header_message: ChatCompletionRequestMessage, messages: Vec, reply_options: GeneratedResponses, @@ -66,11 +66,11 @@ struct ToolResults { } impl Session { - fn new(scene_sink: watch::Sender, messages: Vec, scenery: Scenery, direction: StageDirection, activity_notify: watch::Sender) -> Self { - let mut conversation = vec![]; + fn new(scene_sink: watch::Sender, conversation: ConversationSink, messages: Vec, scenery: Scenery, direction: StageDirection, activity_notify: watch::Sender) -> Self { for msg in &messages { if let Ok(conversation_msg) = msg.clone().try_into() { - conversation.push(conversation_msg); + // FIXME + conversation.send(conversation_msg).unwrap(); } } @@ -277,6 +277,10 @@ impl Session { "query_mixxx" => self.tool_artifact_query(&mut MixxxDB, args).await, _ => unreachable!() }; + // Push tool output messages directly into the conversation as fast as we can + for message in &tool_result.messages { + self.conversation.send(message.clone()).unwrap(); + } results.push((&call.id, tool_result)); }, _ => panic!("Unknown tool was called") @@ -293,8 +297,11 @@ impl Session { self.messages.push(ChatCompletionRequestMessage::Tool(msg.build().unwrap())); tool_messages.append(&mut result.messages); } - for msg in tool_messages { - self.insert_conversation(msg); + // OpenAI requires we put all the tool call results before any other message, so we append them manually down here + for message in tool_messages { + if let Ok(next_msg) = message.clone().try_into() { + self.messages.push(next_msg); + } } } if let Some(content) = message.message.content.as_ref() { @@ -317,17 +324,15 @@ impl Session { } fn as_scene(&self) -> Scene { - Scene::new(self.reply_options.clone(), self.conversation.clone(), self.scenery.clone(), self.tokens_consumed, self.direction.clone()) + Scene::new(self.reply_options.clone(), self.scenery.clone(), self.tokens_consumed, self.direction.clone()) } fn insert_conversation(&mut self, entry: ConversationEntry) { - self.conversation.push(entry.clone()); + self.conversation.send(entry.clone()).unwrap(); if let Ok(next_msg) = entry.try_into() { self.messages.push(next_msg); } - - self.refresh(); } fn refresh(&self) { @@ -339,13 +344,15 @@ impl Session { pub struct SessionControl { event_sink: mpsc::Sender, scene_watch: watch::Receiver, - activity_watch: watch::Receiver + activity_watch: watch::Receiver, + conversation_watch: ConversationSrc } #[derive(Debug)] pub enum SessionUpdate { Scene(Scene), - Thinking(bool) + Thinking(bool), + Conversation(Vec) } impl SessionControl { @@ -364,40 +371,32 @@ impl SessionControl { }, _ = self.scene_watch.changed() => { SessionUpdate::Scene(self.scene_watch.borrow_and_update().clone()) + }, + _ = self.conversation_watch.changed() => { + SessionUpdate::Conversation(self.conversation_watch.borrow_and_update().clone()) } } } } -pub async fn start_prediction(saved_session: SaveData, mut messages: tokio::sync::mpsc::UnboundedReceiver) -> SessionControl { +pub async fn start_prediction(saved_session: SaveData, conversation_src: ConversationSrc, conversations: ConversationSink) -> SessionControl { let (prediction_in, prediction_out) = tokio::sync::watch::channel(Scene::default()); let (activity_notify_sink, activity_notify_src) = tokio::sync::watch::channel(false); let (action_sink, mut action_src) = mpsc::channel(5); - let session = Session::new(prediction_in, saved_session.messages, saved_session.scenery, saved_session.direction, activity_notify_sink); + let mut session = Session::new(prediction_in, conversations, saved_session.messages, saved_session.scenery, saved_session.direction, activity_notify_sink); // Send the initial scene to the UI, after we have loaded the session from the first messages. session.refresh(); - let shared_session = Arc::new(RwLock::new(session)); - - let log_session = Arc::clone(&shared_session); - tokio::spawn(async move { - loop { - if let Some(msg) = messages.recv().await { - log_session.write().await.insert_conversation(ConversationEntry::SystemMessage(msg)); - } - } - }); - tokio::spawn(async move { loop { if let Some(evt) = action_src.recv().await { - shared_session.write().await.on_event(evt).await; + session.on_event(evt).await; // Commit in a separate unlock operation, so the logging task has time to write messages into the conversation // FIXME: The conversation we see in the UI really needs to go to another task. - shared_session.write().await.commit().await; + session.commit().await; } } }); @@ -405,6 +404,7 @@ pub async fn start_prediction(saved_session: SaveData, mut messages: tokio::sync SessionControl { event_sink: action_sink, scene_watch: prediction_out, - activity_watch: activity_notify_src + activity_watch: activity_notify_src, + conversation_watch: conversation_src } } \ No newline at end of file diff --git a/src/scene/conversation.rs b/src/scene/conversation.rs index 9bf1a0b..962cbcd 100644 --- a/src/scene/conversation.rs +++ b/src/scene/conversation.rs @@ -85,4 +85,39 @@ impl TryInto for ChatCompletionRequestMessage { } type Error = (); +} + +pub struct ConversationRunner { + sink: tokio::sync::watch::Sender> +} + +pub type ConversationSrc = tokio::sync::watch::Receiver>; +pub type ConversationSink = tokio::sync::mpsc::UnboundedSender; + +impl ConversationRunner { + pub fn new() -> (Self, ConversationSrc) { + let (sink, src) = tokio::sync::watch::channel(vec![]); + (Self { + sink + }, src) + } + + pub fn insert(&mut self, entry: ConversationEntry) { + self.sink.send_modify(|contents| { + contents.push(entry) + }); + } +} + +pub async fn start_conversation() -> (ConversationSrc, ConversationSink) { + let (raw_sink, mut raw_src) = tokio::sync::mpsc::unbounded_channel(); + let (mut runner, src) = ConversationRunner::new(); + + tokio::spawn(async move { + while let Some(evt) = raw_src.recv().await { + runner.insert(evt); + } + }); + + (src, raw_sink) } \ No newline at end of file diff --git a/src/scene/mod.rs b/src/scene/mod.rs index 87659e2..a467c8a 100644 --- a/src/scene/mod.rs +++ b/src/scene/mod.rs @@ -1,7 +1,7 @@ use chrono::{DateTime, Duration, Utc}; use serde::{Deserialize, Serialize}; -use crate::{artifacts::{Track, archive::Archive}, prediction::{GeneratedResponses, PossibleResponse}, scene::conversation::ConversationEntry}; +use crate::{artifacts::{Track, archive::Archive}, prediction::{GeneratedResponses, PossibleResponse}}; pub mod conversation; @@ -43,17 +43,15 @@ pub struct Scenery { #[derive(Debug, Default, Clone, Serialize, Deserialize)] pub struct Scene { reply_options: GeneratedResponses, - conversation: Vec, direction: StageDirection, pub tokens_consumed: usize, scenery: Scenery } impl Scene { - pub fn new(reply_options: GeneratedResponses, conversation: Vec, scenery: Scenery, tokens_consumed: usize, direction: StageDirection) -> Self { + pub fn new(reply_options: GeneratedResponses, scenery: Scenery, tokens_consumed: usize, direction: StageDirection) -> Self { Self { reply_options, - conversation, scenery, tokens_consumed, direction @@ -67,10 +65,6 @@ impl Scene { pub fn scenery(&self) -> &Scenery { &self.scenery } - - pub fn conversation(&self) -> &Vec { - &self.conversation - } pub fn reply_options(&self) -> &Vec { &self.reply_options.responses diff --git a/src/ui.rs b/src/ui.rs index ef1b3cb..042525f 100644 --- a/src/ui.rs +++ b/src/ui.rs @@ -26,7 +26,8 @@ pub struct Ui { transcription: TranscriptionControl, audio: AudioInputControl, tts: TtsControl, - predictions: SessionControl + predictions: SessionControl, + conversation: Vec } #[derive(Debug)] @@ -51,7 +52,8 @@ impl Ui { focus_state: FocusState::UserInput, tts, predictions, - last_tick: Instant::now() + last_tick: Instant::now(), + conversation: vec![] } } @@ -114,7 +116,7 @@ impl Ui { .constraints([Constraint::Fill(4), Constraint::Fill(1)]) .split(layout[0]); - frame.render_stateful_widget(Conversation(self.scene.conversation()), scene_layout[0], &mut self.conversation_state); + frame.render_stateful_widget(Conversation(&self.conversation), scene_layout[0], &mut self.conversation_state); self.draw_narration(frame, scene_layout[1]); self.draw_options(frame, layout[1]); @@ -198,7 +200,7 @@ impl Ui { KeyCode::Up => self.conversation_state.select_next(), KeyCode::Enter => { let row_num = self.conversation_state.selected().unwrap(); - if let ConversationEntry::Eva(text) = &self.scene.conversation()[self.scene.conversation().len() - 1 - row_num] { + if let ConversationEntry::Eva(text) = &self.conversation[self.conversation.len() - 1 - row_num] { self.tts.speak(text.clone()).await; self.focus_state = FocusState::UserInput; self.conversation_state.select(None); @@ -264,7 +266,10 @@ impl Ui { SessionUpdate::Scene(scene) => { self.scene = scene; self.reply_state.select_first(); - } + }, + SessionUpdate::Conversation(conversation) => { + self.conversation = conversation; + }, } }, next_volume = self.audio.next() => {