prediction: completely rewrite the prediction engine by moving all the conversation manipulation into that task out of the UI

This commit is contained in:
2026-06-04 21:34:10 +02:00
parent 57e3ff9b55
commit 49c720fe46
5 changed files with 368 additions and 219 deletions
+92 -137
View File
@@ -1,7 +1,6 @@
use async_openai::{types::chat::{ChatCompletionMessageToolCalls, CreateChatCompletionResponse}};
use async_openai::types::chat::ChatCompletionRequestMessage;
use chrono::{DateTime, Duration, Utc};
use futures_timer::Delay;
use schemars::JsonSchema;
use scraper::{Html, Selector};
use serde::{Deserialize, Serialize};
@@ -13,6 +12,16 @@ use tokio::{sync::{mpsc, watch}, time::Instant};
use tui_input::{Input, backend::crossterm::EventHandler};
use futures::{StreamExt, future::FutureExt};
use ratatui::prelude::*;
use crate::{events::AudioRecordRequest, prediction::{PossibleResponse}, scene::{ConversationEntry, PlaylistEntry, Scene, StageActions, StageDirection}, tts::start_tts};
mod scene;
mod events;
mod transcription;
mod tts;
mod prediction;
// TODO: We should have a separate 'state.json' file, which remembers jack connections, and the world time for the show to end. Then we only update the 'time remaining' field in the scene and only deal with relative durations inside the scene data
// TODO: We should be able to delete entries from the conversation, or at least go back and edit something I said.
// TODO: I want a "mark" command or keyboard shortcut, that inserts a marker into the log, so I know where to come back for the next speaking segment.
@@ -41,22 +50,6 @@ use futures::{StreamExt, future::FutureExt};
- Right panel: Shortcuts for triggering scenarios, soundboard events, etc
*/
use ratatui::prelude::*;
use crate::{events::AudioRecordRequest, scene::{ConversationEntry, PlaylistEntry, Scene}, tts::start_tts};
mod scene;
mod events;
mod transcription;
mod tts;
mod prediction;
#[derive(JsonSchema, Deserialize, Serialize, Debug, Clone)]
struct PossibleResponse {
text: String,
stage_direction: Option<String>
}
impl<'a> Into<ListItem<'a>> for PossibleResponse {
fn into(self) -> ListItem<'a> {
if let Some(direction) = self.stage_direction {
@@ -71,22 +64,13 @@ impl<'a> Into<ListItem<'a>> for PossibleResponse {
}
}
#[derive(JsonSchema, Deserialize, Serialize, Debug)]
struct GeneratedResponses {
responses: Vec<PossibleResponse>,
}
#[derive(Debug, Default, Serialize, Deserialize, Clone, JsonSchema)]
struct StageEventArgs {
text: String,
}
#[derive(Debug)]
struct App {
scene: Scene,
direction: StageDirection,
next_actions: Vec<ConversationEntry>,
end_time: DateTime<Utc>,
next_reply_options: Vec<PossibleResponse>,
reply_state: ListState,
conversation_state: ListState,
user_input: Input,
@@ -97,7 +81,7 @@ struct App {
focus_state: FocusState,
audio_control_sink: watch::Sender<AudioRecordRequest>,
prediction_request_sink: watch::Sender<Scene>,
prediction_request_sink: watch::Sender<StageActions>,
tts_request_sink: mpsc::Sender<String>,
sys_message_sink: mpsc::Sender<String>
}
@@ -109,15 +93,16 @@ enum FocusState {
}
impl App {
fn new(prediction_request_sink: watch::Sender<Scene>, audio_control_sink: watch::Sender<AudioRecordRequest>, tts_request_sink: mpsc::Sender<String>, sys_message_sink: mpsc::Sender<String>) -> Self {
fn new(prediction_request_sink: watch::Sender<StageActions>, audio_control_sink: watch::Sender<AudioRecordRequest>, tts_request_sink: mpsc::Sender<String>, sys_message_sink: mpsc::Sender<String>, initial_direction: StageDirection) -> Self {
Self {
scene: Scene::default(),
next_reply_options: Vec::new(),
reply_state: ListState::default(),
conversation_state: ListState::default(),
user_input: Input::default(),
scene: Default::default(),
direction: initial_direction,
next_actions: Default::default(),
reply_state: Default::default(),
conversation_state: Default::default(),
user_input: Default::default(),
end_time: Utc::now() + Duration::hours(2),
throbber_state: ThrobberState::default(),
throbber_state: Default::default(),
prediction_request_sink,
is_requesting: false,
audio_level: -60.,
@@ -130,7 +115,7 @@ impl App {
}
fn draw_conversation(&mut self, frame: &mut Frame, area: Rect) {
let items: Vec<Line> = self.scene.conversation.iter().rev().map(|entry| {
let items: Vec<Line> = self.scene.conversation().iter().rev().map(|entry| {
match entry {
ConversationEntry::User(text) => Line::from_iter([Span::from("Argee: ").style(ratatui::style::Color::Magenta), Span::from(text)]),
ConversationEntry::Eva(text) => Line::from_iter([Span::from("Eva: ").style(ratatui::style::Color::Cyan), Span::from(text)]),
@@ -154,7 +139,7 @@ impl App {
fn draw_options(&mut self, frame: &mut Frame, area: Rect) {
frame.render_stateful_widget(
List::new(self.next_reply_options.clone())
List::new(self.scene.reply_options().clone())
.block(Block::bordered().border_style(style::Color::LightGreen).title("Reply Options (Press 'Ctrl+R' to regenerate, Ctrl+Enter to use)"))
.style(ratatui::style::Color::White)
.highlight_symbol("> ")
@@ -184,7 +169,7 @@ impl App {
}
fn draw_status(&self, frame: &mut Frame, area: Rect) {
let minutes_remaining = self.scene.direction.time_remaining.num_seconds() / 60;
let minutes_remaining = self.direction.time_remaining.num_seconds() / 60;
let time_style = if minutes_remaining == 0 {
Style::new().fg(ratatui::style::Color::Red).bold().rapid_blink()
} else if minutes_remaining <= 5 {
@@ -199,26 +184,26 @@ impl App {
ratatui::style::Color::Blue.into()
};
let status_line = Line::from_iter([
Span::from(format!("Episode {}", self.scene.direction.episode_number)).style(ratatui::style::Color::LightBlue),
Span::from(format!("Episode {}", self.direction.episode_number)).style(ratatui::style::Color::LightBlue),
Span::from(" | ").style(ratatui::style::Color::DarkGray),
// FIXME: Looks weird with negative numbers, and it doesn't actually blink in the vscode terminal.
Span::from(format!("Time Remaining: {:0>2}:{:0>2}:{:0>2}", self.scene.direction.time_remaining.num_hours(), self.scene.direction.time_remaining.num_minutes() % 60, self.scene.direction.time_remaining.num_seconds() % 60)).style(time_style)
Span::from(format!("Time Remaining: {:0>2}:{:0>2}:{:0>2}", self.direction.time_remaining.num_hours(), self.direction.time_remaining.num_minutes() % 60, self.direction.time_remaining.num_seconds() % 60)).style(time_style)
]);
frame.render_widget(status_line, area);
}
fn draw_narration(&self, frame: &mut Frame, area: Rect) {
let narrative_desc = if self.scene.direction.narrative.is_empty() {
let narrative_desc = if self.direction.narrative.is_empty() {
Span::from("No narrative available.").style(ratatui::style::Color::DarkGray)
} else {
Span::from(self.scene.direction.narrative.clone())
Span::from(self.direction.narrative.clone())
};
let setting = Paragraph::new(narrative_desc).block(Block::bordered().border_style(style::Color::LightMagenta).title("Stage Direction")).wrap(ratatui::widgets::Wrap { trim: false });
frame.render_widget(setting, area);
}
fn draw_event_log(&self, frame: &mut Frame, area: Rect) {
let items: Vec<Line> = self.scene.conversation.iter().filter(|entry| { if let ConversationEntry::StageDirection(_) = entry { true } else { false }}).rev().map(|entry| {
let items: Vec<Line> = self.scene.conversation().iter().filter(|entry| { if let ConversationEntry::StageDirection(_) = entry { true } else { false }}).rev().map(|entry| {
match entry {
ConversationEntry::StageDirection(text) => Line::from_iter([text]).style(ratatui::style::Color::Yellow),
_ => unreachable!()
@@ -289,12 +274,11 @@ impl App {
}
async fn insert_selected_prompt(&mut self) {
let selected = self.next_reply_options[self.reply_state.selected().unwrap()].clone();
let selected = self.scene.reply_options()[self.reply_state.selected().unwrap()].clone();
if let Some(direction) = &selected.stage_direction {
self.scene.insert_conversation(ConversationEntry::StageDirection(direction.clone()));
self.next_actions.push(ConversationEntry::StageDirection(direction.clone()));
}
self.scene.insert_conversation(ConversationEntry::Eva(selected.text.clone()));
self.save();
self.next_actions.push(ConversationEntry::Eva(selected.text.clone()));
self.speak(selected.text.clone()).await;
self.regenerate_responses();
}
@@ -316,7 +300,7 @@ impl App {
KeyCode::Up => self.conversation_state.select_next(),
KeyCode::Enter => {
let row_num = self.conversation_state.selected().unwrap();
if let ConversationEntry::Eva(text) = &self.scene.conversation[self.scene.conversation.len() - 1 - row_num] {
if let ConversationEntry::Eva(text) = &self.scene.conversation()[self.scene.conversation().len() - 1 - row_num] {
self.speak(text.clone()).await;
self.focus_state = FocusState::UserInput;
self.conversation_state.select(None);
@@ -360,21 +344,19 @@ impl App {
"/bandcamp" => {
self.add_bandcamp_artifact(arg).await;
self.sys_message_sink.send(format!("Added Bandcamp artifact from {}", arg)).await.unwrap();
self.scene.insert_conversation(ConversationEntry::ShipComputer(format!("Incoming transmission.")));
self.next_actions.push(ConversationEntry::ShipComputer(format!("Incoming transmission from {}", arg)));
self.regenerate_responses();
},
"/episode" => {
if let Ok(episode_number) = arg.trim().parse::<u32>() {
self.scene.direction.episode_number = episode_number;
self.sys_message_sink.send(format!("Updated episode number: {}", self.scene.direction.episode_number)).await.unwrap();
self.direction.episode_number = episode_number;
self.sys_message_sink.send(format!("Updated episode number: {}", self.direction.episode_number)).await.unwrap();
self.reload_mixxx_playlist();
} else {
self.sys_message_sink.send("Invalid episode number format. Use /episode [number]".into()).await.unwrap();
return;
}
},
"/reload" => {
self.load().await;
self.reload_mixxx_playlist();
},
"/timer" => {
if let Ok(minutes) = arg.trim().parse::<i64>() {
self.end_time = Utc::now() + Duration::minutes(minutes);
@@ -386,23 +368,17 @@ impl App {
"/clear" => {
match arg.trim() {
"playlist" => {
self.scene.direction.current_playlist.clear();
self.direction.current_playlist.clear();
self.sys_message_sink.send("Cleared current playlist.".into()).await.unwrap();
return;
},
"artifacts" => {
self.scene.direction.artifacts.clear();
self.direction.artifacts.clear();
self.sys_message_sink.send("Cleared artifacts.".into()).await.unwrap();
return;
},
"all" => {
self.scene = Scene::default();
self.sys_message_sink.send("Cleared all data.".into()).await.unwrap();
},
"conversation" => {
self.scene.conversation.clear();
self.sys_message_sink.send("Cleared conversation.".into()).await.unwrap();
},
_ => {
self.sys_message_sink.send("Unknown clear command. Use /clear [playlist|artifacts|all]".into()).await.unwrap();
}
@@ -410,22 +386,25 @@ impl App {
return;
},
"/narrative" => {
self.scene.direction.narrative = arg.to_string();
self.sys_message_sink.send(format!("Updated stage direction: {}", self.scene.direction.narrative)).await.unwrap();
self.direction.narrative = arg.to_string();
self.sys_message_sink.send(format!("Updated stage direction: {}", self.direction.narrative)).await.unwrap();
},
"/event" => {
self.scene.insert_conversation(ConversationEntry::StageDirection(arg.to_string()));
}
self.next_actions.push(ConversationEntry::StageDirection(arg.to_string()));
self.regenerate_responses();
},
"/computer" => {
self.next_actions.push(ConversationEntry::ShipComputer(arg.to_string()));
self.regenerate_responses();
},
_ => {
self.sys_message_sink.send("Unknown command. Available commands: /bandcamp [url], /episode [number], /narrative [text], /reset".into()).await.unwrap();
return;
}
}
} else {
self.scene.insert_conversation(ConversationEntry::User(next_msg));
self.next_actions.push(ConversationEntry::User(next_msg));
self.regenerate_responses();
}
self.save();
self.regenerate_responses();
}
},
_ => {self.user_input.handle_event(&evt);},
@@ -441,25 +420,7 @@ impl App {
let fragment = Html::parse_document(&body);
let selector = Selector::parse("script[type=\"application/ld+json\"]").unwrap();
let json_ld = fragment.select(&selector).next().unwrap().inner_html();
self.scene.direction.artifacts.push(json_ld);
}
fn save(&self) {
let save_data = serde_json::to_string_pretty(&self.scene).unwrap();
std::fs::write("save.json", save_data).unwrap();
}
async fn load(&mut self) {
if let Ok(save_data) = std::fs::read_to_string("save.json") {
if let Ok(scene) = serde_json::from_str(&save_data) {
self.scene = scene;
// FIXME: These should get wiped out when we save as well, or even better, be completely excluded via a custom serde implementation.
self.scene.conversation.retain(|line| { if let ConversationEntry::SystemMessage(_) = line { false } else { true }});
self.sys_message_sink.send("Loaded stored session.".into()).await.unwrap();
} else {
self.sys_message_sink.send("Failed to load saved session!".into()).await.unwrap();
}
}
self.direction.artifacts.push(json_ld.trim().to_string());
}
async fn speak(&mut self, text: String) {
@@ -467,18 +428,23 @@ impl App {
}
fn regenerate_responses(&mut self) {
self.prediction_request_sink.send(self.scene.clone()).unwrap();
self.next_reply_options.clear();
let actions = StageActions {
direction: self.direction.clone(),
additions: std::mem::take(&mut self.next_actions)
};
self.scene.reply_options_mut().clear();
self.scene.conversation_mut().append(&mut actions.additions.clone());
self.prediction_request_sink.send(actions).unwrap();
self.is_requesting = true;
}
fn reload_mixxx_playlist(&mut self) {
// TODO: Should have some status message which states how many tracks are in the playlist
self.scene.direction.current_playlist.clear();
self.direction.current_playlist.clear();
let connection = sqlite::Connection::open_thread_safe_with_flags("mixxxdb.sqlite", OpenFlags::new().with_read_only()).unwrap();
let query = "SELECT id FROM Playlists WHERE name = ? ORDER BY id DESC LIMIT 1";
let mut statement = connection.prepare(query).unwrap();
statement.bind((1, format!("BFF.fm - Episode {}", self.scene.direction.episode_number).as_str())).unwrap();
statement.bind((1, format!("BFF.fm - Episode {}", self.direction.episode_number).as_str())).unwrap();
statement.next().unwrap();
let latest_id = statement.read::<i64, _>("id").unwrap();
@@ -488,42 +454,28 @@ impl App {
let artist = track.try_read::<&str, _>("artist").unwrap_or("Unknown Artist");
let album = track.try_read::<&str, _>("album").unwrap_or("Unknown Album");
let bpm = track.try_read::<f64, _>("bpm").unwrap_or(0.);
self.scene.direction.current_playlist.push(PlaylistEntry {
self.direction.current_playlist.push(PlaylistEntry {
artist: artist.into(),
album: album.into(),
title: title.into(),
bpm
});
}
self.scene.insert_conversation(ConversationEntry::SystemMessage("Mixxx playlist reloaded.".into()));
self.next_actions.push(ConversationEntry::SystemMessage("Mixxx playlist reloaded.".into()));
}
}
fn on_response(&mut self, response: &CreateChatCompletionResponse) {
self.is_requesting = false;
if let Some(calls) = &response.choices[0].message.tool_calls {
for call in calls {
match call {
ChatCompletionMessageToolCalls::Function(call) => {
if call.function.name == "log_stage_event" {
let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
self.scene.insert_conversation(ConversationEntry::StageDirection(args.text));
}
},
_ => panic!("Unkown tool call type"),
}
}
self.regenerate_responses();
} else {
if response.choices.is_empty() {
self.scene.insert_conversation(ConversationEntry::SystemMessage("OpenAI returned no responses".into()));
} else if response.choices[0].message.content.is_none() {
self.scene.insert_conversation(ConversationEntry::SystemMessage("OpenAI response did not contain content!".into()));
} else {
let json_resp: GeneratedResponses = serde_json::from_str(response.choices[0].message.content.as_ref().unwrap().as_str()).unwrap();
self.next_reply_options = json_resp.responses;
self.reply_state.select_first();
}
}
#[derive(Serialize, Deserialize, Debug, Default)]
pub struct SaveData {
pub direction: StageDirection,
pub messages: Vec<ChatCompletionRequestMessage>
}
impl SaveData {
fn save(&self) {
let save_data = serde_json::to_string_pretty(self).unwrap();
std::fs::write("save.json", save_data).unwrap();
}
}
@@ -536,15 +488,21 @@ async fn main() {
return;
}
let saved_session = if let Ok(save_data) = std::fs::read_to_string("save.json") {
serde_json::from_str(&save_data).unwrap_or_default()
//FIXME: Re-add load messages to sys log
} else {
SaveData::default()
};
let mut terminal: Terminal<CrosstermBackend<std::io::Stdout>> = ratatui::init();
let (sys_message_sink, mut sys_message_src) = tokio::sync::mpsc::channel(32);
let (sys_message_sink, sys_message_src) = tokio::sync::mpsc::channel(32);
let tts_request_sender = start_tts().await;
let (prediction_request_in, mut prediction_out) = prediction::start_prediction().await;
let (prediction_request_in, mut prediction_out) = prediction::start_prediction(sys_message_src, saved_session.messages).await;
let (mut audio_state_receiver, audio_control_in, mut transcription_out) = transcription::start_transcription(sys_message_sink.clone()).await;
let mut app = App::new(prediction_request_in, audio_control_in, tts_request_sender, sys_message_sink);
app.load().await;
let mut app = App::new(prediction_request_in, audio_control_in, tts_request_sender, sys_message_sink, saved_session.direction);
let mut events = EventStream::new();
let mut last_tick = Instant::now();
@@ -554,7 +512,7 @@ async fn main() {
last_tick = Instant::now();
app.throbber_state.calc_next();
}
app.scene.direction.time_remaining = app.end_time.signed_duration_since(Utc::now());
app.direction.time_remaining = app.end_time.signed_duration_since(Utc::now());
terminal.draw(|frame| { app.draw(frame)}).unwrap();
let delay = Delay::new(std::time::Duration::from_millis(60)).fuse();
@@ -563,18 +521,15 @@ async fn main() {
tokio::select! {
_ = delay => (),
_ = prediction_out.changed() => {
app.on_response(prediction_out.borrow_and_update().as_ref().unwrap());
app.scene = prediction_out.borrow().clone();
app.reply_state.select_first();
app.is_requesting = false;
},
_ = audio_state_receiver.changed() => {
app.audio_level = *audio_state_receiver.borrow_and_update();
},
maybe_message = sys_message_src.recv() => {
if let Some(message) = maybe_message {
app.scene.insert_conversation(ConversationEntry::SystemMessage(message));
}
app.audio_level = *audio_state_receiver.borrow();
},
maybe_transcription = transcription_out.recv() => {
app.scene.insert_conversation(ConversationEntry::User(maybe_transcription.unwrap()));
app.next_actions.push(ConversationEntry::User(maybe_transcription.unwrap()));
app.regenerate_responses();
}
maybe_event = event => {