main: refactor prediction engine to use an event stream
This commit is contained in:
+47
-92
@@ -1,19 +1,19 @@
|
||||
use async_openai::types::chat::ChatCompletionRequestMessage;
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use chrono::{Duration, Utc};
|
||||
use futures_timer::Delay;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use ratatui::{Frame, layout::{Constraint, Direction, Layout}, widgets::{Block, BorderType, Clear, Gauge, List, ListDirection, ListItem, ListState, Paragraph, Wrap}};
|
||||
use ratatui::{Frame, layout::{Constraint, Direction, Layout}, widgets::{Block, BorderType, Clear, Gauge, List, ListDirection, ListState, Paragraph, Wrap}};
|
||||
use throbber_widgets_tui::{Throbber, ThrobberState};
|
||||
use crossterm::{event::{self, EventStream, KeyCode, KeyModifiers}};
|
||||
use tokio::{sync::{mpsc, watch}, time::Instant};
|
||||
use tokio::time::Instant;
|
||||
use tui_input::{Input, backend::crossterm::EventHandler};
|
||||
use futures::{StreamExt, future::FutureExt};
|
||||
|
||||
use ratatui::prelude::*;
|
||||
use tui_skeleton::{AnimationMode, SkeletonText};
|
||||
|
||||
use crate::{audio::{AudioInputControl, start_audio_input}, prediction::{BandcampResult, PossibleResponse}, scene::{ConversationEntry, Scene, Scenery, StageActions, StageDirection}, transcription::TranscriptionControl, tts::{TtsControl, start_tts}};
|
||||
use crate::{audio::{AudioInputControl, start_audio_input}, prediction::{SessionControl, SessionUpdate}, scene::{ConversationEntry, PredictionAction, Scene, Scenery, StageDirection}, transcription::TranscriptionControl, tts::{TtsControl, start_tts}};
|
||||
|
||||
mod scene;
|
||||
mod events;
|
||||
@@ -51,26 +51,9 @@ mod audio;
|
||||
- Right panel: Shortcuts for triggering scenarios, soundboard events, etc
|
||||
*/
|
||||
|
||||
impl<'a> Into<ListItem<'a>> for PossibleResponse {
|
||||
fn into(self) -> ListItem<'a> {
|
||||
if let Some(direction) = self.stage_direction {
|
||||
Line::from_iter([
|
||||
//Span::from(format!("({})", direction)).style(ratatui::style::Color::Yellow),
|
||||
//Span::from(" "),
|
||||
Span::from(self.text)
|
||||
]).into()
|
||||
} else {
|
||||
Line::from(self.text).into()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct App {
|
||||
scene: Scene,
|
||||
direction: StageDirection,
|
||||
next_actions: Vec<ConversationEntry>,
|
||||
end_time: DateTime<Utc>,
|
||||
|
||||
reply_state: ListState,
|
||||
conversation_state: ListState,
|
||||
@@ -82,10 +65,9 @@ struct App {
|
||||
focus_state: FocusState,
|
||||
|
||||
transcription: TranscriptionControl,
|
||||
prediction_request_sink: watch::Sender<StageActions>,
|
||||
audio: AudioInputControl,
|
||||
tts: TtsControl,
|
||||
sys_message_sink: mpsc::Sender<String>
|
||||
predictions: SessionControl
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -95,17 +77,13 @@ enum FocusState {
|
||||
}
|
||||
|
||||
impl App {
|
||||
fn new(prediction_request_sink: watch::Sender<StageActions>, audio: AudioInputControl, transcription: TranscriptionControl, tts: TtsControl, sys_message_sink: mpsc::Sender<String>, initial_direction: StageDirection) -> Self {
|
||||
fn new(predictions: SessionControl, audio: AudioInputControl, transcription: TranscriptionControl, tts: TtsControl) -> Self {
|
||||
Self {
|
||||
scene: Default::default(),
|
||||
direction: initial_direction,
|
||||
next_actions: Default::default(),
|
||||
reply_state: Default::default(),
|
||||
conversation_state: Default::default(),
|
||||
user_input: Default::default(),
|
||||
end_time: Utc::now() + Duration::hours(2),
|
||||
throbber_state: Default::default(),
|
||||
prediction_request_sink,
|
||||
is_requesting: false,
|
||||
audio_level: -60.,
|
||||
audio,
|
||||
@@ -113,7 +91,7 @@ impl App {
|
||||
transcription,
|
||||
focus_state: FocusState::UserInput,
|
||||
tts,
|
||||
sys_message_sink
|
||||
predictions
|
||||
}
|
||||
}
|
||||
|
||||
@@ -183,7 +161,7 @@ impl App {
|
||||
|
||||
fn draw_options(&mut self, frame: &mut Frame, area: Rect) {
|
||||
let borders = Block::bordered().border_style(style::Color::LightGreen).title("Reply Options (Press 'Ctrl+R' to regenerate, Ctrl+Enter to use)");
|
||||
if self.scene.reply_options().len() == 0 && self.is_requesting {
|
||||
if self.is_requesting {
|
||||
let list = SkeletonText::new(std::time::SystemTime::now().duration_since(std::time::SystemTime::UNIX_EPOCH).unwrap().as_millis() as u64)
|
||||
.braille(true)
|
||||
.line_widths(&[0.25, 0.5, 0.4, 0.6])
|
||||
@@ -242,10 +220,10 @@ impl App {
|
||||
}
|
||||
|
||||
fn draw_status(&self, frame: &mut Frame, area: Rect) {
|
||||
let minutes_remaining = self.direction.time_remaining.num_seconds() / 60;
|
||||
|
||||
let negative = self.direction.time_remaining.abs() != self.direction.time_remaining;
|
||||
let time_remaining: Duration = self.scene.direction.time_remaining();
|
||||
let minutes_remaining = time_remaining.num_seconds() / 60;
|
||||
|
||||
let negative = time_remaining.abs() != time_remaining;
|
||||
|
||||
let time_style = if minutes_remaining <= 0 || negative {
|
||||
Style::new().fg(ratatui::style::Color::LightRed).bold()
|
||||
@@ -262,15 +240,15 @@ impl App {
|
||||
};
|
||||
|
||||
let formatted_time = if negative {
|
||||
format!("-{:0>2}:{:0>2}:{:0>2}", self.direction.time_remaining.num_hours().abs(), self.direction.time_remaining.num_minutes().abs()% 60, self.direction.time_remaining.num_seconds().abs() % 60)
|
||||
format!("-{:0>2}:{:0>2}:{:0>2}", time_remaining.num_hours().abs(), time_remaining.num_minutes().abs()% 60, time_remaining.num_seconds().abs() % 60)
|
||||
} else {
|
||||
format!("{:0>2}:{:0>2}:{:0>2}", self.direction.time_remaining.num_hours(), self.direction.time_remaining.num_minutes() % 60, self.direction.time_remaining.num_seconds() % 60)
|
||||
format!("{:0>2}:{:0>2}:{:0>2}", time_remaining.num_hours(), time_remaining.num_minutes() % 60, time_remaining.num_seconds() % 60)
|
||||
};
|
||||
|
||||
let status_line = Line::from_iter([
|
||||
Span::from(format!("Episode {}", self.direction.episode_number)).style(ratatui::style::Color::LightBlue),
|
||||
Span::from(format!("Episode {}", self.scene.direction.episode_number)).style(ratatui::style::Color::LightBlue),
|
||||
Span::from(" | ").style(ratatui::style::Color::DarkGray),
|
||||
Span::from(format!("{} tracks", self.direction.current_playlist.len())).style(ratatui::style::Color::LightBlue),
|
||||
Span::from(format!("{} tracks", self.scene.direction.current_playlist.len())).style(ratatui::style::Color::LightBlue),
|
||||
Span::from(" | ").style(ratatui::style::Color::DarkGray),
|
||||
Span::from(format!("Time Remaining: {}", formatted_time)).style(time_style),
|
||||
Span::from(" | ").style(ratatui::style::Color::DarkGray),
|
||||
@@ -282,10 +260,10 @@ impl App {
|
||||
}
|
||||
|
||||
fn draw_narration(&self, frame: &mut Frame, area: Rect) {
|
||||
let narrative_desc = if self.direction.narrative.is_empty() {
|
||||
let narrative_desc = if self.scene.direction.narrative.is_empty() {
|
||||
Span::from("No narrative available.").style(ratatui::style::Color::DarkGray)
|
||||
} else {
|
||||
Span::from(self.direction.narrative.clone())
|
||||
Span::from(self.scene.direction.narrative.clone())
|
||||
};
|
||||
let setting = Paragraph::new(narrative_desc).block(Block::bordered().border_style(style::Color::LightMagenta).title("Stage Direction")).wrap(ratatui::widgets::Wrap { trim: false });
|
||||
frame.render_widget(setting, area);
|
||||
@@ -384,11 +362,9 @@ impl App {
|
||||
async fn insert_selected_prompt(&mut self) {
|
||||
let selected = self.scene.reply_options()[self.reply_state.selected().unwrap()].clone();
|
||||
if let Some(direction) = &selected.stage_direction {
|
||||
self.next_actions.push(ConversationEntry::StageDirection(direction.clone()));
|
||||
self.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::StageDirection(direction.clone()))).await;
|
||||
}
|
||||
self.next_actions.push(ConversationEntry::Eva(selected.text.clone()));
|
||||
self.tts.speak(selected.text.clone()).await;
|
||||
self.regenerate_responses();
|
||||
self.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::Eva(selected.text.clone()))).await;
|
||||
}
|
||||
|
||||
async fn on_command(&mut self, command: &str) {
|
||||
@@ -398,38 +374,33 @@ impl App {
|
||||
match command {
|
||||
// FIXME: Need some new kind of /bandcamp command to force loading of specific urls
|
||||
"/episode" => {
|
||||
if let Ok(episode_number) = arg.trim().parse::<u32>() {
|
||||
self.direction.episode_number = episode_number;
|
||||
self.sys_message_sink.send(format!("Updated episode number: {}", self.direction.episode_number)).await.unwrap();
|
||||
self.reload_mixxx_playlist();
|
||||
if let Ok(episode_number) = arg.trim().parse() {
|
||||
self.predictions.insert(scene::PredictionAction::SetEpisodeNumber(episode_number)).await;
|
||||
} else {
|
||||
self.sys_message_sink.send("Invalid episode number format. Use /episode [number]".into()).await.unwrap();
|
||||
self.predictions.log("Invalid episode number format. Use /episode [number]".into()).await;
|
||||
return;
|
||||
}
|
||||
},
|
||||
"/timer" => {
|
||||
if let Ok(minutes) = arg.trim().parse::<i64>() {
|
||||
self.end_time = Utc::now() + Duration::minutes(minutes);
|
||||
self.sys_message_sink.send(format!("Set timer for {} minutes.", minutes)).await.unwrap();
|
||||
let end_time = Utc::now() + Duration::minutes(minutes);
|
||||
self.predictions.insert(PredictionAction::SetShowEndTime(end_time)).await;
|
||||
self.predictions.log(format!("Set timer for {} minutes.", minutes)).await;
|
||||
} else {
|
||||
self.sys_message_sink.send("Invalid timer format. Use /timer [minutes]".into()).await.unwrap();
|
||||
self.predictions.log("Invalid timer format. Use /timer [minutes]".into()).await;
|
||||
}
|
||||
},
|
||||
"/narrative" => {
|
||||
self.direction.narrative = arg.to_string();
|
||||
self.sys_message_sink.send(format!("Updated stage direction: {}", self.direction.narrative)).await.unwrap();
|
||||
self.regenerate_responses();
|
||||
self.predictions.insert(PredictionAction::SetNarrative(arg.to_string())).await;
|
||||
},
|
||||
"/event" => {
|
||||
self.next_actions.push(ConversationEntry::StageDirection(arg.to_string()));
|
||||
self.regenerate_responses();
|
||||
self.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::StageDirection(arg.to_string()))).await;
|
||||
},
|
||||
"/computer" => {
|
||||
self.next_actions.push(ConversationEntry::ShipComputer(arg.to_string()));
|
||||
self.regenerate_responses();
|
||||
self.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::ShipComputer(arg.to_string()))).await;
|
||||
},
|
||||
_ => {
|
||||
self.sys_message_sink.send("Unknown command. Available commands: /episode [number], /narrative [text], /event [text], /computer [text], /timer [minutes]".into()).await.unwrap();
|
||||
self.predictions.log("Unknown command. Available commands: /episode [number], /narrative [text], /event [text], /computer [text], /timer [minutes]".into()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -469,12 +440,11 @@ impl App {
|
||||
self.conversation_state.select_first();
|
||||
self.reply_state.select(None);
|
||||
},
|
||||
KeyCode::Char('r') if key.modifiers.contains(KeyModifiers::CONTROL) => self.regenerate_responses(),
|
||||
KeyCode::Char('r') if key.modifiers.contains(KeyModifiers::CONTROL) => self.predictions.regenerate_options().await,
|
||||
KeyCode::Char('x') if key.modifiers.contains(KeyModifiers::CONTROL) => {
|
||||
if self.recording_audio {
|
||||
self.recording_audio = false;
|
||||
self.transcription.stop();
|
||||
self.is_requesting = true;
|
||||
} else {
|
||||
self.recording_audio = true;
|
||||
self.transcription.start();
|
||||
@@ -493,8 +463,7 @@ impl App {
|
||||
if next_msg.starts_with("/") {
|
||||
self.on_command(&next_msg).await;
|
||||
} else {
|
||||
self.next_actions.push(ConversationEntry::User(next_msg));
|
||||
self.regenerate_responses();
|
||||
self.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::User(next_msg))).await;
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -504,25 +473,6 @@ impl App {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn regenerate_responses(&mut self) {
|
||||
let actions = StageActions {
|
||||
direction: self.direction.clone(),
|
||||
additions: std::mem::take(&mut self.next_actions)
|
||||
};
|
||||
self.scene.reply_options_mut().clear();
|
||||
self.scene.conversation_mut().append(&mut actions.additions.clone());
|
||||
self.prediction_request_sink.send(actions).unwrap();
|
||||
self.is_requesting = true;
|
||||
}
|
||||
|
||||
fn reload_mixxx_playlist(&mut self) {
|
||||
if let Err(err) = self.direction.reload_mixxx_playlist() {
|
||||
self.next_actions.push(ConversationEntry::SystemMessage(format!("Error while loading mixxx playlist: {:?}", err)));
|
||||
} else {
|
||||
self.next_actions.push(ConversationEntry::SystemMessage(format!("Mixxx playlist reloaded. {} tracks found.", self.direction.current_playlist.len()).into()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -562,7 +512,7 @@ async fn main() {
|
||||
|
||||
let mut terminal: Terminal<CrosstermBackend<std::io::Stdout>> = ratatui::init();
|
||||
|
||||
let (sys_message_sink, sys_message_src) = tokio::sync::mpsc::channel(32);
|
||||
let (sys_message_sink, mut sys_message_src) = tokio::sync::mpsc::channel(32);
|
||||
|
||||
let saved_session = if let Ok(save_data) = std::fs::read_to_string("save.json") {
|
||||
if let Ok(ret) = serde_json::from_str(&save_data) {
|
||||
@@ -577,12 +527,12 @@ async fn main() {
|
||||
SaveData::default()
|
||||
};
|
||||
|
||||
let prediction_ctrl = prediction::start_prediction(saved_session).await;
|
||||
let (audio_ctrl, mic_stream, tts_output) = start_audio_input(&sys_message_sink).await;
|
||||
let tts_ctrl = start_tts(tts_output).await;
|
||||
let (prediction_request_in, mut prediction_out) = prediction::start_prediction(sys_message_src, saved_session.messages, saved_session.scenery).await;
|
||||
let transcription_ctrl = transcription::start_transcription(mic_stream).await;
|
||||
|
||||
let mut app = App::new(prediction_request_in, audio_ctrl, transcription_ctrl, tts_ctrl, sys_message_sink, saved_session.direction);
|
||||
let mut app = App::new(prediction_ctrl, audio_ctrl, transcription_ctrl, tts_ctrl);
|
||||
|
||||
let mut events = EventStream::new();
|
||||
let mut last_tick = Instant::now();
|
||||
@@ -592,7 +542,6 @@ async fn main() {
|
||||
last_tick = Instant::now();
|
||||
app.throbber_state.calc_next();
|
||||
}
|
||||
app.direction.time_remaining = app.end_time.signed_duration_since(Utc::now());
|
||||
terminal.draw(|frame| { app.draw(frame)}).unwrap();
|
||||
|
||||
let delay = Delay::new(std::time::Duration::from_millis(60)).fuse();
|
||||
@@ -600,17 +549,23 @@ async fn main() {
|
||||
|
||||
tokio::select! {
|
||||
_ = delay => (),
|
||||
_ = prediction_out.changed() => {
|
||||
app.scene = prediction_out.borrow().clone();
|
||||
app.reply_state.select_first();
|
||||
app.is_requesting = false;
|
||||
Some(next_log) = sys_message_src.recv() => {
|
||||
app.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::SystemMessage(next_log))).await;
|
||||
},
|
||||
next_update = app.predictions.changed() => {
|
||||
match next_update {
|
||||
SessionUpdate::Thinking(is_thinking) => app.is_requesting = is_thinking,
|
||||
SessionUpdate::Scene(scene) => {
|
||||
app.scene = scene;
|
||||
app.reply_state.select_first();
|
||||
}
|
||||
}
|
||||
},
|
||||
next_volume = app.audio.next() => {
|
||||
app.audio_level = next_volume
|
||||
},
|
||||
transcription_result = app.transcription.next() => {
|
||||
app.next_actions.push(ConversationEntry::User(transcription_result));
|
||||
app.regenerate_responses();
|
||||
app.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::User(transcription_result))).await;
|
||||
},
|
||||
maybe_event = event => {
|
||||
match maybe_event {
|
||||
|
||||
Reference in New Issue
Block a user