diff --git a/Cargo.lock b/Cargo.lock index 68875cd..09dbf5c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1000,6 +1000,7 @@ dependencies = [ "iref 4.0.0", "jack", "json-ld", + "log", "minify", "oximedia-metering", "ratatui", @@ -1013,6 +1014,7 @@ dependencies = [ "serde_json", "sqlite", "static-iref", + "static_cell", "tempfile", "textwrap", "throbber-widgets-tui", @@ -2293,9 +2295,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.30" +version = "0.4.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "616ec5685824bcc94416c6d4a7a446eea774a31efd7062c8480ba6fd06d7a6e5" +checksum = "953f07c43838f8e6f9758cab68bf5bed85465e7587ebe0b823f1bcd81978ad3a" [[package]] name = "lru" @@ -4085,6 +4087,15 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "static_cell" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0530892bb4fa575ee0da4b86f86c667132a94b74bb72160f58ee5a4afec74c23" +dependencies = [ + "portable-atomic", +] + [[package]] name = "str-newtype" version = "2.0.0" diff --git a/Cargo.toml b/Cargo.toml index a6291cf..1f88870 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ hound = "3.5.1" iref = { version = "4.0.0", features = ["url", "serde"] } jack = "0.13.5" json-ld = { version = "0.21.4", features = ["reqwest", "serde"] } +log = "0.4.32" minify = "1.3.0" oximedia-metering = "0.1.7" ratatui = "0.30.0" @@ -30,6 +31,7 @@ serde = { version = "1.0.228", features = ["derive"] } serde_json = "1.0.150" sqlite = "0.37.0" static-iref = "3.0.0" +static_cell = "2.1.1" tempfile = "3.27.0" textwrap = "0.16.2" throbber-widgets-tui = "0.11.0" diff --git a/src/audio.rs b/src/audio.rs index 996af5b..0546fa9 100644 --- a/src/audio.rs +++ b/src/audio.rs @@ -60,7 +60,6 @@ struct Notify { config: AudioConfig, mic_port: jack::Port, tts_port: jack::Port, - log: mpsc::Sender } impl NotificationHandler for Notify { @@ -84,10 +83,10 @@ impl NotificationHandler for Notify { if let Ok(port_name) = other_port.name() { if are_connected { - self.log.blocking_send(format!("{} connected to {}", stream_name, port_name)).unwrap(); + log::info!("{} connected to {}", stream_name, port_name); target_cfg.push(port_name); } else { - self.log.blocking_send(format!("{} disconnected from {}", stream_name, port_name)).unwrap(); + log::info!("{} disconnected from {}", stream_name, port_name); target_cfg.retain(|x| { x != &port_name} ); } @@ -97,7 +96,7 @@ impl NotificationHandler for Notify { } } -pub async fn start_audio_input(messages: &mpsc::Sender) -> (AudioInputControl, MicStream, TtsOutStream) { +pub async fn start_audio_input() -> (AudioInputControl, MicStream, TtsOutStream) { let (exit_tx, exit_rx) = oneshot::channel(); @@ -117,17 +116,17 @@ pub async fn start_audio_input(messages: &mpsc::Sender) -> (AudioInputCo for port in &config.mic_in_connections { if let Ok(_) = client.connect_ports_by_name(&port, &mic_name) { - messages.send(format!("Connected mic to {}", port)).await.unwrap(); + log::info!("Connected mic to {}", port); } else { - messages.send(format!("Failed to reconnect mic to {}.", port)).await.unwrap(); + log::warn!("Failed to reconnect mic to {}.", port); } } for port in &config.tts_out_connections { if let Ok(_) = client.connect_ports_by_name(&tts_name, &port) { - messages.send(format!("Connected TTS output to {}", port)).await.unwrap(); + log::info!("Connected TTS output to {}", port); } else { - messages.send(format!("Failed to reconnect TTS output to {}.", port)).await.unwrap(); + log::warn!("Failed to reconnect TTS output to {}.", port); } } @@ -135,7 +134,6 @@ pub async fn start_audio_input(messages: &mpsc::Sender) -> (AudioInputCo config, mic_port: mic_port.clone_unowned(), tts_port: tts_port.clone_unowned(), - log: messages.clone() }; let mut meter = VuMeter::new(rate.into(), 1, None); diff --git a/src/main.rs b/src/main.rs index 8d64328..3046b56 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,12 +1,15 @@ +use std::{cell::RefCell, rc::Rc, sync::Arc}; + use async_openai::types::chat::ChatCompletionRequestMessage; use chrono::{Duration, Utc}; use futures_timer::Delay; use serde::{Deserialize, Serialize}; use ratatui::{Frame, layout::{Constraint, Direction, Layout}, widgets::{Block, BorderType, Clear, Gauge, List, ListDirection, ListState, Paragraph, Wrap}}; +use static_cell::StaticCell; use throbber_widgets_tui::{Throbber, ThrobberState}; use crossterm::{event::{self, EventStream, KeyCode, KeyModifiers}}; -use tokio::time::Instant; +use tokio::{sync::RwLock, time::Instant}; use tui_input::{Input, backend::crossterm::EventHandler}; use futures::{StreamExt, future::FutureExt}; @@ -492,6 +495,20 @@ impl SaveData { } } +struct SysMessageLogger(Arc>); + +impl log::Log for SysMessageLogger { + fn enabled(&self, metadata: &log::Metadata) -> bool { + true + } + + fn flush(&self) {} + + fn log(&self, record: &log::Record) { + self.0.send(format!("{}", record.args())).unwrap(); + } +} + #[tokio::main] async fn main() { let (panic_hook, eyre_hook) = color_eyre::config::HookBuilder::default() @@ -505,6 +522,13 @@ async fn main() { println!("Panic: {}", msg); })); + let (sys_message_sink, sys_message_src) = tokio::sync::mpsc::unbounded_channel(); + + static LOGGER: StaticCell = StaticCell::new(); + let logger = LOGGER.init(SysMessageLogger(Arc::new(sys_message_sink))); + log::set_logger(logger).unwrap(); + log::set_max_level(log::LevelFilter::Info); + dotenv::dotenv().ok(); if std::env::var("OPENAI_API_KEY").is_err() { @@ -514,23 +538,21 @@ async fn main() { let mut terminal: Terminal> = ratatui::init(); - let (sys_message_sink, mut sys_message_src) = tokio::sync::mpsc::channel(32); - let saved_session = if let Ok(save_data) = std::fs::read_to_string("save.json") { if let Ok(ret) = serde_json::from_str(&save_data) { - sys_message_sink.send("Loaded session from save.json".into()).await.unwrap(); + log::info!("Loaded session from save.json"); ret } else { - sys_message_sink.send("Could not load saved session!".into()).await.unwrap(); + log::warn!("Could not load saved session!"); SaveData::default() } } else { - sys_message_sink.send("Creating new session in save.json".into()).await.unwrap(); + log::info!("Creating new session in save.json"); SaveData::default() }; - let prediction_ctrl = prediction::start_prediction(saved_session).await; - let (audio_ctrl, mic_stream, tts_output) = start_audio_input(&sys_message_sink).await; + let prediction_ctrl = prediction::start_prediction(saved_session, sys_message_src).await; + let (audio_ctrl, mic_stream, tts_output) = start_audio_input().await; let tts_ctrl = start_tts(tts_output).await; let transcription_ctrl = transcription::start_transcription(mic_stream).await; @@ -551,9 +573,6 @@ async fn main() { tokio::select! { _ = delay => (), - Some(next_log) = sys_message_src.recv() => { - app.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::SystemMessage(next_log))).await; - }, next_update = app.predictions.changed() => { match next_update { SessionUpdate::Thinking(is_thinking) => app.is_requesting = is_thinking, diff --git a/src/prediction.rs b/src/prediction.rs index c43a7ab..bbc7fb5 100644 --- a/src/prediction.rs +++ b/src/prediction.rs @@ -1,9 +1,11 @@ +use std::sync::Arc; + use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs, ChatCompletionTool, ChatCompletionTools, CreateChatCompletionRequestArgs, FinishReason, FunctionObjectArgs, ResponseFormat, ResponseFormatJsonSchema}}; use bandcamp::SearchResultItem; use schemars::{JsonSchema, schema_for}; use serde::{Deserialize, Serialize}; use serde_json::{Serializer, ser::CompactFormatter}; -use tokio::sync::{mpsc, watch}; +use tokio::sync::{RwLock, mpsc, watch}; use crate::{SaveData, archive::BeatsQueryArgs, artifacts::BandcampQueryArgs, scene::{PredictionAction, Scene, Scenery, StageDirection, conversation::ConversationEntry}}; @@ -31,7 +33,8 @@ struct Session { direction: StageDirection, scenery: Scenery, tokens_consumed: usize, - activity_notify: watch::Sender + activity_notify: watch::Sender, + scene_sink: watch::Sender } #[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)] @@ -52,7 +55,7 @@ struct ToolResults { } impl Session { - fn from_initial_messages(messages: Vec, scenery: Scenery, direction: StageDirection, activity_notify: watch::Sender) -> Self { + fn new(scene_sink: watch::Sender, messages: Vec, scenery: Scenery, direction: StageDirection, activity_notify: watch::Sender) -> Self { let mut conversation = vec![]; for msg in &messages { if let Ok(conversation_msg) = msg.clone().try_into() { @@ -70,6 +73,7 @@ impl Session { direction, tokens_consumed: 0, activity_notify, + scene_sink } } @@ -154,6 +158,8 @@ impl Session { } async fn regenerate_options(&mut self) { + self.reply_options.responses.clear(); + self.refresh(); self.activity_notify.send_if_modified(|x| { if !*x { *x = true; true } else { false }}); loop { let full_conversation = self.generate_conversation(&self.direction); @@ -181,6 +187,7 @@ impl Session { .build().unwrap() }) ]; + self.log("Sending request.."); let request = CreateChatCompletionRequestArgs::default() .messages(full_conversation) .model("gpt-5.4-mini") @@ -202,6 +209,7 @@ impl Session { if let Some(usage) = response.usage { self.tokens_consumed += usage.total_tokens as usize; + self.log(format!("{} tokens cast into the void", usage.total_tokens)); } if let Some(message) = response.choices.first() { @@ -260,26 +268,38 @@ impl Session { self.reply_options = options; break; } else { - self.insert_conversation(ConversationEntry::SystemMessage("Received invalid JSON! Trying again.".into())); + self.log("Received invalid JSON! Trying again."); } } } else { - self.insert_conversation(ConversationEntry::SystemMessage("No messages were received! Trying again.".into())); + self.log("No messages were received! Trying again."); } } self.activity_notify.send_if_modified(|x| { if *x { *x = false; true } else { false }}); + + self.refresh(); } fn as_scene(&self) -> Scene { Scene::new(self.reply_options.clone(), self.conversation.clone(), self.scenery.clone(), self.tokens_consumed, self.direction.clone()) } + fn log>(&mut self, msg: T) { + self.insert_conversation(ConversationEntry::SystemMessage(msg.into())); + } + fn insert_conversation(&mut self, entry: ConversationEntry) { self.conversation.push(entry.clone()); if let Ok(next_msg) = entry.try_into() { self.messages.push(next_msg); } + + self.refresh(); + } + + fn refresh(&self) { + self.scene_sink.send(self.as_scene()).unwrap(); } } @@ -321,20 +341,32 @@ impl SessionControl { } } -pub async fn start_prediction(saved_session: SaveData) -> SessionControl { +pub async fn start_prediction(saved_session: SaveData, mut messages: tokio::sync::mpsc::UnboundedReceiver) -> SessionControl { let (prediction_in, prediction_out) = tokio::sync::watch::channel(Scene::default()); let (activity_notify_sink, activity_notify_src) = tokio::sync::watch::channel(false); let (action_sink, mut action_src) = mpsc::channel(5); - let mut session = Session::from_initial_messages(saved_session.messages, saved_session.scenery, saved_session.direction, activity_notify_sink); + let session = Session::new(prediction_in, saved_session.messages, saved_session.scenery, saved_session.direction, activity_notify_sink); // Send the initial scene to the UI, after we have loaded the session from the first messages. - prediction_in.send(session.as_scene()).unwrap(); + session.refresh(); + + let shared_session = Arc::new(RwLock::new(session)); + + let log_session = Arc::clone(&shared_session); + tokio::spawn(async move { + loop { + if let Some(msg) = messages.recv().await { + log_session.write().await.insert_conversation(ConversationEntry::SystemMessage(msg)); + } + } + }); tokio::spawn(async move { loop { if let Some(evt) = action_src.recv().await { + let mut session = shared_session.write().await; let do_regen = match evt { PredictionAction::ConversationAppend(msg) => { let do_regen = match msg { @@ -348,9 +380,9 @@ pub async fn start_prediction(saved_session: SaveData) -> SessionControl { PredictionAction::SetEpisodeNumber(num) => { session.direction.episode_number = num; if let Err(err) = session.direction.reload_mixxx_playlist() { - session.insert_conversation(ConversationEntry::SystemMessage(format!("Failed to load mixxx playlist: {:?}.", err).into())); + session.log(format!("Failed to load mixxx playlist: {:?}.", err)); } else { - session.insert_conversation(ConversationEntry::SystemMessage("Mixxx playlist reloaded.".into())); + session.log("Mixxx playlist reloaded."); } false }, @@ -359,7 +391,7 @@ pub async fn start_prediction(saved_session: SaveData) -> SessionControl { }, PredictionAction::SetNarrative(narrative) => { session.direction.narrative = narrative; - session.insert_conversation(ConversationEntry::SystemMessage("Updated stage direction narrative".into())); + session.log("Updated stage direction narrative"); true }, PredictionAction::SetShowEndTime(end_time) => { @@ -377,14 +409,8 @@ pub async fn start_prediction(saved_session: SaveData) -> SessionControl { save_data.save(); if do_regen { - session.reply_options.responses.clear(); - } - - prediction_in.send(session.as_scene()).unwrap(); - - if do_regen { - session.regenerate_options().await; - prediction_in.send(session.as_scene()).unwrap(); + drop(session); + shared_session.write().await.regenerate_options().await; } } }