prediction: split out maintenance (and thereby the logging interface) of the UI conversation to a separate task, so log::* can work in realtime.
This commit is contained in:
+6
-6
@@ -10,7 +10,7 @@ use futures::StreamExt;
|
||||
|
||||
use ratatui::prelude::*;
|
||||
|
||||
use crate::{audio::start_audio_input, scene::{Scenery, StageDirection}, tts::start_tts, ui::Ui};
|
||||
use crate::{audio::start_audio_input, scene::{Scenery, StageDirection, conversation::{ConversationEntry, start_conversation}}, tts::start_tts, ui::Ui};
|
||||
|
||||
mod scene;
|
||||
mod events;
|
||||
@@ -66,7 +66,7 @@ impl SaveData {
|
||||
}
|
||||
}
|
||||
|
||||
struct SysMessageLogger<T>(Arc<tokio::sync::mpsc::UnboundedSender<String>>, Mutex<T>);
|
||||
struct SysMessageLogger<T>(Arc<tokio::sync::mpsc::UnboundedSender<ConversationEntry>>, Mutex<T>);
|
||||
|
||||
impl<T: std::io::Write + Send + Sync> log::Log for SysMessageLogger<T> {
|
||||
fn enabled(&self, _metadata: &log::Metadata) -> bool {
|
||||
@@ -79,7 +79,7 @@ impl<T: std::io::Write + Send + Sync> log::Log for SysMessageLogger<T> {
|
||||
let msg = format!("{}", record.args());
|
||||
write!(self.1.lock().unwrap(), "{}\n", msg).unwrap();
|
||||
if record.level() <= LevelFilter::Info {
|
||||
self.0.send(msg).unwrap();
|
||||
self.0.send(ConversationEntry::SystemMessage(msg)).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -97,10 +97,10 @@ async fn main() {
|
||||
println!("Panic: {}", msg);
|
||||
}));
|
||||
|
||||
let (sys_message_sink, sys_message_src) = tokio::sync::mpsc::unbounded_channel();
|
||||
let (conversation_src, conversation_sink) = start_conversation().await;
|
||||
|
||||
static LOGGER: StaticCell<SysMessageLogger<std::fs::File>> = StaticCell::new();
|
||||
let logger = LOGGER.init(SysMessageLogger(Arc::new(sys_message_sink), Mutex::new(std::fs::File::create("out.log").unwrap())));
|
||||
let logger = LOGGER.init(SysMessageLogger(Arc::new(conversation_sink.clone()), Mutex::new(std::fs::File::create("out.log").unwrap())));
|
||||
log::set_logger(logger).unwrap();
|
||||
log::set_max_level(log::LevelFilter::Debug);
|
||||
|
||||
@@ -129,7 +129,7 @@ async fn main() {
|
||||
SaveData::default()
|
||||
};
|
||||
|
||||
let prediction_ctrl = prediction::start_prediction(saved_session, sys_message_src).await;
|
||||
let prediction_ctrl = prediction::start_prediction(saved_session, conversation_src, conversation_sink).await;
|
||||
let (audio_ctrl, mic_stream, tts_output, _sfx_output) = start_audio_input().await;
|
||||
let tts_ctrl = start_tts(tts_output).await;
|
||||
let transcription_ctrl = transcription::start_transcription(mic_stream).await;
|
||||
|
||||
+31
-31
@@ -1,13 +1,13 @@
|
||||
use std::{fmt::Debug, sync::Arc};
|
||||
use std::fmt::Debug;
|
||||
|
||||
use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs, CreateChatCompletionRequestArgs, FinishReason, ResponseFormat, ResponseFormatJsonSchema}};
|
||||
use chrono::{DateTime, Utc};
|
||||
use schemars::{JsonSchema, schema_for};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Serializer, ser::CompactFormatter};
|
||||
use tokio::sync::{RwLock, mpsc, watch};
|
||||
use tokio::sync::{mpsc, watch};
|
||||
|
||||
use crate::{SaveData, artifacts::{Contents, bandcamp::BandcampSource, beets::BeetsDB, mixxx::{MixxxDB, MixxxQuery}, musicbrainz::MBQuery, tools::{DataSource, Tool}}, scene::{Scene, Scenery, StageDirection, conversation::ConversationEntry}};
|
||||
use crate::{SaveData, artifacts::{Contents, bandcamp::BandcampSource, beets::BeetsDB, mixxx::{MixxxDB, MixxxQuery}, musicbrainz::MBQuery, tools::{DataSource, Tool}}, scene::{Scene, Scenery, StageDirection, conversation::{ConversationEntry, ConversationSink, ConversationSrc}}};
|
||||
|
||||
|
||||
const SYSTEM_PROMPT: &str = include_str!("system-prompt.txt");
|
||||
@@ -36,7 +36,7 @@ pub struct GeneratedResponses {
|
||||
#[derive(Debug)]
|
||||
struct Session {
|
||||
client: Client<OpenAIConfig>,
|
||||
conversation: Vec<ConversationEntry>,
|
||||
conversation: ConversationSink,
|
||||
header_message: ChatCompletionRequestMessage,
|
||||
messages: Vec<ChatCompletionRequestMessage>,
|
||||
reply_options: GeneratedResponses,
|
||||
@@ -66,11 +66,11 @@ struct ToolResults {
|
||||
}
|
||||
|
||||
impl Session {
|
||||
fn new(scene_sink: watch::Sender<Scene>, messages: Vec<ChatCompletionRequestMessage>, scenery: Scenery, direction: StageDirection, activity_notify: watch::Sender<bool>) -> Self {
|
||||
let mut conversation = vec![];
|
||||
fn new(scene_sink: watch::Sender<Scene>, conversation: ConversationSink, messages: Vec<ChatCompletionRequestMessage>, scenery: Scenery, direction: StageDirection, activity_notify: watch::Sender<bool>) -> Self {
|
||||
for msg in &messages {
|
||||
if let Ok(conversation_msg) = msg.clone().try_into() {
|
||||
conversation.push(conversation_msg);
|
||||
// FIXME
|
||||
conversation.send(conversation_msg).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -277,6 +277,10 @@ impl Session {
|
||||
"query_mixxx" => self.tool_artifact_query(&mut MixxxDB, args).await,
|
||||
_ => unreachable!()
|
||||
};
|
||||
// Push tool output messages directly into the conversation as fast as we can
|
||||
for message in &tool_result.messages {
|
||||
self.conversation.send(message.clone()).unwrap();
|
||||
}
|
||||
results.push((&call.id, tool_result));
|
||||
},
|
||||
_ => panic!("Unknown tool was called")
|
||||
@@ -293,8 +297,11 @@ impl Session {
|
||||
self.messages.push(ChatCompletionRequestMessage::Tool(msg.build().unwrap()));
|
||||
tool_messages.append(&mut result.messages);
|
||||
}
|
||||
for msg in tool_messages {
|
||||
self.insert_conversation(msg);
|
||||
// OpenAI requires we put all the tool call results before any other message, so we append them manually down here
|
||||
for message in tool_messages {
|
||||
if let Ok(next_msg) = message.clone().try_into() {
|
||||
self.messages.push(next_msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(content) = message.message.content.as_ref() {
|
||||
@@ -317,17 +324,15 @@ impl Session {
|
||||
}
|
||||
|
||||
fn as_scene(&self) -> Scene {
|
||||
Scene::new(self.reply_options.clone(), self.conversation.clone(), self.scenery.clone(), self.tokens_consumed, self.direction.clone())
|
||||
Scene::new(self.reply_options.clone(), self.scenery.clone(), self.tokens_consumed, self.direction.clone())
|
||||
}
|
||||
|
||||
fn insert_conversation(&mut self, entry: ConversationEntry) {
|
||||
self.conversation.push(entry.clone());
|
||||
self.conversation.send(entry.clone()).unwrap();
|
||||
|
||||
if let Ok(next_msg) = entry.try_into() {
|
||||
self.messages.push(next_msg);
|
||||
}
|
||||
|
||||
self.refresh();
|
||||
}
|
||||
|
||||
fn refresh(&self) {
|
||||
@@ -339,13 +344,15 @@ impl Session {
|
||||
pub struct SessionControl {
|
||||
event_sink: mpsc::Sender<PredictionAction>,
|
||||
scene_watch: watch::Receiver<Scene>,
|
||||
activity_watch: watch::Receiver<bool>
|
||||
activity_watch: watch::Receiver<bool>,
|
||||
conversation_watch: ConversationSrc
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum SessionUpdate {
|
||||
Scene(Scene),
|
||||
Thinking(bool)
|
||||
Thinking(bool),
|
||||
Conversation(Vec<ConversationEntry>)
|
||||
}
|
||||
|
||||
impl SessionControl {
|
||||
@@ -364,40 +371,32 @@ impl SessionControl {
|
||||
},
|
||||
_ = self.scene_watch.changed() => {
|
||||
SessionUpdate::Scene(self.scene_watch.borrow_and_update().clone())
|
||||
},
|
||||
_ = self.conversation_watch.changed() => {
|
||||
SessionUpdate::Conversation(self.conversation_watch.borrow_and_update().clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn start_prediction(saved_session: SaveData, mut messages: tokio::sync::mpsc::UnboundedReceiver<String>) -> SessionControl {
|
||||
pub async fn start_prediction(saved_session: SaveData, conversation_src: ConversationSrc, conversations: ConversationSink) -> SessionControl {
|
||||
let (prediction_in, prediction_out) = tokio::sync::watch::channel(Scene::default());
|
||||
let (activity_notify_sink, activity_notify_src) = tokio::sync::watch::channel(false);
|
||||
|
||||
let (action_sink, mut action_src) = mpsc::channel(5);
|
||||
|
||||
let session = Session::new(prediction_in, saved_session.messages, saved_session.scenery, saved_session.direction, activity_notify_sink);
|
||||
let mut session = Session::new(prediction_in, conversations, saved_session.messages, saved_session.scenery, saved_session.direction, activity_notify_sink);
|
||||
|
||||
// Send the initial scene to the UI, after we have loaded the session from the first messages.
|
||||
session.refresh();
|
||||
|
||||
let shared_session = Arc::new(RwLock::new(session));
|
||||
|
||||
let log_session = Arc::clone(&shared_session);
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
if let Some(msg) = messages.recv().await {
|
||||
log_session.write().await.insert_conversation(ConversationEntry::SystemMessage(msg));
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
if let Some(evt) = action_src.recv().await {
|
||||
shared_session.write().await.on_event(evt).await;
|
||||
session.on_event(evt).await;
|
||||
// Commit in a separate unlock operation, so the logging task has time to write messages into the conversation
|
||||
// FIXME: The conversation we see in the UI really needs to go to another task.
|
||||
shared_session.write().await.commit().await;
|
||||
session.commit().await;
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -405,6 +404,7 @@ pub async fn start_prediction(saved_session: SaveData, mut messages: tokio::sync
|
||||
SessionControl {
|
||||
event_sink: action_sink,
|
||||
scene_watch: prediction_out,
|
||||
activity_watch: activity_notify_src
|
||||
activity_watch: activity_notify_src,
|
||||
conversation_watch: conversation_src
|
||||
}
|
||||
}
|
||||
@@ -85,4 +85,39 @@ impl TryInto<ConversationEntry> for ChatCompletionRequestMessage {
|
||||
}
|
||||
|
||||
type Error = ();
|
||||
}
|
||||
|
||||
pub struct ConversationRunner {
|
||||
sink: tokio::sync::watch::Sender<Vec<ConversationEntry>>
|
||||
}
|
||||
|
||||
pub type ConversationSrc = tokio::sync::watch::Receiver<Vec<ConversationEntry>>;
|
||||
pub type ConversationSink = tokio::sync::mpsc::UnboundedSender<ConversationEntry>;
|
||||
|
||||
impl ConversationRunner {
|
||||
pub fn new() -> (Self, ConversationSrc) {
|
||||
let (sink, src) = tokio::sync::watch::channel(vec![]);
|
||||
(Self {
|
||||
sink
|
||||
}, src)
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, entry: ConversationEntry) {
|
||||
self.sink.send_modify(|contents| {
|
||||
contents.push(entry)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn start_conversation() -> (ConversationSrc, ConversationSink) {
|
||||
let (raw_sink, mut raw_src) = tokio::sync::mpsc::unbounded_channel();
|
||||
let (mut runner, src) = ConversationRunner::new();
|
||||
|
||||
tokio::spawn(async move {
|
||||
while let Some(evt) = raw_src.recv().await {
|
||||
runner.insert(evt);
|
||||
}
|
||||
});
|
||||
|
||||
(src, raw_sink)
|
||||
}
|
||||
+2
-8
@@ -1,7 +1,7 @@
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{artifacts::{Track, archive::Archive}, prediction::{GeneratedResponses, PossibleResponse}, scene::conversation::ConversationEntry};
|
||||
use crate::{artifacts::{Track, archive::Archive}, prediction::{GeneratedResponses, PossibleResponse}};
|
||||
|
||||
pub mod conversation;
|
||||
|
||||
@@ -43,17 +43,15 @@ pub struct Scenery {
|
||||
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
|
||||
pub struct Scene {
|
||||
reply_options: GeneratedResponses,
|
||||
conversation: Vec<ConversationEntry>,
|
||||
direction: StageDirection,
|
||||
pub tokens_consumed: usize,
|
||||
scenery: Scenery
|
||||
}
|
||||
|
||||
impl Scene {
|
||||
pub fn new(reply_options: GeneratedResponses, conversation: Vec<ConversationEntry>, scenery: Scenery, tokens_consumed: usize, direction: StageDirection) -> Self {
|
||||
pub fn new(reply_options: GeneratedResponses, scenery: Scenery, tokens_consumed: usize, direction: StageDirection) -> Self {
|
||||
Self {
|
||||
reply_options,
|
||||
conversation,
|
||||
scenery,
|
||||
tokens_consumed,
|
||||
direction
|
||||
@@ -67,10 +65,6 @@ impl Scene {
|
||||
pub fn scenery(&self) -> &Scenery {
|
||||
&self.scenery
|
||||
}
|
||||
|
||||
pub fn conversation(&self) -> &Vec<ConversationEntry> {
|
||||
&self.conversation
|
||||
}
|
||||
|
||||
pub fn reply_options(&self) -> &Vec<PossibleResponse> {
|
||||
&self.reply_options.responses
|
||||
|
||||
@@ -26,7 +26,8 @@ pub struct Ui {
|
||||
transcription: TranscriptionControl,
|
||||
audio: AudioInputControl,
|
||||
tts: TtsControl,
|
||||
predictions: SessionControl
|
||||
predictions: SessionControl,
|
||||
conversation: Vec<ConversationEntry>
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -51,7 +52,8 @@ impl Ui {
|
||||
focus_state: FocusState::UserInput,
|
||||
tts,
|
||||
predictions,
|
||||
last_tick: Instant::now()
|
||||
last_tick: Instant::now(),
|
||||
conversation: vec![]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,7 +116,7 @@ impl Ui {
|
||||
.constraints([Constraint::Fill(4), Constraint::Fill(1)])
|
||||
.split(layout[0]);
|
||||
|
||||
frame.render_stateful_widget(Conversation(self.scene.conversation()), scene_layout[0], &mut self.conversation_state);
|
||||
frame.render_stateful_widget(Conversation(&self.conversation), scene_layout[0], &mut self.conversation_state);
|
||||
self.draw_narration(frame, scene_layout[1]);
|
||||
self.draw_options(frame, layout[1]);
|
||||
|
||||
@@ -198,7 +200,7 @@ impl Ui {
|
||||
KeyCode::Up => self.conversation_state.select_next(),
|
||||
KeyCode::Enter => {
|
||||
let row_num = self.conversation_state.selected().unwrap();
|
||||
if let ConversationEntry::Eva(text) = &self.scene.conversation()[self.scene.conversation().len() - 1 - row_num] {
|
||||
if let ConversationEntry::Eva(text) = &self.conversation[self.conversation.len() - 1 - row_num] {
|
||||
self.tts.speak(text.clone()).await;
|
||||
self.focus_state = FocusState::UserInput;
|
||||
self.conversation_state.select(None);
|
||||
@@ -264,7 +266,10 @@ impl Ui {
|
||||
SessionUpdate::Scene(scene) => {
|
||||
self.scene = scene;
|
||||
self.reply_state.select_first();
|
||||
}
|
||||
},
|
||||
SessionUpdate::Conversation(conversation) => {
|
||||
self.conversation = conversation;
|
||||
},
|
||||
}
|
||||
},
|
||||
next_volume = self.audio.next() => {
|
||||
|
||||
Reference in New Issue
Block a user