prediction: split out maintenance (and thereby the logging interface) of the UI conversation to a separate task, so log::* can work in realtime.

This commit is contained in:
2026-06-17 22:16:19 +02:00
parent a8a44dae63
commit 89125d2def
5 changed files with 84 additions and 50 deletions
+6 -6
View File
@@ -10,7 +10,7 @@ use futures::StreamExt;
use ratatui::prelude::*; use ratatui::prelude::*;
use crate::{audio::start_audio_input, scene::{Scenery, StageDirection}, tts::start_tts, ui::Ui}; use crate::{audio::start_audio_input, scene::{Scenery, StageDirection, conversation::{ConversationEntry, start_conversation}}, tts::start_tts, ui::Ui};
mod scene; mod scene;
mod events; mod events;
@@ -66,7 +66,7 @@ impl SaveData {
} }
} }
struct SysMessageLogger<T>(Arc<tokio::sync::mpsc::UnboundedSender<String>>, Mutex<T>); struct SysMessageLogger<T>(Arc<tokio::sync::mpsc::UnboundedSender<ConversationEntry>>, Mutex<T>);
impl<T: std::io::Write + Send + Sync> log::Log for SysMessageLogger<T> { impl<T: std::io::Write + Send + Sync> log::Log for SysMessageLogger<T> {
fn enabled(&self, _metadata: &log::Metadata) -> bool { fn enabled(&self, _metadata: &log::Metadata) -> bool {
@@ -79,7 +79,7 @@ impl<T: std::io::Write + Send + Sync> log::Log for SysMessageLogger<T> {
let msg = format!("{}", record.args()); let msg = format!("{}", record.args());
write!(self.1.lock().unwrap(), "{}\n", msg).unwrap(); write!(self.1.lock().unwrap(), "{}\n", msg).unwrap();
if record.level() <= LevelFilter::Info { if record.level() <= LevelFilter::Info {
self.0.send(msg).unwrap(); self.0.send(ConversationEntry::SystemMessage(msg)).unwrap();
} }
} }
} }
@@ -97,10 +97,10 @@ async fn main() {
println!("Panic: {}", msg); println!("Panic: {}", msg);
})); }));
let (sys_message_sink, sys_message_src) = tokio::sync::mpsc::unbounded_channel(); let (conversation_src, conversation_sink) = start_conversation().await;
static LOGGER: StaticCell<SysMessageLogger<std::fs::File>> = StaticCell::new(); static LOGGER: StaticCell<SysMessageLogger<std::fs::File>> = StaticCell::new();
let logger = LOGGER.init(SysMessageLogger(Arc::new(sys_message_sink), Mutex::new(std::fs::File::create("out.log").unwrap()))); let logger = LOGGER.init(SysMessageLogger(Arc::new(conversation_sink.clone()), Mutex::new(std::fs::File::create("out.log").unwrap())));
log::set_logger(logger).unwrap(); log::set_logger(logger).unwrap();
log::set_max_level(log::LevelFilter::Debug); log::set_max_level(log::LevelFilter::Debug);
@@ -129,7 +129,7 @@ async fn main() {
SaveData::default() SaveData::default()
}; };
let prediction_ctrl = prediction::start_prediction(saved_session, sys_message_src).await; let prediction_ctrl = prediction::start_prediction(saved_session, conversation_src, conversation_sink).await;
let (audio_ctrl, mic_stream, tts_output, _sfx_output) = start_audio_input().await; let (audio_ctrl, mic_stream, tts_output, _sfx_output) = start_audio_input().await;
let tts_ctrl = start_tts(tts_output).await; let tts_ctrl = start_tts(tts_output).await;
let transcription_ctrl = transcription::start_transcription(mic_stream).await; let transcription_ctrl = transcription::start_transcription(mic_stream).await;
+31 -31
View File
@@ -1,13 +1,13 @@
use std::{fmt::Debug, sync::Arc}; use std::fmt::Debug;
use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs, CreateChatCompletionRequestArgs, FinishReason, ResponseFormat, ResponseFormatJsonSchema}}; use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs, CreateChatCompletionRequestArgs, FinishReason, ResponseFormat, ResponseFormatJsonSchema}};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use schemars::{JsonSchema, schema_for}; use schemars::{JsonSchema, schema_for};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{Serializer, ser::CompactFormatter}; use serde_json::{Serializer, ser::CompactFormatter};
use tokio::sync::{RwLock, mpsc, watch}; use tokio::sync::{mpsc, watch};
use crate::{SaveData, artifacts::{Contents, bandcamp::BandcampSource, beets::BeetsDB, mixxx::{MixxxDB, MixxxQuery}, musicbrainz::MBQuery, tools::{DataSource, Tool}}, scene::{Scene, Scenery, StageDirection, conversation::ConversationEntry}}; use crate::{SaveData, artifacts::{Contents, bandcamp::BandcampSource, beets::BeetsDB, mixxx::{MixxxDB, MixxxQuery}, musicbrainz::MBQuery, tools::{DataSource, Tool}}, scene::{Scene, Scenery, StageDirection, conversation::{ConversationEntry, ConversationSink, ConversationSrc}}};
const SYSTEM_PROMPT: &str = include_str!("system-prompt.txt"); const SYSTEM_PROMPT: &str = include_str!("system-prompt.txt");
@@ -36,7 +36,7 @@ pub struct GeneratedResponses {
#[derive(Debug)] #[derive(Debug)]
struct Session { struct Session {
client: Client<OpenAIConfig>, client: Client<OpenAIConfig>,
conversation: Vec<ConversationEntry>, conversation: ConversationSink,
header_message: ChatCompletionRequestMessage, header_message: ChatCompletionRequestMessage,
messages: Vec<ChatCompletionRequestMessage>, messages: Vec<ChatCompletionRequestMessage>,
reply_options: GeneratedResponses, reply_options: GeneratedResponses,
@@ -66,11 +66,11 @@ struct ToolResults {
} }
impl Session { impl Session {
fn new(scene_sink: watch::Sender<Scene>, messages: Vec<ChatCompletionRequestMessage>, scenery: Scenery, direction: StageDirection, activity_notify: watch::Sender<bool>) -> Self { fn new(scene_sink: watch::Sender<Scene>, conversation: ConversationSink, messages: Vec<ChatCompletionRequestMessage>, scenery: Scenery, direction: StageDirection, activity_notify: watch::Sender<bool>) -> Self {
let mut conversation = vec![];
for msg in &messages { for msg in &messages {
if let Ok(conversation_msg) = msg.clone().try_into() { if let Ok(conversation_msg) = msg.clone().try_into() {
conversation.push(conversation_msg); // FIXME
conversation.send(conversation_msg).unwrap();
} }
} }
@@ -277,6 +277,10 @@ impl Session {
"query_mixxx" => self.tool_artifact_query(&mut MixxxDB, args).await, "query_mixxx" => self.tool_artifact_query(&mut MixxxDB, args).await,
_ => unreachable!() _ => unreachable!()
}; };
// Push tool output messages directly into the conversation as fast as we can
for message in &tool_result.messages {
self.conversation.send(message.clone()).unwrap();
}
results.push((&call.id, tool_result)); results.push((&call.id, tool_result));
}, },
_ => panic!("Unknown tool was called") _ => panic!("Unknown tool was called")
@@ -293,8 +297,11 @@ impl Session {
self.messages.push(ChatCompletionRequestMessage::Tool(msg.build().unwrap())); self.messages.push(ChatCompletionRequestMessage::Tool(msg.build().unwrap()));
tool_messages.append(&mut result.messages); tool_messages.append(&mut result.messages);
} }
for msg in tool_messages { // OpenAI requires we put all the tool call results before any other message, so we append them manually down here
self.insert_conversation(msg); for message in tool_messages {
if let Ok(next_msg) = message.clone().try_into() {
self.messages.push(next_msg);
}
} }
} }
if let Some(content) = message.message.content.as_ref() { if let Some(content) = message.message.content.as_ref() {
@@ -317,17 +324,15 @@ impl Session {
} }
fn as_scene(&self) -> Scene { fn as_scene(&self) -> Scene {
Scene::new(self.reply_options.clone(), self.conversation.clone(), self.scenery.clone(), self.tokens_consumed, self.direction.clone()) Scene::new(self.reply_options.clone(), self.scenery.clone(), self.tokens_consumed, self.direction.clone())
} }
fn insert_conversation(&mut self, entry: ConversationEntry) { fn insert_conversation(&mut self, entry: ConversationEntry) {
self.conversation.push(entry.clone()); self.conversation.send(entry.clone()).unwrap();
if let Ok(next_msg) = entry.try_into() { if let Ok(next_msg) = entry.try_into() {
self.messages.push(next_msg); self.messages.push(next_msg);
} }
self.refresh();
} }
fn refresh(&self) { fn refresh(&self) {
@@ -339,13 +344,15 @@ impl Session {
pub struct SessionControl { pub struct SessionControl {
event_sink: mpsc::Sender<PredictionAction>, event_sink: mpsc::Sender<PredictionAction>,
scene_watch: watch::Receiver<Scene>, scene_watch: watch::Receiver<Scene>,
activity_watch: watch::Receiver<bool> activity_watch: watch::Receiver<bool>,
conversation_watch: ConversationSrc
} }
#[derive(Debug)] #[derive(Debug)]
pub enum SessionUpdate { pub enum SessionUpdate {
Scene(Scene), Scene(Scene),
Thinking(bool) Thinking(bool),
Conversation(Vec<ConversationEntry>)
} }
impl SessionControl { impl SessionControl {
@@ -364,40 +371,32 @@ impl SessionControl {
}, },
_ = self.scene_watch.changed() => { _ = self.scene_watch.changed() => {
SessionUpdate::Scene(self.scene_watch.borrow_and_update().clone()) SessionUpdate::Scene(self.scene_watch.borrow_and_update().clone())
},
_ = self.conversation_watch.changed() => {
SessionUpdate::Conversation(self.conversation_watch.borrow_and_update().clone())
} }
} }
} }
} }
pub async fn start_prediction(saved_session: SaveData, mut messages: tokio::sync::mpsc::UnboundedReceiver<String>) -> SessionControl { pub async fn start_prediction(saved_session: SaveData, conversation_src: ConversationSrc, conversations: ConversationSink) -> SessionControl {
let (prediction_in, prediction_out) = tokio::sync::watch::channel(Scene::default()); let (prediction_in, prediction_out) = tokio::sync::watch::channel(Scene::default());
let (activity_notify_sink, activity_notify_src) = tokio::sync::watch::channel(false); let (activity_notify_sink, activity_notify_src) = tokio::sync::watch::channel(false);
let (action_sink, mut action_src) = mpsc::channel(5); let (action_sink, mut action_src) = mpsc::channel(5);
let session = Session::new(prediction_in, saved_session.messages, saved_session.scenery, saved_session.direction, activity_notify_sink); let mut session = Session::new(prediction_in, conversations, saved_session.messages, saved_session.scenery, saved_session.direction, activity_notify_sink);
// Send the initial scene to the UI, after we have loaded the session from the first messages. // Send the initial scene to the UI, after we have loaded the session from the first messages.
session.refresh(); session.refresh();
let shared_session = Arc::new(RwLock::new(session));
let log_session = Arc::clone(&shared_session);
tokio::spawn(async move {
loop {
if let Some(msg) = messages.recv().await {
log_session.write().await.insert_conversation(ConversationEntry::SystemMessage(msg));
}
}
});
tokio::spawn(async move { tokio::spawn(async move {
loop { loop {
if let Some(evt) = action_src.recv().await { if let Some(evt) = action_src.recv().await {
shared_session.write().await.on_event(evt).await; session.on_event(evt).await;
// Commit in a separate unlock operation, so the logging task has time to write messages into the conversation // Commit in a separate unlock operation, so the logging task has time to write messages into the conversation
// FIXME: The conversation we see in the UI really needs to go to another task. // FIXME: The conversation we see in the UI really needs to go to another task.
shared_session.write().await.commit().await; session.commit().await;
} }
} }
}); });
@@ -405,6 +404,7 @@ pub async fn start_prediction(saved_session: SaveData, mut messages: tokio::sync
SessionControl { SessionControl {
event_sink: action_sink, event_sink: action_sink,
scene_watch: prediction_out, scene_watch: prediction_out,
activity_watch: activity_notify_src activity_watch: activity_notify_src,
conversation_watch: conversation_src
} }
} }
+35
View File
@@ -86,3 +86,38 @@ impl TryInto<ConversationEntry> for ChatCompletionRequestMessage {
type Error = (); type Error = ();
} }
pub struct ConversationRunner {
sink: tokio::sync::watch::Sender<Vec<ConversationEntry>>
}
pub type ConversationSrc = tokio::sync::watch::Receiver<Vec<ConversationEntry>>;
pub type ConversationSink = tokio::sync::mpsc::UnboundedSender<ConversationEntry>;
impl ConversationRunner {
pub fn new() -> (Self, ConversationSrc) {
let (sink, src) = tokio::sync::watch::channel(vec![]);
(Self {
sink
}, src)
}
pub fn insert(&mut self, entry: ConversationEntry) {
self.sink.send_modify(|contents| {
contents.push(entry)
});
}
}
pub async fn start_conversation() -> (ConversationSrc, ConversationSink) {
let (raw_sink, mut raw_src) = tokio::sync::mpsc::unbounded_channel();
let (mut runner, src) = ConversationRunner::new();
tokio::spawn(async move {
while let Some(evt) = raw_src.recv().await {
runner.insert(evt);
}
});
(src, raw_sink)
}
+2 -8
View File
@@ -1,7 +1,7 @@
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::{artifacts::{Track, archive::Archive}, prediction::{GeneratedResponses, PossibleResponse}, scene::conversation::ConversationEntry}; use crate::{artifacts::{Track, archive::Archive}, prediction::{GeneratedResponses, PossibleResponse}};
pub mod conversation; pub mod conversation;
@@ -43,17 +43,15 @@ pub struct Scenery {
#[derive(Debug, Default, Clone, Serialize, Deserialize)] #[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct Scene { pub struct Scene {
reply_options: GeneratedResponses, reply_options: GeneratedResponses,
conversation: Vec<ConversationEntry>,
direction: StageDirection, direction: StageDirection,
pub tokens_consumed: usize, pub tokens_consumed: usize,
scenery: Scenery scenery: Scenery
} }
impl Scene { impl Scene {
pub fn new(reply_options: GeneratedResponses, conversation: Vec<ConversationEntry>, scenery: Scenery, tokens_consumed: usize, direction: StageDirection) -> Self { pub fn new(reply_options: GeneratedResponses, scenery: Scenery, tokens_consumed: usize, direction: StageDirection) -> Self {
Self { Self {
reply_options, reply_options,
conversation,
scenery, scenery,
tokens_consumed, tokens_consumed,
direction direction
@@ -68,10 +66,6 @@ impl Scene {
&self.scenery &self.scenery
} }
pub fn conversation(&self) -> &Vec<ConversationEntry> {
&self.conversation
}
pub fn reply_options(&self) -> &Vec<PossibleResponse> { pub fn reply_options(&self) -> &Vec<PossibleResponse> {
&self.reply_options.responses &self.reply_options.responses
} }
+10 -5
View File
@@ -26,7 +26,8 @@ pub struct Ui {
transcription: TranscriptionControl, transcription: TranscriptionControl,
audio: AudioInputControl, audio: AudioInputControl,
tts: TtsControl, tts: TtsControl,
predictions: SessionControl predictions: SessionControl,
conversation: Vec<ConversationEntry>
} }
#[derive(Debug)] #[derive(Debug)]
@@ -51,7 +52,8 @@ impl Ui {
focus_state: FocusState::UserInput, focus_state: FocusState::UserInput,
tts, tts,
predictions, predictions,
last_tick: Instant::now() last_tick: Instant::now(),
conversation: vec![]
} }
} }
@@ -114,7 +116,7 @@ impl Ui {
.constraints([Constraint::Fill(4), Constraint::Fill(1)]) .constraints([Constraint::Fill(4), Constraint::Fill(1)])
.split(layout[0]); .split(layout[0]);
frame.render_stateful_widget(Conversation(self.scene.conversation()), scene_layout[0], &mut self.conversation_state); frame.render_stateful_widget(Conversation(&self.conversation), scene_layout[0], &mut self.conversation_state);
self.draw_narration(frame, scene_layout[1]); self.draw_narration(frame, scene_layout[1]);
self.draw_options(frame, layout[1]); self.draw_options(frame, layout[1]);
@@ -198,7 +200,7 @@ impl Ui {
KeyCode::Up => self.conversation_state.select_next(), KeyCode::Up => self.conversation_state.select_next(),
KeyCode::Enter => { KeyCode::Enter => {
let row_num = self.conversation_state.selected().unwrap(); let row_num = self.conversation_state.selected().unwrap();
if let ConversationEntry::Eva(text) = &self.scene.conversation()[self.scene.conversation().len() - 1 - row_num] { if let ConversationEntry::Eva(text) = &self.conversation[self.conversation.len() - 1 - row_num] {
self.tts.speak(text.clone()).await; self.tts.speak(text.clone()).await;
self.focus_state = FocusState::UserInput; self.focus_state = FocusState::UserInput;
self.conversation_state.select(None); self.conversation_state.select(None);
@@ -264,7 +266,10 @@ impl Ui {
SessionUpdate::Scene(scene) => { SessionUpdate::Scene(scene) => {
self.scene = scene; self.scene = scene;
self.reply_state.select_first(); self.reply_state.select_first();
} },
SessionUpdate::Conversation(conversation) => {
self.conversation = conversation;
},
} }
}, },
next_volume = self.audio.next() => { next_volume = self.audio.next() => {