diff --git a/src/artifacts/archive.rs b/src/artifacts/archive.rs index 96c1a5c..2f88801 100644 --- a/src/artifacts/archive.rs +++ b/src/artifacts/archive.rs @@ -56,7 +56,7 @@ impl Archive { // track, album, artist pub fn stats(&self) -> (usize, usize, usize) { - self.contents.iter().map(|(_id, artifact)| { + self.contents.values().map(|artifact| { match artifact.contents() { Contents::Track(_) => (1, 0, 0), Contents::Album(_) => (0, 1, 0), @@ -68,16 +68,16 @@ impl Archive { } pub fn get<'a>(&'a self, id: &Uuid) -> Option> { - if self.contents.get(id).is_some() { - Some(ArtifactRef { id: id.clone(), archive: self }) + if self.contents.contains_key(id) { + Some(ArtifactRef { id: *id, archive: self }) } else { None } } pub fn get_mut<'a>(&'a mut self, id: &Uuid) -> Option> { - if self.contents.get(id).is_some() { - Some(ArtifactRefMut { id: id.clone(), archive: self }) + if self.contents.contains_key(id) { + Some(ArtifactRefMut { id: *id, archive: self }) } else { None } @@ -136,7 +136,7 @@ impl Archive { pub fn insert<'a>(&'a mut self, artifact: Artifact) -> ArtifactRef<'a> { // If we are inserting a new artifact with a complete MBID... - if let Some(mbid) = artifact.mbid.clone() { + if let Some(mbid) = artifact.mbid { let search_id = mbid; // If an entry already exists keyed by this MBID, merge into it if let Some(existing) = self.contents.get_mut(&search_id) { @@ -145,7 +145,7 @@ impl Archive { } else { // Otherwise, attempt to find existing artifacts with the same contents (but no MBID) let mut targets: Vec<(Uuid, Artifact)> = self.contents.extract_if(|_, v| { v.contents == artifact.contents }).collect(); - if let Some((target_id, mut target)) = targets.pop() { + if let Some((_target_id, mut target)) = targets.pop() { // Merge any other extracted targets into the primary one for (_, next) in targets { target.merge(next); @@ -153,11 +153,11 @@ impl Archive { // Merge the incoming artifact into the merged target target.merge(artifact); // Insert merged target under the canonical MBID key - self.contents.insert(search_id.clone(), target); + self.contents.insert(search_id, target); ArtifactRef { id: search_id, archive: self } } else { // No matching content found: insert under the MBID key - self.contents.insert(search_id.clone(), artifact); + self.contents.insert(search_id, artifact); ArtifactRef { id: search_id, archive: self } } } @@ -165,9 +165,9 @@ impl Archive { // Otherwise, we attempt to merge it in. In the end, there will somehow still be a record with this mbid let mut targets: Vec<(Uuid, Artifact)> = self.contents.extract_if(|_, v| { v.contents == artifact.contents }).collect(); if let Some((target_id, mut target)) = targets.pop() { - let next_id = if let Some(ref mbid) = artifact.mbid { + let next_id = if let Some(mbid) = artifact.mbid { // If the new artifact has an mbid, we start using that as the archive key - mbid.clone() + mbid } else { // Otherwise, why regenerate a new one? target_id @@ -178,11 +178,11 @@ impl Archive { } target.merge(artifact); // Re-insert the merged target back into the archive under the chosen id - self.contents.insert(next_id.clone(), target); + self.contents.insert(next_id, target); ArtifactRef { id: next_id, archive: self } } else { let new_id = Uuid::new_v4(); - self.contents.insert(new_id.clone(), artifact); + self.contents.insert(new_id, artifact); ArtifactRef { id: new_id, archive: self } } } diff --git a/src/artifacts/bandcamp.rs b/src/artifacts/bandcamp.rs index a034b41..12073d4 100644 --- a/src/artifacts/bandcamp.rs +++ b/src/artifacts/bandcamp.rs @@ -9,21 +9,21 @@ pub struct BandcampQueryArgs { pub query: String } -impl Into for bandcamp::Artist { - fn into(self) -> Artifact { - ArtifactBuilder::new(SourceID::Bandcamp).contents(Artist { name: self.name, bio: self.bio, location: self.location }).build() +impl From for Artifact { + fn from(val: bandcamp::Artist) -> Self { + ArtifactBuilder::new(SourceID::Bandcamp).contents(Artist { name: val.name, bio: val.bio, location: val.location }).build() } } -impl Into for bandcamp::Album { - fn into(self) -> Artifact { +impl From for Artifact { + fn from(val: bandcamp::Album) -> Self { ArtifactBuilder::new(SourceID::Bandcamp) .contents(Album { - about: self.about, - title: self.title, - artist: self.band.name, - credits: self.credits, - release_date: Some(self.release_date) + about: val.about, + title: val.title, + artist: val.band.name, + credits: val.credits, + release_date: Some(val.release_date) }).build() } } diff --git a/src/artifacts/beets.rs b/src/artifacts/beets.rs index 4d4319c..bbf8572 100644 --- a/src/artifacts/beets.rs +++ b/src/artifacts/beets.rs @@ -35,20 +35,20 @@ struct BeetsTrack { mb_trackid: Option } -impl Into for BeetsTrack { - fn into(self) -> Artifact { +impl From for Artifact { + fn from(val: BeetsTrack) -> Self { let track_data = Track { - title: self.title, - label: self.label, - year: Some(self.year), - genres: self.genres.unwrap_or_default(), - album: Some(self.album), - artist: Some(self.artist), + title: val.title, + label: val.label, + year: Some(val.year), + genres: val.genres.unwrap_or_default(), + album: Some(val.album), + artist: Some(val.artist), bpm: None, }; let builder = ArtifactBuilder::new(SourceID::Beets) .contents(track_data); - if let Ok(mbid) = Uuid::parse_str(&self.mb_trackid.unwrap_or_default()) { + if let Ok(mbid) = Uuid::parse_str(&val.mb_trackid.unwrap_or_default()) { builder.mbid(mbid).build() } else { builder.build() @@ -136,25 +136,21 @@ impl DataSource for BeetsDB { type Error = BeetsError; async fn synchronize(&self, artifact: &mut Artifact) -> Result, Self::Error> { - match artifact.contents { - Contents::Track(ref mut target_track) => { - let args = BeatsQueryArgs { - title: Some(target_track.title.clone()), - artist: target_track.artist.clone(), - album: target_track.album.clone(), - ..Default::default() - }; + if let Contents::Track(ref mut target_track) = artifact.contents { + let args = BeatsQueryArgs { + title: Some(target_track.title.clone()), + artist: target_track.artist.clone(), + album: target_track.album.clone(), + ..Default::default() + }; - let results = self.query(&BeatsQueryMultiArgs { args: vec![args] }).await?; + let results = self.query(&BeatsQueryMultiArgs { args: vec![args] }).await?; - if let Some(first) = results.first() { - artifact.merge(first.clone()); - } else { - log::debug!("Beets could not find {:?}", target_track); - } - - }, - _ => () + if let Some(first) = results.first() { + artifact.merge(first.clone()); + } else { + log::debug!("Beets could not find {:?}", target_track); + } } Ok(vec![]) diff --git a/src/artifacts/mod.rs b/src/artifacts/mod.rs index 83e5aa7..fa83374 100644 --- a/src/artifacts/mod.rs +++ b/src/artifacts/mod.rs @@ -37,10 +37,6 @@ impl PartialEq for Artist { true } - - fn ne(&self, other: &Self) -> bool { - !self.eq(other) - } } #[derive(Debug, Serialize, Deserialize, Clone, Default)] @@ -63,10 +59,6 @@ impl PartialEq for Album { true } - - fn ne(&self, other: &Self) -> bool { - !self.eq(other) - } } #[derive(Debug, Serialize, Deserialize, Clone, Default)] @@ -103,10 +95,6 @@ impl PartialEq for Track { true } - - fn ne(&self, other: &Self) -> bool { - !self.eq(other) - } } #[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] diff --git a/src/artifacts/musicbrainz.rs b/src/artifacts/musicbrainz.rs index 0dd0e62..9460fae 100644 --- a/src/artifacts/musicbrainz.rs +++ b/src/artifacts/musicbrainz.rs @@ -11,17 +11,8 @@ use crate::artifacts::{Album, Artifact, ArtifactBuilder, Artist, Contents, Merge impl From for Track { fn from(value: Recording) -> Self { - let artist = if let Some(artist) = value.artist_credit.unwrap_or_default().first() { - Some(artist.name.clone()) - } else { - None - }; - - let album = if let Some(album) = value.releases.unwrap_or_default().first() { - Some(album.title.clone()) - } else { - None - }; + let artist = value.artist_credit.unwrap_or_default().first().map(|x| x.name.clone() ); + let album = value.releases.unwrap_or_default().first().map(|x| x.title.clone() ); Self { title: value.title, @@ -34,16 +25,12 @@ impl From for Track { impl From for Album { fn from(value: Release) -> Self { - let artist = if let Some(artist) = value.artist_credit.unwrap_or_default().first() { - Some(artist.name.clone()) - } else { - None - }.unwrap_or_default(); + let artist = value.artist_credit.unwrap_or_default().first().map(|x| x.name.clone() ); Self { about: value.annotation, title: value.title, - artist, + artist: artist.unwrap_or_default(), ..Default::default() } } @@ -55,7 +42,6 @@ impl From for Artist { bio: value.artist.annotation, location: value.artist.country, name: value.name, - ..Default::default() } } } @@ -98,29 +84,27 @@ impl DataSource for MBQuery { if artifact.mbid.is_none() { return Ok(ret); } - let artifact_id = artifact.mbid.clone().unwrap(); + let artifact_id = artifact.mbid.unwrap(); log::debug!("Synchronizing {} with musicbrainz", artifact_id); - match artifact.contents { - Contents::Track(_) => { - let mb_track = Recording::fetch() - .id(&artifact_id.to_string()) - .with_releases().with_artists().with_annotations().execute_async().await; + // FIXME: Need to also synchronize albums and artists + if let Contents::Track(_) = artifact.contents { + let mb_track = Recording::fetch() + .id(&artifact_id.to_string()) + .with_releases().with_artists().with_annotations().execute_async().await; - let track = match mb_track { - Ok(track) => track, - Err(err) => { - log::error!("Failed to grab musicbrainz data: {:?}", err); - return Err(err); - } - }; + let track = match mb_track { + Ok(track) => track, + Err(err) => { + log::error!("Failed to grab musicbrainz data: {:?}", err); + return Err(err); + } + }; - let (track, mut new_artifacts) = Self::extract_recording_data(track); + let (track, mut new_artifacts) = Self::extract_recording_data(track); - ret.push(track.clone()); - ret.append(&mut new_artifacts); - artifact.merge(track); - }, - _ => () + ret.push(track.clone()); + ret.append(&mut new_artifacts); + artifact.merge(track); } Ok(ret) diff --git a/src/artifacts/tools.rs b/src/artifacts/tools.rs index 9971a85..e7c0e58 100644 --- a/src/artifacts/tools.rs +++ b/src/artifacts/tools.rs @@ -32,19 +32,19 @@ impl Tool { } } -impl Into for Tool { - fn into(self) -> ChatCompletionTool { +impl From for ChatCompletionTool { + fn from(val: Tool) -> Self { ChatCompletionTool { function: FunctionObjectArgs::default() - .name(self.name) - .description(self.description) - .parameters(self.schema).build().unwrap() + .name(val.name) + .description(val.description) + .parameters(val.schema).build().unwrap() } } } -impl Into for Tool { - fn into(self) -> ChatCompletionTools { - ChatCompletionTools::Function(self.into()) +impl From for ChatCompletionTools { + fn from(val: Tool) -> Self { + ChatCompletionTools::Function(val.into()) } } \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 46262aa..41f8121 100644 --- a/src/main.rs +++ b/src/main.rs @@ -78,7 +78,7 @@ impl log::Log for SysMessageLogger { fn log(&self, record: &log::Record) { let msg = format!("{}", record.args()); - write!(self.1.lock().unwrap(), "{}\n", msg).unwrap(); + writeln!(self.1.lock().unwrap(), "{}", msg).unwrap(); if record.level() <= LevelFilter::Info { self.0.send(ConversationEntry::SystemMessage(msg)).unwrap(); } diff --git a/src/prediction/character.rs b/src/prediction/character.rs index 8b80125..9c51e02 100644 --- a/src/prediction/character.rs +++ b/src/prediction/character.rs @@ -48,7 +48,7 @@ pub struct Character { } impl Character { - pub async fn regenerate<'a, T: Toolbox>(&mut self, client: &mut Client, context: ChatCompletionRequestMessage, toolbox: &mut T, output: &mut mpsc::UnboundedSender, schema: &Value) -> Result<(usize, Option), CharacterError> { + pub async fn regenerate(&mut self, client: &mut Client, context: ChatCompletionRequestMessage, toolbox: &mut T, output: &mut mpsc::UnboundedSender, schema: &Value) -> Result<(usize, Option), CharacterError> { let mut full_conversation = vec![ self.header_message.clone(), context @@ -106,15 +106,21 @@ impl Character { ChatCompletionMessageToolCalls::Function(call) => { log::debug!("Tool {} {}/{}", call.function.name, idx+1, calls.len()); log::debug!("Args {}", call.function.arguments); - if let Some(tool_result) = toolbox.execute_tool(call).await { - // Push tool output messages directly into the conversation as fast as we can - for message in &tool_result.messages { - output.send(PredictionAction::ConversationAppend(message.clone())).unwrap(); + match toolbox.execute_tool(call).await { + Ok(tool_result) => { + // Push tool output messages directly into the conversation as fast as we can + for message in &tool_result.messages { + output.send(PredictionAction::ConversationAppend(message.clone())).unwrap(); + } + results.push((&call.id, tool_result)); + }, + Err(err) => { + results.push((&call.id, ToolResults { + result: Some(format!("Error while calling tool: {:?}", err)), + ..Default::default() + })); + log::error!("Attemped to call {:?}, but got an error instead: {:?}", call, err); } - results.push((&call.id, tool_result)); - } else { - results.push((&call.id, ToolResults::default())); - log::error!("Attemped to call {:?}, but no result was returned.", call); } }, _ => panic!("Unknown tool was called") @@ -141,13 +147,12 @@ impl Character { } if let Some(content) = message.message.content.as_ref() { let options = serde_json::from_str(content.as_str())?; - return Ok((tokens_used as usize, Some(options))); + Ok((tokens_used as usize, Some(options))) } else { - return Ok((tokens_used as usize, None)); + Ok((tokens_used as usize, None)) } } else { - log::info!("No messages were received!"); - return Err(CharacterError::NoOutput); + Err(CharacterError::NoOutput) } } diff --git a/src/prediction/mod.rs b/src/prediction/mod.rs index 47d90ab..4514321 100644 --- a/src/prediction/mod.rs +++ b/src/prediction/mod.rs @@ -164,7 +164,7 @@ impl Conversation { self.insert(ConversationEntry::Spoken(Speaker::ShipComputer, response.message)).await; if response.finished.unwrap_or_default() { self.characters.get_mut(&Speaker::ShipComputer).unwrap().forget().await; - if self.computer_todo.lock().await.iter().filter(|(_, is_finished)| !*is_finished).next().is_none() { + if !self.computer_todo.lock().await.iter().any(|(_, is_finished)| !*is_finished) { self.insert(ConversationEntry::StageDirection("The ship computer goes idle.".into())).await; self.event_sink.send(SessionUpdate::Thinking(Speaker::ShipComputer, false)).unwrap(); } else { @@ -290,11 +290,7 @@ pub async fn conversation_task(save_data: SaveData, sys_log_messages: tokio::syn }; let backlog: Vec<_> = save_data.messages.iter().filter_map(|msg| { - if let Ok(entry) = msg.clone().try_into() { - Some(entry) - } else { - None - } + msg.clone().try_into().ok() }).collect(); event_sink.send(SessionUpdate::Conversation(backlog.clone())).unwrap(); diff --git a/src/prediction/toolbox.rs b/src/prediction/toolbox.rs index 460b324..a5151dc 100644 --- a/src/prediction/toolbox.rs +++ b/src/prediction/toolbox.rs @@ -5,11 +5,11 @@ use schemars::{JsonSchema, schema_for}; use serde::{Deserialize, Serialize}; use tokio::sync::Mutex; -use crate::{artifacts::{archive::Archive, bandcamp::BandcampSource, beets::BeetsDB, mixxx::MixxxDB, musicbrainz::MBQuery, tools::{DataSource, Tool}}, scene::conversation::{ConversationEntry, Speaker}}; +use crate::{artifacts::{archive::Archive, beets::BeetsDB, mixxx::MixxxDB, musicbrainz::MBQuery, tools::{DataSource, Tool}}, scene::conversation::{ConversationEntry, Speaker}}; pub trait Toolbox { fn tools(&self) -> Vec; - fn execute_tool(&mut self, call: &ChatCompletionMessageToolCall) -> impl Future> + Send; + fn execute_tool(&mut self, call: &ChatCompletionMessageToolCall) -> impl Future> + Send; } pub struct StageToolbox; @@ -27,6 +27,18 @@ impl StageToolbox { } } +#[derive(Debug)] +pub enum ToolError { + InvalidToolName, + Json(serde_json::Error) +} + +impl From for ToolError { + fn from(value: serde_json::Error) -> Self { + Self::Json(value) + } +} + impl Toolbox for StageToolbox { fn tools(&self) -> Vec { vec![ @@ -34,13 +46,13 @@ impl Toolbox for StageToolbox { ] } - async fn execute_tool(&mut self, call: &ChatCompletionMessageToolCall) -> Option { + async fn execute_tool(&mut self, call: &ChatCompletionMessageToolCall) -> Result { let func_name = call.function.name.as_str(); let args = call.function.arguments.as_str(); - Some(match func_name { - "log_stage_event" => self.tool_stage_event(serde_json::from_str(args).unwrap()).await, - _ => return None - }) + match func_name { + "log_stage_event" => Ok(self.tool_stage_event(serde_json::from_str(args)?).await), + _ => Err(ToolError::InvalidToolName) + } } } @@ -82,24 +94,24 @@ impl Toolbox for ArchiveToolbox { ] } - async fn execute_tool(&mut self, call: &ChatCompletionMessageToolCall) -> Option { + async fn execute_tool(&mut self, call: &ChatCompletionMessageToolCall) -> Result { let func_name = call.function.name.as_str(); let args = call.function.arguments.as_str(); - Some(match func_name { - "query_bandcamp" => ToolResults { result: None, messages: vec![] }, + match func_name { + "query_bandcamp" => Ok(ToolResults { result: None, messages: vec![] }), "query_beets" => self.tool_artifact_query(&mut BeetsDB, args).await, "query_musicbrainz" => self.tool_artifact_query(&mut MBQuery, args).await, "query_mixxx" => self.tool_artifact_query(&mut MixxxDB, args).await, "synchronize_artifacts" => self.synchronize_artifacts().await, "task_list" => self.tasklist_operation(args).await, - _ => return None - }) + _ => Err(ToolError::InvalidToolName) + } } } impl ArchiveToolbox { - async fn tasklist_operation(&mut self, json_args: &str) -> ToolResults { - let args: TaskListArgs = serde_json::from_str(json_args).unwrap_or_default(); + async fn tasklist_operation(&mut self, json_args: &str) -> Result { + let args: TaskListArgs = serde_json::from_str(json_args)?; let mut locked = self.todo_list.lock().await; @@ -117,22 +129,22 @@ impl ArchiveToolbox { } } - ToolResults { + Ok(ToolResults { ..Default::default() - } + }) } - async fn synchronize_artifacts(&mut self) -> ToolResults { + async fn synchronize_artifacts(&mut self) -> Result { let updated_count = self.archive.lock().await.synchronize().await; - ToolResults { + Ok(ToolResults { messages: vec![ConversationEntry::Spoken(Speaker::ShipComputer, format!("Synchronized {} items", updated_count))], ..Default::default() - } + }) } - async fn tool_artifact_query(&mut self, src: &mut Src, json_args: &str) -> ToolResults where Src::Args: core::fmt::Debug, Src::Error: core::fmt::Debug { - let args: Src::Args = serde_json::from_str(json_args).unwrap(); + async fn tool_artifact_query(&mut self, src: &mut Src, json_args: &str) -> Result where Src::Args: core::fmt::Debug, Src::Error: core::fmt::Debug { + let args: Src::Args = serde_json::from_str(json_args)?; log::debug!("Executing query {:?}", args); let result; match src.query(&args).await { @@ -147,10 +159,10 @@ impl ArchiveToolbox { } } - ToolResults { + Ok(ToolResults { result: Some(result), messages: vec![] - } + }) } } diff --git a/src/ui.rs b/src/ui.rs index 764e77b..d2a3c77 100644 --- a/src/ui.rs +++ b/src/ui.rs @@ -91,7 +91,7 @@ impl Ui { let throbber = Throbber::default().throbber_style(Style::new().cyan()); frame.render_stateful_widget(throbber, throb_area, &mut self.throbber_state); } else { - frame.render_widget(Clear::default(), throb_area); + frame.render_widget(Clear, throb_area); } } @@ -101,7 +101,7 @@ impl Ui { let throbber = Throbber::default().throbber_style(Style::new().red()).throbber_set(throbber_widgets_tui::WHITE_SQUARE); frame.render_stateful_widget(throbber, throb_area, &mut self.throbber_state); } else { - frame.render_widget(Clear::default(), throb_area); + frame.render_widget(Clear, throb_area); } } @@ -170,7 +170,6 @@ impl Ui { self.predictions.insert(PredictionAction::SetPlaylist(playlist_name)).await; } else { log::error!("Invalid episode number format. Use /episode [number]"); - return; } }, "/playlist" => { @@ -193,7 +192,6 @@ impl Ui { }, "/computer" => { self.predictions.insert(PredictionAction::ComputerCommand(arg.to_string())).await; - return; }, _ => { log::error!("Unknown command. Available commands: /episode [number], /narrative [text], /event [text], /computer [text], /timer [minutes]"); @@ -221,7 +219,7 @@ impl Ui { KeyCode::Enter => { let row_num = self.conversation_state.selected().unwrap(); if let ConversationEntry::Spoken(Speaker::Eva, text) = &self.conversation[self.conversation.len() - 1 - row_num] { - self.tts.speak(text.clone()).await; + self.tts.speak(text.clone()).await.unwrap(); self.focus_state = FocusState::UserInput; self.conversation_state.select(None); self.reply_state.select_first(); @@ -241,10 +239,10 @@ impl Ui { KeyCode::Char('x') if key.modifiers.contains(KeyModifiers::CONTROL) => { if self.recording_audio { self.recording_audio = false; - self.transcription.stop(); + self.transcription.stop().unwrap(); } else { self.recording_audio = true; - self.transcription.start(); + self.transcription.start().unwrap(); } }, KeyCode::Down => self.reply_state.select_next(), @@ -256,12 +254,10 @@ impl Ui { let next_msg = self.user_input.value_and_reset(); if next_msg.trim().is_empty() { self.insert_selected_prompt().await; + } else if next_msg.starts_with("/") { + self.on_command(&next_msg).await; } else { - if next_msg.starts_with("/") { - self.on_command(&next_msg).await; - } else { - self.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::Spoken(Speaker::User, next_msg))).await; - } + self.predictions.insert(PredictionAction::ConversationAppend(ConversationEntry::Spoken(Speaker::User, next_msg))).await; } }, _ => {self.user_input.handle_event(&evt);}, diff --git a/src/widgets.rs b/src/widgets.rs index ad42efc..b673643 100644 --- a/src/widgets.rs +++ b/src/widgets.rs @@ -182,7 +182,7 @@ impl Widget for StatusBar<'_> { let negative = time_remaining.abs() != time_remaining; let time_style = if minutes_remaining <= 0 || negative { - Style::new().fg(ratatui::style::Color::LightRed).bold() + Style::new().fg(ratatui::style::Color::LightRed).underlined() } else if minutes_remaining < 5 { Style::new().fg(ratatui::style::Color::LightRed).bold() } else if minutes_remaining < 10 {