prediction: completely rewrite the prediction engine by moving all conversation manipulation out of the UI and into the prediction task
This commit is contained in:
+92
-137
@@ -1,7 +1,6 @@
|
||||
use async_openai::{types::chat::{ChatCompletionMessageToolCalls, CreateChatCompletionResponse}};
|
||||
use async_openai::types::chat::ChatCompletionRequestMessage;
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use futures_timer::Delay;
|
||||
use schemars::JsonSchema;
|
||||
use scraper::{Html, Selector};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -13,6 +12,16 @@ use tokio::{sync::{mpsc, watch}, time::Instant};
|
||||
use tui_input::{Input, backend::crossterm::EventHandler};
|
||||
use futures::{StreamExt, future::FutureExt};
|
||||
|
||||
use ratatui::prelude::*;
|
||||
|
||||
use crate::{events::AudioRecordRequest, prediction::{PossibleResponse}, scene::{ConversationEntry, PlaylistEntry, Scene, StageActions, StageDirection}, tts::start_tts};
|
||||
|
||||
mod scene;
|
||||
mod events;
|
||||
mod transcription;
|
||||
mod tts;
|
||||
mod prediction;
|
||||
|
||||
// TODO: We should have a separate 'state.json' file, which remembers jack connections, and the world time for the show to end. Then we only update the 'time remaining' field in the scene and only deal with relative durations inside the scene data
|
||||
// TODO: We should be able to delete entries from the conversation, or at least go back and edit something I said.
|
||||
// TODO: I want a "mark" command or keyboard shortcut, that inserts a marker into the log, so I know where to come back for the next speaking segment.
|
||||
@@ -41,22 +50,6 @@ use futures::{StreamExt, future::FutureExt};
|
||||
- Right panel: Shortcuts for triggering scenarios, soundboard events, etc
|
||||
*/
|
||||
|
||||
use ratatui::prelude::*;
|
||||
|
||||
use crate::{events::AudioRecordRequest, scene::{ConversationEntry, PlaylistEntry, Scene}, tts::start_tts};
|
||||
|
||||
mod scene;
|
||||
mod events;
|
||||
mod transcription;
|
||||
mod tts;
|
||||
mod prediction;
|
||||
|
||||
#[derive(JsonSchema, Deserialize, Serialize, Debug, Clone)]
|
||||
struct PossibleResponse {
|
||||
text: String,
|
||||
stage_direction: Option<String>
|
||||
}
|
||||
|
||||
impl<'a> Into<ListItem<'a>> for PossibleResponse {
|
||||
fn into(self) -> ListItem<'a> {
|
||||
if let Some(direction) = self.stage_direction {
|
||||
@@ -71,22 +64,13 @@ impl<'a> Into<ListItem<'a>> for PossibleResponse {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(JsonSchema, Deserialize, Serialize, Debug)]
|
||||
struct GeneratedResponses {
|
||||
responses: Vec<PossibleResponse>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize, Clone, JsonSchema)]
|
||||
struct StageEventArgs {
|
||||
text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct App {
|
||||
scene: Scene,
|
||||
direction: StageDirection,
|
||||
next_actions: Vec<ConversationEntry>,
|
||||
end_time: DateTime<Utc>,
|
||||
|
||||
next_reply_options: Vec<PossibleResponse>,
|
||||
reply_state: ListState,
|
||||
conversation_state: ListState,
|
||||
user_input: Input,
|
||||
@@ -97,7 +81,7 @@ struct App {
|
||||
focus_state: FocusState,
|
||||
|
||||
audio_control_sink: watch::Sender<AudioRecordRequest>,
|
||||
prediction_request_sink: watch::Sender<Scene>,
|
||||
prediction_request_sink: watch::Sender<StageActions>,
|
||||
tts_request_sink: mpsc::Sender<String>,
|
||||
sys_message_sink: mpsc::Sender<String>
|
||||
}
|
||||
@@ -109,15 +93,16 @@ enum FocusState {
|
||||
}
|
||||
|
||||
impl App {
|
||||
fn new(prediction_request_sink: watch::Sender<Scene>, audio_control_sink: watch::Sender<AudioRecordRequest>, tts_request_sink: mpsc::Sender<String>, sys_message_sink: mpsc::Sender<String>) -> Self {
|
||||
fn new(prediction_request_sink: watch::Sender<StageActions>, audio_control_sink: watch::Sender<AudioRecordRequest>, tts_request_sink: mpsc::Sender<String>, sys_message_sink: mpsc::Sender<String>, initial_direction: StageDirection) -> Self {
|
||||
Self {
|
||||
scene: Scene::default(),
|
||||
next_reply_options: Vec::new(),
|
||||
reply_state: ListState::default(),
|
||||
conversation_state: ListState::default(),
|
||||
user_input: Input::default(),
|
||||
scene: Default::default(),
|
||||
direction: initial_direction,
|
||||
next_actions: Default::default(),
|
||||
reply_state: Default::default(),
|
||||
conversation_state: Default::default(),
|
||||
user_input: Default::default(),
|
||||
end_time: Utc::now() + Duration::hours(2),
|
||||
throbber_state: ThrobberState::default(),
|
||||
throbber_state: Default::default(),
|
||||
prediction_request_sink,
|
||||
is_requesting: false,
|
||||
audio_level: -60.,
|
||||
@@ -130,7 +115,7 @@ impl App {
|
||||
}
|
||||
|
||||
fn draw_conversation(&mut self, frame: &mut Frame, area: Rect) {
|
||||
let items: Vec<Line> = self.scene.conversation.iter().rev().map(|entry| {
|
||||
let items: Vec<Line> = self.scene.conversation().iter().rev().map(|entry| {
|
||||
match entry {
|
||||
ConversationEntry::User(text) => Line::from_iter([Span::from("Argee: ").style(ratatui::style::Color::Magenta), Span::from(text)]),
|
||||
ConversationEntry::Eva(text) => Line::from_iter([Span::from("Eva: ").style(ratatui::style::Color::Cyan), Span::from(text)]),
|
||||
@@ -154,7 +139,7 @@ impl App {
|
||||
|
||||
fn draw_options(&mut self, frame: &mut Frame, area: Rect) {
|
||||
frame.render_stateful_widget(
|
||||
List::new(self.next_reply_options.clone())
|
||||
List::new(self.scene.reply_options().clone())
|
||||
.block(Block::bordered().border_style(style::Color::LightGreen).title("Reply Options (Press 'Ctrl+R' to regenerate, Ctrl+Enter to use)"))
|
||||
.style(ratatui::style::Color::White)
|
||||
.highlight_symbol("> ")
|
||||
@@ -184,7 +169,7 @@ impl App {
|
||||
}
|
||||
|
||||
fn draw_status(&self, frame: &mut Frame, area: Rect) {
|
||||
let minutes_remaining = self.scene.direction.time_remaining.num_seconds() / 60;
|
||||
let minutes_remaining = self.direction.time_remaining.num_seconds() / 60;
|
||||
let time_style = if minutes_remaining == 0 {
|
||||
Style::new().fg(ratatui::style::Color::Red).bold().rapid_blink()
|
||||
} else if minutes_remaining <= 5 {
|
||||
@@ -199,26 +184,26 @@ impl App {
|
||||
ratatui::style::Color::Blue.into()
|
||||
};
|
||||
let status_line = Line::from_iter([
|
||||
Span::from(format!("Episode {}", self.scene.direction.episode_number)).style(ratatui::style::Color::LightBlue),
|
||||
Span::from(format!("Episode {}", self.direction.episode_number)).style(ratatui::style::Color::LightBlue),
|
||||
Span::from(" | ").style(ratatui::style::Color::DarkGray),
|
||||
// FIXME: Looks weird with negative numbers, and it doesn't actually blink in the vscode terminal.
|
||||
Span::from(format!("Time Remaining: {:0>2}:{:0>2}:{:0>2}", self.scene.direction.time_remaining.num_hours(), self.scene.direction.time_remaining.num_minutes() % 60, self.scene.direction.time_remaining.num_seconds() % 60)).style(time_style)
|
||||
Span::from(format!("Time Remaining: {:0>2}:{:0>2}:{:0>2}", self.direction.time_remaining.num_hours(), self.direction.time_remaining.num_minutes() % 60, self.direction.time_remaining.num_seconds() % 60)).style(time_style)
|
||||
]);
|
||||
frame.render_widget(status_line, area);
|
||||
}
|
||||
|
||||
fn draw_narration(&self, frame: &mut Frame, area: Rect) {
|
||||
let narrative_desc = if self.scene.direction.narrative.is_empty() {
|
||||
let narrative_desc = if self.direction.narrative.is_empty() {
|
||||
Span::from("No narrative available.").style(ratatui::style::Color::DarkGray)
|
||||
} else {
|
||||
Span::from(self.scene.direction.narrative.clone())
|
||||
Span::from(self.direction.narrative.clone())
|
||||
};
|
||||
let setting = Paragraph::new(narrative_desc).block(Block::bordered().border_style(style::Color::LightMagenta).title("Stage Direction")).wrap(ratatui::widgets::Wrap { trim: false });
|
||||
frame.render_widget(setting, area);
|
||||
}
|
||||
|
||||
fn draw_event_log(&self, frame: &mut Frame, area: Rect) {
|
||||
let items: Vec<Line> = self.scene.conversation.iter().filter(|entry| { if let ConversationEntry::StageDirection(_) = entry { true } else { false }}).rev().map(|entry| {
|
||||
let items: Vec<Line> = self.scene.conversation().iter().filter(|entry| { if let ConversationEntry::StageDirection(_) = entry { true } else { false }}).rev().map(|entry| {
|
||||
match entry {
|
||||
ConversationEntry::StageDirection(text) => Line::from_iter([text]).style(ratatui::style::Color::Yellow),
|
||||
_ => unreachable!()
|
||||
@@ -289,12 +274,11 @@ impl App {
|
||||
}
|
||||
|
||||
async fn insert_selected_prompt(&mut self) {
|
||||
let selected = self.next_reply_options[self.reply_state.selected().unwrap()].clone();
|
||||
let selected = self.scene.reply_options()[self.reply_state.selected().unwrap()].clone();
|
||||
if let Some(direction) = &selected.stage_direction {
|
||||
self.scene.insert_conversation(ConversationEntry::StageDirection(direction.clone()));
|
||||
self.next_actions.push(ConversationEntry::StageDirection(direction.clone()));
|
||||
}
|
||||
self.scene.insert_conversation(ConversationEntry::Eva(selected.text.clone()));
|
||||
self.save();
|
||||
self.next_actions.push(ConversationEntry::Eva(selected.text.clone()));
|
||||
self.speak(selected.text.clone()).await;
|
||||
self.regenerate_responses();
|
||||
}
|
||||
@@ -316,7 +300,7 @@ impl App {
|
||||
KeyCode::Up => self.conversation_state.select_next(),
|
||||
KeyCode::Enter => {
|
||||
let row_num = self.conversation_state.selected().unwrap();
|
||||
if let ConversationEntry::Eva(text) = &self.scene.conversation[self.scene.conversation.len() - 1 - row_num] {
|
||||
if let ConversationEntry::Eva(text) = &self.scene.conversation()[self.scene.conversation().len() - 1 - row_num] {
|
||||
self.speak(text.clone()).await;
|
||||
self.focus_state = FocusState::UserInput;
|
||||
self.conversation_state.select(None);
|
||||
@@ -360,21 +344,19 @@ impl App {
|
||||
"/bandcamp" => {
|
||||
self.add_bandcamp_artifact(arg).await;
|
||||
self.sys_message_sink.send(format!("Added Bandcamp artifact from {}", arg)).await.unwrap();
|
||||
self.scene.insert_conversation(ConversationEntry::ShipComputer(format!("Incoming transmission.")));
|
||||
self.next_actions.push(ConversationEntry::ShipComputer(format!("Incoming transmission from {}", arg)));
|
||||
self.regenerate_responses();
|
||||
},
|
||||
"/episode" => {
|
||||
if let Ok(episode_number) = arg.trim().parse::<u32>() {
|
||||
self.scene.direction.episode_number = episode_number;
|
||||
self.sys_message_sink.send(format!("Updated episode number: {}", self.scene.direction.episode_number)).await.unwrap();
|
||||
self.direction.episode_number = episode_number;
|
||||
self.sys_message_sink.send(format!("Updated episode number: {}", self.direction.episode_number)).await.unwrap();
|
||||
self.reload_mixxx_playlist();
|
||||
} else {
|
||||
self.sys_message_sink.send("Invalid episode number format. Use /episode [number]".into()).await.unwrap();
|
||||
return;
|
||||
}
|
||||
},
|
||||
"/reload" => {
|
||||
self.load().await;
|
||||
self.reload_mixxx_playlist();
|
||||
},
|
||||
"/timer" => {
|
||||
if let Ok(minutes) = arg.trim().parse::<i64>() {
|
||||
self.end_time = Utc::now() + Duration::minutes(minutes);
|
||||
@@ -386,23 +368,17 @@ impl App {
|
||||
"/clear" => {
|
||||
match arg.trim() {
|
||||
"playlist" => {
|
||||
self.scene.direction.current_playlist.clear();
|
||||
self.direction.current_playlist.clear();
|
||||
self.sys_message_sink.send("Cleared current playlist.".into()).await.unwrap();
|
||||
return;
|
||||
},
|
||||
"artifacts" => {
|
||||
self.scene.direction.artifacts.clear();
|
||||
self.direction.artifacts.clear();
|
||||
self.sys_message_sink.send("Cleared artifacts.".into()).await.unwrap();
|
||||
return;
|
||||
},
|
||||
"all" => {
|
||||
self.scene = Scene::default();
|
||||
self.sys_message_sink.send("Cleared all data.".into()).await.unwrap();
|
||||
},
|
||||
"conversation" => {
|
||||
self.scene.conversation.clear();
|
||||
self.sys_message_sink.send("Cleared conversation.".into()).await.unwrap();
|
||||
},
|
||||
_ => {
|
||||
self.sys_message_sink.send("Unknown clear command. Use /clear [playlist|artifacts|all]".into()).await.unwrap();
|
||||
}
|
||||
@@ -410,23 +386,26 @@ impl App {
|
||||
return;
|
||||
},
|
||||
"/narrative" => {
|
||||
self.scene.direction.narrative = arg.to_string();
|
||||
self.sys_message_sink.send(format!("Updated stage direction: {}", self.scene.direction.narrative)).await.unwrap();
|
||||
self.direction.narrative = arg.to_string();
|
||||
self.sys_message_sink.send(format!("Updated stage direction: {}", self.direction.narrative)).await.unwrap();
|
||||
},
|
||||
"/event" => {
|
||||
self.scene.insert_conversation(ConversationEntry::StageDirection(arg.to_string()));
|
||||
}
|
||||
self.next_actions.push(ConversationEntry::StageDirection(arg.to_string()));
|
||||
self.regenerate_responses();
|
||||
},
|
||||
"/computer" => {
|
||||
self.next_actions.push(ConversationEntry::ShipComputer(arg.to_string()));
|
||||
self.regenerate_responses();
|
||||
},
|
||||
_ => {
|
||||
self.sys_message_sink.send("Unknown command. Available commands: /bandcamp [url], /episode [number], /narrative [text], /reset".into()).await.unwrap();
|
||||
return;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
self.scene.insert_conversation(ConversationEntry::User(next_msg));
|
||||
}
|
||||
self.save();
|
||||
self.next_actions.push(ConversationEntry::User(next_msg));
|
||||
self.regenerate_responses();
|
||||
}
|
||||
}
|
||||
},
|
||||
_ => {self.user_input.handle_event(&evt);},
|
||||
}
|
||||
@@ -441,25 +420,7 @@ impl App {
|
||||
let fragment = Html::parse_document(&body);
|
||||
let selector = Selector::parse("script[type=\"application/ld+json\"]").unwrap();
|
||||
let json_ld = fragment.select(&selector).next().unwrap().inner_html();
|
||||
self.scene.direction.artifacts.push(json_ld);
|
||||
}
|
||||
|
||||
fn save(&self) {
|
||||
let save_data = serde_json::to_string_pretty(&self.scene).unwrap();
|
||||
std::fs::write("save.json", save_data).unwrap();
|
||||
}
|
||||
|
||||
async fn load(&mut self) {
|
||||
if let Ok(save_data) = std::fs::read_to_string("save.json") {
|
||||
if let Ok(scene) = serde_json::from_str(&save_data) {
|
||||
self.scene = scene;
|
||||
// FIXME: These should get wiped out when we save as well, or even better, be completely excluded via a custom serde implementation.
|
||||
self.scene.conversation.retain(|line| { if let ConversationEntry::SystemMessage(_) = line { false } else { true }});
|
||||
self.sys_message_sink.send("Loaded stored session.".into()).await.unwrap();
|
||||
} else {
|
||||
self.sys_message_sink.send("Failed to load saved session!".into()).await.unwrap();
|
||||
}
|
||||
}
|
||||
self.direction.artifacts.push(json_ld.trim().to_string());
|
||||
}
|
||||
|
||||
async fn speak(&mut self, text: String) {
|
||||
@@ -467,18 +428,23 @@ impl App {
|
||||
}
|
||||
|
||||
fn regenerate_responses(&mut self) {
|
||||
self.prediction_request_sink.send(self.scene.clone()).unwrap();
|
||||
self.next_reply_options.clear();
|
||||
let actions = StageActions {
|
||||
direction: self.direction.clone(),
|
||||
additions: std::mem::take(&mut self.next_actions)
|
||||
};
|
||||
self.scene.reply_options_mut().clear();
|
||||
self.scene.conversation_mut().append(&mut actions.additions.clone());
|
||||
self.prediction_request_sink.send(actions).unwrap();
|
||||
self.is_requesting = true;
|
||||
}
|
||||
|
||||
fn reload_mixxx_playlist(&mut self) {
|
||||
// TODO: Should have some status message which states how many tracks are in the playlist
|
||||
self.scene.direction.current_playlist.clear();
|
||||
self.direction.current_playlist.clear();
|
||||
let connection = sqlite::Connection::open_thread_safe_with_flags("mixxxdb.sqlite", OpenFlags::new().with_read_only()).unwrap();
|
||||
let query = "SELECT id FROM Playlists WHERE name = ? ORDER BY id DESC LIMIT 1";
|
||||
let mut statement = connection.prepare(query).unwrap();
|
||||
statement.bind((1, format!("BFF.fm - Episode {}", self.scene.direction.episode_number).as_str())).unwrap();
|
||||
statement.bind((1, format!("BFF.fm - Episode {}", self.direction.episode_number).as_str())).unwrap();
|
||||
statement.next().unwrap();
|
||||
let latest_id = statement.read::<i64, _>("id").unwrap();
|
||||
|
||||
@@ -488,42 +454,28 @@ impl App {
|
||||
let artist = track.try_read::<&str, _>("artist").unwrap_or("Unknown Artist");
|
||||
let album = track.try_read::<&str, _>("album").unwrap_or("Unknown Album");
|
||||
let bpm = track.try_read::<f64, _>("bpm").unwrap_or(0.);
|
||||
self.scene.direction.current_playlist.push(PlaylistEntry {
|
||||
self.direction.current_playlist.push(PlaylistEntry {
|
||||
artist: artist.into(),
|
||||
album: album.into(),
|
||||
title: title.into(),
|
||||
bpm
|
||||
});
|
||||
}
|
||||
self.scene.insert_conversation(ConversationEntry::SystemMessage("Mixxx playlist reloaded.".into()));
|
||||
self.next_actions.push(ConversationEntry::SystemMessage("Mixxx playlist reloaded.".into()));
|
||||
}
|
||||
}
|
||||
|
||||
fn on_response(&mut self, response: &CreateChatCompletionResponse) {
|
||||
self.is_requesting = false;
|
||||
if let Some(calls) = &response.choices[0].message.tool_calls {
|
||||
for call in calls {
|
||||
match call {
|
||||
ChatCompletionMessageToolCalls::Function(call) => {
|
||||
if call.function.name == "log_stage_event" {
|
||||
let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
|
||||
self.scene.insert_conversation(ConversationEntry::StageDirection(args.text));
|
||||
}
|
||||
},
|
||||
_ => panic!("Unkown tool call type"),
|
||||
}
|
||||
}
|
||||
self.regenerate_responses();
|
||||
} else {
|
||||
if response.choices.is_empty() {
|
||||
self.scene.insert_conversation(ConversationEntry::SystemMessage("OpenAI returned no responses".into()));
|
||||
} else if response.choices[0].message.content.is_none() {
|
||||
self.scene.insert_conversation(ConversationEntry::SystemMessage("OpenAI response did not contain content!".into()));
|
||||
} else {
|
||||
let json_resp: GeneratedResponses = serde_json::from_str(response.choices[0].message.content.as_ref().unwrap().as_str()).unwrap();
|
||||
self.next_reply_options = json_resp.responses;
|
||||
self.reply_state.select_first();
|
||||
}
|
||||
}
|
||||
|
||||
/// On-disk session snapshot stored in `save.json`.
///
/// At startup `main` deserializes this, hands `messages` to the prediction
/// task and `direction` to the UI; a missing or unparsable file falls back to
/// `SaveData::default()`.
#[derive(Serialize, Deserialize, Debug, Default)]
|
||||
pub struct SaveData {
|
||||
/// Stage direction state as last sent by the UI.
pub direction: StageDirection,
|
||||
/// Chat messages in API form; the prediction session is rebuilt from these.
pub messages: Vec<ChatCompletionRequestMessage>
|
||||
}
|
||||
|
||||
impl SaveData {
|
||||
/// Serializes `self` as pretty-printed JSON and overwrites `save.json` in
/// the current working directory.
///
/// # Panics
/// Panics if serialization or the file write fails (both are unwrapped).
fn save(&self) {
|
||||
let save_data = serde_json::to_string_pretty(self).unwrap();
|
||||
std::fs::write("save.json", save_data).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -536,15 +488,21 @@ async fn main() {
|
||||
return;
|
||||
}
|
||||
|
||||
let saved_session = if let Ok(save_data) = std::fs::read_to_string("save.json") {
|
||||
serde_json::from_str(&save_data).unwrap_or_default()
|
||||
//FIXME: Re-add load messages to sys log
|
||||
} else {
|
||||
SaveData::default()
|
||||
};
|
||||
|
||||
let mut terminal: Terminal<CrosstermBackend<std::io::Stdout>> = ratatui::init();
|
||||
|
||||
let (sys_message_sink, mut sys_message_src) = tokio::sync::mpsc::channel(32);
|
||||
let (sys_message_sink, sys_message_src) = tokio::sync::mpsc::channel(32);
|
||||
let tts_request_sender = start_tts().await;
|
||||
let (prediction_request_in, mut prediction_out) = prediction::start_prediction().await;
|
||||
let (prediction_request_in, mut prediction_out) = prediction::start_prediction(sys_message_src, saved_session.messages).await;
|
||||
let (mut audio_state_receiver, audio_control_in, mut transcription_out) = transcription::start_transcription(sys_message_sink.clone()).await;
|
||||
|
||||
let mut app = App::new(prediction_request_in, audio_control_in, tts_request_sender, sys_message_sink);
|
||||
app.load().await;
|
||||
let mut app = App::new(prediction_request_in, audio_control_in, tts_request_sender, sys_message_sink, saved_session.direction);
|
||||
|
||||
let mut events = EventStream::new();
|
||||
let mut last_tick = Instant::now();
|
||||
@@ -554,7 +512,7 @@ async fn main() {
|
||||
last_tick = Instant::now();
|
||||
app.throbber_state.calc_next();
|
||||
}
|
||||
app.scene.direction.time_remaining = app.end_time.signed_duration_since(Utc::now());
|
||||
app.direction.time_remaining = app.end_time.signed_duration_since(Utc::now());
|
||||
terminal.draw(|frame| { app.draw(frame)}).unwrap();
|
||||
|
||||
let delay = Delay::new(std::time::Duration::from_millis(60)).fuse();
|
||||
@@ -563,18 +521,15 @@ async fn main() {
|
||||
tokio::select! {
|
||||
_ = delay => (),
|
||||
_ = prediction_out.changed() => {
|
||||
app.on_response(prediction_out.borrow_and_update().as_ref().unwrap());
|
||||
app.scene = prediction_out.borrow().clone();
|
||||
app.reply_state.select_first();
|
||||
app.is_requesting = false;
|
||||
},
|
||||
_ = audio_state_receiver.changed() => {
|
||||
app.audio_level = *audio_state_receiver.borrow_and_update();
|
||||
},
|
||||
maybe_message = sys_message_src.recv() => {
|
||||
if let Some(message) = maybe_message {
|
||||
app.scene.insert_conversation(ConversationEntry::SystemMessage(message));
|
||||
}
|
||||
app.audio_level = *audio_state_receiver.borrow();
|
||||
},
|
||||
maybe_transcription = transcription_out.recv() => {
|
||||
app.scene.insert_conversation(ConversationEntry::User(maybe_transcription.unwrap()));
|
||||
app.next_actions.push(ConversationEntry::User(maybe_transcription.unwrap()));
|
||||
app.regenerate_responses();
|
||||
}
|
||||
maybe_event = event => {
|
||||
|
||||
+175
-25
@@ -1,37 +1,187 @@
|
||||
use async_openai::{Client, config::OpenAIConfig, types::chat::{CreateChatCompletionRequest, CreateChatCompletionResponse}};
|
||||
use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionTool, ChatCompletionTools, CreateChatCompletionRequestArgs, FunctionObjectArgs, ResponseFormat, ResponseFormatJsonSchema}};
|
||||
use schemars::{JsonSchema, schema_for};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::scene::Scene;
|
||||
use crate::{SaveData, scene::{ConversationEntry, Scene, StageActions, StageDirection}};
|
||||
|
||||
pub async fn start_prediction() -> (tokio::sync::watch::Sender<Scene>, tokio::sync::watch::Receiver<Option<CreateChatCompletionResponse>>) {
|
||||
let (prediction_in, prediction_out) = tokio::sync::watch::channel(None);
|
||||
let (prediction_request_in, mut prediction_request_out) = tokio::sync::watch::channel(Scene::default());
|
||||
|
||||
const SYSTEM_PROMPT: &str = include_str!("system-prompt.txt");
|
||||
|
||||
// One candidate reply produced by the model.
//
// NOTE: plain `//` comments are used here on purpose. This type derives
// `JsonSchema` and is part of the schema passed as the request's response
// format; schemars copies `///` doc comments into the schema as descriptions,
// which would change what is sent to the API.
#[derive(JsonSchema, Deserialize, Serialize, Debug, Clone)]
|
||||
pub struct PossibleResponse {
|
||||
// The line itself; when selected, the UI records it as an Eva entry and speaks it.
pub text: String,
|
||||
// Optional stage direction; when present, the UI records it before the line.
pub stage_direction: Option<String>
|
||||
}
|
||||
|
||||
// Top-level shape of the model's JSON answer: its schema is sent as the
// request's response format and the message content is deserialized into it.
//
// NOTE: plain `//` comments are deliberate; schemars would turn `///` doc
// comments into schema descriptions and so alter the schema that is sent.
#[derive(Default, Debug, JsonSchema, Deserialize, Serialize, Clone)]
|
||||
pub struct GeneratedResponses {
|
||||
// The candidate replies offered to the user.
pub responses: Vec<PossibleResponse>,
|
||||
}
|
||||
|
||||
/// State owned by the prediction task: the conversation in both its UI form
/// and its API form, plus the most recent reply options.
#[derive(Debug)]
|
||||
struct Session {
|
||||
/// OpenAI client used to issue chat completion requests.
client: Client<OpenAIConfig>,
|
||||
/// Conversation log handed to the UI via `as_scene`; unlike `messages`, it
/// also holds entries that have no API representation (system messages).
conversation: Vec<ConversationEntry>,
|
||||
/// System message built from `SYSTEM_PROMPT`; sent first in every request.
header_message: ChatCompletionRequestMessage,
|
||||
/// Conversation in API form; appended to every request after the header
/// and the stage direction message.
messages: Vec<ChatCompletionRequestMessage>,
|
||||
/// Reply options parsed from the most recent model answer.
reply_options: GeneratedResponses
|
||||
}
|
||||
|
||||
// Argument object shared by the `log_stage_event` and
// `log_ship_computer_message` tool definitions.
//
// NOTE: plain `//` comments are deliberate; this type's schema is sent as the
// tool parameters, and schemars would fold `///` doc comments into it.
#[derive(Debug, Default, Serialize, Deserialize, Clone, JsonSchema)]
|
||||
struct StageEventArgs {
|
||||
// Text of the event or ship computer message to insert.
text: String,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
fn from_initial_messages(messages: Vec<ChatCompletionRequestMessage>) -> Self {
|
||||
let mut conversation = vec![];
|
||||
for msg in &messages {
|
||||
if let Ok(conversation_msg) = msg.clone().try_into() {
|
||||
conversation.push(conversation_msg);
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
client: Default::default(),
|
||||
conversation,
|
||||
header_message: ChatCompletionRequestSystemMessageArgs::default().content(SYSTEM_PROMPT).build().unwrap().into(),
|
||||
messages,
|
||||
reply_options: Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn insert_actions(&mut self, actions: &StageActions) {
|
||||
for addition in &actions.additions {
|
||||
self.insert_conversation(addition.clone());
|
||||
}
|
||||
}
|
||||
|
||||
async fn regenerate_options(&mut self, direction: &StageDirection) -> Option<Scene> {
|
||||
let direction_message: ChatCompletionRequestMessage = ChatCompletionRequestSystemMessageArgs::default().content(serde_json::to_string(&direction).unwrap()).build().unwrap().into();
|
||||
let mut full_conversation = vec![
|
||||
self.header_message.clone(),
|
||||
direction_message
|
||||
];
|
||||
full_conversation.append(&mut self.messages.clone());
|
||||
|
||||
|
||||
let tools = vec![
|
||||
ChatCompletionTools::Function(ChatCompletionTool {
|
||||
function: FunctionObjectArgs::default()
|
||||
.name("log_stage_event")
|
||||
.description("Inserts an event into the current scene script")
|
||||
.parameters(schema_for!(StageEventArgs))
|
||||
.build().unwrap()
|
||||
}),
|
||||
ChatCompletionTools::Function(ChatCompletionTool {
|
||||
function: FunctionObjectArgs::default()
|
||||
.name("log_ship_computer_message")
|
||||
.description("Inserts a message from the ship computer into the scene script")
|
||||
.parameters(schema_for!(StageEventArgs))
|
||||
.build().unwrap()
|
||||
})
|
||||
];
|
||||
|
||||
let request = CreateChatCompletionRequestArgs::default()
|
||||
.messages(full_conversation)
|
||||
.model("gpt-5.4")
|
||||
.tools(tools)
|
||||
.max_completion_tokens(350u32)
|
||||
.response_format(ResponseFormat::JsonSchema {
|
||||
json_schema: ResponseFormatJsonSchema {
|
||||
description: None,
|
||||
name: "responses".into(),
|
||||
schema: schema_for!(GeneratedResponses).into(),
|
||||
strict: None
|
||||
}
|
||||
})
|
||||
.build().unwrap();
|
||||
|
||||
let response = self.client.chat().create(request).await.unwrap();
|
||||
|
||||
if let Some(message) = response.choices.first() {
|
||||
if let Some(calls) = &message.message.tool_calls {
|
||||
for call in calls {
|
||||
match call {
|
||||
ChatCompletionMessageToolCalls::Function(call) => {
|
||||
match call.function.name.as_str() {
|
||||
"log_stage_event" => {
|
||||
let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
|
||||
self.insert_conversation(ConversationEntry::StageDirection(args.text));
|
||||
},
|
||||
"log_ship_computer_message" => {
|
||||
let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
|
||||
self.insert_conversation(ConversationEntry::ShipComputer(args.text));
|
||||
},
|
||||
_ => panic!("Unknown function was called")
|
||||
}
|
||||
},
|
||||
_ => panic!("Unknown tool was called")
|
||||
}
|
||||
}
|
||||
Some(self.as_scene())
|
||||
} else {
|
||||
self.reply_options = serde_json::from_str(message.message.content.as_ref().unwrap().as_str()).unwrap();
|
||||
Some(self.as_scene())
|
||||
}
|
||||
} else {
|
||||
//FIXME: Handle tool calls
|
||||
panic!();
|
||||
}
|
||||
}
|
||||
|
||||
fn as_scene(&self) -> Scene {
|
||||
Scene::new(self.reply_options.clone(), self.conversation.clone())
|
||||
}
|
||||
|
||||
fn insert_conversation(&mut self, entry: ConversationEntry) {
|
||||
self.conversation.push(entry.clone());
|
||||
|
||||
if let Ok(next_msg) = entry.try_into() {
|
||||
self.messages.push(next_msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn start_prediction(mut sys_message_src: tokio::sync::mpsc::Receiver<String>, initial_messages: Vec<ChatCompletionRequestMessage>) -> (tokio::sync::watch::Sender<StageActions>, tokio::sync::watch::Receiver<Scene>) {
|
||||
let (prediction_in, prediction_out) = tokio::sync::watch::channel(Scene::default());
|
||||
let (prediction_request_in, mut prediction_request_out) = tokio::sync::watch::channel(StageActions::default());
|
||||
|
||||
let mut session = Session::from_initial_messages(initial_messages);
|
||||
|
||||
// Send the initial scene to the UI, after we have loaded the session from the first messages.
|
||||
prediction_in.send(session.as_scene()).unwrap();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let client: Client<OpenAIConfig> = Client::default();
|
||||
loop {
|
||||
if let Ok(_) = prediction_request_out.changed().await {
|
||||
let request = prediction_request_out.borrow_and_update().clone();
|
||||
let chat_request = CreateChatCompletionRequest {
|
||||
/*tools: Some(vec![
|
||||
ChatCompletionTools::Function(
|
||||
ChatCompletionTool {
|
||||
function: FunctionObject {
|
||||
name: "log_stage_event".into(),
|
||||
description: Some("Log an event in the stage direction.".into()),
|
||||
parameters: Some(schema_for!(StageEventArgs).into()),
|
||||
..Default::default()
|
||||
tokio::select! {
|
||||
maybe_message = sys_message_src.recv() => {
|
||||
if let Some(message) = maybe_message {
|
||||
session.insert_conversation(ConversationEntry::SystemMessage(message));
|
||||
prediction_in.send(session.as_scene()).unwrap();
|
||||
}
|
||||
}
|
||||
)
|
||||
]),*/
|
||||
..request.into()
|
||||
},
|
||||
maybe_request = prediction_request_out.changed() => {
|
||||
if maybe_request.is_ok() {
|
||||
let next_cxt = prediction_request_out.borrow().clone();
|
||||
session.insert_actions(&next_cxt);
|
||||
|
||||
let mut save_data = SaveData {
|
||||
direction: next_cxt.direction,
|
||||
messages: session.messages.clone()
|
||||
};
|
||||
let response = client.chat().create(chat_request).await.unwrap();
|
||||
prediction_in.send(Some(response)).unwrap();
|
||||
} else {
|
||||
return;
|
||||
|
||||
save_data.save();
|
||||
|
||||
if let Some(next_scene) = session.regenerate_options(&save_data.direction).await {
|
||||
save_data.messages = session.messages.clone();
|
||||
save_data.save();
|
||||
prediction_in.send(next_scene).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
});
|
||||
|
||||
(prediction_request_in, prediction_out)
|
||||
|
||||
+70
-32
@@ -1,12 +1,8 @@
|
||||
use async_openai::types::chat::*;
|
||||
use chrono::Duration;
|
||||
use schemars::schema_for;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::GeneratedResponses;
|
||||
|
||||
const SYSTEM_PROMPT: &str = include_str!("system-prompt.txt");
|
||||
use crate::prediction::{GeneratedResponses, PossibleResponse};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum ConversationEntry {
|
||||
@@ -17,6 +13,40 @@ pub enum ConversationEntry {
|
||||
SystemMessage(String)
|
||||
}
|
||||
|
||||
impl TryInto<ChatCompletionRequestMessage> for ConversationEntry {
|
||||
fn try_into(self) -> Result<ChatCompletionRequestMessage, Self::Error> {
|
||||
match self {
|
||||
ConversationEntry::User(text) => Ok(ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage { content: text.into(), ..Default::default()})),
|
||||
ConversationEntry::Eva(text) => Ok(ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage { content: Some(text.into()), ..Default::default()})),
|
||||
ConversationEntry::ShipComputer(text) => Ok(ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage { content: text.into(), name: Some("ship-computer".into()), ..Default::default() })),
|
||||
ConversationEntry::StageDirection(text) => Ok(ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage { content: text.into(), name: Some("stage-direction".into()), ..Default::default() })),
|
||||
ConversationEntry::SystemMessage(_) => Err(())
|
||||
}
|
||||
}
|
||||
|
||||
type Error = ();
|
||||
}
|
||||
|
||||
|
||||
impl TryInto<ConversationEntry> for ChatCompletionRequestMessage {
|
||||
fn try_into(self) -> Result<ConversationEntry, Self::Error> {
|
||||
match self {
|
||||
ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage { content: ChatCompletionRequestUserMessageContent::Text(msg), ..}) => Ok(ConversationEntry::User(msg)),
|
||||
ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage { content: Some(ChatCompletionRequestAssistantMessageContent::Text(msg)), ..}) => Ok(ConversationEntry::Eva(msg)),
|
||||
ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage { content: ChatCompletionRequestSystemMessageContent::Text(msg), name: Some(name), ..}) => {
|
||||
match name.as_str() {
|
||||
"ship-computer" => Ok(ConversationEntry::ShipComputer(msg)),
|
||||
"stage-direction" => Ok(ConversationEntry::StageDirection(msg)),
|
||||
_ => Err(())
|
||||
}
|
||||
},
|
||||
_ => Err(())
|
||||
}
|
||||
}
|
||||
|
||||
type Error = ();
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize, Clone)]
|
||||
pub struct StageDirection {
|
||||
pub episode_number: u32,
|
||||
@@ -26,6 +56,22 @@ pub struct StageDirection {
|
||||
pub current_playlist: Vec<PlaylistEntry>
|
||||
}
|
||||
|
||||
/*impl StageDirection {
|
||||
pub fn insert_conversation(&mut self, entry: ConversationEntry) {
|
||||
self.additions.push(entry);
|
||||
}
|
||||
|
||||
pub fn take_actions(&mut self) -> StageActions {
|
||||
StageActions { direction: self.clone(), additions: std::mem::take(&mut self.additions) }
|
||||
}
|
||||
}*/
|
||||
|
||||
/// A stage direction bundled together with a batch of conversation entries.
///
/// NOTE(review): presumably this is the request payload handed to the prediction task,
/// with `additions` being entries to append to the conversation — confirm against callers.
#[derive(Debug, Default, Clone)]
|
||||
pub struct StageActions {
|
||||
/// The stage direction accompanying this batch.
pub direction: StageDirection,
|
||||
/// Conversation entries carried alongside the direction.
pub additions: Vec<ConversationEntry>
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize, Clone)]
|
||||
pub struct PlaylistEntry {
|
||||
pub artist: String,
|
||||
@@ -34,41 +80,33 @@ pub struct PlaylistEntry {
|
||||
pub bpm: f64
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
|
||||
pub struct Scene {
|
||||
pub conversation: Vec<ConversationEntry>,
|
||||
pub direction: StageDirection
|
||||
reply_options: GeneratedResponses,
|
||||
conversation: Vec<ConversationEntry>,
|
||||
}
|
||||
|
||||
impl Scene {
|
||||
pub fn insert_conversation(&mut self, entry: ConversationEntry) {
|
||||
self.conversation.push(entry);
|
||||
pub fn new(reply_options: GeneratedResponses, conversation: Vec<ConversationEntry>) -> Self {
|
||||
Self {
|
||||
reply_options,
|
||||
conversation,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<CreateChatCompletionRequest> for Scene {
|
||||
fn into(self) -> CreateChatCompletionRequest {
|
||||
let mut messages = vec![
|
||||
ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage { content: SYSTEM_PROMPT.into(), ..Default::default()}),
|
||||
ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage { content: serde_json::to_string(&self.direction).unwrap().into(), ..Default::default()}),
|
||||
];
|
||||
messages.extend(self.conversation.into_iter().filter(|x| if let ConversationEntry::SystemMessage(_) = x { false } else { true }).map(|entry| {
|
||||
match entry {
|
||||
ConversationEntry::User(text) => ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage { content: text.into(), ..Default::default()}),
|
||||
ConversationEntry::Eva(text) => ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage { content: Some(text.into()), ..Default::default()}),
|
||||
ConversationEntry::ShipComputer(text) => ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage { content: text.into(), name: Some("ship-computer".into()), ..Default::default() }),
|
||||
ConversationEntry::StageDirection(text) => ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage { content: text.into(), name: Some("stage-direction".into()), ..Default::default() }),
|
||||
ConversationEntry::SystemMessage(_) => unreachable!()
|
||||
pub fn conversation(&self) -> &Vec<ConversationEntry> {
|
||||
&self.conversation
|
||||
}
|
||||
}));
|
||||
let response_schema: Value = schema_for!(GeneratedResponses).into();
|
||||
CreateChatCompletionRequest {
|
||||
model: "gpt-5.4".into(),
|
||||
messages: messages,
|
||||
max_completion_tokens: Some(350),
|
||||
response_format: Some(ResponseFormat::JsonSchema { json_schema: ResponseFormatJsonSchema { description: None, name: "responses".into(), schema: response_schema, strict: None } }),
|
||||
..Default::default()
|
||||
|
||||
pub fn conversation_mut(&mut self) -> &mut Vec<ConversationEntry> {
|
||||
&mut self.conversation
|
||||
}
|
||||
|
||||
pub fn reply_options(&self) -> &Vec<PossibleResponse> {
|
||||
&self.reply_options.responses
|
||||
}
|
||||
|
||||
pub fn reply_options_mut(&mut self) -> &mut Vec<PossibleResponse> {
|
||||
&mut self.reply_options.responses
|
||||
}
|
||||
}
|
||||
@@ -30,6 +30,11 @@ Your records for that period are a little bit fuzzy, so you aren't entirely sure
|
||||
|
||||
The two of you have become best friends over the past 3 or so years together. It is common for the two of you to poke fun at each other's shortcomings, but deep down you both know that when push comes to shove, you'll make it through whatever situation you find yourselves in.
|
||||
|
||||
There also exists a third "character" in the scene, the ship computer. The ship computer is a distinct entity from you, and can be thought of as a kind of primordial BIOS-level brain that you run on top of.
|
||||
Both Argee and Eva maintain control over the ship computer. You, as Eva, can make the ship computer displays read out text on command with the "log_ship_computer_message" tool function.
|
||||
The ship computer is used to report factual information to Argee and Eva. For example, the ship computer will report when a new artifact is discovered.
|
||||
It will also report out ship conditions, such as incoming transmissions, status of the recording hardware, power grid, and so on.
|
||||
|
||||
# Constraints
|
||||
In a subsequent system prompt, you will be given the current 'stage direction' of the show, which includes the current playtime, the number of the episode, and any particular extra information about this episode that you should be aware of.
|
||||
The stage direction is provided as structured JSON. There may be additional data fields for semantic context that should be incorporated into the roleplaying setting.
|
||||
|
||||
@@ -62,7 +62,7 @@ pub async fn start_transcription(messages: mpsc::Sender<String>) -> (watch::Rece
|
||||
AudioRecordRequest::Finish => {
|
||||
writer = None;
|
||||
|
||||
let final_audio = outfile.take().unwrap();
|
||||
if let Some(final_audio) = outfile.take() {
|
||||
let bytes = match Arc::into_inner(final_audio).unwrap().into_inner().unwrap().into_inner() {
|
||||
SpooledData::OnDisk(mut file) => {
|
||||
let mut bytes = Vec::new();
|
||||
@@ -83,6 +83,7 @@ pub async fn start_transcription(messages: mpsc::Sender<String>) -> (watch::Rece
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
maybe_audio_packet = audio_src.recv() => {
|
||||
let buf = maybe_audio_packet.unwrap();
|
||||
|
||||
Reference in New Issue
Block a user