main: refactor prediction engine to use an event stream
This commit is contained in:
+46
-91
@@ -1,19 +1,19 @@
|
||||
use async_openai::types::chat::ChatCompletionRequestMessage;
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use chrono::{Duration, Utc};
|
||||
use futures_timer::Delay;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use ratatui::{Frame, layout::{Constraint, Direction, Layout}, widgets::{Block, BorderType, Clear, Gauge, List, ListDirection, ListItem, ListState, Paragraph, Wrap}};
|
||||
use ratatui::{Frame, layout::{Constraint, Direction, Layout}, widgets::{Block, BorderType, Clear, Gauge, List, ListDirection, ListState, Paragraph, Wrap}};
|
||||
use throbber_widgets_tui::{Throbber, ThrobberState};
|
||||
use crossterm::{event::{self, EventStream, KeyCode, KeyModifiers}};
|
||||
use tokio::{sync::{mpsc, watch}, time::Instant};
|
||||
use tokio::time::Instant;
|
||||
use tui_input::{Input, backend::crossterm::EventHandler};
|
||||
use futures::{StreamExt, future::FutureExt};
|
||||
|
||||
use ratatui::prelude::*;
|
||||
use tui_skeleton::{AnimationMode, SkeletonText};
|
||||
|
||||
use crate::{audio::{AudioInputControl, start_audio_input}, prediction::{BandcampResult, PossibleResponse}, scene::{ConversationEntry, Scene, Scenery, StageActions, StageDirection}, transcription::TranscriptionControl, tts::{TtsControl, start_tts}};
|
||||
use crate::{audio::{AudioInputControl, start_audio_input}, prediction::{SessionControl, SessionUpdate}, scene::{ConversationEntry, PredictionAction, Scene, Scenery, StageDirection}, transcription::TranscriptionControl, tts::{TtsControl, start_tts}};
|
||||
|
||||
mod scene;
|
||||
mod events;
|
||||
@@ -51,26 +51,9 @@ mod audio;
|
||||
- Right panel: Shortcuts for triggering scenarios, soundboard events, etc
|
||||
*/
|
||||
|
||||
impl<'a> Into<ListItem<'a>> for PossibleResponse {
|
||||
fn into(self) -> ListItem<'a> {
|
||||
if let Some(direction) = self.stage_direction {
|
||||
Line::from_iter([
|
||||
//Span::from(format!("({})", direction)).style(ratatui::style::Color::Yellow),
|
||||
//Span::from(" "),
|
||||
Span::from(self.text)
|
||||
]).into()
|
||||
} else {
|
||||
Line::from(self.text).into()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct App {
|
||||
scene: Scene,
|
||||
direction: StageDirection,
|
||||
next_actions: Vec<ConversationEntry>,
|
||||
end_time: DateTime<Utc>,
|
||||
|
||||
reply_state: ListState,
|
||||
conversation_state: ListState,
|
||||
@@ -82,10 +65,9 @@ struct App {
|
||||
focus_state: FocusState,
|
||||
|
||||
transcription: TranscriptionControl,
|
||||
prediction_request_sink: watch::Sender<StageActions>,
|
||||
audio: AudioInputControl,
|
||||
tts: TtsControl,
|
||||
sys_message_sink: mpsc::Sender<String>
|
||||
predictions: SessionControl
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -95,17 +77,13 @@ enum FocusState {
|
||||
}
|
||||
|
||||
impl App {
|
||||
fn new(prediction_request_sink: watch::Sender<StageActions>, audio: AudioInputControl, transcription: TranscriptionControl, tts: TtsControl, sys_message_sink: mpsc::Sender<String>, initial_direction: StageDirection) -> Self {
|
||||
fn new(predictions: SessionControl, audio: AudioInputControl, transcription: TranscriptionControl, tts: TtsControl) -> Self {
|
||||
Self {
|
||||
scene: Default::default(),
|
||||
direction: initial_direction,
|
||||
next_actions: Default::default(),
|
||||
reply_state: Default::default(),
|
||||
conversation_state: Default::default(),
|
||||
user_input: Default::default(),
|
||||
end_time: Utc::now() + Duration::hours(2),
|
||||
throbber_state: Default::default(),
|
||||
prediction_request_sink,
|
||||
is_requesting: false,
|
||||
audio_level: -60.,
|
||||
audio,
|
||||
@@ -113,7 +91,7 @@ impl App {
|
||||
transcription,
|
||||
focus_state: FocusState::UserInput,
|
||||
tts,
|
||||
sys_message_sink
|
||||
predictions
|
||||
}
|
||||
}
|
||||
|
||||
@@ -183,7 +161,7 @@ impl App {
|
||||
|
||||
fn draw_options(&mut self, frame: &mut Frame, area: Rect) {
|
||||
let borders = Block::bordered().border_style(style::Color::LightGreen).title("Reply Options (Press 'Ctrl+R' to regenerate, Ctrl+Enter to use)");
|
||||
if self.scene.reply_options().len() == 0 && self.is_requesting {
|
||||
if self.is_requesting {
|
||||
let list = SkeletonText::new(std::time::SystemTime::now().duration_since(std::time::SystemTime::UNIX_EPOCH).unwrap().as_millis() as u64)
|
||||
.braille(true)
|
||||
.line_widths(&[0.25, 0.5, 0.4, 0.6])
|
||||
@@ -242,10 +220,10 @@ impl App {
|
||||
}
|
||||
|
||||
fn draw_status(&self, frame: &mut Frame, area: Rect) {
|
||||
let minutes_remaining = self.direction.time_remaining.num_seconds() / 60;
|
||||
|
||||
let negative = self.direction.time_remaining.abs() != self.direction.time_remaining;
|
||||
let time_remaining: Duration = self.scene.direction.time_remaining();
|
||||
let minutes_remaining = time_remaining.num_seconds() / 60;
|
||||
|
||||
let negative = time_remaining.abs() != time_remaining;
|
||||
|
||||
let time_style = if minutes_remaining <= 0 || negative {
|
||||
Style::new().fg(ratatui::style::Color::LightRed).bold()
|
||||
@@ -262,15 +240,15 @@ impl App {
|
||||
};
|
||||
|
||||
let formatted_time = if negative {
|
||||
format!("-{:0>2}:{:0>2}:{:0>2}", self.direction.time_remaining.num_hours().abs(), self.direction.time_remaining.num_minutes().abs()% 60, self.direction.time_remaining.num_seconds().abs() % 60)
|
||||
format!("-{:0>2}:{:0>2}:{:0>2}", time_remaining.num_hours().abs(), time_remaining.num_minutes().abs()% 60, time_remaining.num_seconds().abs() % 60)
|
||||
} else {
|
||||
format!("{:0>2}:{:0>2}:{:0>2}", self.direction.time_remaining.num_hours(), self.direction.time_remaining.num_minutes() % 60, self.direction.time_remaining.num_seconds() % 60)
|
||||
format!("{:0>2}:{:0>2}:{:0>2}", time_remaining.num_hours(), time_remaining.num_minutes() % 60, time_remaining.num_seconds() % 60)
|
||||
};
|
||||
|
||||
let status_line = Line::from_iter([
|
||||
Span::from(format!("Episode {}", self.direction.episode_number)).style(ratatui::style::Color::LightBlue),
|
||||
Span::from(format!("Episode {}", self.scene.direction.episode_number)).style(ratatui::style::Color::LightBlue),
|
||||
Span::from(" | ").style(ratatui::style::Color::DarkGray),
|
||||
Span::from(format!("{} tracks", self.direction.current_playlist.len())).style(ratatui::style::Color::LightBlue),
|
||||
Span::from(format!("{} tracks", self.scene.direction.current_playlist.len())).style(ratatui::style::Color::LightBlue),
|
||||
Span::from(" | ").style(ratatui::style::Color::DarkGray),
|
||||
Span::from(format!("Time Remaining: {}", formatted_time)).style(time_style),
|
||||
Span::from(" | ").style(ratatui::style::Color::DarkGray),
|
||||
@@ -282,10 +260,10 @@ impl App {
|
||||
}
|
||||
|
||||
fn draw_narration(&self, frame: &mut Frame, area: Rect) {
|
||||
let narrative_desc = if self.direction.narrative.is_empty() {
|
||||
let narrative_desc = if self.scene.direction.narrative.is_empty() {
|
||||
Span::from("No narrative available.").style(ratatui::style::Color::DarkGray)
|
||||
} else {
|
||||
Span::from(self.direction.narrative.clone())
|
||||
Span::from(self.scene.direction.narrative.clone())
|
||||
};
|
||||
let setting = Paragraph::new(narrative_desc).block(Block::bordered().border_style(style::Color::LightMagenta).title("Stage Direction")).wrap(ratatui::widgets::Wrap { trim: false });
|
||||
frame.render_widget(setting, area);
|
||||
@@ -384,11 +362,9 @@ impl App {
|
||||
async fn insert_selected_prompt(&mut self) {
|
||||
let selected = self.scene.reply_options()[self.reply_state.selected().unwrap()].clone();
|
||||
if let Some(direction) = &selected.stage_direction {
|
||||
self.next_actions.push(ConversationEntry::StageDirection(direction.clone()));
|
||||
self.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::StageDirection(direction.clone()))).await;
|
||||
}
|
||||
self.next_actions.push(ConversationEntry::Eva(selected.text.clone()));
|
||||
self.tts.speak(selected.text.clone()).await;
|
||||
self.regenerate_responses();
|
||||
self.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::Eva(selected.text.clone()))).await;
|
||||
}
|
||||
|
||||
async fn on_command(&mut self, command: &str) {
|
||||
@@ -398,38 +374,33 @@ impl App {
|
||||
match command {
|
||||
// FIXME: Need some new kind of /bandcamp command to force loading of specific urls
|
||||
"/episode" => {
|
||||
if let Ok(episode_number) = arg.trim().parse::<u32>() {
|
||||
self.direction.episode_number = episode_number;
|
||||
self.sys_message_sink.send(format!("Updated episode number: {}", self.direction.episode_number)).await.unwrap();
|
||||
self.reload_mixxx_playlist();
|
||||
if let Ok(episode_number) = arg.trim().parse() {
|
||||
self.predictions.insert(scene::PredictionAction::SetEpisodeNumber(episode_number)).await;
|
||||
} else {
|
||||
self.sys_message_sink.send("Invalid episode number format. Use /episode [number]".into()).await.unwrap();
|
||||
self.predictions.log("Invalid episode number format. Use /episode [number]".into()).await;
|
||||
return;
|
||||
}
|
||||
},
|
||||
"/timer" => {
|
||||
if let Ok(minutes) = arg.trim().parse::<i64>() {
|
||||
self.end_time = Utc::now() + Duration::minutes(minutes);
|
||||
self.sys_message_sink.send(format!("Set timer for {} minutes.", minutes)).await.unwrap();
|
||||
let end_time = Utc::now() + Duration::minutes(minutes);
|
||||
self.predictions.insert(PredictionAction::SetShowEndTime(end_time)).await;
|
||||
self.predictions.log(format!("Set timer for {} minutes.", minutes)).await;
|
||||
} else {
|
||||
self.sys_message_sink.send("Invalid timer format. Use /timer [minutes]".into()).await.unwrap();
|
||||
self.predictions.log("Invalid timer format. Use /timer [minutes]".into()).await;
|
||||
}
|
||||
},
|
||||
"/narrative" => {
|
||||
self.direction.narrative = arg.to_string();
|
||||
self.sys_message_sink.send(format!("Updated stage direction: {}", self.direction.narrative)).await.unwrap();
|
||||
self.regenerate_responses();
|
||||
self.predictions.insert(PredictionAction::SetNarrative(arg.to_string())).await;
|
||||
},
|
||||
"/event" => {
|
||||
self.next_actions.push(ConversationEntry::StageDirection(arg.to_string()));
|
||||
self.regenerate_responses();
|
||||
self.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::StageDirection(arg.to_string()))).await;
|
||||
},
|
||||
"/computer" => {
|
||||
self.next_actions.push(ConversationEntry::ShipComputer(arg.to_string()));
|
||||
self.regenerate_responses();
|
||||
self.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::ShipComputer(arg.to_string()))).await;
|
||||
},
|
||||
_ => {
|
||||
self.sys_message_sink.send("Unknown command. Available commands: /episode [number], /narrative [text], /event [text], /computer [text], /timer [minutes]".into()).await.unwrap();
|
||||
self.predictions.log("Unknown command. Available commands: /episode [number], /narrative [text], /event [text], /computer [text], /timer [minutes]".into()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -469,12 +440,11 @@ impl App {
|
||||
self.conversation_state.select_first();
|
||||
self.reply_state.select(None);
|
||||
},
|
||||
KeyCode::Char('r') if key.modifiers.contains(KeyModifiers::CONTROL) => self.regenerate_responses(),
|
||||
KeyCode::Char('r') if key.modifiers.contains(KeyModifiers::CONTROL) => self.predictions.regenerate_options().await,
|
||||
KeyCode::Char('x') if key.modifiers.contains(KeyModifiers::CONTROL) => {
|
||||
if self.recording_audio {
|
||||
self.recording_audio = false;
|
||||
self.transcription.stop();
|
||||
self.is_requesting = true;
|
||||
} else {
|
||||
self.recording_audio = true;
|
||||
self.transcription.start();
|
||||
@@ -493,8 +463,7 @@ impl App {
|
||||
if next_msg.starts_with("/") {
|
||||
self.on_command(&next_msg).await;
|
||||
} else {
|
||||
self.next_actions.push(ConversationEntry::User(next_msg));
|
||||
self.regenerate_responses();
|
||||
self.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::User(next_msg))).await;
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -504,25 +473,6 @@ impl App {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn regenerate_responses(&mut self) {
|
||||
let actions = StageActions {
|
||||
direction: self.direction.clone(),
|
||||
additions: std::mem::take(&mut self.next_actions)
|
||||
};
|
||||
self.scene.reply_options_mut().clear();
|
||||
self.scene.conversation_mut().append(&mut actions.additions.clone());
|
||||
self.prediction_request_sink.send(actions).unwrap();
|
||||
self.is_requesting = true;
|
||||
}
|
||||
|
||||
fn reload_mixxx_playlist(&mut self) {
|
||||
if let Err(err) = self.direction.reload_mixxx_playlist() {
|
||||
self.next_actions.push(ConversationEntry::SystemMessage(format!("Error while loading mixxx playlist: {:?}", err)));
|
||||
} else {
|
||||
self.next_actions.push(ConversationEntry::SystemMessage(format!("Mixxx playlist reloaded. {} tracks found.", self.direction.current_playlist.len()).into()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -562,7 +512,7 @@ async fn main() {
|
||||
|
||||
let mut terminal: Terminal<CrosstermBackend<std::io::Stdout>> = ratatui::init();
|
||||
|
||||
let (sys_message_sink, sys_message_src) = tokio::sync::mpsc::channel(32);
|
||||
let (sys_message_sink, mut sys_message_src) = tokio::sync::mpsc::channel(32);
|
||||
|
||||
let saved_session = if let Ok(save_data) = std::fs::read_to_string("save.json") {
|
||||
if let Ok(ret) = serde_json::from_str(&save_data) {
|
||||
@@ -577,12 +527,12 @@ async fn main() {
|
||||
SaveData::default()
|
||||
};
|
||||
|
||||
let prediction_ctrl = prediction::start_prediction(saved_session).await;
|
||||
let (audio_ctrl, mic_stream, tts_output) = start_audio_input(&sys_message_sink).await;
|
||||
let tts_ctrl = start_tts(tts_output).await;
|
||||
let (prediction_request_in, mut prediction_out) = prediction::start_prediction(sys_message_src, saved_session.messages, saved_session.scenery).await;
|
||||
let transcription_ctrl = transcription::start_transcription(mic_stream).await;
|
||||
|
||||
let mut app = App::new(prediction_request_in, audio_ctrl, transcription_ctrl, tts_ctrl, sys_message_sink, saved_session.direction);
|
||||
let mut app = App::new(prediction_ctrl, audio_ctrl, transcription_ctrl, tts_ctrl);
|
||||
|
||||
let mut events = EventStream::new();
|
||||
let mut last_tick = Instant::now();
|
||||
@@ -592,7 +542,6 @@ async fn main() {
|
||||
last_tick = Instant::now();
|
||||
app.throbber_state.calc_next();
|
||||
}
|
||||
app.direction.time_remaining = app.end_time.signed_duration_since(Utc::now());
|
||||
terminal.draw(|frame| { app.draw(frame)}).unwrap();
|
||||
|
||||
let delay = Delay::new(std::time::Duration::from_millis(60)).fuse();
|
||||
@@ -600,17 +549,23 @@ async fn main() {
|
||||
|
||||
tokio::select! {
|
||||
_ = delay => (),
|
||||
_ = prediction_out.changed() => {
|
||||
app.scene = prediction_out.borrow().clone();
|
||||
Some(next_log) = sys_message_src.recv() => {
|
||||
app.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::SystemMessage(next_log))).await;
|
||||
},
|
||||
next_update = app.predictions.changed() => {
|
||||
match next_update {
|
||||
SessionUpdate::Thinking(is_thinking) => app.is_requesting = is_thinking,
|
||||
SessionUpdate::Scene(scene) => {
|
||||
app.scene = scene;
|
||||
app.reply_state.select_first();
|
||||
app.is_requesting = false;
|
||||
}
|
||||
}
|
||||
},
|
||||
next_volume = app.audio.next() => {
|
||||
app.audio_level = next_volume
|
||||
},
|
||||
transcription_result = app.transcription.next() => {
|
||||
app.next_actions.push(ConversationEntry::User(transcription_result));
|
||||
app.regenerate_responses();
|
||||
app.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::User(transcription_result))).await;
|
||||
},
|
||||
maybe_event = event => {
|
||||
match maybe_event {
|
||||
|
||||
+108
-39
@@ -3,12 +3,12 @@ use std::process::{Command, Stdio};
|
||||
use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs, ChatCompletionTool, ChatCompletionTools, CreateChatCompletionRequestArgs, FinishReason, FunctionObjectArgs, ResponseFormat, ResponseFormatJsonSchema}};
|
||||
use bandcamp::SearchResultItem;
|
||||
use chrono::{DateTime, Utc};
|
||||
use color_eyre::eyre::eyre;
|
||||
use schemars::{JsonSchema, schema_for};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Serializer, ser::CompactFormatter};
|
||||
use tokio::sync::{mpsc, watch};
|
||||
|
||||
use crate::{SaveData, scene::{Artifact, ConversationEntry, Scene, Scenery, StageActions, StageDirection}};
|
||||
use crate::{SaveData, scene::{Artifact, ConversationEntry, PredictionAction, Scene, Scenery, StageDirection}};
|
||||
|
||||
|
||||
const SYSTEM_PROMPT: &str = include_str!("system-prompt.txt");
|
||||
@@ -31,8 +31,10 @@ struct Session {
|
||||
header_message: ChatCompletionRequestMessage,
|
||||
messages: Vec<ChatCompletionRequestMessage>,
|
||||
reply_options: GeneratedResponses,
|
||||
direction: StageDirection,
|
||||
scenery: Scenery,
|
||||
tokens_consumed: usize
|
||||
tokens_consumed: usize,
|
||||
activity_notify: watch::Sender<bool>
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)]
|
||||
@@ -91,7 +93,7 @@ struct ToolResults {
|
||||
}
|
||||
|
||||
impl Session {
|
||||
fn from_initial_messages(messages: Vec<ChatCompletionRequestMessage>, scenery: Scenery) -> Self {
|
||||
fn from_initial_messages(messages: Vec<ChatCompletionRequestMessage>, scenery: Scenery, direction: StageDirection, activity_notify: watch::Sender<bool>) -> Self {
|
||||
let mut conversation = vec![];
|
||||
for msg in &messages {
|
||||
if let Ok(conversation_msg) = msg.clone().try_into() {
|
||||
@@ -106,13 +108,9 @@ impl Session {
|
||||
messages,
|
||||
reply_options: Default::default(),
|
||||
scenery,
|
||||
tokens_consumed: 0
|
||||
}
|
||||
}
|
||||
|
||||
fn insert_actions(&mut self, actions: &StageActions) {
|
||||
for addition in &actions.additions {
|
||||
self.insert_conversation(addition.clone());
|
||||
direction,
|
||||
tokens_consumed: 0,
|
||||
activity_notify,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -213,9 +211,10 @@ impl Session {
|
||||
full_conversation
|
||||
}
|
||||
|
||||
async fn regenerate_options(&mut self, direction: &StageDirection) {
|
||||
async fn regenerate_options(&mut self) {
|
||||
self.activity_notify.send_if_modified(|x| { if !*x { *x = true; true } else { false }});
|
||||
loop {
|
||||
let full_conversation = self.generate_conversation(direction);
|
||||
let full_conversation = self.generate_conversation(&self.direction);
|
||||
|
||||
let tools = vec![
|
||||
ChatCompletionTools::Function(ChatCompletionTool {
|
||||
@@ -268,11 +267,11 @@ impl Session {
|
||||
match message.finish_reason {
|
||||
Some(FinishReason::ContentFilter) => {
|
||||
self.insert_conversation(ConversationEntry::SystemMessage("Content filter triggered.".into()));
|
||||
return;
|
||||
break;
|
||||
},
|
||||
Some(FinishReason::Length) => {
|
||||
self.insert_conversation(ConversationEntry::SystemMessage("Maximum token count exceeded!".into()));
|
||||
return;
|
||||
break;
|
||||
},
|
||||
_ => ()
|
||||
}
|
||||
@@ -317,7 +316,7 @@ impl Session {
|
||||
if let Some(content) = message.message.content.as_ref() {
|
||||
if let Ok(options) = serde_json::from_str(content.as_str()) {
|
||||
self.reply_options = options;
|
||||
return;
|
||||
break;
|
||||
} else {
|
||||
self.insert_conversation(ConversationEntry::SystemMessage("Received invalid JSON! Trying again.".into()));
|
||||
}
|
||||
@@ -326,10 +325,11 @@ impl Session {
|
||||
self.insert_conversation(ConversationEntry::SystemMessage("No messages were received! Trying again.".into()));
|
||||
}
|
||||
}
|
||||
self.activity_notify.send_if_modified(|x| { if *x { *x = false; true } else { false }});
|
||||
}
|
||||
|
||||
fn as_scene(&self) -> Scene {
|
||||
Scene::new(self.reply_options.clone(), self.conversation.clone(), self.scenery.clone(), self.tokens_consumed)
|
||||
Scene::new(self.reply_options.clone(), self.conversation.clone(), self.scenery.clone(), self.tokens_consumed, self.direction.clone())
|
||||
}
|
||||
|
||||
fn insert_conversation(&mut self, entry: ConversationEntry) {
|
||||
@@ -341,47 +341,116 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn start_prediction(mut sys_message_src: tokio::sync::mpsc::Receiver<String>, initial_messages: Vec<ChatCompletionRequestMessage>, scenery: Scenery) -> (tokio::sync::watch::Sender<StageActions>, tokio::sync::watch::Receiver<Scene>) {
|
||||
let (prediction_in, prediction_out) = tokio::sync::watch::channel(Scene::default());
|
||||
let (prediction_request_in, mut prediction_request_out) = tokio::sync::watch::channel(StageActions::default());
|
||||
#[derive(Debug)]
|
||||
pub struct SessionControl {
|
||||
event_sink: mpsc::Sender<PredictionAction>,
|
||||
scene_watch: watch::Receiver<Scene>,
|
||||
activity_watch: watch::Receiver<bool>
|
||||
}
|
||||
|
||||
let mut session = Session::from_initial_messages(initial_messages, scenery);
|
||||
#[derive(Debug)]
|
||||
pub enum SessionUpdate {
|
||||
Scene(Scene),
|
||||
Thinking(bool)
|
||||
}
|
||||
|
||||
impl SessionControl {
|
||||
pub async fn insert(&self, action: PredictionAction) {
|
||||
self.event_sink.send(action).await.unwrap();
|
||||
}
|
||||
|
||||
pub async fn log(&self, message: String) {
|
||||
self.insert(PredictionAction::ConversationAppend(ConversationEntry::SystemMessage(message))).await;
|
||||
}
|
||||
|
||||
pub async fn regenerate_options(&self) {
|
||||
self.insert(PredictionAction::GeneratePredictions).await;
|
||||
}
|
||||
|
||||
pub async fn changed(&mut self) -> SessionUpdate {
|
||||
tokio::select! {
|
||||
_ = self.activity_watch.changed() => {
|
||||
SessionUpdate::Thinking(*self.activity_watch.borrow_and_update())
|
||||
},
|
||||
_ = self.scene_watch.changed() => {
|
||||
SessionUpdate::Scene(self.scene_watch.borrow_and_update().clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn start_prediction(saved_session: SaveData) -> SessionControl {
|
||||
let (prediction_in, prediction_out) = tokio::sync::watch::channel(Scene::default());
|
||||
let (activity_notify_sink, activity_notify_src) = tokio::sync::watch::channel(false);
|
||||
|
||||
let (action_sink, mut action_src) = mpsc::channel(5);
|
||||
|
||||
let mut session = Session::from_initial_messages(saved_session.messages, saved_session.scenery, saved_session.direction, activity_notify_sink);
|
||||
|
||||
// Send the initial scene to the UI, after we have loaded the session from the first messages.
|
||||
prediction_in.send(session.as_scene()).unwrap();
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::select! {
|
||||
maybe_message = sys_message_src.recv() => {
|
||||
if let Some(message) = maybe_message {
|
||||
session.insert_conversation(ConversationEntry::SystemMessage(message));
|
||||
prediction_in.send(session.as_scene()).unwrap();
|
||||
}
|
||||
},
|
||||
maybe_request = prediction_request_out.changed() => {
|
||||
if maybe_request.is_ok() {
|
||||
let next_cxt = prediction_request_out.borrow().clone();
|
||||
session.insert_actions(&next_cxt);
|
||||
if let Some(evt) = action_src.recv().await {
|
||||
let do_regen = match evt {
|
||||
PredictionAction::ConversationAppend(msg) => {
|
||||
let do_regen = match msg {
|
||||
ConversationEntry::Eva(_) | ConversationEntry::ShipComputer(_) | ConversationEntry::User(_) => true,
|
||||
_ => false
|
||||
};
|
||||
session.insert_conversation(msg);
|
||||
|
||||
let mut save_data = SaveData {
|
||||
direction: next_cxt.direction,
|
||||
do_regen
|
||||
},
|
||||
PredictionAction::SetEpisodeNumber(num) => {
|
||||
session.direction.episode_number = num;
|
||||
if let Err(err) = session.direction.reload_mixxx_playlist() {
|
||||
session.insert_conversation(ConversationEntry::SystemMessage(format!("Failed to load mixxx playlist: {:?}.", err).into()));
|
||||
} else {
|
||||
session.insert_conversation(ConversationEntry::SystemMessage("Mixxx playlist reloaded.".into()));
|
||||
}
|
||||
false
|
||||
},
|
||||
PredictionAction::GeneratePredictions => {
|
||||
true
|
||||
},
|
||||
PredictionAction::SetNarrative(narrative) => {
|
||||
session.direction.narrative = narrative;
|
||||
session.insert_conversation(ConversationEntry::SystemMessage("Updated stage direction narrative".into()));
|
||||
true
|
||||
},
|
||||
PredictionAction::SetShowEndTime(end_time) => {
|
||||
session.direction.end_time = end_time;
|
||||
false
|
||||
}
|
||||
};
|
||||
|
||||
let save_data = SaveData {
|
||||
direction: session.direction.clone(),
|
||||
messages: session.messages.clone(),
|
||||
scenery: session.scenery.clone()
|
||||
};
|
||||
|
||||
save_data.save();
|
||||
|
||||
session.regenerate_options(&save_data.direction).await;
|
||||
if do_regen {
|
||||
session.reply_options.responses.clear();
|
||||
}
|
||||
|
||||
save_data.messages = session.messages.clone();
|
||||
save_data.save();
|
||||
prediction_in.send(session.as_scene()).unwrap();
|
||||
|
||||
if do_regen {
|
||||
session.regenerate_options().await;
|
||||
prediction_in.send(session.as_scene()).unwrap();
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
});
|
||||
|
||||
(prediction_request_in, prediction_out)
|
||||
SessionControl {
|
||||
event_sink: action_sink,
|
||||
scene_watch: prediction_out,
|
||||
activity_watch: activity_notify_src
|
||||
}
|
||||
}
|
||||
+34
-18
@@ -1,5 +1,5 @@
|
||||
use async_openai::types::chat::*;
|
||||
use chrono::Duration;
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlite::OpenFlags;
|
||||
|
||||
@@ -48,14 +48,32 @@ impl TryInto<ConversationEntry> for ChatCompletionRequestMessage {
|
||||
type Error = ();
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize, Clone)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct StageDirection {
|
||||
pub episode_number: u32,
|
||||
pub time_remaining: Duration,
|
||||
#[serde(skip)]
|
||||
pub end_time: DateTime<Utc>,
|
||||
pub narrative: String,
|
||||
pub current_playlist: Vec<PlaylistEntry>
|
||||
}
|
||||
|
||||
impl StageDirection {
|
||||
pub fn time_remaining(&self) -> Duration {
|
||||
self.end_time.signed_duration_since(Utc::now())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for StageDirection {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
episode_number: 0,
|
||||
end_time: Utc::now() + Duration::hours(2),
|
||||
narrative: Default::default(),
|
||||
current_playlist: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub enum Artifact {
|
||||
Bandcamp(BandcampResult),
|
||||
@@ -81,10 +99,11 @@ impl From<sqlite::Error> for MixxxError {
|
||||
impl StageDirection {
|
||||
pub fn reload_mixxx_playlist(&mut self) -> Result<(), MixxxError> {
|
||||
self.current_playlist.clear();
|
||||
let playlist_name = format!("BFF.fm - Episode {}", self.episode_number);
|
||||
let connection = sqlite::Connection::open_thread_safe_with_flags("mixxxdb.sqlite", OpenFlags::new().with_read_only())?;
|
||||
let query = "SELECT id FROM Playlists WHERE name = ? ORDER BY id DESC LIMIT 1";
|
||||
let mut statement = connection.prepare(query)?;
|
||||
statement.bind((1, format!("BFF.fm - Episode {}", self.episode_number).as_str()))?;
|
||||
statement.bind((1, playlist_name.as_str()))?;
|
||||
statement.next()?;
|
||||
let latest_id = statement.read::<i64, _>("id").unwrap();
|
||||
|
||||
@@ -106,10 +125,13 @@ impl StageDirection {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct StageActions {
|
||||
pub direction: StageDirection,
|
||||
pub additions: Vec<ConversationEntry>
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum PredictionAction {
|
||||
ConversationAppend(ConversationEntry),
|
||||
SetEpisodeNumber(u32),
|
||||
GeneratePredictions,
|
||||
SetNarrative(String),
|
||||
SetShowEndTime(DateTime<Utc>)
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize, Clone)]
|
||||
@@ -124,17 +146,19 @@ pub struct PlaylistEntry {
|
||||
pub struct Scene {
|
||||
reply_options: GeneratedResponses,
|
||||
conversation: Vec<ConversationEntry>,
|
||||
pub direction: StageDirection,
|
||||
pub tokens_consumed: usize,
|
||||
scenery: Scenery
|
||||
}
|
||||
|
||||
impl Scene {
|
||||
pub fn new(reply_options: GeneratedResponses, conversation: Vec<ConversationEntry>, scenery: Scenery, tokens_consumed: usize) -> Self {
|
||||
pub fn new(reply_options: GeneratedResponses, conversation: Vec<ConversationEntry>, scenery: Scenery, tokens_consumed: usize, direction: StageDirection) -> Self {
|
||||
Self {
|
||||
reply_options,
|
||||
conversation,
|
||||
scenery,
|
||||
tokens_consumed
|
||||
tokens_consumed,
|
||||
direction
|
||||
}
|
||||
}
|
||||
|
||||
@@ -146,15 +170,7 @@ impl Scene {
|
||||
&self.conversation
|
||||
}
|
||||
|
||||
pub fn conversation_mut(&mut self) -> &mut Vec<ConversationEntry> {
|
||||
&mut self.conversation
|
||||
}
|
||||
|
||||
pub fn reply_options(&self) -> &Vec<PossibleResponse> {
|
||||
&self.reply_options.responses
|
||||
}
|
||||
|
||||
pub fn reply_options_mut(&mut self) -> &mut Vec<PossibleResponse> {
|
||||
&mut self.reply_options.responses
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user