prediction: completely rewrite the prediction engine by moving all the conversation manipulation into that task out of the UI
This commit is contained in:
+178
-28
@@ -1,36 +1,186 @@
|
||||
use async_openai::{Client, config::OpenAIConfig, types::chat::{CreateChatCompletionRequest, CreateChatCompletionResponse}};
|
||||
use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionTool, ChatCompletionTools, CreateChatCompletionRequestArgs, FunctionObjectArgs, ResponseFormat, ResponseFormatJsonSchema}};
|
||||
use schemars::{JsonSchema, schema_for};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::scene::Scene;
|
||||
use crate::{SaveData, scene::{ConversationEntry, Scene, StageActions, StageDirection}};
|
||||
|
||||
pub async fn start_prediction() -> (tokio::sync::watch::Sender<Scene>, tokio::sync::watch::Receiver<Option<CreateChatCompletionResponse>>) {
|
||||
let (prediction_in, prediction_out) = tokio::sync::watch::channel(None);
|
||||
let (prediction_request_in, mut prediction_request_out) = tokio::sync::watch::channel(Scene::default());
|
||||
|
||||
const SYSTEM_PROMPT: &str = include_str!("system-prompt.txt");
|
||||
|
||||
#[derive(JsonSchema, Deserialize, Serialize, Debug, Clone)]
|
||||
pub struct PossibleResponse {
|
||||
pub text: String,
|
||||
pub stage_direction: Option<String>
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, JsonSchema, Deserialize, Serialize, Clone)]
|
||||
pub struct GeneratedResponses {
|
||||
pub responses: Vec<PossibleResponse>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Session {
|
||||
client: Client<OpenAIConfig>,
|
||||
conversation: Vec<ConversationEntry>,
|
||||
header_message: ChatCompletionRequestMessage,
|
||||
messages: Vec<ChatCompletionRequestMessage>,
|
||||
reply_options: GeneratedResponses
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize, Clone, JsonSchema)]
|
||||
struct StageEventArgs {
|
||||
text: String,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
fn from_initial_messages(messages: Vec<ChatCompletionRequestMessage>) -> Self {
|
||||
let mut conversation = vec![];
|
||||
for msg in &messages {
|
||||
if let Ok(conversation_msg) = msg.clone().try_into() {
|
||||
conversation.push(conversation_msg);
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
client: Default::default(),
|
||||
conversation,
|
||||
header_message: ChatCompletionRequestSystemMessageArgs::default().content(SYSTEM_PROMPT).build().unwrap().into(),
|
||||
messages,
|
||||
reply_options: Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn insert_actions(&mut self, actions: &StageActions) {
|
||||
for addition in &actions.additions {
|
||||
self.insert_conversation(addition.clone());
|
||||
}
|
||||
}
|
||||
|
||||
async fn regenerate_options(&mut self, direction: &StageDirection) -> Option<Scene> {
|
||||
let direction_message: ChatCompletionRequestMessage = ChatCompletionRequestSystemMessageArgs::default().content(serde_json::to_string(&direction).unwrap()).build().unwrap().into();
|
||||
let mut full_conversation = vec![
|
||||
self.header_message.clone(),
|
||||
direction_message
|
||||
];
|
||||
full_conversation.append(&mut self.messages.clone());
|
||||
|
||||
|
||||
let tools = vec![
|
||||
ChatCompletionTools::Function(ChatCompletionTool {
|
||||
function: FunctionObjectArgs::default()
|
||||
.name("log_stage_event")
|
||||
.description("Inserts an event into the current scene script")
|
||||
.parameters(schema_for!(StageEventArgs))
|
||||
.build().unwrap()
|
||||
}),
|
||||
ChatCompletionTools::Function(ChatCompletionTool {
|
||||
function: FunctionObjectArgs::default()
|
||||
.name("log_ship_computer_message")
|
||||
.description("Inserts a message from the ship computer into the scene script")
|
||||
.parameters(schema_for!(StageEventArgs))
|
||||
.build().unwrap()
|
||||
})
|
||||
];
|
||||
|
||||
let request = CreateChatCompletionRequestArgs::default()
|
||||
.messages(full_conversation)
|
||||
.model("gpt-5.4")
|
||||
.tools(tools)
|
||||
.max_completion_tokens(350u32)
|
||||
.response_format(ResponseFormat::JsonSchema {
|
||||
json_schema: ResponseFormatJsonSchema {
|
||||
description: None,
|
||||
name: "responses".into(),
|
||||
schema: schema_for!(GeneratedResponses).into(),
|
||||
strict: None
|
||||
}
|
||||
})
|
||||
.build().unwrap();
|
||||
|
||||
let response = self.client.chat().create(request).await.unwrap();
|
||||
|
||||
if let Some(message) = response.choices.first() {
|
||||
if let Some(calls) = &message.message.tool_calls {
|
||||
for call in calls {
|
||||
match call {
|
||||
ChatCompletionMessageToolCalls::Function(call) => {
|
||||
match call.function.name.as_str() {
|
||||
"log_stage_event" => {
|
||||
let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
|
||||
self.insert_conversation(ConversationEntry::StageDirection(args.text));
|
||||
},
|
||||
"log_ship_computer_message" => {
|
||||
let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
|
||||
self.insert_conversation(ConversationEntry::ShipComputer(args.text));
|
||||
},
|
||||
_ => panic!("Unknown function was called")
|
||||
}
|
||||
},
|
||||
_ => panic!("Unknown tool was called")
|
||||
}
|
||||
}
|
||||
Some(self.as_scene())
|
||||
} else {
|
||||
self.reply_options = serde_json::from_str(message.message.content.as_ref().unwrap().as_str()).unwrap();
|
||||
Some(self.as_scene())
|
||||
}
|
||||
} else {
|
||||
//FIXME: Handle tool calls
|
||||
panic!();
|
||||
}
|
||||
}
|
||||
|
||||
fn as_scene(&self) -> Scene {
|
||||
Scene::new(self.reply_options.clone(), self.conversation.clone())
|
||||
}
|
||||
|
||||
fn insert_conversation(&mut self, entry: ConversationEntry) {
|
||||
self.conversation.push(entry.clone());
|
||||
|
||||
if let Ok(next_msg) = entry.try_into() {
|
||||
self.messages.push(next_msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn start_prediction(mut sys_message_src: tokio::sync::mpsc::Receiver<String>, initial_messages: Vec<ChatCompletionRequestMessage>) -> (tokio::sync::watch::Sender<StageActions>, tokio::sync::watch::Receiver<Scene>) {
|
||||
let (prediction_in, prediction_out) = tokio::sync::watch::channel(Scene::default());
|
||||
let (prediction_request_in, mut prediction_request_out) = tokio::sync::watch::channel(StageActions::default());
|
||||
|
||||
let mut session = Session::from_initial_messages(initial_messages);
|
||||
|
||||
// Send the initial scene to the UI, after we have loaded the session from the first messages.
|
||||
prediction_in.send(session.as_scene()).unwrap();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let client: Client<OpenAIConfig> = Client::default();
|
||||
loop {
|
||||
if let Ok(_) = prediction_request_out.changed().await {
|
||||
let request = prediction_request_out.borrow_and_update().clone();
|
||||
let chat_request = CreateChatCompletionRequest {
|
||||
/*tools: Some(vec![
|
||||
ChatCompletionTools::Function(
|
||||
ChatCompletionTool {
|
||||
function: FunctionObject {
|
||||
name: "log_stage_event".into(),
|
||||
description: Some("Log an event in the stage direction.".into()),
|
||||
parameters: Some(schema_for!(StageEventArgs).into()),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
)
|
||||
]),*/
|
||||
..request.into()
|
||||
};
|
||||
let response = client.chat().create(chat_request).await.unwrap();
|
||||
prediction_in.send(Some(response)).unwrap();
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
tokio::select! {
|
||||
maybe_message = sys_message_src.recv() => {
|
||||
if let Some(message) = maybe_message {
|
||||
session.insert_conversation(ConversationEntry::SystemMessage(message));
|
||||
prediction_in.send(session.as_scene()).unwrap();
|
||||
}
|
||||
},
|
||||
maybe_request = prediction_request_out.changed() => {
|
||||
if maybe_request.is_ok() {
|
||||
let next_cxt = prediction_request_out.borrow().clone();
|
||||
session.insert_actions(&next_cxt);
|
||||
|
||||
let mut save_data = SaveData {
|
||||
direction: next_cxt.direction,
|
||||
messages: session.messages.clone()
|
||||
};
|
||||
|
||||
save_data.save();
|
||||
|
||||
if let Some(next_scene) = session.regenerate_options(&save_data.direction).await {
|
||||
save_data.messages = session.messages.clone();
|
||||
save_data.save();
|
||||
prediction_in.send(next_scene).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user