prediction: also split out the prediction task to another module for future growth
This commit is contained in:
+3
-33
@@ -1,4 +1,4 @@
|
|||||||
use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, CreateChatCompletionRequest, CreateChatCompletionResponse}};
|
use async_openai::{types::chat::{ChatCompletionMessageToolCalls, CreateChatCompletionResponse}};
|
||||||
use chrono::{DateTime, Duration, Utc};
|
use chrono::{DateTime, Duration, Utc};
|
||||||
use futures_timer::Delay;
|
use futures_timer::Delay;
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
@@ -49,6 +49,7 @@ mod scene;
|
|||||||
mod events;
|
mod events;
|
||||||
mod transcription;
|
mod transcription;
|
||||||
mod tts;
|
mod tts;
|
||||||
|
mod prediction;
|
||||||
|
|
||||||
#[derive(JsonSchema, Deserialize, Serialize, Debug, Clone)]
|
#[derive(JsonSchema, Deserialize, Serialize, Debug, Clone)]
|
||||||
struct PossibleResponse {
|
struct PossibleResponse {
|
||||||
@@ -535,43 +536,12 @@ async fn main() {
|
|||||||
|
|
||||||
let (sys_message_sender, mut sys_message_receiver) = tokio::sync::mpsc::channel(5);
|
let (sys_message_sender, mut sys_message_receiver) = tokio::sync::mpsc::channel(5);
|
||||||
let tts_request_sender = start_tts().await;
|
let tts_request_sender = start_tts().await;
|
||||||
|
let (prediction_request_in, mut prediction_out) = prediction::start_prediction().await;
|
||||||
let (prediction_in, mut prediction_out) = tokio::sync::watch::channel(None);
|
|
||||||
let (prediction_request_in, mut prediction_request_out) = tokio::sync::watch::channel(Scene::default());
|
|
||||||
|
|
||||||
let (mut audio_state_receiver, audio_control_in, mut transcription_out) = transcription::start_transcription(sys_message_sender).await;
|
let (mut audio_state_receiver, audio_control_in, mut transcription_out) = transcription::start_transcription(sys_message_sender).await;
|
||||||
|
|
||||||
let mut app = App::new(prediction_request_in, audio_control_in, tts_request_sender);
|
let mut app = App::new(prediction_request_in, audio_control_in, tts_request_sender);
|
||||||
app.load();
|
app.load();
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
let client: Client<OpenAIConfig> = Client::default();
|
|
||||||
loop {
|
|
||||||
if let Ok(_) = prediction_request_out.changed().await {
|
|
||||||
let request = prediction_request_out.borrow_and_update().clone();
|
|
||||||
let chat_request = CreateChatCompletionRequest {
|
|
||||||
/*tools: Some(vec![
|
|
||||||
ChatCompletionTools::Function(
|
|
||||||
ChatCompletionTool {
|
|
||||||
function: FunctionObject {
|
|
||||||
name: "log_stage_event".into(),
|
|
||||||
description: Some("Log an event in the stage direction.".into()),
|
|
||||||
parameters: Some(schema_for!(StageEventArgs).into()),
|
|
||||||
..Default::default()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
]),*/
|
|
||||||
..request.into()
|
|
||||||
};
|
|
||||||
let response = client.chat().create(chat_request).await.unwrap();
|
|
||||||
prediction_in.send(Some(response)).unwrap();
|
|
||||||
} else {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
let mut events = EventStream::new();
|
let mut events = EventStream::new();
|
||||||
let mut last_tick = Instant::now();
|
let mut last_tick = Instant::now();
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,38 @@
|
|||||||
|
use async_openai::{Client, config::OpenAIConfig, types::chat::{CreateChatCompletionRequest, CreateChatCompletionResponse}};
|
||||||
|
|
||||||
|
use crate::scene::Scene;
|
||||||
|
|
||||||
|
pub async fn start_prediction() -> (tokio::sync::watch::Sender<Scene>, tokio::sync::watch::Receiver<Option<CreateChatCompletionResponse>>) {
|
||||||
|
let (prediction_in, prediction_out) = tokio::sync::watch::channel(None);
|
||||||
|
let (prediction_request_in, mut prediction_request_out) = tokio::sync::watch::channel(Scene::default());
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let client: Client<OpenAIConfig> = Client::default();
|
||||||
|
loop {
|
||||||
|
if let Ok(_) = prediction_request_out.changed().await {
|
||||||
|
let request = prediction_request_out.borrow_and_update().clone();
|
||||||
|
let chat_request = CreateChatCompletionRequest {
|
||||||
|
/*tools: Some(vec![
|
||||||
|
ChatCompletionTools::Function(
|
||||||
|
ChatCompletionTool {
|
||||||
|
function: FunctionObject {
|
||||||
|
name: "log_stage_event".into(),
|
||||||
|
description: Some("Log an event in the stage direction.".into()),
|
||||||
|
parameters: Some(schema_for!(StageEventArgs).into()),
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
]),*/
|
||||||
|
..request.into()
|
||||||
|
};
|
||||||
|
let response = client.chat().create(chat_request).await.unwrap();
|
||||||
|
prediction_in.send(Some(response)).unwrap();
|
||||||
|
} else {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
(prediction_request_in, prediction_out)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user