main: refactor prediction engine to use an event stream
This commit is contained in:
+115
-46
@@ -3,12 +3,12 @@ use std::process::{Command, Stdio};
|
||||
use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs, ChatCompletionTool, ChatCompletionTools, CreateChatCompletionRequestArgs, FinishReason, FunctionObjectArgs, ResponseFormat, ResponseFormatJsonSchema}};
|
||||
use bandcamp::SearchResultItem;
|
||||
use chrono::{DateTime, Utc};
|
||||
use color_eyre::eyre::eyre;
|
||||
use schemars::{JsonSchema, schema_for};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Serializer, ser::CompactFormatter};
|
||||
use tokio::sync::{mpsc, watch};
|
||||
|
||||
use crate::{SaveData, scene::{Artifact, ConversationEntry, Scene, Scenery, StageActions, StageDirection}};
|
||||
use crate::{SaveData, scene::{Artifact, ConversationEntry, PredictionAction, Scene, Scenery, StageDirection}};
|
||||
|
||||
|
||||
const SYSTEM_PROMPT: &str = include_str!("system-prompt.txt");
|
||||
@@ -31,8 +31,10 @@ struct Session {
|
||||
header_message: ChatCompletionRequestMessage,
|
||||
messages: Vec<ChatCompletionRequestMessage>,
|
||||
reply_options: GeneratedResponses,
|
||||
direction: StageDirection,
|
||||
scenery: Scenery,
|
||||
tokens_consumed: usize
|
||||
tokens_consumed: usize,
|
||||
activity_notify: watch::Sender<bool>
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)]
|
||||
@@ -91,7 +93,7 @@ struct ToolResults {
|
||||
}
|
||||
|
||||
impl Session {
|
||||
fn from_initial_messages(messages: Vec<ChatCompletionRequestMessage>, scenery: Scenery) -> Self {
|
||||
fn from_initial_messages(messages: Vec<ChatCompletionRequestMessage>, scenery: Scenery, direction: StageDirection, activity_notify: watch::Sender<bool>) -> Self {
|
||||
let mut conversation = vec![];
|
||||
for msg in &messages {
|
||||
if let Ok(conversation_msg) = msg.clone().try_into() {
|
||||
@@ -106,13 +108,9 @@ impl Session {
|
||||
messages,
|
||||
reply_options: Default::default(),
|
||||
scenery,
|
||||
tokens_consumed: 0
|
||||
}
|
||||
}
|
||||
|
||||
fn insert_actions(&mut self, actions: &StageActions) {
|
||||
for addition in &actions.additions {
|
||||
self.insert_conversation(addition.clone());
|
||||
direction,
|
||||
tokens_consumed: 0,
|
||||
activity_notify,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -213,9 +211,10 @@ impl Session {
|
||||
full_conversation
|
||||
}
|
||||
|
||||
async fn regenerate_options(&mut self, direction: &StageDirection) {
|
||||
async fn regenerate_options(&mut self) {
|
||||
self.activity_notify.send_if_modified(|x| { if !*x { *x = true; true } else { false }});
|
||||
loop {
|
||||
let full_conversation = self.generate_conversation(direction);
|
||||
let full_conversation = self.generate_conversation(&self.direction);
|
||||
|
||||
let tools = vec![
|
||||
ChatCompletionTools::Function(ChatCompletionTool {
|
||||
@@ -268,11 +267,11 @@ impl Session {
|
||||
match message.finish_reason {
|
||||
Some(FinishReason::ContentFilter) => {
|
||||
self.insert_conversation(ConversationEntry::SystemMessage("Content filter triggered.".into()));
|
||||
return;
|
||||
break;
|
||||
},
|
||||
Some(FinishReason::Length) => {
|
||||
self.insert_conversation(ConversationEntry::SystemMessage("Maximum token count exceeded!".into()));
|
||||
return;
|
||||
break;
|
||||
},
|
||||
_ => ()
|
||||
}
|
||||
@@ -317,7 +316,7 @@ impl Session {
|
||||
if let Some(content) = message.message.content.as_ref() {
|
||||
if let Ok(options) = serde_json::from_str(content.as_str()) {
|
||||
self.reply_options = options;
|
||||
return;
|
||||
break;
|
||||
} else {
|
||||
self.insert_conversation(ConversationEntry::SystemMessage("Received invalid JSON! Trying again.".into()));
|
||||
}
|
||||
@@ -326,10 +325,11 @@ impl Session {
|
||||
self.insert_conversation(ConversationEntry::SystemMessage("No messages were received! Trying again.".into()));
|
||||
}
|
||||
}
|
||||
self.activity_notify.send_if_modified(|x| { if *x { *x = false; true } else { false }});
|
||||
}
|
||||
|
||||
fn as_scene(&self) -> Scene {
|
||||
Scene::new(self.reply_options.clone(), self.conversation.clone(), self.scenery.clone(), self.tokens_consumed)
|
||||
Scene::new(self.reply_options.clone(), self.conversation.clone(), self.scenery.clone(), self.tokens_consumed, self.direction.clone())
|
||||
}
|
||||
|
||||
fn insert_conversation(&mut self, entry: ConversationEntry) {
|
||||
@@ -341,47 +341,116 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn start_prediction(mut sys_message_src: tokio::sync::mpsc::Receiver<String>, initial_messages: Vec<ChatCompletionRequestMessage>, scenery: Scenery) -> (tokio::sync::watch::Sender<StageActions>, tokio::sync::watch::Receiver<Scene>) {
|
||||
let (prediction_in, prediction_out) = tokio::sync::watch::channel(Scene::default());
|
||||
let (prediction_request_in, mut prediction_request_out) = tokio::sync::watch::channel(StageActions::default());
|
||||
#[derive(Debug)]
|
||||
pub struct SessionControl {
|
||||
event_sink: mpsc::Sender<PredictionAction>,
|
||||
scene_watch: watch::Receiver<Scene>,
|
||||
activity_watch: watch::Receiver<bool>
|
||||
}
|
||||
|
||||
let mut session = Session::from_initial_messages(initial_messages, scenery);
|
||||
#[derive(Debug)]
|
||||
pub enum SessionUpdate {
|
||||
Scene(Scene),
|
||||
Thinking(bool)
|
||||
}
|
||||
|
||||
impl SessionControl {
|
||||
pub async fn insert(&self, action: PredictionAction) {
|
||||
self.event_sink.send(action).await.unwrap();
|
||||
}
|
||||
|
||||
pub async fn log(&self, message: String) {
|
||||
self.insert(PredictionAction::ConversationAppend(ConversationEntry::SystemMessage(message))).await;
|
||||
}
|
||||
|
||||
pub async fn regenerate_options(&self) {
|
||||
self.insert(PredictionAction::GeneratePredictions).await;
|
||||
}
|
||||
|
||||
pub async fn changed(&mut self) -> SessionUpdate {
|
||||
tokio::select! {
|
||||
_ = self.activity_watch.changed() => {
|
||||
SessionUpdate::Thinking(*self.activity_watch.borrow_and_update())
|
||||
},
|
||||
_ = self.scene_watch.changed() => {
|
||||
SessionUpdate::Scene(self.scene_watch.borrow_and_update().clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn start_prediction(saved_session: SaveData) -> SessionControl {
|
||||
let (prediction_in, prediction_out) = tokio::sync::watch::channel(Scene::default());
|
||||
let (activity_notify_sink, activity_notify_src) = tokio::sync::watch::channel(false);
|
||||
|
||||
let (action_sink, mut action_src) = mpsc::channel(5);
|
||||
|
||||
let mut session = Session::from_initial_messages(saved_session.messages, saved_session.scenery, saved_session.direction, activity_notify_sink);
|
||||
|
||||
// Send the initial scene to the UI, after we have loaded the session from the first messages.
|
||||
prediction_in.send(session.as_scene()).unwrap();
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::select! {
|
||||
maybe_message = sys_message_src.recv() => {
|
||||
if let Some(message) = maybe_message {
|
||||
session.insert_conversation(ConversationEntry::SystemMessage(message));
|
||||
prediction_in.send(session.as_scene()).unwrap();
|
||||
}
|
||||
},
|
||||
maybe_request = prediction_request_out.changed() => {
|
||||
if maybe_request.is_ok() {
|
||||
let next_cxt = prediction_request_out.borrow().clone();
|
||||
session.insert_actions(&next_cxt);
|
||||
|
||||
let mut save_data = SaveData {
|
||||
direction: next_cxt.direction,
|
||||
messages: session.messages.clone(),
|
||||
scenery: session.scenery.clone()
|
||||
if let Some(evt) = action_src.recv().await {
|
||||
let do_regen = match evt {
|
||||
PredictionAction::ConversationAppend(msg) => {
|
||||
let do_regen = match msg {
|
||||
ConversationEntry::Eva(_) | ConversationEntry::ShipComputer(_) | ConversationEntry::User(_) => true,
|
||||
_ => false
|
||||
};
|
||||
session.insert_conversation(msg);
|
||||
|
||||
save_data.save();
|
||||
|
||||
session.regenerate_options(&save_data.direction).await;
|
||||
|
||||
save_data.messages = session.messages.clone();
|
||||
save_data.save();
|
||||
prediction_in.send(session.as_scene()).unwrap();
|
||||
do_regen
|
||||
},
|
||||
PredictionAction::SetEpisodeNumber(num) => {
|
||||
session.direction.episode_number = num;
|
||||
if let Err(err) = session.direction.reload_mixxx_playlist() {
|
||||
session.insert_conversation(ConversationEntry::SystemMessage(format!("Failed to load mixxx playlist: {:?}.", err).into()));
|
||||
} else {
|
||||
session.insert_conversation(ConversationEntry::SystemMessage("Mixxx playlist reloaded.".into()));
|
||||
}
|
||||
false
|
||||
},
|
||||
PredictionAction::GeneratePredictions => {
|
||||
true
|
||||
},
|
||||
PredictionAction::SetNarrative(narrative) => {
|
||||
session.direction.narrative = narrative;
|
||||
session.insert_conversation(ConversationEntry::SystemMessage("Updated stage direction narrative".into()));
|
||||
true
|
||||
},
|
||||
PredictionAction::SetShowEndTime(end_time) => {
|
||||
session.direction.end_time = end_time;
|
||||
false
|
||||
}
|
||||
};
|
||||
|
||||
let save_data = SaveData {
|
||||
direction: session.direction.clone(),
|
||||
messages: session.messages.clone(),
|
||||
scenery: session.scenery.clone()
|
||||
};
|
||||
|
||||
save_data.save();
|
||||
|
||||
if do_regen {
|
||||
session.reply_options.responses.clear();
|
||||
}
|
||||
};
|
||||
|
||||
prediction_in.send(session.as_scene()).unwrap();
|
||||
|
||||
if do_regen {
|
||||
session.regenerate_options().await;
|
||||
prediction_in.send(session.as_scene()).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
(prediction_request_in, prediction_out)
|
||||
SessionControl {
|
||||
event_sink: action_sink,
|
||||
scene_watch: prediction_out,
|
||||
activity_watch: activity_notify_src
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user