prediction: completely rewrite the prediction engine by moving all the conversation manipulation into that task out of the UI

This commit is contained in:
2026-06-04 21:34:10 +02:00
parent 57e3ff9b55
commit 49c720fe46
5 changed files with 368 additions and 219 deletions
+178 -28
View File
@@ -1,36 +1,186 @@
use async_openai::{Client, config::OpenAIConfig, types::chat::{CreateChatCompletionRequest, CreateChatCompletionResponse}};
use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionTool, ChatCompletionTools, CreateChatCompletionRequestArgs, FunctionObjectArgs, ResponseFormat, ResponseFormatJsonSchema}};
use schemars::{JsonSchema, schema_for};
use serde::{Deserialize, Serialize};
use crate::scene::Scene;
use crate::{SaveData, scene::{ConversationEntry, Scene, StageActions, StageDirection}};
pub async fn start_prediction() -> (tokio::sync::watch::Sender<Scene>, tokio::sync::watch::Receiver<Option<CreateChatCompletionResponse>>) {
let (prediction_in, prediction_out) = tokio::sync::watch::channel(None);
let (prediction_request_in, mut prediction_request_out) = tokio::sync::watch::channel(Scene::default());
const SYSTEM_PROMPT: &str = include_str!("system-prompt.txt");
#[derive(JsonSchema, Deserialize, Serialize, Debug, Clone)]
pub struct PossibleResponse {
pub text: String,
pub stage_direction: Option<String>
}
#[derive(Default, Debug, JsonSchema, Deserialize, Serialize, Clone)]
pub struct GeneratedResponses {
pub responses: Vec<PossibleResponse>,
}
#[derive(Debug)]
struct Session {
client: Client<OpenAIConfig>,
conversation: Vec<ConversationEntry>,
header_message: ChatCompletionRequestMessage,
messages: Vec<ChatCompletionRequestMessage>,
reply_options: GeneratedResponses
}
#[derive(Debug, Default, Serialize, Deserialize, Clone, JsonSchema)]
struct StageEventArgs {
text: String,
}
impl Session {
fn from_initial_messages(messages: Vec<ChatCompletionRequestMessage>) -> Self {
let mut conversation = vec![];
for msg in &messages {
if let Ok(conversation_msg) = msg.clone().try_into() {
conversation.push(conversation_msg);
}
}
Self {
client: Default::default(),
conversation,
header_message: ChatCompletionRequestSystemMessageArgs::default().content(SYSTEM_PROMPT).build().unwrap().into(),
messages,
reply_options: Default::default()
}
}
fn insert_actions(&mut self, actions: &StageActions) {
for addition in &actions.additions {
self.insert_conversation(addition.clone());
}
}
async fn regenerate_options(&mut self, direction: &StageDirection) -> Option<Scene> {
let direction_message: ChatCompletionRequestMessage = ChatCompletionRequestSystemMessageArgs::default().content(serde_json::to_string(&direction).unwrap()).build().unwrap().into();
let mut full_conversation = vec![
self.header_message.clone(),
direction_message
];
full_conversation.append(&mut self.messages.clone());
let tools = vec![
ChatCompletionTools::Function(ChatCompletionTool {
function: FunctionObjectArgs::default()
.name("log_stage_event")
.description("Inserts an event into the current scene script")
.parameters(schema_for!(StageEventArgs))
.build().unwrap()
}),
ChatCompletionTools::Function(ChatCompletionTool {
function: FunctionObjectArgs::default()
.name("log_ship_computer_message")
.description("Inserts a message from the ship computer into the scene script")
.parameters(schema_for!(StageEventArgs))
.build().unwrap()
})
];
let request = CreateChatCompletionRequestArgs::default()
.messages(full_conversation)
.model("gpt-5.4")
.tools(tools)
.max_completion_tokens(350u32)
.response_format(ResponseFormat::JsonSchema {
json_schema: ResponseFormatJsonSchema {
description: None,
name: "responses".into(),
schema: schema_for!(GeneratedResponses).into(),
strict: None
}
})
.build().unwrap();
let response = self.client.chat().create(request).await.unwrap();
if let Some(message) = response.choices.first() {
if let Some(calls) = &message.message.tool_calls {
for call in calls {
match call {
ChatCompletionMessageToolCalls::Function(call) => {
match call.function.name.as_str() {
"log_stage_event" => {
let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
self.insert_conversation(ConversationEntry::StageDirection(args.text));
},
"log_ship_computer_message" => {
let args: StageEventArgs = serde_json::from_str(call.function.arguments.as_str()).unwrap();
self.insert_conversation(ConversationEntry::ShipComputer(args.text));
},
_ => panic!("Unknown function was called")
}
},
_ => panic!("Unknown tool was called")
}
}
Some(self.as_scene())
} else {
self.reply_options = serde_json::from_str(message.message.content.as_ref().unwrap().as_str()).unwrap();
Some(self.as_scene())
}
} else {
//FIXME: Handle tool calls
panic!();
}
}
fn as_scene(&self) -> Scene {
Scene::new(self.reply_options.clone(), self.conversation.clone())
}
fn insert_conversation(&mut self, entry: ConversationEntry) {
self.conversation.push(entry.clone());
if let Ok(next_msg) = entry.try_into() {
self.messages.push(next_msg);
}
}
}
pub async fn start_prediction(mut sys_message_src: tokio::sync::mpsc::Receiver<String>, initial_messages: Vec<ChatCompletionRequestMessage>) -> (tokio::sync::watch::Sender<StageActions>, tokio::sync::watch::Receiver<Scene>) {
let (prediction_in, prediction_out) = tokio::sync::watch::channel(Scene::default());
let (prediction_request_in, mut prediction_request_out) = tokio::sync::watch::channel(StageActions::default());
let mut session = Session::from_initial_messages(initial_messages);
// Send the initial scene to the UI, after we have loaded the session from the first messages.
prediction_in.send(session.as_scene()).unwrap();
tokio::spawn(async move {
let client: Client<OpenAIConfig> = Client::default();
loop {
if let Ok(_) = prediction_request_out.changed().await {
let request = prediction_request_out.borrow_and_update().clone();
let chat_request = CreateChatCompletionRequest {
/*tools: Some(vec![
ChatCompletionTools::Function(
ChatCompletionTool {
function: FunctionObject {
name: "log_stage_event".into(),
description: Some("Log an event in the stage direction.".into()),
parameters: Some(schema_for!(StageEventArgs).into()),
..Default::default()
}
}
)
]),*/
..request.into()
};
let response = client.chat().create(chat_request).await.unwrap();
prediction_in.send(Some(response)).unwrap();
} else {
return;
}
tokio::select! {
maybe_message = sys_message_src.recv() => {
if let Some(message) = maybe_message {
session.insert_conversation(ConversationEntry::SystemMessage(message));
prediction_in.send(session.as_scene()).unwrap();
}
},
maybe_request = prediction_request_out.changed() => {
if maybe_request.is_ok() {
let next_cxt = prediction_request_out.borrow().clone();
session.insert_actions(&next_cxt);
let mut save_data = SaveData {
direction: next_cxt.direction,
messages: session.messages.clone()
};
save_data.save();
if let Some(next_scene) = session.regenerate_options(&save_data.direction).await {
save_data.messages = session.messages.clone();
save_data.save();
prediction_in.send(next_scene).unwrap();
}
}
}
};
}
});