main: refactor prediction engine to use an event stream

This commit is contained in:
2026-06-09 19:08:08 +02:00
parent 8394a076d4
commit ad90df7767
3 changed files with 196 additions and 156 deletions
+47 -92
View File
@@ -1,19 +1,19 @@
use async_openai::types::chat::ChatCompletionRequestMessage; use async_openai::types::chat::ChatCompletionRequestMessage;
use chrono::{DateTime, Duration, Utc}; use chrono::{Duration, Utc};
use futures_timer::Delay; use futures_timer::Delay;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use ratatui::{Frame, layout::{Constraint, Direction, Layout}, widgets::{Block, BorderType, Clear, Gauge, List, ListDirection, ListItem, ListState, Paragraph, Wrap}}; use ratatui::{Frame, layout::{Constraint, Direction, Layout}, widgets::{Block, BorderType, Clear, Gauge, List, ListDirection, ListState, Paragraph, Wrap}};
use throbber_widgets_tui::{Throbber, ThrobberState}; use throbber_widgets_tui::{Throbber, ThrobberState};
use crossterm::{event::{self, EventStream, KeyCode, KeyModifiers}}; use crossterm::{event::{self, EventStream, KeyCode, KeyModifiers}};
use tokio::{sync::{mpsc, watch}, time::Instant}; use tokio::time::Instant;
use tui_input::{Input, backend::crossterm::EventHandler}; use tui_input::{Input, backend::crossterm::EventHandler};
use futures::{StreamExt, future::FutureExt}; use futures::{StreamExt, future::FutureExt};
use ratatui::prelude::*; use ratatui::prelude::*;
use tui_skeleton::{AnimationMode, SkeletonText}; use tui_skeleton::{AnimationMode, SkeletonText};
use crate::{audio::{AudioInputControl, start_audio_input}, prediction::{BandcampResult, PossibleResponse}, scene::{ConversationEntry, Scene, Scenery, StageActions, StageDirection}, transcription::TranscriptionControl, tts::{TtsControl, start_tts}}; use crate::{audio::{AudioInputControl, start_audio_input}, prediction::{SessionControl, SessionUpdate}, scene::{ConversationEntry, PredictionAction, Scene, Scenery, StageDirection}, transcription::TranscriptionControl, tts::{TtsControl, start_tts}};
mod scene; mod scene;
mod events; mod events;
@@ -51,26 +51,9 @@ mod audio;
- Right panel: Shortcuts for triggering scenarios, soundboard events, etc - Right panel: Shortcuts for triggering scenarios, soundboard events, etc
*/ */
impl<'a> Into<ListItem<'a>> for PossibleResponse {
fn into(self) -> ListItem<'a> {
if let Some(direction) = self.stage_direction {
Line::from_iter([
//Span::from(format!("({})", direction)).style(ratatui::style::Color::Yellow),
//Span::from(" "),
Span::from(self.text)
]).into()
} else {
Line::from(self.text).into()
}
}
}
#[derive(Debug)] #[derive(Debug)]
struct App { struct App {
scene: Scene, scene: Scene,
direction: StageDirection,
next_actions: Vec<ConversationEntry>,
end_time: DateTime<Utc>,
reply_state: ListState, reply_state: ListState,
conversation_state: ListState, conversation_state: ListState,
@@ -82,10 +65,9 @@ struct App {
focus_state: FocusState, focus_state: FocusState,
transcription: TranscriptionControl, transcription: TranscriptionControl,
prediction_request_sink: watch::Sender<StageActions>,
audio: AudioInputControl, audio: AudioInputControl,
tts: TtsControl, tts: TtsControl,
sys_message_sink: mpsc::Sender<String> predictions: SessionControl
} }
#[derive(Debug)] #[derive(Debug)]
@@ -95,17 +77,13 @@ enum FocusState {
} }
impl App { impl App {
fn new(prediction_request_sink: watch::Sender<StageActions>, audio: AudioInputControl, transcription: TranscriptionControl, tts: TtsControl, sys_message_sink: mpsc::Sender<String>, initial_direction: StageDirection) -> Self { fn new(predictions: SessionControl, audio: AudioInputControl, transcription: TranscriptionControl, tts: TtsControl) -> Self {
Self { Self {
scene: Default::default(), scene: Default::default(),
direction: initial_direction,
next_actions: Default::default(),
reply_state: Default::default(), reply_state: Default::default(),
conversation_state: Default::default(), conversation_state: Default::default(),
user_input: Default::default(), user_input: Default::default(),
end_time: Utc::now() + Duration::hours(2),
throbber_state: Default::default(), throbber_state: Default::default(),
prediction_request_sink,
is_requesting: false, is_requesting: false,
audio_level: -60., audio_level: -60.,
audio, audio,
@@ -113,7 +91,7 @@ impl App {
transcription, transcription,
focus_state: FocusState::UserInput, focus_state: FocusState::UserInput,
tts, tts,
sys_message_sink predictions
} }
} }
@@ -183,7 +161,7 @@ impl App {
fn draw_options(&mut self, frame: &mut Frame, area: Rect) { fn draw_options(&mut self, frame: &mut Frame, area: Rect) {
let borders = Block::bordered().border_style(style::Color::LightGreen).title("Reply Options (Press 'Ctrl+R' to regenerate, Ctrl+Enter to use)"); let borders = Block::bordered().border_style(style::Color::LightGreen).title("Reply Options (Press 'Ctrl+R' to regenerate, Ctrl+Enter to use)");
if self.scene.reply_options().len() == 0 && self.is_requesting { if self.is_requesting {
let list = SkeletonText::new(std::time::SystemTime::now().duration_since(std::time::SystemTime::UNIX_EPOCH).unwrap().as_millis() as u64) let list = SkeletonText::new(std::time::SystemTime::now().duration_since(std::time::SystemTime::UNIX_EPOCH).unwrap().as_millis() as u64)
.braille(true) .braille(true)
.line_widths(&[0.25, 0.5, 0.4, 0.6]) .line_widths(&[0.25, 0.5, 0.4, 0.6])
@@ -242,10 +220,10 @@ impl App {
} }
fn draw_status(&self, frame: &mut Frame, area: Rect) { fn draw_status(&self, frame: &mut Frame, area: Rect) {
let minutes_remaining = self.direction.time_remaining.num_seconds() / 60; let time_remaining: Duration = self.scene.direction.time_remaining();
let minutes_remaining = time_remaining.num_seconds() / 60;
let negative = self.direction.time_remaining.abs() != self.direction.time_remaining;
let negative = time_remaining.abs() != time_remaining;
let time_style = if minutes_remaining <= 0 || negative { let time_style = if minutes_remaining <= 0 || negative {
Style::new().fg(ratatui::style::Color::LightRed).bold() Style::new().fg(ratatui::style::Color::LightRed).bold()
@@ -262,15 +240,15 @@ impl App {
}; };
let formatted_time = if negative { let formatted_time = if negative {
format!("-{:0>2}:{:0>2}:{:0>2}", self.direction.time_remaining.num_hours().abs(), self.direction.time_remaining.num_minutes().abs()% 60, self.direction.time_remaining.num_seconds().abs() % 60) format!("-{:0>2}:{:0>2}:{:0>2}", time_remaining.num_hours().abs(), time_remaining.num_minutes().abs()% 60, time_remaining.num_seconds().abs() % 60)
} else { } else {
format!("{:0>2}:{:0>2}:{:0>2}", self.direction.time_remaining.num_hours(), self.direction.time_remaining.num_minutes() % 60, self.direction.time_remaining.num_seconds() % 60) format!("{:0>2}:{:0>2}:{:0>2}", time_remaining.num_hours(), time_remaining.num_minutes() % 60, time_remaining.num_seconds() % 60)
}; };
let status_line = Line::from_iter([ let status_line = Line::from_iter([
Span::from(format!("Episode {}", self.direction.episode_number)).style(ratatui::style::Color::LightBlue), Span::from(format!("Episode {}", self.scene.direction.episode_number)).style(ratatui::style::Color::LightBlue),
Span::from(" | ").style(ratatui::style::Color::DarkGray), Span::from(" | ").style(ratatui::style::Color::DarkGray),
Span::from(format!("{} tracks", self.direction.current_playlist.len())).style(ratatui::style::Color::LightBlue), Span::from(format!("{} tracks", self.scene.direction.current_playlist.len())).style(ratatui::style::Color::LightBlue),
Span::from(" | ").style(ratatui::style::Color::DarkGray), Span::from(" | ").style(ratatui::style::Color::DarkGray),
Span::from(format!("Time Remaining: {}", formatted_time)).style(time_style), Span::from(format!("Time Remaining: {}", formatted_time)).style(time_style),
Span::from(" | ").style(ratatui::style::Color::DarkGray), Span::from(" | ").style(ratatui::style::Color::DarkGray),
@@ -282,10 +260,10 @@ impl App {
} }
fn draw_narration(&self, frame: &mut Frame, area: Rect) { fn draw_narration(&self, frame: &mut Frame, area: Rect) {
let narrative_desc = if self.direction.narrative.is_empty() { let narrative_desc = if self.scene.direction.narrative.is_empty() {
Span::from("No narrative available.").style(ratatui::style::Color::DarkGray) Span::from("No narrative available.").style(ratatui::style::Color::DarkGray)
} else { } else {
Span::from(self.direction.narrative.clone()) Span::from(self.scene.direction.narrative.clone())
}; };
let setting = Paragraph::new(narrative_desc).block(Block::bordered().border_style(style::Color::LightMagenta).title("Stage Direction")).wrap(ratatui::widgets::Wrap { trim: false }); let setting = Paragraph::new(narrative_desc).block(Block::bordered().border_style(style::Color::LightMagenta).title("Stage Direction")).wrap(ratatui::widgets::Wrap { trim: false });
frame.render_widget(setting, area); frame.render_widget(setting, area);
@@ -384,11 +362,9 @@ impl App {
async fn insert_selected_prompt(&mut self) { async fn insert_selected_prompt(&mut self) {
let selected = self.scene.reply_options()[self.reply_state.selected().unwrap()].clone(); let selected = self.scene.reply_options()[self.reply_state.selected().unwrap()].clone();
if let Some(direction) = &selected.stage_direction { if let Some(direction) = &selected.stage_direction {
self.next_actions.push(ConversationEntry::StageDirection(direction.clone())); self.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::StageDirection(direction.clone()))).await;
} }
self.next_actions.push(ConversationEntry::Eva(selected.text.clone())); self.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::Eva(selected.text.clone()))).await;
self.tts.speak(selected.text.clone()).await;
self.regenerate_responses();
} }
async fn on_command(&mut self, command: &str) { async fn on_command(&mut self, command: &str) {
@@ -398,38 +374,33 @@ impl App {
match command { match command {
// FIXME: Need some new kind of /bandcamp command to force loading of specific urls // FIXME: Need some new kind of /bandcamp command to force loading of specific urls
"/episode" => { "/episode" => {
if let Ok(episode_number) = arg.trim().parse::<u32>() { if let Ok(episode_number) = arg.trim().parse() {
self.direction.episode_number = episode_number; self.predictions.insert(scene::PredictionAction::SetEpisodeNumber(episode_number)).await;
self.sys_message_sink.send(format!("Updated episode number: {}", self.direction.episode_number)).await.unwrap();
self.reload_mixxx_playlist();
} else { } else {
self.sys_message_sink.send("Invalid episode number format. Use /episode [number]".into()).await.unwrap(); self.predictions.log("Invalid episode number format. Use /episode [number]".into()).await;
return; return;
} }
}, },
"/timer" => { "/timer" => {
if let Ok(minutes) = arg.trim().parse::<i64>() { if let Ok(minutes) = arg.trim().parse::<i64>() {
self.end_time = Utc::now() + Duration::minutes(minutes); let end_time = Utc::now() + Duration::minutes(minutes);
self.sys_message_sink.send(format!("Set timer for {} minutes.", minutes)).await.unwrap(); self.predictions.insert(PredictionAction::SetShowEndTime(end_time)).await;
self.predictions.log(format!("Set timer for {} minutes.", minutes)).await;
} else { } else {
self.sys_message_sink.send("Invalid timer format. Use /timer [minutes]".into()).await.unwrap(); self.predictions.log("Invalid timer format. Use /timer [minutes]".into()).await;
} }
}, },
"/narrative" => { "/narrative" => {
self.direction.narrative = arg.to_string(); self.predictions.insert(PredictionAction::SetNarrative(arg.to_string())).await;
self.sys_message_sink.send(format!("Updated stage direction: {}", self.direction.narrative)).await.unwrap();
self.regenerate_responses();
}, },
"/event" => { "/event" => {
self.next_actions.push(ConversationEntry::StageDirection(arg.to_string())); self.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::StageDirection(arg.to_string()))).await;
self.regenerate_responses();
}, },
"/computer" => { "/computer" => {
self.next_actions.push(ConversationEntry::ShipComputer(arg.to_string())); self.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::ShipComputer(arg.to_string()))).await;
self.regenerate_responses();
}, },
_ => { _ => {
self.sys_message_sink.send("Unknown command. Available commands: /episode [number], /narrative [text], /event [text], /computer [text], /timer [minutes]".into()).await.unwrap(); self.predictions.log("Unknown command. Available commands: /episode [number], /narrative [text], /event [text], /computer [text], /timer [minutes]".into()).await;
} }
} }
} }
@@ -469,12 +440,11 @@ impl App {
self.conversation_state.select_first(); self.conversation_state.select_first();
self.reply_state.select(None); self.reply_state.select(None);
}, },
KeyCode::Char('r') if key.modifiers.contains(KeyModifiers::CONTROL) => self.regenerate_responses(), KeyCode::Char('r') if key.modifiers.contains(KeyModifiers::CONTROL) => self.predictions.regenerate_options().await,
KeyCode::Char('x') if key.modifiers.contains(KeyModifiers::CONTROL) => { KeyCode::Char('x') if key.modifiers.contains(KeyModifiers::CONTROL) => {
if self.recording_audio { if self.recording_audio {
self.recording_audio = false; self.recording_audio = false;
self.transcription.stop(); self.transcription.stop();
self.is_requesting = true;
} else { } else {
self.recording_audio = true; self.recording_audio = true;
self.transcription.start(); self.transcription.start();
@@ -493,8 +463,7 @@ impl App {
if next_msg.starts_with("/") { if next_msg.starts_with("/") {
self.on_command(&next_msg).await; self.on_command(&next_msg).await;
} else { } else {
self.next_actions.push(ConversationEntry::User(next_msg)); self.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::User(next_msg))).await;
self.regenerate_responses();
} }
} }
}, },
@@ -504,25 +473,6 @@ impl App {
} }
} }
} }
fn regenerate_responses(&mut self) {
let actions = StageActions {
direction: self.direction.clone(),
additions: std::mem::take(&mut self.next_actions)
};
self.scene.reply_options_mut().clear();
self.scene.conversation_mut().append(&mut actions.additions.clone());
self.prediction_request_sink.send(actions).unwrap();
self.is_requesting = true;
}
fn reload_mixxx_playlist(&mut self) {
if let Err(err) = self.direction.reload_mixxx_playlist() {
self.next_actions.push(ConversationEntry::SystemMessage(format!("Error while loading mixxx playlist: {:?}", err)));
} else {
self.next_actions.push(ConversationEntry::SystemMessage(format!("Mixxx playlist reloaded. {} tracks found.", self.direction.current_playlist.len()).into()));
}
}
} }
@@ -562,7 +512,7 @@ async fn main() {
let mut terminal: Terminal<CrosstermBackend<std::io::Stdout>> = ratatui::init(); let mut terminal: Terminal<CrosstermBackend<std::io::Stdout>> = ratatui::init();
let (sys_message_sink, sys_message_src) = tokio::sync::mpsc::channel(32); let (sys_message_sink, mut sys_message_src) = tokio::sync::mpsc::channel(32);
let saved_session = if let Ok(save_data) = std::fs::read_to_string("save.json") { let saved_session = if let Ok(save_data) = std::fs::read_to_string("save.json") {
if let Ok(ret) = serde_json::from_str(&save_data) { if let Ok(ret) = serde_json::from_str(&save_data) {
@@ -577,12 +527,12 @@ async fn main() {
SaveData::default() SaveData::default()
}; };
let prediction_ctrl = prediction::start_prediction(saved_session).await;
let (audio_ctrl, mic_stream, tts_output) = start_audio_input(&sys_message_sink).await; let (audio_ctrl, mic_stream, tts_output) = start_audio_input(&sys_message_sink).await;
let tts_ctrl = start_tts(tts_output).await; let tts_ctrl = start_tts(tts_output).await;
let (prediction_request_in, mut prediction_out) = prediction::start_prediction(sys_message_src, saved_session.messages, saved_session.scenery).await;
let transcription_ctrl = transcription::start_transcription(mic_stream).await; let transcription_ctrl = transcription::start_transcription(mic_stream).await;
let mut app = App::new(prediction_request_in, audio_ctrl, transcription_ctrl, tts_ctrl, sys_message_sink, saved_session.direction); let mut app = App::new(prediction_ctrl, audio_ctrl, transcription_ctrl, tts_ctrl);
let mut events = EventStream::new(); let mut events = EventStream::new();
let mut last_tick = Instant::now(); let mut last_tick = Instant::now();
@@ -592,7 +542,6 @@ async fn main() {
last_tick = Instant::now(); last_tick = Instant::now();
app.throbber_state.calc_next(); app.throbber_state.calc_next();
} }
app.direction.time_remaining = app.end_time.signed_duration_since(Utc::now());
terminal.draw(|frame| { app.draw(frame)}).unwrap(); terminal.draw(|frame| { app.draw(frame)}).unwrap();
let delay = Delay::new(std::time::Duration::from_millis(60)).fuse(); let delay = Delay::new(std::time::Duration::from_millis(60)).fuse();
@@ -600,17 +549,23 @@ async fn main() {
tokio::select! { tokio::select! {
_ = delay => (), _ = delay => (),
_ = prediction_out.changed() => { Some(next_log) = sys_message_src.recv() => {
app.scene = prediction_out.borrow().clone(); app.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::SystemMessage(next_log))).await;
app.reply_state.select_first(); },
app.is_requesting = false; next_update = app.predictions.changed() => {
match next_update {
SessionUpdate::Thinking(is_thinking) => app.is_requesting = is_thinking,
SessionUpdate::Scene(scene) => {
app.scene = scene;
app.reply_state.select_first();
}
}
}, },
next_volume = app.audio.next() => { next_volume = app.audio.next() => {
app.audio_level = next_volume app.audio_level = next_volume
}, },
transcription_result = app.transcription.next() => { transcription_result = app.transcription.next() => {
app.next_actions.push(ConversationEntry::User(transcription_result)); app.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::User(transcription_result))).await;
app.regenerate_responses();
}, },
maybe_event = event => { maybe_event = event => {
match maybe_event { match maybe_event {
+115 -46
View File
@@ -3,12 +3,12 @@ use std::process::{Command, Stdio};
use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs, ChatCompletionTool, ChatCompletionTools, CreateChatCompletionRequestArgs, FinishReason, FunctionObjectArgs, ResponseFormat, ResponseFormatJsonSchema}}; use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs, ChatCompletionTool, ChatCompletionTools, CreateChatCompletionRequestArgs, FinishReason, FunctionObjectArgs, ResponseFormat, ResponseFormatJsonSchema}};
use bandcamp::SearchResultItem; use bandcamp::SearchResultItem;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use color_eyre::eyre::eyre;
use schemars::{JsonSchema, schema_for}; use schemars::{JsonSchema, schema_for};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{Serializer, ser::CompactFormatter}; use serde_json::{Serializer, ser::CompactFormatter};
use tokio::sync::{mpsc, watch};
use crate::{SaveData, scene::{Artifact, ConversationEntry, Scene, Scenery, StageActions, StageDirection}}; use crate::{SaveData, scene::{Artifact, ConversationEntry, PredictionAction, Scene, Scenery, StageDirection}};
const SYSTEM_PROMPT: &str = include_str!("system-prompt.txt"); const SYSTEM_PROMPT: &str = include_str!("system-prompt.txt");
@@ -31,8 +31,10 @@ struct Session {
header_message: ChatCompletionRequestMessage, header_message: ChatCompletionRequestMessage,
messages: Vec<ChatCompletionRequestMessage>, messages: Vec<ChatCompletionRequestMessage>,
reply_options: GeneratedResponses, reply_options: GeneratedResponses,
direction: StageDirection,
scenery: Scenery, scenery: Scenery,
tokens_consumed: usize tokens_consumed: usize,
activity_notify: watch::Sender<bool>
} }
#[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)] #[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)]
@@ -91,7 +93,7 @@ struct ToolResults {
} }
impl Session { impl Session {
fn from_initial_messages(messages: Vec<ChatCompletionRequestMessage>, scenery: Scenery) -> Self { fn from_initial_messages(messages: Vec<ChatCompletionRequestMessage>, scenery: Scenery, direction: StageDirection, activity_notify: watch::Sender<bool>) -> Self {
let mut conversation = vec![]; let mut conversation = vec![];
for msg in &messages { for msg in &messages {
if let Ok(conversation_msg) = msg.clone().try_into() { if let Ok(conversation_msg) = msg.clone().try_into() {
@@ -106,13 +108,9 @@ impl Session {
messages, messages,
reply_options: Default::default(), reply_options: Default::default(),
scenery, scenery,
tokens_consumed: 0 direction,
} tokens_consumed: 0,
} activity_notify,
fn insert_actions(&mut self, actions: &StageActions) {
for addition in &actions.additions {
self.insert_conversation(addition.clone());
} }
} }
@@ -213,9 +211,10 @@ impl Session {
full_conversation full_conversation
} }
async fn regenerate_options(&mut self, direction: &StageDirection) { async fn regenerate_options(&mut self) {
self.activity_notify.send_if_modified(|x| { if !*x { *x = true; true } else { false }});
loop { loop {
let full_conversation = self.generate_conversation(direction); let full_conversation = self.generate_conversation(&self.direction);
let tools = vec![ let tools = vec![
ChatCompletionTools::Function(ChatCompletionTool { ChatCompletionTools::Function(ChatCompletionTool {
@@ -268,11 +267,11 @@ impl Session {
match message.finish_reason { match message.finish_reason {
Some(FinishReason::ContentFilter) => { Some(FinishReason::ContentFilter) => {
self.insert_conversation(ConversationEntry::SystemMessage("Content filter triggered.".into())); self.insert_conversation(ConversationEntry::SystemMessage("Content filter triggered.".into()));
return; break;
}, },
Some(FinishReason::Length) => { Some(FinishReason::Length) => {
self.insert_conversation(ConversationEntry::SystemMessage("Maximum token count exceeded!".into())); self.insert_conversation(ConversationEntry::SystemMessage("Maximum token count exceeded!".into()));
return; break;
}, },
_ => () _ => ()
} }
@@ -317,7 +316,7 @@ impl Session {
if let Some(content) = message.message.content.as_ref() { if let Some(content) = message.message.content.as_ref() {
if let Ok(options) = serde_json::from_str(content.as_str()) { if let Ok(options) = serde_json::from_str(content.as_str()) {
self.reply_options = options; self.reply_options = options;
return; break;
} else { } else {
self.insert_conversation(ConversationEntry::SystemMessage("Received invalid JSON! Trying again.".into())); self.insert_conversation(ConversationEntry::SystemMessage("Received invalid JSON! Trying again.".into()));
} }
@@ -326,10 +325,11 @@ impl Session {
self.insert_conversation(ConversationEntry::SystemMessage("No messages were received! Trying again.".into())); self.insert_conversation(ConversationEntry::SystemMessage("No messages were received! Trying again.".into()));
} }
} }
self.activity_notify.send_if_modified(|x| { if *x { *x = false; true } else { false }});
} }
fn as_scene(&self) -> Scene { fn as_scene(&self) -> Scene {
Scene::new(self.reply_options.clone(), self.conversation.clone(), self.scenery.clone(), self.tokens_consumed) Scene::new(self.reply_options.clone(), self.conversation.clone(), self.scenery.clone(), self.tokens_consumed, self.direction.clone())
} }
fn insert_conversation(&mut self, entry: ConversationEntry) { fn insert_conversation(&mut self, entry: ConversationEntry) {
@@ -341,47 +341,116 @@ impl Session {
} }
} }
pub async fn start_prediction(mut sys_message_src: tokio::sync::mpsc::Receiver<String>, initial_messages: Vec<ChatCompletionRequestMessage>, scenery: Scenery) -> (tokio::sync::watch::Sender<StageActions>, tokio::sync::watch::Receiver<Scene>) { #[derive(Debug)]
let (prediction_in, prediction_out) = tokio::sync::watch::channel(Scene::default()); pub struct SessionControl {
let (prediction_request_in, mut prediction_request_out) = tokio::sync::watch::channel(StageActions::default()); event_sink: mpsc::Sender<PredictionAction>,
scene_watch: watch::Receiver<Scene>,
activity_watch: watch::Receiver<bool>
}
let mut session = Session::from_initial_messages(initial_messages, scenery); #[derive(Debug)]
pub enum SessionUpdate {
Scene(Scene),
Thinking(bool)
}
impl SessionControl {
pub async fn insert(&self, action: PredictionAction) {
self.event_sink.send(action).await.unwrap();
}
pub async fn log(&self, message: String) {
self.insert(PredictionAction::ConversationAppend(ConversationEntry::SystemMessage(message))).await;
}
pub async fn regenerate_options(&self) {
self.insert(PredictionAction::GeneratePredictions).await;
}
pub async fn changed(&mut self) -> SessionUpdate {
tokio::select! {
_ = self.activity_watch.changed() => {
SessionUpdate::Thinking(*self.activity_watch.borrow_and_update())
},
_ = self.scene_watch.changed() => {
SessionUpdate::Scene(self.scene_watch.borrow_and_update().clone())
}
}
}
}
pub async fn start_prediction(saved_session: SaveData) -> SessionControl {
let (prediction_in, prediction_out) = tokio::sync::watch::channel(Scene::default());
let (activity_notify_sink, activity_notify_src) = tokio::sync::watch::channel(false);
let (action_sink, mut action_src) = mpsc::channel(5);
let mut session = Session::from_initial_messages(saved_session.messages, saved_session.scenery, saved_session.direction, activity_notify_sink);
// Send the initial scene to the UI, after we have loaded the session from the first messages. // Send the initial scene to the UI, after we have loaded the session from the first messages.
prediction_in.send(session.as_scene()).unwrap(); prediction_in.send(session.as_scene()).unwrap();
tokio::spawn(async move { tokio::spawn(async move {
loop { loop {
tokio::select! { if let Some(evt) = action_src.recv().await {
maybe_message = sys_message_src.recv() => { let do_regen = match evt {
if let Some(message) = maybe_message { PredictionAction::ConversationAppend(msg) => {
session.insert_conversation(ConversationEntry::SystemMessage(message)); let do_regen = match msg {
prediction_in.send(session.as_scene()).unwrap(); ConversationEntry::Eva(_) | ConversationEntry::ShipComputer(_) | ConversationEntry::User(_) => true,
} _ => false
},
maybe_request = prediction_request_out.changed() => {
if maybe_request.is_ok() {
let next_cxt = prediction_request_out.borrow().clone();
session.insert_actions(&next_cxt);
let mut save_data = SaveData {
direction: next_cxt.direction,
messages: session.messages.clone(),
scenery: session.scenery.clone()
}; };
session.insert_conversation(msg);
save_data.save(); do_regen
},
session.regenerate_options(&save_data.direction).await; PredictionAction::SetEpisodeNumber(num) => {
session.direction.episode_number = num;
save_data.messages = session.messages.clone(); if let Err(err) = session.direction.reload_mixxx_playlist() {
save_data.save(); session.insert_conversation(ConversationEntry::SystemMessage(format!("Failed to load mixxx playlist: {:?}.", err).into()));
prediction_in.send(session.as_scene()).unwrap(); } else {
session.insert_conversation(ConversationEntry::SystemMessage("Mixxx playlist reloaded.".into()));
}
false
},
PredictionAction::GeneratePredictions => {
true
},
PredictionAction::SetNarrative(narrative) => {
session.direction.narrative = narrative;
session.insert_conversation(ConversationEntry::SystemMessage("Updated stage direction narrative".into()));
true
},
PredictionAction::SetShowEndTime(end_time) => {
session.direction.end_time = end_time;
false
} }
};
let save_data = SaveData {
direction: session.direction.clone(),
messages: session.messages.clone(),
scenery: session.scenery.clone()
};
save_data.save();
if do_regen {
session.reply_options.responses.clear();
} }
};
prediction_in.send(session.as_scene()).unwrap();
if do_regen {
session.regenerate_options().await;
prediction_in.send(session.as_scene()).unwrap();
}
}
} }
}); });
(prediction_request_in, prediction_out) SessionControl {
event_sink: action_sink,
scene_watch: prediction_out,
activity_watch: activity_notify_src
}
} }
+34 -18
View File
@@ -1,5 +1,5 @@
use async_openai::types::chat::*; use async_openai::types::chat::*;
use chrono::Duration; use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlite::OpenFlags; use sqlite::OpenFlags;
@@ -48,14 +48,32 @@ impl TryInto<ConversationEntry> for ChatCompletionRequestMessage {
type Error = (); type Error = ();
} }
#[derive(Debug, Default, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StageDirection { pub struct StageDirection {
pub episode_number: u32, pub episode_number: u32,
pub time_remaining: Duration, #[serde(skip)]
pub end_time: DateTime<Utc>,
pub narrative: String, pub narrative: String,
pub current_playlist: Vec<PlaylistEntry> pub current_playlist: Vec<PlaylistEntry>
} }
impl StageDirection {
pub fn time_remaining(&self) -> Duration {
self.end_time.signed_duration_since(Utc::now())
}
}
impl Default for StageDirection {
fn default() -> Self {
Self {
episode_number: 0,
end_time: Utc::now() + Duration::hours(2),
narrative: Default::default(),
current_playlist: Default::default(),
}
}
}
#[derive(Debug, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
pub enum Artifact { pub enum Artifact {
Bandcamp(BandcampResult), Bandcamp(BandcampResult),
@@ -81,10 +99,11 @@ impl From<sqlite::Error> for MixxxError {
impl StageDirection { impl StageDirection {
pub fn reload_mixxx_playlist(&mut self) -> Result<(), MixxxError> { pub fn reload_mixxx_playlist(&mut self) -> Result<(), MixxxError> {
self.current_playlist.clear(); self.current_playlist.clear();
let playlist_name = format!("BFF.fm - Episode {}", self.episode_number);
let connection = sqlite::Connection::open_thread_safe_with_flags("mixxxdb.sqlite", OpenFlags::new().with_read_only())?; let connection = sqlite::Connection::open_thread_safe_with_flags("mixxxdb.sqlite", OpenFlags::new().with_read_only())?;
let query = "SELECT id FROM Playlists WHERE name = ? ORDER BY id DESC LIMIT 1"; let query = "SELECT id FROM Playlists WHERE name = ? ORDER BY id DESC LIMIT 1";
let mut statement = connection.prepare(query)?; let mut statement = connection.prepare(query)?;
statement.bind((1, format!("BFF.fm - Episode {}", self.episode_number).as_str()))?; statement.bind((1, playlist_name.as_str()))?;
statement.next()?; statement.next()?;
let latest_id = statement.read::<i64, _>("id").unwrap(); let latest_id = statement.read::<i64, _>("id").unwrap();
@@ -106,10 +125,13 @@ impl StageDirection {
} }
} }
#[derive(Debug, Default, Clone)] #[derive(Debug, Clone)]
pub struct StageActions { pub enum PredictionAction {
pub direction: StageDirection, ConversationAppend(ConversationEntry),
pub additions: Vec<ConversationEntry> SetEpisodeNumber(u32),
GeneratePredictions,
SetNarrative(String),
SetShowEndTime(DateTime<Utc>)
} }
#[derive(Debug, Default, Serialize, Deserialize, Clone)] #[derive(Debug, Default, Serialize, Deserialize, Clone)]
@@ -124,17 +146,19 @@ pub struct PlaylistEntry {
pub struct Scene { pub struct Scene {
reply_options: GeneratedResponses, reply_options: GeneratedResponses,
conversation: Vec<ConversationEntry>, conversation: Vec<ConversationEntry>,
pub direction: StageDirection,
pub tokens_consumed: usize, pub tokens_consumed: usize,
scenery: Scenery scenery: Scenery
} }
impl Scene { impl Scene {
pub fn new(reply_options: GeneratedResponses, conversation: Vec<ConversationEntry>, scenery: Scenery, tokens_consumed: usize) -> Self { pub fn new(reply_options: GeneratedResponses, conversation: Vec<ConversationEntry>, scenery: Scenery, tokens_consumed: usize, direction: StageDirection) -> Self {
Self { Self {
reply_options, reply_options,
conversation, conversation,
scenery, scenery,
tokens_consumed tokens_consumed,
direction
} }
} }
@@ -146,15 +170,7 @@ impl Scene {
&self.conversation &self.conversation
} }
pub fn conversation_mut(&mut self) -> &mut Vec<ConversationEntry> {
&mut self.conversation
}
pub fn reply_options(&self) -> &Vec<PossibleResponse> { pub fn reply_options(&self) -> &Vec<PossibleResponse> {
&self.reply_options.responses &self.reply_options.responses
} }
pub fn reply_options_mut(&mut self) -> &mut Vec<PossibleResponse> {
&mut self.reply_options.responses
}
} }