From 3a8130d785d4795876ca5c9f07785db1faa20436 Mon Sep 17 00:00:00 2001 From: Victoria Fischer Date: Wed, 17 Jun 2026 11:09:50 +0200 Subject: [PATCH] artifacts: rewrite the entire artifact querying layer to create modular 'tools' and 'datasource's --- Cargo.lock | 6 +- Cargo.toml | 1 + src/artifacts/archive.rs | 153 +++++++++++++++++++++++++++++++ src/artifacts/bandcamp.rs | 76 +++++++++++++--- src/artifacts/beets.rs | 87 +++++++++++++----- src/artifacts/mixxx.rs | 73 +++++++++------ src/artifacts/mod.rs | 170 ++++++++++++++++++++++++++--------- src/artifacts/musicbrainz.rs | 168 +++++++++++++++++++++++++--------- src/artifacts/tools.rs | 50 +++++++++++ src/prediction.rs | 139 +++++++--------------------- src/scene/mod.rs | 6 +- 11 files changed, 672 insertions(+), 257 deletions(-) create mode 100644 src/artifacts/archive.rs create mode 100644 src/artifacts/tools.rs diff --git a/Cargo.lock b/Cargo.lock index caf871b..3b3c476 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1216,6 +1216,7 @@ dependencies = [ "tokio", "tui-input", "tui-skeleton", + "uuid", ] [[package]] @@ -5190,13 +5191,14 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.23.1" +version = "1.23.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddd74a9687298c6858e9b88ec8935ec45d22e8fd5e6394fa1bd4e99a87789c76" +checksum = "144d6b123cef80b301b8f72a9e2ca4370ddec21950d0a103dd22c437006d2db7" dependencies = [ "atomic", "getrandom 0.4.2", "js-sys", + "serde_core", "wasm-bindgen", ] diff --git a/Cargo.toml b/Cargo.toml index 9fdef58..c46b3fd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,3 +40,4 @@ throbber-widgets-tui = "0.11.0" tokio = { version = "1.52.3", features = ["full"] } tui-input = "0.15.3" tui-skeleton = "0.3.0" +uuid = { version = "1.23.3", features = ["serde", "v4"] } diff --git a/src/artifacts/archive.rs b/src/artifacts/archive.rs new file mode 100644 index 0000000..2b9166a --- /dev/null +++ b/src/artifacts/archive.rs @@ -0,0 +1,153 @@ +use std::{collections::HashMap, ops::{Deref, DerefMut}}; +use std::fmt::Debug; + +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::artifacts::{Artifact, Merge, SourceID, beets::BeetsDB, mixxx::MixxxDB, musicbrainz::MBQuery, tools::DataSource}; + +pub struct ArtifactRef<'a> { + id: Uuid, + archive: &'a Archive +} + + +pub struct ArtifactRefMut<'a> { + id: Uuid, + archive: &'a mut Archive +} + +impl<'a> ArtifactRefMut<'a> { + pub fn downgrade(self) -> ArtifactRef<'a> { + ArtifactRef { id: self.id, archive: self.archive } + } +} + +impl<'a> Deref for ArtifactRef<'a> { + type Target = Artifact; + fn deref(&self) -> &Self::Target { + self.archive.contents.get(&self.id).unwrap() + } +} + +impl<'a> Deref for ArtifactRefMut<'a> { + type Target = Artifact; + fn deref(&self) -> &Self::Target { + self.archive.contents.get(&self.id).unwrap() + } +} + +impl<'a> DerefMut for ArtifactRefMut<'a> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.archive.contents.get_mut(&self.id).unwrap() + } +} + +#[derive(Debug, Default, Clone, Serialize, Deserialize)] +pub struct Archive { + #[serde(flatten)] + contents: HashMap +} + +impl Archive { + + pub fn len(&self) -> usize { + self.contents.len() + } + + pub fn get<'a>(&'a self, id: &Uuid) -> Option> { + if self.contents.get(id).is_some() { + Some(ArtifactRef { id: id.clone(), archive: self }) + } else { + None + } + } + + pub fn get_mut<'a>(&'a mut self, id: &Uuid) -> Option> { + if self.contents.get(id).is_some() { + Some(ArtifactRefMut { id: id.clone(), archive: self }) + } else { + None + } + } + + pub async fn data_sync(&mut self, datasrc: &mut Src, source: SourceID) -> usize where Src::Error: Debug { + let mut count = 0; + + let mut new_artifacts = vec![]; + + let pending = self.contents.iter_mut().filter_map(|(_, artifact)| { + if !artifact.sources.contains(&source) { + Some(artifact) + } else { + None + } + }); + + for artifact in pending { + match datasrc.synchronize(artifact).await { + Ok(mut new_pending) => { + count += new_pending.len() + 1; + new_artifacts.append(&mut new_pending); + }, + Err(err) => { + log::error!("Failed to synchronize {:?}: {:?}", artifact, err); + } + } + } + + count + } + + pub async fn synchronize(&mut self) -> usize { + log::info!("Synchronizing records"); + let mut count = 0; + + count += self.data_sync(&mut MixxxDB, SourceID::Mixxx).await; + count += self.data_sync(&mut BeetsDB, SourceID::Beets).await; + count += self.data_sync(&mut MBQuery, SourceID::Musicbrainz).await; + + log::info!("Updated {} records", count); + + count + } + + pub fn insert<'a>(&'a mut self, artifact: Artifact) -> ArtifactRef<'a> { + // If we are inserting a new artifact with a complete MBID... + if let Some(mbid) = artifact.mbid.clone() { + let search_id = mbid; + // And that one already exists... + if let Some(existing) = self.contents.get_mut(&search_id) { + // Update the data + existing.merge(artifact); + ArtifactRef { id: search_id, archive: self } + } else { + // Otherwise, we have a valid ID from some source, but it isn't in the system yet, so lets just fill it up + self.contents.insert(search_id.clone(), artifact); + ArtifactRef { id: search_id, archive: self } + } + } else { + // Otherwise, we attempt to merge it in. In the end, there will somehow still be a record with this mbid + let mut targets: Vec<(Uuid, Artifact)> = self.contents.extract_if(|_, v| { *v == artifact }).collect(); + if let Some((target_id, mut target)) = targets.pop() { + let next_id = if let Some(ref mbid) = artifact.mbid { + // If the new artifact has an mbid, we start using that as the archive key + mbid.clone() + } else { + // Otherwise, why regenerate a new one? + target_id + }; + + for (_, next) in targets { + target.merge(next); + } + target.merge(artifact); + ArtifactRef { id: next_id, archive: self } + } else { + let new_id = Uuid::new_v4(); + self.contents.insert(new_id.clone(), artifact); + ArtifactRef { id: new_id, archive: self } + } + } + } +} \ No newline at end of file diff --git a/src/artifacts/bandcamp.rs b/src/artifacts/bandcamp.rs index abd5e42..956a4ca 100644 --- a/src/artifacts/bandcamp.rs +++ b/src/artifacts/bandcamp.rs @@ -1,9 +1,8 @@ -use std::collections::HashSet; - +use bandcamp::SearchResultItem; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use crate::artifacts::{Album, Artifact, Artist, SourceID}; +use crate::artifacts::{Album, Artifact, ArtifactBuilder, Artist, SourceID, Track, tools::{DataSource, ToolDescription}}; #[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)] pub struct BandcampQueryArgs { @@ -12,19 +11,72 @@ pub struct BandcampQueryArgs { impl Into for bandcamp::Artist { fn into(self) -> Artifact { - Artifact::Artist(Artist { name: self.name, bio: self.bio, location: self.location, sources: HashSet::from([SourceID::Bandcamp(self.id)])}) + ArtifactBuilder::new(SourceID::Bandcamp).contents(Artist { name: self.name, bio: self.bio, location: self.location }).build() } } impl Into for bandcamp::Album { fn into(self) -> Artifact { - Artifact::Album(Album { - about: self.about, - title: self.title, - artist: self.band.name, - credits: self.credits, - release_date: Some(self.release_date), - sources: HashSet::from([SourceID::Bandcamp(self.id)]) - }) + ArtifactBuilder::new(SourceID::Bandcamp) + .contents(Album { + about: self.about, + title: self.title, + artist: self.band.name, + credits: self.credits, + release_date: Some(self.release_date) + }).build() + } +} + +pub struct BandcampSource; + +impl DataSource for BandcampSource { + type Args = BandcampQueryArgs; + type Error = (); + + async fn synchronize(&mut self, _artifact: &mut Artifact) -> Result, Self::Error> { + todo!() + } + + async fn query(&mut self, args: &Self::Args) -> Result, Self::Error> { + log::debug!("Fetching artifacts from Bandcamp with {:?}", args); + let mut json_results = vec![]; + if let Ok(results) = bandcamp::search(args.query.as_str()).await { + for result in results { + log::debug!("Result: {:?}", result); + match result { + SearchResultItem::Artist(data) => { + let result = bandcamp::fetch_artist(data.artist_id).await.unwrap().into(); + json_results.push(result); + }, + SearchResultItem::Album(data) => { + let result = bandcamp::fetch_album(data.band_id, data.album_id).await.unwrap().into(); + json_results.push(result); + }, + SearchResultItem::Track(data) => { + let result = ArtifactBuilder::new(SourceID::Bandcamp) + .contents(Track { + title: data.name, + artist: Some(data.band_name), + album: data.album_name, + ..Default::default() + }).build(); + json_results.push(result); + } + _ => () + } + } + } + Ok(json_results) + } +} + +impl ToolDescription for BandcampSource { + fn description(&self) -> &str { + "Scans Bandcamp to find artifacts to use in the scene that match the given search parameters. To find an artist, provide only the artist name. To find an album, provide the artist and the album." + } + + fn name(&self) -> &str { + "query_bandcamp" } } \ No newline at end of file diff --git a/src/artifacts/beets.rs b/src/artifacts/beets.rs index d5ed760..f1477d3 100644 --- a/src/artifacts/beets.rs +++ b/src/artifacts/beets.rs @@ -1,18 +1,19 @@ -use std::{collections::HashSet, process::{Command, Stdio}}; +use std::process::{Command, Stdio}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use uuid::Uuid; -use crate::artifacts::{Artifact, SourceID, Track}; +use crate::artifacts::{Artifact, ArtifactBuilder, Contents, Merge, SourceID, Track, tools::{DataSource, ToolDescription}}; #[derive(Debug, Default, Serialize, Deserialize, Clone, JsonSchema)] pub struct BeatsQueryArgs { - artist: Option, - album: Option, - genre: Option, - title: Option, - year: Option + pub artist: Option, + pub album: Option, + pub genre: Option, + pub title: Option, + pub year: Option } #[derive(Debug, Default, Deserialize)] @@ -28,12 +29,7 @@ struct BeetsTrack { impl Into for BeetsTrack { fn into(self) -> Artifact { - let sources = if let Some(mbid) = self.mb_trackid { - HashSet::from([SourceID::Beets, SourceID::Musicbrainz(mbid)]) - } else { - HashSet::from([SourceID::Beets]) - }; - Artifact::Track(Track { + let track_data = Track { title: self.title, label: self.label, year: Some(self.year), @@ -41,33 +37,70 @@ impl Into for BeetsTrack { album: Some(self.album), artist: Some(self.artist), bpm: None, - sources - }) + }; + let builder = ArtifactBuilder::new(SourceID::Beets) + .contents(track_data); + if let Some(mbid) = self.mb_trackid { + builder.mbid(Uuid::parse_str(&mbid).unwrap()).build() + } else { + builder.build() + } } } -impl BeatsQueryArgs { - pub fn execute(self) -> Result, ()> { +pub struct BeetsDB; + +impl DataSource for BeetsDB { + type Args = BeatsQueryArgs; + + type Error = (); + + async fn synchronize(&mut self, artifact: &mut Artifact) -> Result, Self::Error> { + match artifact.contents { + Contents::Track(ref mut target_track) => { + let args = BeatsQueryArgs { + title: Some(target_track.title.clone()), + artist: target_track.artist.clone(), + album: target_track.album.clone(), + ..Default::default() + }; + + let results = self.query(&args).await.unwrap(); + + if let Some(first) = results.first() { + artifact.merge(first.clone()); + } else { + log::error!("Beets could not find {:?}", target_track); + } + + }, + _ => () + } + + Ok(vec![]) + } + + async fn query(&mut self, args: &Self::Args) -> Result, Self::Error> { let mut beets_cmd = Command::new("beet"); beets_cmd.args(["export", "-f", "json", "-i", "title,label,year,genres,album,artist,mb_trackid"]); let mut valid = false; - if let Some(artist) = self.artist { + if let Some(ref artist) = args.artist { beets_cmd.arg(format!("artist:{}", artist)); valid = true; } - if let Some(genre) = self.genre { + if let Some(ref genre) = args.genre { beets_cmd.arg(format!("genre:{}", genre)); valid = true; } - if let Some(album) = self.album { + if let Some(ref album) = args.album { beets_cmd.arg(format!("album:{}", album)); valid = true; } - if let Some(title) = self.title { + if let Some(ref title) = args.title { beets_cmd.arg(format!("title:{}", title)); valid = true; } - if let Some(year) = self.year { + if let Some(year) = args.year { beets_cmd.arg(format!("year:{}", year)); valid = true; } @@ -91,4 +124,14 @@ impl BeatsQueryArgs { Err(()) } } +} + +impl ToolDescription for BeetsDB { + fn description(&self) -> &str { + "Queries the ship's musical artifact archives for tracks matching the given search parameters" + } + + fn name(&self) -> &str { + "query_beets" + } } \ No newline at end of file diff --git a/src/artifacts/mixxx.rs b/src/artifacts/mixxx.rs index b2d933c..209ce20 100644 --- a/src/artifacts/mixxx.rs +++ b/src/artifacts/mixxx.rs @@ -1,8 +1,8 @@ -use std::collections::HashSet; - +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; use sqlite::OpenFlags; -use crate::artifacts::{Album, Artifact, Artist, SourceID, Track}; +use crate::artifacts::{Album, Artifact, ArtifactBuilder, Artist, SourceID, Track, tools::{DataSource, ToolDescription}}; #[derive(Debug)] #[allow(unused)] @@ -16,11 +16,24 @@ impl From for MixxxError { } } -pub struct MixxxDB(()); +pub struct MixxxDB; -impl MixxxDB { - pub fn load(playlist_name: &str) -> Result, MixxxError> { +#[derive(Serialize, Deserialize, Debug, Default, JsonSchema)] +pub struct MixxxQuery { + pub playlist_name: String +} + +impl DataSource for MixxxDB { + type Args = MixxxQuery; + type Error = MixxxError; + + async fn synchronize(&mut self, _artifact: &mut Artifact) -> Result, Self::Error> { + Ok(vec![]) + } + + async fn query(&mut self, args: &Self::Args) -> Result, Self::Error> { let mut ret = vec![]; + let playlist_name = args.playlist_name.as_str(); log::info!("Loading Mixxx playlist {}", playlist_name); let connection = sqlite::Connection::open_thread_safe_with_flags("mixxxdb.sqlite", OpenFlags::new().with_read_only())?; let query = "SELECT id FROM Playlists WHERE name = ? ORDER BY id DESC LIMIT 1"; @@ -36,29 +49,39 @@ impl MixxxDB { let artist = track.try_read::<&str, _>("artist").unwrap_or("Unknown Artist"); let album = track.try_read::<&str, _>("album").unwrap_or("Unknown Album"); let bpm = track.try_read::("bpm").unwrap_or(0.); - ret.push(Artifact::Track(Track { - artist: Some(artist.into()), - album: Some(album.into()), - title: title.into(), - bpm: Some(bpm), - sources: HashSet::from([SourceID::Mixxx]), - ..Default::default() - })); + ret.push(ArtifactBuilder::new(SourceID::Mixxx) + .contents(Track { + artist: Some(artist.into()), + album: Some(album.into()), + title: title.into(), + bpm: Some(bpm), + ..Default::default() + }).build()); - ret.push(Artifact::Album(Album { - artist: artist.into(), - title: album.into(), - sources: HashSet::from([SourceID::Mixxx]), - ..Default::default() - })); + ret.push(ArtifactBuilder::new(SourceID::Mixxx) + .contents(Album { + artist: artist.into(), + title: album.into(), + ..Default::default() + }).build()); - ret.push(Artifact::Artist(Artist { - name: artist.into(), - sources: HashSet::from([SourceID::Mixxx]), - ..Default::default() - })); + ret.push(ArtifactBuilder::new(SourceID::Mixxx) + .contents(Artist { + name: artist.into(), + ..Default::default() + }).build()); } Ok(ret) } +} + +impl ToolDescription for MixxxDB { + fn description(&self) -> &str { + "Loads artifacts from a given Mixxx playlist name" + } + + fn name(&self) -> &str { + "query_mixxx" + } } \ No newline at end of file diff --git a/src/artifacts/mod.rs b/src/artifacts/mod.rs index 17608fe..ec71bc2 100644 --- a/src/artifacts/mod.rs +++ b/src/artifacts/mod.rs @@ -1,17 +1,21 @@ -use std::collections::{HashMap, HashSet}; +use std::{collections::HashSet, fmt::Debug, }; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; +use uuid::Uuid; + pub mod bandcamp; pub mod mixxx; pub mod beets; pub mod musicbrainz; +pub mod archive; +pub mod tools; #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum SourceID { - Bandcamp(u64), - Musicbrainz(String), + Bandcamp, + Musicbrainz, Mixxx, Beets } @@ -23,8 +27,6 @@ pub struct Artist { pub bio: Option, #[serde(skip_serializing_if = "Option::is_none")] pub location: Option, - - pub sources: HashSet } impl PartialEq for Artist { @@ -50,9 +52,7 @@ pub struct Album { #[serde(skip_serializing_if = "Option::is_none")] pub credits: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub release_date: Option>, - - pub sources: HashSet + pub release_date: Option> } #[derive(Debug, Serialize, Deserialize, Clone, Default)] @@ -70,9 +70,7 @@ pub struct Track { #[serde(skip_serializing_if = "Option::is_none")] pub artist: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub bpm: Option, - - pub sources: HashSet + pub bpm: Option } impl PartialEq for Track { @@ -97,48 +95,29 @@ impl PartialEq for Track { } } -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub enum Artifact { +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] +#[serde(tag = "type")] +pub enum Contents { Artist(Artist), Album(Album), Track(Track) } -macro_rules! merge_fields { - ($this:expr, $that:expr, $field:tt) => { - if $this.$field.is_none() { - $this.$field = $that.$field; - } - }; - ($this:tt, $that:tt, $($fields:tt),+) => { - $( - merge_fields!($this, $that, $fields); - - for src in &$that.sources { - $this.sources.insert(src.clone()); - } - )+ +impl From for Contents { + fn from(value: Artist) -> Self { + Self::Artist(value) } } -impl Merge for Artifact { - fn merge(&mut self, other: Self) { - if *self != other { - return; - } +impl From for Contents { + fn from(value: Album) -> Self { + Self::Album(value) + } +} - match (self, other) { - (Self::Track(this_track), Self::Track(that_track)) => { - merge_fields!(this_track, that_track, album, label, year, artist, bpm); - }, - (Self::Album(this_album), Self::Album(that_album)) => { - merge_fields!(this_album, that_album, about, credits, release_date); - }, - (Self::Artist(this_artist), Self::Artist(that_artist)) => { - merge_fields!(this_artist, that_artist, bio, location); - }, - _ => () - } +impl From for Contents { + fn from(value: Track) -> Self { + Self::Track(value) } } @@ -156,4 +135,107 @@ impl Merge for Vec { pub trait Merge { fn merge(&mut self, other: Self); +} + +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] +pub struct Artifact { + #[serde(skip_serializing_if = "Option::is_none")] + mbid: Option, + #[serde(flatten)] + contents: Contents, + + sources: HashSet, +} + +#[derive(Debug)] +pub struct ArtifactBuilder { + contents: Option, + mbid: Option, + source: SourceID, +} + +impl ArtifactBuilder { + pub fn new(source: SourceID) -> Self { + Self { + contents: None, + mbid: None, + source, + } + } + + pub fn contents>(mut self, contents: T) -> Self { + self.contents = Some(contents.into()); + self + } + + pub fn mbid>(mut self, mbid: T) -> Self { + self.mbid = Some(mbid.into()); + self + } + + pub fn build(self) -> Artifact { + Artifact { + mbid: self.mbid, + contents: self.contents.unwrap(), + sources: HashSet::from_iter([self.source]), + } + } +} + +impl Artifact { + pub fn contents(&self) -> &Contents { + &self.contents + } +} + +impl Merge for Artifact { + fn merge(&mut self, other: Self) { + self.contents.merge(other.contents); + if self.mbid.is_none() { + self.mbid = other.mbid; + } + for src in other.sources { + self.sources.insert(src); + } + } +} + +macro_rules! merge_fields { + ($this:expr, $that:expr, $field:tt) => { + if $this.$field.is_none() { + $this.$field = $that.$field; + } + }; + ($this:tt, $that:tt, $($fields:tt),+) => { + $( + merge_fields!($this, $that, $fields); + )+ + } +} + +impl Merge for Contents { + fn merge(&mut self, other: Self) { + if *self != other { + return; + } + + match (self, other) { + (Self::Track(this_track), Self::Track(that_track)) => { + this_track.merge(that_track); + }, + (Self::Album(this_album), Self::Album(that_album)) => { + merge_fields!(this_album, that_album, about, credits, release_date); + }, + (Self::Artist(this_artist), Self::Artist(that_artist)) => { + merge_fields!(this_artist, that_artist, bio, location); + }, + _ => () + } + } +} + +impl Merge for Track { + fn merge(&mut self, other: Self) { + merge_fields!(self, other, album, label, year, artist, bpm); + } } \ No newline at end of file diff --git a/src/artifacts/musicbrainz.rs b/src/artifacts/musicbrainz.rs index 7b1bd7b..c87dcc5 100644 --- a/src/artifacts/musicbrainz.rs +++ b/src/artifacts/musicbrainz.rs @@ -1,50 +1,136 @@ -use std::collections::HashSet; - -use musicbrainz_rs::entity::recording::Recording; +use musicbrainz_rs::{ApiEndpointError, entity::recording::Recording}; use musicbrainz_rs::prelude::*; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use uuid::Uuid; -use crate::artifacts::{Album, Artifact, Artist, SourceID, Track}; +use crate::artifacts::tools::{DataSource, ToolDescription}; +use crate::artifacts::{Album, Artifact, ArtifactBuilder, Artist, Contents, Merge, SourceID, Track}; + +pub struct MBQuery; + +impl DataSource for MBQuery { + type Error = ApiEndpointError; + type Args = MusicbrainzQueryArgs; + + async fn synchronize(&mut self, artifact: &mut Artifact) -> Result, Self::Error> { + let mut new_artifacts = vec![]; + if artifact.mbid.is_none() { + return Ok(new_artifacts); + } + let artifact_id = artifact.mbid.clone().unwrap(); + log::debug!("Synchronizing {} with musicbrainz", artifact_id); + match artifact.contents { + Contents::Track(ref mut target_track) => { + let mb_track = Recording::fetch() + .id(&artifact_id.to_string()) + .with_releases().with_artists().with_annotations().execute_async().await; + + let track = match mb_track { + Ok(track) => track, + Err(err) => { + log::error!("Failed to grab musicbrainz data: {:?}", err); + return Err(err); + } + }; + + artifact.sources.insert(SourceID::Musicbrainz); + + for release in track.releases.unwrap_or_default() { + log::debug!("Found new release: {:?}", release); + let first_artist = release.artist_credit.unwrap_or_default().first().unwrap().clone(); + new_artifacts.push(ArtifactBuilder::new(SourceID::Musicbrainz) + .contents(Album { + title: release.title.clone(), + artist: first_artist.name.clone(), + about: release.annotation, + ..Default::default() + }) + .mbid(Uuid::parse_str(&release.id).unwrap()).build()); + target_track.merge(Track { + album: Some(release.title), + title: track.title.clone(), + artist: Some(first_artist.artist.name.clone()), + ..Default::default() + }); + new_artifacts.push(ArtifactBuilder::new(SourceID::Musicbrainz) + .contents(Artist { + name: first_artist.name, + bio: first_artist.artist.annotation, + location: first_artist.artist.area.and_then(|area| { Some(area.name) }), + ..Default::default() + }) + .mbid(Uuid::parse_str(&first_artist.artist.id).unwrap()).build()); + } + }, + _ => () + } + + Ok(new_artifacts) + } + + async fn query(&mut self, args: &Self::Args) -> Result, Self::Error> { + let mut ret = vec![]; + + for mbid in &args.mb_ids { + log::debug!("Fetching recording id {}", mbid); + let track = Recording::fetch() + .id(&mbid) + .with_releases().with_artists().with_annotations().execute_async().await; + + let track = match track { + Ok(track) => track, + Err(err) => { + log::error!("Failed to grab musicbrainz data: {:?}", err); + continue; + } + }; + + for release in track.releases.unwrap_or_default() { + log::debug!("Found new release: {:?}", release); + let first_artist = release.artist_credit.unwrap_or_default().first().unwrap().clone(); + ret.push(ArtifactBuilder::new(SourceID::Musicbrainz) + .contents(Album { + title: release.title.clone(), + artist: first_artist.name.clone(), + about: release.annotation, + ..Default::default() + }) + .mbid(Uuid::parse_str(&release.id).unwrap()).build()); + ret.push(ArtifactBuilder::new(SourceID::Musicbrainz) + .contents(Track { + album: Some(release.title), + title: track.title.clone(), + artist: Some(first_artist.artist.name.clone()), + ..Default::default() + }) + .mbid(Uuid::parse_str(&mbid).unwrap()).build()); + ret.push(ArtifactBuilder::new(SourceID::Musicbrainz) + .contents(Artist { + name: first_artist.name, + bio: first_artist.artist.annotation, + location: first_artist.artist.area.and_then(|area| { Some(area.name) }), + ..Default::default() + }) + .mbid(Uuid::parse_str(&first_artist.artist.id).unwrap()).build()); + } + } + + Ok(ret) + } +} + +impl ToolDescription for MBQuery { + fn description(&self) -> &str { + "Fetches artifacts from Musicbrainz" + } + + fn name(&self) -> &str { + "query_musicbrainz" + } +} #[derive(Debug, Default, Deserialize, Serialize, JsonSchema)] pub struct MusicbrainzQueryArgs { pub mb_ids: Vec -} - -pub async fn search_artifacts(query: MusicbrainzQueryArgs) -> Result, musicbrainz_rs::ApiEndpointError> { - let mut ret = vec![]; - for mbid in query.mb_ids { - let track = Recording::fetch() - .id(&mbid) - .with_releases().with_artists().with_annotations().execute_async().await?; - - for release in track.releases.unwrap_or_default() { - log::debug!("Found new release: {:?}", release); - let first_artist = release.artist_credit.unwrap_or_default().first().unwrap().clone(); - ret.push(Artifact::Album(Album { - title: release.title.clone(), - artist: first_artist.name.clone(), - about: release.annotation, - sources: HashSet::from([SourceID::Musicbrainz(release.id.clone())]), - ..Default::default() - })); - ret.push(Artifact::Track(Track { - album: Some(release.title), - title: track.title.clone(), - artist: Some(first_artist.artist.name.clone()), - sources: HashSet::from([SourceID::Musicbrainz(release.id.clone())]), - ..Default::default() - })); - ret.push(Artifact::Artist(Artist { - name: first_artist.name, - bio: first_artist.artist.annotation, - location: first_artist.artist.area.and_then(|area| { Some(area.name) }), - sources: HashSet::from([SourceID::Musicbrainz(release.id)]), - ..Default::default() - })) - } - } - - Ok(ret) } \ No newline at end of file diff --git a/src/artifacts/tools.rs b/src/artifacts/tools.rs new file mode 100644 index 0000000..79854ca --- /dev/null +++ b/src/artifacts/tools.rs @@ -0,0 +1,50 @@ +use async_openai::types::chat::{ChatCompletionTool, ChatCompletionTools, FunctionObjectArgs}; +use schemars::{JsonSchema, Schema, schema_for}; +use serde::de::DeserializeOwned; + +use crate::artifacts::Artifact; + +pub trait DataSource: ToolDescription { + type Args: JsonSchema + DeserializeOwned; + type Error; + fn synchronize(&mut self, artifact: &mut Artifact) -> impl Future, Self::Error>>; + fn query(&mut self, args: &Self::Args) -> impl Future, Self::Error>>; +} + +pub trait ToolDescription { + fn description(&self) -> &str; + fn name(&self) -> &str; +} + +pub struct Tool { + pub name: String, + pub description: String, + pub schema: Schema +} + +impl Tool { + pub fn from_datasource(src: &T) -> Self { + Self { + name: src.name().to_string(), + description: src.description().to_string(), + schema: schema_for!(T::Args) + } + } +} + +impl Into for Tool { + fn into(self) -> ChatCompletionTool { + ChatCompletionTool { + function: FunctionObjectArgs::default() + .name(self.name) + .description(self.description) + .parameters(self.schema).build().unwrap() + } + } +} + +impl Into for Tool { + fn into(self) -> ChatCompletionTools { + ChatCompletionTools::Function(self.into()) + } +} \ No newline at end of file diff --git a/src/prediction.rs b/src/prediction.rs index 0b53aa4..d6e8304 100644 --- a/src/prediction.rs +++ b/src/prediction.rs @@ -1,14 +1,13 @@ -use std::{collections::HashSet, sync::Arc}; +use std::{fmt::Debug, sync::Arc}; -use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs, ChatCompletionTool, ChatCompletionTools, CreateChatCompletionRequestArgs, FinishReason, FunctionObjectArgs, ResponseFormat, ResponseFormatJsonSchema}}; -use bandcamp::SearchResultItem; +use async_openai::{Client, config::OpenAIConfig, types::chat::{ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs, CreateChatCompletionRequestArgs, FinishReason, ResponseFormat, ResponseFormatJsonSchema}}; use chrono::{DateTime, Utc}; use schemars::{JsonSchema, schema_for}; use serde::{Deserialize, Serialize}; use serde_json::{Serializer, ser::CompactFormatter}; use tokio::sync::{RwLock, mpsc, watch}; -use crate::{SaveData, artifacts::{self, Album, Artifact, Artist, Merge, SourceID, Track, bandcamp::BandcampQueryArgs, beets::BeatsQueryArgs, mixxx::MixxxDB, musicbrainz::{MusicbrainzQueryArgs, search_artifacts}}, scene::{Scene, Scenery, StageDirection, conversation::ConversationEntry}}; +use crate::{SaveData, artifacts::{Contents, bandcamp::BandcampSource, beets::BeetsDB, mixxx::{MixxxDB, MixxxQuery}, musicbrainz::MBQuery, tools::{DataSource, Tool}}, scene::{Scene, Scenery, StageDirection, conversation::ConversationEntry}}; const SYSTEM_PROMPT: &str = include_str!("system-prompt.txt"); @@ -99,63 +98,16 @@ impl Session { } } - async fn tool_bandcamp_scan(&mut self, args: BandcampQueryArgs) -> ToolResults { + async fn tool_artifact_query(&mut self, src: &mut Src, json_args: &str) -> ToolResults where Src::Args: Debug { + let args: Src::Args = serde_json::from_str(json_args).unwrap(); let mut messages = vec![]; - log::debug!("Fetching artifacts from Bandcamp with {:?}", args); - let mut json_results = vec![]; - if let Ok(results) = bandcamp::search(args.query.as_str()).await { - for result in results { - log::debug!("Result: {:?}", result); - match result { - SearchResultItem::Artist(data) => { - /*let result = Artifact::Artist(Artist { - name: data.name, - location: data.location, - ..Default::default() - });*/ - let result = bandcamp::fetch_artist(data.artist_id).await.unwrap().into(); - json_results.push(result); - }, - SearchResultItem::Album(data) => { - let result = bandcamp::fetch_album(data.band_id, data.album_id).await.unwrap().into(); - /*let result = Artifact::Album(Album { - title: data.name, - artist: data.band_name, - ..Default::default() - });*/ - json_results.push(result); - }, - SearchResultItem::Track(data) => { - let result = Artifact::Track(Track { - title: data.name, - artist: Some(data.band_name), - album: data.album_name, - sources: HashSet::from([SourceID::Bandcamp(data.track_id)]), - ..Default::default() - }); - json_results.push(result); - } - _ => () - } - } - } - let artifact_count = json_results.len(); - messages.push(ConversationEntry::ShipComputer(format!("Bandcamp relay scan for '{}' complete. {} artifacts added to the archive.", args.query, artifact_count).into())); - - self.scenery.artifacts.merge(json_results); - - ToolResults { - result: Some(format!("{} artifacts were added to the archive.", artifact_count)), - messages - } - } - - async fn tool_artifact_query(&mut self, args: BeatsQueryArgs) -> ToolResults { - let mut messages = vec![]; - log::debug!("Executing beets query {:?}", args); - if let Ok(output) = args.clone().execute() { + log::debug!("Executing query {:?}", args); + if let Ok(output) = src.query(&args).await { messages.push(ConversationEntry::ShipComputer(format!("Found {} artifacts with archive query {:?}", output.len(), args))); - self.scenery.artifacts.merge(output); + for result in output { + self.scenery.artifacts.insert(result); + } + self.scenery.artifacts.synchronize().await; } else { messages.push(ConversationEntry::ShipComputer("Unable to execute query!".into())); }; @@ -166,20 +118,6 @@ impl Session { } } - async fn tool_musicbrainz_fetch_tracks(&mut self, args: MusicbrainzQueryArgs) -> ToolResults { - log::debug!("Executing musicbrainz fetch for {:?}", args); - let results = search_artifacts(args).await.unwrap(); - - let msg = format!("Found {} results via Musicbrainz relay search.", results.len()); - - self.scenery.artifacts.merge(results); - - ToolResults { - result: Some(msg.clone()), - messages: vec![ConversationEntry::ShipComputer(msg)] - } - } - fn generate_conversation(&self, direction: &StageDirection) -> Vec { let mut json_buf = vec![]; let mut ser = Serializer::with_formatter(&mut json_buf, CompactFormatter); @@ -208,38 +146,15 @@ impl Session { let full_conversation = self.generate_conversation(&self.direction); let tools = vec![ - ChatCompletionTools::Function(ChatCompletionTool { - function: FunctionObjectArgs::default() - .name("log_stage_event") - .description("Inserts an event into the current scene script") - .parameters(schema_for!(StageEventArgs)) - .build().unwrap() - }), + Tool { name: "log_stage_event".into(), description: "Inserts an event into the current scene script".into(), schema: schema_for!(StageEventArgs)}.into(), // TODO: There should only be two queries, one against the ship's onboard archive, and another against the relay network, or whatever we call it. Both should be structured with the same arguments schema // TODO: A relay search should try to grab first from beets, then musicbrainz, then from bandcamp. // TODO: A query should specify what parts of metadata are sufficient for the result, so we don't always have to hit all the layers of data. beets can of course, ignore this. // TODO: A query should be hierarchical somehow? eg, "I already know about artist X, but I want to know everything about track Y from album Z" or "I don't know anything about artist X/album Y, please give me an overview" - ChatCompletionTools::Function(ChatCompletionTool { - function: FunctionObjectArgs::default() - .name("archive_query") - .description("Queries the ship's musical artifact archives for tracks matching the given search parameters") - .parameters(schema_for!(BeatsQueryArgs)) - .build().unwrap() - }), - ChatCompletionTools::Function(ChatCompletionTool { - function: FunctionObjectArgs::default() - .name("bandcamp_artifact_scan") - .description("Scans Bandcamp to find artifacts to use in the scene that match the given search parameters. To find an artist, provide only the artist name. To find an album, provide the artist and the album.") - .parameters(schema_for!(BandcampQueryArgs)) - .build().unwrap() - }), - ChatCompletionTools::Function(ChatCompletionTool { - function: FunctionObjectArgs::default() - .name("musicbrainz_track_search") - .description("Fetches metadata from bandcamp for the given musicbrainz recording IDs (mbid)") - .parameters(schema_for!(MusicbrainzQueryArgs)) - .build().unwrap() - }) + Tool::from_datasource(&MBQuery).into(), + Tool::from_datasource(&BandcampSource).into(), + Tool::from_datasource(&BeetsDB).into(), + Tool::from_datasource(&MixxxDB).into(), // TODO: We should be able to have eva update lore memories with a function call, and this lore is somehow fed into the show? but only the relevant bits? or maybe eva even queries it directly // TODO: The memory should also be able to remember facts about artists, albums, tracks we've had in the past, and those could be pulled up when there are hits in the playlist. ]; @@ -295,9 +210,10 @@ impl Session { let args = call.function.arguments.as_str(); let tool_result = match func_name { "log_stage_event" => self.tool_stage_event(serde_json::from_str(args).unwrap()).await, - "bandcamp_artifact_scan" => self.tool_bandcamp_scan(serde_json::from_str(args).unwrap()).await, - "archive_query" => self.tool_artifact_query(serde_json::from_str(args).unwrap()).await, - "musicbrainz_track_search" => self.tool_musicbrainz_fetch_tracks(serde_json::from_str(args).unwrap()).await, + "query_bandcamp" => self.tool_artifact_query(&mut BandcampSource, args).await, + "query_beets" => self.tool_artifact_query(&mut BeetsDB, args).await, + "query_musicbrainz" => self.tool_artifact_query(&mut MBQuery, args).await, + "query_mixxx" => self.tool_artifact_query(&mut MixxxDB, args).await, _ => unreachable!() }; results.push((&call.id, tool_result)); @@ -429,12 +345,19 @@ pub async fn start_prediction(saved_session: SaveData, mut messages: tokio::sync do_regen }, PredictionAction::SetPlaylist(playlist_name) => { - match MixxxDB::load(&playlist_name) { + let args = MixxxQuery { playlist_name }; + match MixxxDB.query(&args).await { Err(err) => log::info!("Failed to load mixxx playlist: {:?}.", err), Ok(playlist) => { - session.scenery.artifacts.merge(playlist.clone()); - session.scenery.current_playlist = playlist; - session.direction.playlist = playlist_name; + session.scenery.current_playlist = vec![]; + for item in playlist.clone() { + if let Contents::Track(as_track) = item.contents() { + session.scenery.current_playlist.push(as_track.clone()); + } + session.scenery.artifacts.insert(item); + } + session.scenery.artifacts.synchronize().await; + session.direction.playlist = args.playlist_name; log::info!("Mixxx playlist reloaded."); } } diff --git a/src/scene/mod.rs b/src/scene/mod.rs index f4f6263..87659e2 100644 --- a/src/scene/mod.rs +++ b/src/scene/mod.rs @@ -1,7 +1,7 @@ use chrono::{DateTime, Duration, Utc}; use serde::{Deserialize, Serialize}; -use crate::{artifacts::Artifact, prediction::{GeneratedResponses, PossibleResponse}, scene::conversation::ConversationEntry}; +use crate::{artifacts::{Track, archive::Archive}, prediction::{GeneratedResponses, PossibleResponse}, scene::conversation::ConversationEntry}; pub mod conversation; @@ -36,8 +36,8 @@ impl Default for StageDirection { #[derive(Debug, Default, Serialize, Deserialize, Clone)] pub struct Scenery { - pub artifacts: Vec, - pub current_playlist: Vec + pub artifacts: Archive, + pub current_playlist: Vec } #[derive(Debug, Default, Clone, Serialize, Deserialize)]