diff --git a/src/artifacts/musicbrainz.rs b/src/artifacts/musicbrainz.rs index f064458..85a505a 100644 --- a/src/artifacts/musicbrainz.rs +++ b/src/artifacts/musicbrainz.rs @@ -1,4 +1,3 @@ -use log::Record; use musicbrainz_rs::entity::artist_credit::ArtistCredit; use musicbrainz_rs::entity::release::Release; use musicbrainz_rs::{ApiEndpointError, entity::recording::Recording}; diff --git a/src/audio.rs b/src/audio.rs index 9f48e03..84fd212 100644 --- a/src/audio.rs +++ b/src/audio.rs @@ -1,4 +1,6 @@ -use jack::{AudioIn, AudioOut, ClientOptions, NotificationHandler}; +use std::{collections::HashMap, fmt::Display}; + +use jack::{AudioIn, AudioOut, ClientOptions, NotificationHandler, Port, ProcessScope}; use oximedia_metering::vu_meter::VuMeter; use serde::{Deserialize, Serialize}; use tokio::sync::*; @@ -27,29 +29,123 @@ impl AudioInputControl { } } +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy, Hash)] +enum Role { + Mic, + Tts, + Sfx +} + +impl Display for Role { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let val = match self { + Self::Mic => "Microphone Input", + Self::Tts => "TTS Output", + Self::Sfx => "SFX Output" + }; + + f.write_str(val) + } +} + +impl Role { + fn is_input(&self) -> bool { + match self { + Role::Mic => true, + _ => false + } + } +} + #[derive(Debug)] -pub struct MicStream { +pub struct AudioInStream { pub src: mpsc::Receiver>, pub sample_rate: u32 } #[derive(Debug)] -pub struct TtsOutStream { +pub struct AudioOutStream { pub sink: mpsc::Sender>, pub sample_rate: u32 } +struct AudioSource { + port: Port, + sample_sink: mpsc::Sender>, + meter: VuMeter +} + +impl AudioSource { + fn new(client: &jack::Client, name: &str) -> (Self, AudioInStream) { + let (sample_sink, receiver) = mpsc::channel(32); + let port = client.register_port(name, AudioIn::default()).unwrap(); + (AudioSource { + port, + sample_sink, + meter: VuMeter::new(client.sample_rate().into(), 1, None) + }, AudioInStream { + sample_rate: client.sample_rate(), + src: receiver + }) + } + + fn process(&mut self, scope: &ProcessScope) -> Option { + if self.port.connected_count().unwrap() > 0 { + let buf: Vec<_> = self.port.as_slice(scope).iter().copied().collect(); + self.meter.process_interleaved(&buf); + self.sample_sink.blocking_send(buf).unwrap(); + + self.meter.channel_vu(0) + } else { + None + } + } +} + #[derive(Debug)] -pub struct SfxOutStream { - pub sink: mpsc::Sender>, - pub sample_rate: u32 +struct AudioSink { + output_buf: Vec, + port: Port, + sample_src: mpsc::Receiver> +} + +impl AudioSink { + fn new(client: &jack::Client, name: &str) -> (Self, AudioOutStream) { + let (sender, sample_src) = mpsc::channel(32); + let port = client.register_port(name, AudioOut::default()).unwrap(); + (AudioSink { + output_buf: Vec::with_capacity(1024), + port, + sample_src + }, AudioOutStream { + sample_rate: client.sample_rate(), + sink: sender, + }) + } + + fn process(&mut self, scope: &ProcessScope) { + if let Ok(mut next_outbuf) = self.sample_src.try_recv() { + self.output_buf.append(&mut next_outbuf); + } + + if self.port.connected_count().unwrap() > 0 && !self.output_buf.is_empty() { + let outbuf = self.port.as_mut_slice(scope); + let mut next_segment: Vec = self.output_buf.drain(0..(outbuf.len()).min(self.output_buf.len())).collect(); + let underrun = outbuf.len() - next_segment.len(); + if underrun > 0 { + for _ in 0..underrun { + next_segment.push(0.); + } + } + + outbuf.copy_from_slice(&next_segment); + } + } } #[derive(Debug, Default, Clone, Serialize, Deserialize)] struct AudioConfig { - mic_in_connections: Vec, - tts_out_connections: Vec, - sfx_out_connections: Vec, + connections: HashMap> } impl AudioConfig { @@ -65,9 +161,7 @@ impl AudioConfig { #[derive(Debug)] struct Notify { config: AudioConfig, - mic_port: jack::Port, - tts_port: jack::Port, - sfx_port: jack::Port, + ports: HashMap> } impl NotificationHandler for Notify { @@ -78,26 +172,31 @@ impl NotificationHandler for Notify { port_id_b: jack::PortId, are_connected: bool, ) { - let port_a = client.port_by_id(port_id_a).unwrap(); - let port_b = client.port_by_id(port_id_b).unwrap(); + let port_src = client.port_by_id(port_id_a).unwrap(); + let port_dst = client.port_by_id(port_id_b).unwrap(); - let (stream_name, other_port, target_cfg) = if port_b == self.mic_port { - ("Microphone input", port_a, &mut self.config.mic_in_connections) - } else if port_a == self.tts_port { - ("TTS output", port_b, &mut self.config.tts_out_connections) - } else if port_a == self.sfx_port { - ("SFX output", port_b, &mut self.config.sfx_out_connections) - } else { - return; - }; - - if let Ok(port_name) = other_port.name() { - if are_connected { - log::info!("{} connected to {}", stream_name, port_name); - target_cfg.push(port_name); + let port_match = self.ports.iter().filter_map(|(role, local_port)| { + if role.is_input() && *local_port == port_dst { + Some((role, port_src.name())) + } else if *local_port == port_src { + Some((role, port_dst.name())) } else { - log::info!("{} disconnected from {}", stream_name, port_name); - target_cfg.retain(|x| { x != &port_name} ); + None + } + }).next(); + + if let Some((role, Ok(target_port))) = port_match { + if !self.config.connections.contains_key(role) { + self.config.connections.insert(*role, Default::default()); + } + let cfg_slot = self.config.connections.get_mut(role).unwrap(); + + if are_connected { + log::info!("{} connected to {}", role, target_port); + cfg_slot.push(target_port); + } else { + log::info!("{} disconnected from {}", role, target_port); + cfg_slot.retain(|x| { x != &target_port} ); } let save_data = serde_json::to_string_pretty(&self.config).unwrap(); @@ -106,69 +205,51 @@ impl NotificationHandler for Notify { } } -pub async fn start_audio_input() -> (AudioInputControl, MicStream, TtsOutStream, SfxOutStream) { +pub async fn start_audio_input() -> (AudioInputControl, AudioInStream, AudioOutStream, AudioOutStream) { let (exit_tx, exit_rx) = oneshot::channel(); let config = AudioConfig::load(); - let (mic_audio_sink, mic_audio_src) = mpsc::channel(32); - let (sfx_audio_sink, mut sfx_audio_src) = mpsc::channel(32); - let (tts_audio_sink, mut tts_audio_src) = mpsc::channel(32); let (volume_sink, volume_src) = watch::channel(0.); let (client, _status) = jack::Client::new("Eva-Cohost", ClientOptions::default() | ClientOptions::SESSION_ID).unwrap(); - let mic_port = client.register_port("microphone-in", AudioIn::default()).unwrap(); - let mut tts_port = client.register_port("tts-out", AudioOut::default()).unwrap(); - let mut sfx_port = client.register_port("sfx-out", AudioOut::default()).unwrap(); - let rate = client.sample_rate(); - for (port, connections) in [ - (&tts_port, &config.tts_out_connections), - (&sfx_port, &config.sfx_out_connections), - ] { - for peer_name in connections { - if let Some(peer) = client.port_by_name(peer_name) { - if let Err(err) = client.connect_ports(port, &peer) { - log::error!("Failed to reconnect {} to {}", port.name().unwrap(), peer_name); - } else { - log::info!("Reconnected {} to {}", port.name().unwrap(), peer_name); + let (mut tts_sink, tts_stream) = AudioSink::new(&client, "tts-out"); + let (mut sfx_sink, sfx_stream) = AudioSink::new(&client, "sfx-out"); + let (mut mic_src, mic_stream) = AudioSource::new(&client, "microphone-in"); + + let notifier = Notify { + config, + ports: HashMap::from_iter([ + (Role::Mic, mic_src.port.clone_unowned()), + (Role::Tts, tts_sink.port.clone_unowned()), + (Role::Sfx, sfx_sink.port.clone_unowned()) + ]) + }; + + for (role, local_port) in ¬ifier.ports { + if let Some(targets) = notifier.config.connections.get(role) { + for peer_name in targets { + if let Some(peer) = client.port_by_name(peer_name) { + let (src, dst) = if role.is_input() { + (&peer, local_port) + } else { + (local_port, &peer) + }; + if let Err(err) = client.connect_ports(&src, &dst) { + log::error!("Failed to reconnect {} to {}: {:?}", role, peer_name, err); + } else { + log::info!("Reconnected {} to {}", role, peer_name); + } } } } } - for (port, connections) in [ - (&mic_port, &config.mic_in_connections) - ] { - for peer_name in connections { - if let Some(peer) = client.port_by_name(peer_name) { - client.connect_ports(&peer, port).unwrap(); - } - } - } - - let notifier = Notify { - config, - mic_port: mic_port.clone_unowned(), - tts_port: tts_port.clone_unowned(), - sfx_port: sfx_port.clone_unowned() - }; - - let mut meter = VuMeter::new(rate.into(), 1, None); - let mut tts_output_buf = vec![]; - let mut sfx_output_buf = vec![]; - tts_output_buf.reserve(1024); - sfx_output_buf.reserve(1024); - let handler = jack::contrib::ClosureProcessHandler::new(move |_client, scope| { - if mic_port.connected_count().unwrap() > 0 { - let buf: Vec<_> = mic_port.as_slice(scope).iter().copied().collect(); - meter.process_interleaved(&buf); - mic_audio_sink.blocking_send(buf).unwrap(); - + if let Some(next_vu) = mic_src.process(scope) { volume_sink.send_if_modified(|v| { - let next_vu = meter.channel_vu(0).unwrap(); let next_vu = (next_vu * 100.0).round() / 100.0; if *v != next_vu { *v = next_vu; @@ -179,27 +260,8 @@ pub async fn start_audio_input() -> (AudioInputControl, MicStream, TtsOutStream, }); } - for (src, output, port) in [ - (&mut tts_audio_src, &mut tts_output_buf, &mut tts_port), - (&mut sfx_audio_src, &mut sfx_output_buf, &mut sfx_port) - ] { - if let Ok(mut next_outbuf) = src.try_recv() { - output.append(&mut next_outbuf); - } - - if port.connected_count().unwrap() > 0 && !output.is_empty() { - let outbuf = port.as_mut_slice(scope); - let mut next_segment: Vec = output.drain(0..(outbuf.len()).min(output.len())).collect(); - let underrun = outbuf.len() - next_segment.len(); - if underrun > 0 { - for _ in 0..underrun { - next_segment.push(0.); - } - } - - outbuf.copy_from_slice(&next_segment); - } - + for sink in [&mut tts_sink, &mut sfx_sink] { + sink.process(scope); } jack::Control::Continue }); @@ -215,14 +277,5 @@ pub async fn start_audio_input() -> (AudioInputControl, MicStream, TtsOutStream, (AudioInputControl { volume_src, _jack_client: JackClientRef { killswitch: Some(exit_tx) } - }, MicStream { - sample_rate: rate, - src: mic_audio_src - }, TtsOutStream { - sample_rate: rate, - sink: tts_audio_sink - }, SfxOutStream { - sample_rate: rate, - sink: sfx_audio_sink - }) + }, mic_stream, tts_stream, sfx_stream) } \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 5edcd57..84d3585 100644 --- a/src/main.rs +++ b/src/main.rs @@ -130,7 +130,7 @@ async fn main() { }; let prediction_ctrl = prediction::start_prediction(saved_session, sys_message_src).await; - let (audio_ctrl, mic_stream, tts_output, sfx_output) = start_audio_input().await; + let (audio_ctrl, mic_stream, tts_output, _sfx_output) = start_audio_input().await; let tts_ctrl = start_tts(tts_output).await; let transcription_ctrl = transcription::start_transcription(mic_stream).await; diff --git a/src/transcription.rs b/src/transcription.rs index a6a7c79..6d48508 100644 --- a/src/transcription.rs +++ b/src/transcription.rs @@ -4,7 +4,7 @@ use async_openai::{Client, config::OpenAIConfig, types::{InputSource, audio::{Au use tempfile::SpooledData; use tokio::sync::{mpsc, watch}; -use crate::{audio::MicStream, events::AudioRecordRequest}; +use crate::{audio::AudioInStream, events::AudioRecordRequest}; #[derive(Debug)] pub struct TranscriptionControl { @@ -44,7 +44,7 @@ impl std::io::Seek for RcFile { } } -pub async fn start_transcription(mut mic_src: MicStream) -> TranscriptionControl { +pub async fn start_transcription(mut mic_src: AudioInStream) -> TranscriptionControl { let (audio_control_in, mut audio_control_out) = watch::channel(AudioRecordRequest::Finish); let (transcription_in, transcription_out) = mpsc::channel(1); diff --git a/src/tts.rs b/src/tts.rs index 699b6bf..e6d67db 100644 --- a/src/tts.rs +++ b/src/tts.rs @@ -1,6 +1,6 @@ use std::process::{Command, Stdio}; -use crate::audio::TtsOutStream; +use crate::audio::AudioOutStream; #[derive(Debug)] pub struct TtsControl { @@ -13,7 +13,7 @@ impl TtsControl { } } -pub async fn start_tts(audio_sink: TtsOutStream) -> TtsControl { +pub async fn start_tts(audio_sink: AudioOutStream) -> TtsControl { let (tts_request_sender, mut tts_request_receiver) = tokio::sync::mpsc::channel(3);