Enhance state direction command with ship computer outputs, and report token burn on the UI

This commit is contained in:
2026-06-09 09:04:03 +02:00
parent 114f1ea4df
commit 88e1f2a62b
3 changed files with 105 additions and 100 deletions
+70 -43
View File
@@ -3,10 +3,12 @@ use std::process::{Command, Stdio};
use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs, ChatCompletionTool, ChatCompletionTools, CreateChatCompletionRequestArgs, FinishReason, FunctionObjectArgs, ResponseFormat, ResponseFormatJsonSchema}};
use bandcamp::SearchResultItem;
use chrono::{DateTime, Utc};
use color_eyre::eyre::eyre;
use schemars::{JsonSchema, schema_for};
use serde::{Deserialize, Serialize};
use serde_json::{Serializer, ser::CompactFormatter};
use crate::{SaveData, scene::{ConversationEntry, Scene, StageActions, StageDirection}};
use crate::{SaveData, scene::{Artifact, ConversationEntry, Scene, Scenery, StageActions, StageDirection}};
const SYSTEM_PROMPT: &str = include_str!("system-prompt.txt");
@@ -28,12 +30,20 @@ struct Session {
conversation: Vec<ConversationEntry>,
header_message: ChatCompletionRequestMessage,
messages: Vec<ChatCompletionRequestMessage>,
reply_options: GeneratedResponses
reply_options: GeneratedResponses,
scenery: Scenery,
tokens_consumed: usize
}
#[derive(Debug, Default, Serialize, Deserialize, Clone, JsonSchema)]
#[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)]
enum StageEvent {
ShipComputer(String),
StageDirection(String)
}
#[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)]
struct StageEventArgs {
text: String,
event: StageEvent
}
#[derive(Debug, Default, Serialize, Deserialize, Clone, JsonSchema)]
@@ -81,7 +91,7 @@ struct ToolResults {
}
impl Session {
fn from_initial_messages(messages: Vec<ChatCompletionRequestMessage>) -> Self {
fn from_initial_messages(messages: Vec<ChatCompletionRequestMessage>, scenery: Scenery) -> Self {
let mut conversation = vec![];
for msg in &messages {
if let Ok(conversation_msg) = msg.clone().try_into() {
@@ -94,7 +104,9 @@ impl Session {
conversation,
header_message: ChatCompletionRequestSystemMessageArgs::default().content(SYSTEM_PROMPT).build().unwrap().into(),
messages,
reply_options: Default::default()
reply_options: Default::default(),
scenery,
tokens_consumed: 0
}
}
@@ -105,15 +117,12 @@ impl Session {
}
async fn tool_stage_event(&mut self, args: StageEventArgs) -> ToolResults {
let msg = match args.event {
StageEvent::ShipComputer(text) => ConversationEntry::ShipComputer(text),
StageEvent::StageDirection(text) => ConversationEntry::StageDirection(text)
};
ToolResults {
messages: vec![ConversationEntry::StageDirection(args.text)],
..Default::default()
}
}
async fn tool_computer_event(&mut self, args: StageEventArgs) -> ToolResults {
ToolResults {
messages: vec![ConversationEntry::ShipComputer(args.text)],
messages: vec![msg],
..Default::default()
}
}
@@ -127,20 +136,23 @@ impl Session {
match result {
SearchResultItem::Artist(data) => {
let result: BandcampResult = bandcamp::fetch_artist(data.artist_id).await.unwrap().into();
json_results.push(result);
json_results.push(Artifact::Bandcamp(result));
},
SearchResultItem::Album(data) => {
let result: BandcampResult = bandcamp::fetch_album(data.band_id, data.album_id).await.unwrap().into();
json_results.push(result);
json_results.push(Artifact::Bandcamp(result));
}
_ => ()
}
}
}
messages.push(ConversationEntry::ShipComputer(format!("Artifact scan for '{}' complete. {} results.", args.query, json_results.len()).into()));
let artifact_count = json_results.len();
messages.push(ConversationEntry::ShipComputer(format!("Relay scan for '{}' complete. {} artifacts added to the archive.", args.query, artifact_count).into()));
self.scenery.artifacts.append(&mut json_results);
ToolResults {
result: Some(serde_json::to_string(&json_results).unwrap()),
result: Some(format!("{} artifacts were added to the archive.", artifact_count)),
messages
}
}
@@ -164,28 +176,46 @@ impl Session {
if let Some(year) = args.year {
beets_cmd.arg(format!("year:{}", year));
}
let result = if let Ok(output) = beets_cmd.stdout(Stdio::piped()).spawn().unwrap().wait_with_output() {
if let Ok(output) = beets_cmd.stdout(Stdio::piped()).spawn().unwrap().wait_with_output() {
messages.push(ConversationEntry::ShipComputer(format!("Executing archive query {:?}", beets_cmd)));
Some(minify::json::minify(str::from_utf8(&output.stdout).unwrap()))
self.scenery.artifacts.push(Artifact::BeetsTrack(serde_json::from_str(str::from_utf8(&output.stdout).unwrap()).unwrap()));
} else {
messages.push(ConversationEntry::ShipComputer("Unable to execute query!".into()));
None
};
ToolResults {
result,
result: None,
messages
}
}
fn generate_conversation(&self, direction: &StageDirection) -> Vec<ChatCompletionRequestMessage> {
let mut json_buf = vec![];
let mut ser = Serializer::with_formatter(&mut json_buf, CompactFormatter);
direction.serialize(&mut ser).unwrap();
let direction_message: ChatCompletionRequestMessage = ChatCompletionRequestSystemMessageArgs::default()
.content(String::from_utf8(json_buf).unwrap())
.build().unwrap().into();
let mut json_buf = vec![];
let mut ser = Serializer::with_formatter(&mut json_buf, CompactFormatter);
self.scenery.serialize(&mut ser).unwrap();
let scenery_message: ChatCompletionRequestMessage = ChatCompletionRequestSystemMessageArgs::default()
.content(String::from_utf8(json_buf).unwrap())
.build().unwrap().into();
let mut full_conversation = vec![
self.header_message.clone(),
direction_message,
scenery_message,
];
full_conversation.append(&mut self.messages.clone());
full_conversation
}
async fn regenerate_options(&mut self, direction: &StageDirection) {
loop {
let direction_message: ChatCompletionRequestMessage = ChatCompletionRequestSystemMessageArgs::default().content(serde_json::to_string(&direction).unwrap()).build().unwrap().into();
let mut full_conversation = vec![
self.header_message.clone(),
direction_message
];
full_conversation.append(&mut self.messages.clone());
let full_conversation = self.generate_conversation(direction);
let tools = vec![
ChatCompletionTools::Function(ChatCompletionTool {
@@ -195,13 +225,6 @@ impl Session {
.parameters(schema_for!(StageEventArgs))
.build().unwrap()
}),
ChatCompletionTools::Function(ChatCompletionTool {
function: FunctionObjectArgs::default()
.name("log_ship_computer_message")
.description("Inserts a message from the ship computer into the scene script")
.parameters(schema_for!(StageEventArgs))
.build().unwrap()
}),
ChatCompletionTools::Function(ChatCompletionTool {
function: FunctionObjectArgs::default()
.name("archive_query")
@@ -212,13 +235,13 @@ impl Session {
ChatCompletionTools::Function(ChatCompletionTool {
function: FunctionObjectArgs::default()
.name("bandcamp_artifact_scan")
.description("Scans Bandcamp to find artifacts to use in the scene that match the given search parameters")
.description("Scans Bandcamp to find artifacts to use in the scene that match the given search parameters. To find an artist, provide only the artist name. To find an album, provide the artist and the album.")
.parameters(schema_for!(BandcampQueryArgs))
.build().unwrap()
})
];
let request = CreateChatCompletionRequestArgs::default()
.messages(full_conversation.clone())
.messages(full_conversation)
.model("gpt-5.4-mini")
.tools(tools)
.max_completion_tokens(1024u32)
@@ -233,9 +256,13 @@ impl Session {
.build().unwrap();
let response = self.client.chat().create(request).await.unwrap_or_else(|err| {
panic!("{} {:?}", err, full_conversation);
panic!("OpenAI Panic: {}", err);
});
if let Some(usage) = response.usage {
self.tokens_consumed += usage.total_tokens as usize;
}
if let Some(message) = response.choices.first() {
match message.finish_reason {
@@ -263,7 +290,6 @@ impl Session {
let args = call.function.arguments.as_str();
let tool_result = match func_name {
"log_stage_event" => self.tool_stage_event(serde_json::from_str(args).unwrap()).await,
"log_ship_computer_message" => self.tool_computer_event(serde_json::from_str(args).unwrap()).await,
"bandcamp_artifact_scan" => self.tool_bandcamp_scan(serde_json::from_str(args).unwrap()).await,
"archive_query" => self.tool_artifact_query(serde_json::from_str(args).unwrap()).await,
_ => unreachable!()
@@ -303,7 +329,7 @@ impl Session {
}
fn as_scene(&self) -> Scene {
Scene::new(self.reply_options.clone(), self.conversation.clone())
Scene::new(self.reply_options.clone(), self.conversation.clone(), self.scenery.clone(), self.tokens_consumed)
}
fn insert_conversation(&mut self, entry: ConversationEntry) {
@@ -315,11 +341,11 @@ impl Session {
}
}
pub async fn start_prediction(mut sys_message_src: tokio::sync::mpsc::Receiver<String>, initial_messages: Vec<ChatCompletionRequestMessage>) -> (tokio::sync::watch::Sender<StageActions>, tokio::sync::watch::Receiver<Scene>) {
pub async fn start_prediction(mut sys_message_src: tokio::sync::mpsc::Receiver<String>, initial_messages: Vec<ChatCompletionRequestMessage>, scenery: Scenery) -> (tokio::sync::watch::Sender<StageActions>, tokio::sync::watch::Receiver<Scene>) {
let (prediction_in, prediction_out) = tokio::sync::watch::channel(Scene::default());
let (prediction_request_in, mut prediction_request_out) = tokio::sync::watch::channel(StageActions::default());
let mut session = Session::from_initial_messages(initial_messages);
let mut session = Session::from_initial_messages(initial_messages, scenery);
// Send the initial scene to the UI, after we have loaded the session from the first messages.
prediction_in.send(session.as_scene()).unwrap();
@@ -340,7 +366,8 @@ pub async fn start_prediction(mut sys_message_src: tokio::sync::mpsc::Receiver<S
let mut save_data = SaveData {
direction: next_cxt.direction,
messages: session.messages.clone()
messages: session.messages.clone(),
scenery: session.scenery.clone()
};
save_data.save();