From 9c2023f6cab0f2c17079c6885cc02192f2b6a4c7 Mon Sep 17 00:00:00 2001 From: Victoria Fischer Date: Wed, 3 Jun 2026 19:30:23 +0200 Subject: [PATCH] prediction: also split out the prediction task to another module for future growth --- src/main.rs | 36 +++--------------------------------- src/prediction.rs | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 33 deletions(-) create mode 100644 src/prediction.rs diff --git a/src/main.rs b/src/main.rs index 22499ed..a5b21a6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, CreateChatCompletionRequest, CreateChatCompletionResponse}}; +use async_openai::{types::chat::{ChatCompletionMessageToolCalls, CreateChatCompletionResponse}}; use chrono::{DateTime, Duration, Utc}; use futures_timer::Delay; use schemars::JsonSchema; @@ -49,6 +49,7 @@ mod scene; mod events; mod transcription; mod tts; +mod prediction; #[derive(JsonSchema, Deserialize, Serialize, Debug, Clone)] struct PossibleResponse { @@ -535,43 +536,12 @@ async fn main() { let (sys_message_sender, mut sys_message_receiver) = tokio::sync::mpsc::channel(5); let tts_request_sender = start_tts().await; - - let (prediction_in, mut prediction_out) = tokio::sync::watch::channel(None); - let (prediction_request_in, mut prediction_request_out) = tokio::sync::watch::channel(Scene::default()); - + let (prediction_request_in, mut prediction_out) = prediction::start_prediction().await; let (mut audio_state_receiver, audio_control_in, mut transcription_out) = transcription::start_transcription(sys_message_sender).await; let mut app = App::new(prediction_request_in, audio_control_in, tts_request_sender); app.load(); - tokio::spawn(async move { - let client: Client = Client::default(); - loop { - if let Ok(_) = prediction_request_out.changed().await { - let request = prediction_request_out.borrow_and_update().clone(); - let chat_request = CreateChatCompletionRequest { - /*tools: Some(vec![ - ChatCompletionTools::Function( - ChatCompletionTool { - function: FunctionObject { - name: "log_stage_event".into(), - description: Some("Log an event in the stage direction.".into()), - parameters: Some(schema_for!(StageEventArgs).into()), - ..Default::default() - } - } - ) - ]),*/ - ..request.into() - }; - let response = client.chat().create(chat_request).await.unwrap(); - prediction_in.send(Some(response)).unwrap(); - } else { - return; - } - } - }); - let mut events = EventStream::new(); let mut last_tick = Instant::now(); diff --git a/src/prediction.rs b/src/prediction.rs new file mode 100644 index 0000000..cf67259 --- /dev/null +++ b/src/prediction.rs @@ -0,0 +1,38 @@ +use async_openai::{Client, config::OpenAIConfig, types::chat::{CreateChatCompletionRequest, CreateChatCompletionResponse}}; + +use crate::scene::Scene; + +pub async fn start_prediction() -> (tokio::sync::watch::Sender, tokio::sync::watch::Receiver>) { + let (prediction_in, prediction_out) = tokio::sync::watch::channel(None); + let (prediction_request_in, mut prediction_request_out) = tokio::sync::watch::channel(Scene::default()); + + tokio::spawn(async move { + let client: Client = Client::default(); + loop { + if let Ok(_) = prediction_request_out.changed().await { + let request = prediction_request_out.borrow_and_update().clone(); + let chat_request = CreateChatCompletionRequest { + /*tools: Some(vec![ + ChatCompletionTools::Function( + ChatCompletionTool { + function: FunctionObject { + name: "log_stage_event".into(), + description: Some("Log an event in the stage direction.".into()), + parameters: Some(schema_for!(StageEventArgs).into()), + ..Default::default() + } + } + ) + ]),*/ + ..request.into() + }; + let response = client.chat().create(chat_request).await.unwrap(); + prediction_in.send(Some(response)).unwrap(); + } else { + return; + } + } + }); + + (prediction_request_in, prediction_out) +} \ No newline at end of file