prediction: split maintenance of the UI conversation (and thereby the logging interface) out into a separate task, so log::* can work in real time.
This commit is contained in:
+6
-6
@@ -10,7 +10,7 @@ use futures::StreamExt;
|
|||||||
|
|
||||||
use ratatui::prelude::*;
|
use ratatui::prelude::*;
|
||||||
|
|
||||||
use crate::{audio::start_audio_input, scene::{Scenery, StageDirection}, tts::start_tts, ui::Ui};
|
use crate::{audio::start_audio_input, scene::{Scenery, StageDirection, conversation::{ConversationEntry, start_conversation}}, tts::start_tts, ui::Ui};
|
||||||
|
|
||||||
mod scene;
|
mod scene;
|
||||||
mod events;
|
mod events;
|
||||||
@@ -66,7 +66,7 @@ impl SaveData {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct SysMessageLogger<T>(Arc<tokio::sync::mpsc::UnboundedSender<String>>, Mutex<T>);
|
struct SysMessageLogger<T>(Arc<tokio::sync::mpsc::UnboundedSender<ConversationEntry>>, Mutex<T>);
|
||||||
|
|
||||||
impl<T: std::io::Write + Send + Sync> log::Log for SysMessageLogger<T> {
|
impl<T: std::io::Write + Send + Sync> log::Log for SysMessageLogger<T> {
|
||||||
fn enabled(&self, _metadata: &log::Metadata) -> bool {
|
fn enabled(&self, _metadata: &log::Metadata) -> bool {
|
||||||
@@ -79,7 +79,7 @@ impl<T: std::io::Write + Send + Sync> log::Log for SysMessageLogger<T> {
|
|||||||
let msg = format!("{}", record.args());
|
let msg = format!("{}", record.args());
|
||||||
write!(self.1.lock().unwrap(), "{}\n", msg).unwrap();
|
write!(self.1.lock().unwrap(), "{}\n", msg).unwrap();
|
||||||
if record.level() <= LevelFilter::Info {
|
if record.level() <= LevelFilter::Info {
|
||||||
self.0.send(msg).unwrap();
|
self.0.send(ConversationEntry::SystemMessage(msg)).unwrap();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -97,10 +97,10 @@ async fn main() {
|
|||||||
println!("Panic: {}", msg);
|
println!("Panic: {}", msg);
|
||||||
}));
|
}));
|
||||||
|
|
||||||
let (sys_message_sink, sys_message_src) = tokio::sync::mpsc::unbounded_channel();
|
let (conversation_src, conversation_sink) = start_conversation().await;
|
||||||
|
|
||||||
static LOGGER: StaticCell<SysMessageLogger<std::fs::File>> = StaticCell::new();
|
static LOGGER: StaticCell<SysMessageLogger<std::fs::File>> = StaticCell::new();
|
||||||
let logger = LOGGER.init(SysMessageLogger(Arc::new(sys_message_sink), Mutex::new(std::fs::File::create("out.log").unwrap())));
|
let logger = LOGGER.init(SysMessageLogger(Arc::new(conversation_sink.clone()), Mutex::new(std::fs::File::create("out.log").unwrap())));
|
||||||
log::set_logger(logger).unwrap();
|
log::set_logger(logger).unwrap();
|
||||||
log::set_max_level(log::LevelFilter::Debug);
|
log::set_max_level(log::LevelFilter::Debug);
|
||||||
|
|
||||||
@@ -129,7 +129,7 @@ async fn main() {
|
|||||||
SaveData::default()
|
SaveData::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
let prediction_ctrl = prediction::start_prediction(saved_session, sys_message_src).await;
|
let prediction_ctrl = prediction::start_prediction(saved_session, conversation_src, conversation_sink).await;
|
||||||
let (audio_ctrl, mic_stream, tts_output, _sfx_output) = start_audio_input().await;
|
let (audio_ctrl, mic_stream, tts_output, _sfx_output) = start_audio_input().await;
|
||||||
let tts_ctrl = start_tts(tts_output).await;
|
let tts_ctrl = start_tts(tts_output).await;
|
||||||
let transcription_ctrl = transcription::start_transcription(mic_stream).await;
|
let transcription_ctrl = transcription::start_transcription(mic_stream).await;
|
||||||
|
|||||||
+31
-31
@@ -1,13 +1,13 @@
|
|||||||
use std::{fmt::Debug, sync::Arc};
|
use std::fmt::Debug;
|
||||||
|
|
||||||
use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs, CreateChatCompletionRequestArgs, FinishReason, ResponseFormat, ResponseFormatJsonSchema}};
|
use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs, CreateChatCompletionRequestArgs, FinishReason, ResponseFormat, ResponseFormatJsonSchema}};
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use schemars::{JsonSchema, schema_for};
|
use schemars::{JsonSchema, schema_for};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::{Serializer, ser::CompactFormatter};
|
use serde_json::{Serializer, ser::CompactFormatter};
|
||||||
use tokio::sync::{RwLock, mpsc, watch};
|
use tokio::sync::{mpsc, watch};
|
||||||
|
|
||||||
use crate::{SaveData, artifacts::{Contents, bandcamp::BandcampSource, beets::BeetsDB, mixxx::{MixxxDB, MixxxQuery}, musicbrainz::MBQuery, tools::{DataSource, Tool}}, scene::{Scene, Scenery, StageDirection, conversation::ConversationEntry}};
|
use crate::{SaveData, artifacts::{Contents, bandcamp::BandcampSource, beets::BeetsDB, mixxx::{MixxxDB, MixxxQuery}, musicbrainz::MBQuery, tools::{DataSource, Tool}}, scene::{Scene, Scenery, StageDirection, conversation::{ConversationEntry, ConversationSink, ConversationSrc}}};
|
||||||
|
|
||||||
|
|
||||||
const SYSTEM_PROMPT: &str = include_str!("system-prompt.txt");
|
const SYSTEM_PROMPT: &str = include_str!("system-prompt.txt");
|
||||||
@@ -36,7 +36,7 @@ pub struct GeneratedResponses {
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct Session {
|
struct Session {
|
||||||
client: Client<OpenAIConfig>,
|
client: Client<OpenAIConfig>,
|
||||||
conversation: Vec<ConversationEntry>,
|
conversation: ConversationSink,
|
||||||
header_message: ChatCompletionRequestMessage,
|
header_message: ChatCompletionRequestMessage,
|
||||||
messages: Vec<ChatCompletionRequestMessage>,
|
messages: Vec<ChatCompletionRequestMessage>,
|
||||||
reply_options: GeneratedResponses,
|
reply_options: GeneratedResponses,
|
||||||
@@ -66,11 +66,11 @@ struct ToolResults {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Session {
|
impl Session {
|
||||||
fn new(scene_sink: watch::Sender<Scene>, messages: Vec<ChatCompletionRequestMessage>, scenery: Scenery, direction: StageDirection, activity_notify: watch::Sender<bool>) -> Self {
|
fn new(scene_sink: watch::Sender<Scene>, conversation: ConversationSink, messages: Vec<ChatCompletionRequestMessage>, scenery: Scenery, direction: StageDirection, activity_notify: watch::Sender<bool>) -> Self {
|
||||||
let mut conversation = vec![];
|
|
||||||
for msg in &messages {
|
for msg in &messages {
|
||||||
if let Ok(conversation_msg) = msg.clone().try_into() {
|
if let Ok(conversation_msg) = msg.clone().try_into() {
|
||||||
conversation.push(conversation_msg);
|
// FIXME
|
||||||
|
conversation.send(conversation_msg).unwrap();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -277,6 +277,10 @@ impl Session {
|
|||||||
"query_mixxx" => self.tool_artifact_query(&mut MixxxDB, args).await,
|
"query_mixxx" => self.tool_artifact_query(&mut MixxxDB, args).await,
|
||||||
_ => unreachable!()
|
_ => unreachable!()
|
||||||
};
|
};
|
||||||
|
// Push tool output messages directly into the conversation as fast as we can
|
||||||
|
for message in &tool_result.messages {
|
||||||
|
self.conversation.send(message.clone()).unwrap();
|
||||||
|
}
|
||||||
results.push((&call.id, tool_result));
|
results.push((&call.id, tool_result));
|
||||||
},
|
},
|
||||||
_ => panic!("Unknown tool was called")
|
_ => panic!("Unknown tool was called")
|
||||||
@@ -293,8 +297,11 @@ impl Session {
|
|||||||
self.messages.push(ChatCompletionRequestMessage::Tool(msg.build().unwrap()));
|
self.messages.push(ChatCompletionRequestMessage::Tool(msg.build().unwrap()));
|
||||||
tool_messages.append(&mut result.messages);
|
tool_messages.append(&mut result.messages);
|
||||||
}
|
}
|
||||||
for msg in tool_messages {
|
// OpenAI requires we put all the tool call results before any other message, so we append them manually down here
|
||||||
self.insert_conversation(msg);
|
for message in tool_messages {
|
||||||
|
if let Ok(next_msg) = message.clone().try_into() {
|
||||||
|
self.messages.push(next_msg);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if let Some(content) = message.message.content.as_ref() {
|
if let Some(content) = message.message.content.as_ref() {
|
||||||
@@ -317,17 +324,15 @@ impl Session {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn as_scene(&self) -> Scene {
|
fn as_scene(&self) -> Scene {
|
||||||
Scene::new(self.reply_options.clone(), self.conversation.clone(), self.scenery.clone(), self.tokens_consumed, self.direction.clone())
|
Scene::new(self.reply_options.clone(), self.scenery.clone(), self.tokens_consumed, self.direction.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn insert_conversation(&mut self, entry: ConversationEntry) {
|
fn insert_conversation(&mut self, entry: ConversationEntry) {
|
||||||
self.conversation.push(entry.clone());
|
self.conversation.send(entry.clone()).unwrap();
|
||||||
|
|
||||||
if let Ok(next_msg) = entry.try_into() {
|
if let Ok(next_msg) = entry.try_into() {
|
||||||
self.messages.push(next_msg);
|
self.messages.push(next_msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
self.refresh();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn refresh(&self) {
|
fn refresh(&self) {
|
||||||
@@ -339,13 +344,15 @@ impl Session {
|
|||||||
pub struct SessionControl {
|
pub struct SessionControl {
|
||||||
event_sink: mpsc::Sender<PredictionAction>,
|
event_sink: mpsc::Sender<PredictionAction>,
|
||||||
scene_watch: watch::Receiver<Scene>,
|
scene_watch: watch::Receiver<Scene>,
|
||||||
activity_watch: watch::Receiver<bool>
|
activity_watch: watch::Receiver<bool>,
|
||||||
|
conversation_watch: ConversationSrc
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum SessionUpdate {
|
pub enum SessionUpdate {
|
||||||
Scene(Scene),
|
Scene(Scene),
|
||||||
Thinking(bool)
|
Thinking(bool),
|
||||||
|
Conversation(Vec<ConversationEntry>)
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SessionControl {
|
impl SessionControl {
|
||||||
@@ -364,40 +371,32 @@ impl SessionControl {
|
|||||||
},
|
},
|
||||||
_ = self.scene_watch.changed() => {
|
_ = self.scene_watch.changed() => {
|
||||||
SessionUpdate::Scene(self.scene_watch.borrow_and_update().clone())
|
SessionUpdate::Scene(self.scene_watch.borrow_and_update().clone())
|
||||||
|
},
|
||||||
|
_ = self.conversation_watch.changed() => {
|
||||||
|
SessionUpdate::Conversation(self.conversation_watch.borrow_and_update().clone())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn start_prediction(saved_session: SaveData, mut messages: tokio::sync::mpsc::UnboundedReceiver<String>) -> SessionControl {
|
pub async fn start_prediction(saved_session: SaveData, conversation_src: ConversationSrc, conversations: ConversationSink) -> SessionControl {
|
||||||
let (prediction_in, prediction_out) = tokio::sync::watch::channel(Scene::default());
|
let (prediction_in, prediction_out) = tokio::sync::watch::channel(Scene::default());
|
||||||
let (activity_notify_sink, activity_notify_src) = tokio::sync::watch::channel(false);
|
let (activity_notify_sink, activity_notify_src) = tokio::sync::watch::channel(false);
|
||||||
|
|
||||||
let (action_sink, mut action_src) = mpsc::channel(5);
|
let (action_sink, mut action_src) = mpsc::channel(5);
|
||||||
|
|
||||||
let session = Session::new(prediction_in, saved_session.messages, saved_session.scenery, saved_session.direction, activity_notify_sink);
|
let mut session = Session::new(prediction_in, conversations, saved_session.messages, saved_session.scenery, saved_session.direction, activity_notify_sink);
|
||||||
|
|
||||||
// Send the initial scene to the UI, after we have loaded the session from the first messages.
|
// Send the initial scene to the UI, after we have loaded the session from the first messages.
|
||||||
session.refresh();
|
session.refresh();
|
||||||
|
|
||||||
let shared_session = Arc::new(RwLock::new(session));
|
|
||||||
|
|
||||||
let log_session = Arc::clone(&shared_session);
|
|
||||||
tokio::spawn(async move {
|
|
||||||
loop {
|
|
||||||
if let Some(msg) = messages.recv().await {
|
|
||||||
log_session.write().await.insert_conversation(ConversationEntry::SystemMessage(msg));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
loop {
|
loop {
|
||||||
if let Some(evt) = action_src.recv().await {
|
if let Some(evt) = action_src.recv().await {
|
||||||
shared_session.write().await.on_event(evt).await;
|
session.on_event(evt).await;
|
||||||
// Commit in a separate unlock operation, so the logging task has time to write messages into the conversation
|
// Commit as a separate step once the event has been handled (no lock is involved now that this task owns the session directly)
|
||||||
// FIXME: The conversation we see in the UI really needs to go to another task.
|
// NOTE: The conversation we see in the UI is now maintained by a separate task (see start_conversation).
|
||||||
shared_session.write().await.commit().await;
|
session.commit().await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -405,6 +404,7 @@ pub async fn start_prediction(saved_session: SaveData, mut messages: tokio::sync
|
|||||||
SessionControl {
|
SessionControl {
|
||||||
event_sink: action_sink,
|
event_sink: action_sink,
|
||||||
scene_watch: prediction_out,
|
scene_watch: prediction_out,
|
||||||
activity_watch: activity_notify_src
|
activity_watch: activity_notify_src,
|
||||||
|
conversation_watch: conversation_src
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -85,4 +85,39 @@ impl TryInto<ConversationEntry> for ChatCompletionRequestMessage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Error = ();
|
type Error = ();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Owns the sending half of the watch channel that publishes the conversation history.
///
/// The channel's value is the complete `Vec<ConversationEntry>` accumulated so far;
/// receivers are handed out as [`ConversationSrc`] by `ConversationRunner::new`.
pub struct ConversationRunner {
|
||||||
|
/// Watch sender whose current value is the list of every entry inserted so far.
sink: tokio::sync::watch::Sender<Vec<ConversationEntry>>
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Receiving side of the conversation: a watch receiver over the full entry list.
pub type ConversationSrc = tokio::sync::watch::Receiver<Vec<ConversationEntry>>;
|
||||||
|
/// Sending side of the conversation: an unbounded channel accepting one entry at a time.
pub type ConversationSink = tokio::sync::mpsc::UnboundedSender<ConversationEntry>;
|
||||||
|
|
||||||
|
impl ConversationRunner {
|
||||||
|
/// Creates a runner together with the receiver that observes its conversation.
///
/// The underlying watch channel starts out holding an empty `Vec`.
pub fn new() -> (Self, ConversationSrc) {
|
||||||
|
let (sink, src) = tokio::sync::watch::channel(vec![]);
|
||||||
|
(Self {
|
||||||
|
sink
|
||||||
|
}, src)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Appends `entry` to the end of the published conversation.
///
/// The push happens inside `send_modify`, i.e. directly on the value stored
/// in the watch channel rather than on a local copy.
pub fn insert(&mut self, entry: ConversationEntry) {
|
||||||
|
self.sink.send_modify(|contents| {
|
||||||
|
contents.push(entry)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Starts the task that maintains the conversation and returns both ends of it.
///
/// Returns `(src, raw_sink)`: `src` is the watch receiver exposing the accumulated
/// entries, and `raw_sink` is the unbounded sender through which new entries are
/// submitted. Although declared `async`, the body awaits nothing outside the
/// spawned task.
pub async fn start_conversation() -> (ConversationSrc, ConversationSink) {
|
||||||
|
let (raw_sink, mut raw_src) = tokio::sync::mpsc::unbounded_channel();
|
||||||
|
let (mut runner, src) = ConversationRunner::new();
|
||||||
|
|
||||||
|
// The join handle is not kept, so the task runs detached from the caller.
tokio::spawn(async move {
|
||||||
|
// Forward every received entry into the runner until `recv` yields `None`.
while let Some(evt) = raw_src.recv().await {
|
||||||
|
runner.insert(evt);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
(src, raw_sink)
|
||||||
}
|
}
|
||||||
+2
-8
@@ -1,7 +1,7 @@
|
|||||||
use chrono::{DateTime, Duration, Utc};
|
use chrono::{DateTime, Duration, Utc};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::{artifacts::{Track, archive::Archive}, prediction::{GeneratedResponses, PossibleResponse}, scene::conversation::ConversationEntry};
|
use crate::{artifacts::{Track, archive::Archive}, prediction::{GeneratedResponses, PossibleResponse}};
|
||||||
|
|
||||||
pub mod conversation;
|
pub mod conversation;
|
||||||
|
|
||||||
@@ -43,17 +43,15 @@ pub struct Scenery {
|
|||||||
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
|
||||||
pub struct Scene {
|
pub struct Scene {
|
||||||
reply_options: GeneratedResponses,
|
reply_options: GeneratedResponses,
|
||||||
conversation: Vec<ConversationEntry>,
|
|
||||||
direction: StageDirection,
|
direction: StageDirection,
|
||||||
pub tokens_consumed: usize,
|
pub tokens_consumed: usize,
|
||||||
scenery: Scenery
|
scenery: Scenery
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Scene {
|
impl Scene {
|
||||||
pub fn new(reply_options: GeneratedResponses, conversation: Vec<ConversationEntry>, scenery: Scenery, tokens_consumed: usize, direction: StageDirection) -> Self {
|
pub fn new(reply_options: GeneratedResponses, scenery: Scenery, tokens_consumed: usize, direction: StageDirection) -> Self {
|
||||||
Self {
|
Self {
|
||||||
reply_options,
|
reply_options,
|
||||||
conversation,
|
|
||||||
scenery,
|
scenery,
|
||||||
tokens_consumed,
|
tokens_consumed,
|
||||||
direction
|
direction
|
||||||
@@ -67,10 +65,6 @@ impl Scene {
|
|||||||
pub fn scenery(&self) -> &Scenery {
|
pub fn scenery(&self) -> &Scenery {
|
||||||
&self.scenery
|
&self.scenery
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn conversation(&self) -> &Vec<ConversationEntry> {
|
|
||||||
&self.conversation
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn reply_options(&self) -> &Vec<PossibleResponse> {
|
pub fn reply_options(&self) -> &Vec<PossibleResponse> {
|
||||||
&self.reply_options.responses
|
&self.reply_options.responses
|
||||||
|
|||||||
@@ -26,7 +26,8 @@ pub struct Ui {
|
|||||||
transcription: TranscriptionControl,
|
transcription: TranscriptionControl,
|
||||||
audio: AudioInputControl,
|
audio: AudioInputControl,
|
||||||
tts: TtsControl,
|
tts: TtsControl,
|
||||||
predictions: SessionControl
|
predictions: SessionControl,
|
||||||
|
conversation: Vec<ConversationEntry>
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@@ -51,7 +52,8 @@ impl Ui {
|
|||||||
focus_state: FocusState::UserInput,
|
focus_state: FocusState::UserInput,
|
||||||
tts,
|
tts,
|
||||||
predictions,
|
predictions,
|
||||||
last_tick: Instant::now()
|
last_tick: Instant::now(),
|
||||||
|
conversation: vec![]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -114,7 +116,7 @@ impl Ui {
|
|||||||
.constraints([Constraint::Fill(4), Constraint::Fill(1)])
|
.constraints([Constraint::Fill(4), Constraint::Fill(1)])
|
||||||
.split(layout[0]);
|
.split(layout[0]);
|
||||||
|
|
||||||
frame.render_stateful_widget(Conversation(self.scene.conversation()), scene_layout[0], &mut self.conversation_state);
|
frame.render_stateful_widget(Conversation(&self.conversation), scene_layout[0], &mut self.conversation_state);
|
||||||
self.draw_narration(frame, scene_layout[1]);
|
self.draw_narration(frame, scene_layout[1]);
|
||||||
self.draw_options(frame, layout[1]);
|
self.draw_options(frame, layout[1]);
|
||||||
|
|
||||||
@@ -198,7 +200,7 @@ impl Ui {
|
|||||||
KeyCode::Up => self.conversation_state.select_next(),
|
KeyCode::Up => self.conversation_state.select_next(),
|
||||||
KeyCode::Enter => {
|
KeyCode::Enter => {
|
||||||
let row_num = self.conversation_state.selected().unwrap();
|
let row_num = self.conversation_state.selected().unwrap();
|
||||||
if let ConversationEntry::Eva(text) = &self.scene.conversation()[self.scene.conversation().len() - 1 - row_num] {
|
if let ConversationEntry::Eva(text) = &self.conversation[self.conversation.len() - 1 - row_num] {
|
||||||
self.tts.speak(text.clone()).await;
|
self.tts.speak(text.clone()).await;
|
||||||
self.focus_state = FocusState::UserInput;
|
self.focus_state = FocusState::UserInput;
|
||||||
self.conversation_state.select(None);
|
self.conversation_state.select(None);
|
||||||
@@ -264,7 +266,10 @@ impl Ui {
|
|||||||
SessionUpdate::Scene(scene) => {
|
SessionUpdate::Scene(scene) => {
|
||||||
self.scene = scene;
|
self.scene = scene;
|
||||||
self.reply_state.select_first();
|
self.reply_state.select_first();
|
||||||
}
|
},
|
||||||
|
SessionUpdate::Conversation(conversation) => {
|
||||||
|
self.conversation = conversation;
|
||||||
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
next_volume = self.audio.next() => {
|
next_volume = self.audio.next() => {
|
||||||
|
|||||||
Reference in New Issue
Block a user