Route TUI realtime audio through shared echo cancellation

Add a shared capture/render processor for realtime local audio and wire microphone and speaker streams through it.

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
Ahmed Ibrahim
2026-04-04 12:45:18 -07:00
parent 5462954edd
commit 95885a8572
7 changed files with 549 additions and 328 deletions

41
codex-rs/Cargo.lock generated
View File

@@ -801,6 +801,15 @@ version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
[[package]]
name = "autotools"
version = "0.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef941527c41b0fc0dd48511a8154cd5fc7e29200a0ff8b7203c5d777dbc795cf"
dependencies = [
"cc",
]
[[package]]
name = "aws-lc-rs"
version = "1.16.2"
@@ -2954,6 +2963,7 @@ dependencies = [
"uuid",
"vt100",
"webbrowser",
"webrtc-audio-processing",
"which 8.0.0",
"windows-sys 0.52.0",
"winsplit",
@@ -11888,6 +11898,37 @@ dependencies = [
"webrtc-util",
]
[[package]]
name = "webrtc-audio-processing"
version = "2.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e31cc7cbf143d9c3de985b113ddc3fddb4a60c7746635a69114f94a5b5af8f42"
dependencies = [
"webrtc-audio-processing-config",
"webrtc-audio-processing-sys",
]
[[package]]
name = "webrtc-audio-processing-config"
version = "2.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb5a1fcf911c54bf3d0e022020f1626ae99add5728c65d345bacfbf3adbf2ef2"
[[package]]
name = "webrtc-audio-processing-sys"
version = "2.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dbd9fcfbedf79b0c7bc0f627d4e9cd5efcb648dcfc271c15503194482b39b0e8"
dependencies = [
"anyhow",
"autotools",
"bindgen",
"cc",
"fs_extra",
"pkg-config",
"regex",
]
[[package]]
name = "webrtc-data"
version = "0.17.1"

View File

@@ -352,6 +352,7 @@ vt100 = "0.16.2"
walkdir = "2.5.0"
webbrowser = "1.0"
webrtc = "0.17.1"
webrtc-audio-processing = { version = "~2.0", features = ["bundled"] }
which = "8"
wildmatch = "2.6.1"
zip = "2.4.2"

View File

@@ -113,6 +113,7 @@ tokio-util = { workspace = true, features = ["time"] }
[target.'cfg(not(target_os = "linux"))'.dependencies]
cpal = "0.15"
webrtc-audio-processing = { workspace = true }
[target.'cfg(unix)'.dependencies]
libc = { workspace = true }

View File

@@ -1,4 +1,6 @@
use super::*;
#[cfg(not(target_os = "linux"))]
use crate::realtime_audio_processing::RealtimeAudioProcessor;
use codex_protocol::protocol::ConversationStartParams;
use codex_protocol::protocol::RealtimeAudioFrame;
use codex_protocol::protocol::RealtimeConversationClosedEvent;
@@ -30,6 +32,8 @@ pub(super) struct RealtimeConversationUiState {
#[cfg(not(target_os = "linux"))]
capture_stop_flag: Option<Arc<AtomicBool>>,
#[cfg(not(target_os = "linux"))]
audio_processor: Option<RealtimeAudioProcessor>,
#[cfg(not(target_os = "linux"))]
capture: Option<crate::voice::VoiceCapture>,
#[cfg(not(target_os = "linux"))]
audio_player: Option<crate::voice::RealtimeAudioPlayer>,
@@ -331,9 +335,26 @@ impl ChatWidget {
fn enqueue_realtime_audio_out(&mut self, frame: &RealtimeAudioFrame) {
#[cfg(not(target_os = "linux"))]
{
if !self.realtime_conversation.is_active() {
return;
}
if self.realtime_conversation.audio_player.is_none() {
self.realtime_conversation.audio_player =
crate::voice::RealtimeAudioPlayer::start(&self.config).ok();
let Some(audio_processor) = self.realtime_conversation.audio_processor.clone()
else {
self.fail_realtime_conversation(
"Realtime audio processor was unavailable".to_string(),
);
return;
};
match crate::voice::RealtimeAudioPlayer::start(&self.config, audio_processor) {
Ok(player) => self.realtime_conversation.audio_player = Some(player),
Err(err) => {
self.fail_realtime_conversation(format!(
"Failed to start speaker output: {err}"
));
return;
}
}
}
if let Some(player) = &self.realtime_conversation.audio_player
&& let Err(err) = player.enqueue_frame(frame)
@@ -367,12 +388,42 @@ impl ChatWidget {
self.realtime_conversation.meter_placeholder_id = Some(placeholder_id.clone());
self.request_redraw();
let audio_processor = match RealtimeAudioProcessor::new() {
Ok(audio_processor) => audio_processor,
Err(err) => {
self.realtime_conversation.meter_placeholder_id = None;
self.remove_recording_meter_placeholder(&placeholder_id);
self.fail_realtime_conversation(format!(
"Failed to start realtime audio processor: {err}"
));
return;
}
};
self.realtime_conversation.audio_processor = Some(audio_processor.clone());
let audio_player =
match crate::voice::RealtimeAudioPlayer::start(&self.config, audio_processor.clone()) {
Ok(player) => player,
Err(err) => {
self.realtime_conversation.audio_processor = None;
self.realtime_conversation.meter_placeholder_id = None;
self.remove_recording_meter_placeholder(&placeholder_id);
self.fail_realtime_conversation(format!(
"Failed to start speaker output: {err}"
));
return;
}
};
let capture = match crate::voice::VoiceCapture::start_realtime(
&self.config,
self.app_event_tx.clone(),
audio_processor,
) {
Ok(capture) => capture,
Err(err) => {
drop(audio_player);
self.realtime_conversation.audio_processor = None;
self.realtime_conversation.meter_placeholder_id = None;
self.remove_recording_meter_placeholder(&placeholder_id);
self.fail_realtime_conversation(format!(
@@ -389,10 +440,7 @@ impl ChatWidget {
self.realtime_conversation.capture_stop_flag = Some(stop_flag.clone());
self.realtime_conversation.capture = Some(capture);
if self.realtime_conversation.audio_player.is_none() {
self.realtime_conversation.audio_player =
crate::voice::RealtimeAudioPlayer::start(&self.config).ok();
}
self.realtime_conversation.audio_player = Some(audio_player);
std::thread::spawn(move || {
let mut meter = crate::voice::RecordingMeterState::new();
@@ -423,23 +471,10 @@ impl ChatWidget {
}
match kind {
RealtimeAudioDeviceKind::Microphone => {
self.stop_realtime_microphone();
RealtimeAudioDeviceKind::Microphone | RealtimeAudioDeviceKind::Speaker => {
self.stop_realtime_local_audio();
self.start_realtime_local_audio();
}
RealtimeAudioDeviceKind::Speaker => {
self.stop_realtime_speaker();
match crate::voice::RealtimeAudioPlayer::start(&self.config) {
Ok(player) => {
self.realtime_conversation.audio_player = Some(player);
}
Err(err) => {
self.fail_realtime_conversation(format!(
"Failed to start speaker output: {err}"
));
}
}
}
}
self.request_redraw();
}
@@ -453,6 +488,7 @@ impl ChatWidget {
fn stop_realtime_local_audio(&mut self) {
self.stop_realtime_microphone();
self.stop_realtime_speaker();
self.realtime_conversation.audio_processor = None;
}
#[cfg(target_os = "linux")]

View File

@@ -130,6 +130,8 @@ pub mod onboarding;
mod oss_selection;
mod pager_overlay;
pub mod public_widgets;
#[cfg(not(target_os = "linux"))]
mod realtime_audio_processing;
mod render;
mod resume_picker;
mod selection_list;
@@ -577,42 +579,15 @@ fn latest_session_lookup_params(
source_kinds: (!include_non_interactive)
.then_some(vec![ThreadSourceKind::Cli, ThreadSourceKind::VsCode]),
archived: Some(false),
cwd: cwd_filter.map(|cwd| cwd.to_string_lossy().to_string()),
cwd: if is_remote {
None
} else {
cwd_filter.map(|cwd| cwd.to_string_lossy().to_string())
},
search_term: None,
}
}
fn config_cwd_for_app_server_target(
cwd: Option<&Path>,
app_server_target: &AppServerTarget,
) -> std::io::Result<AbsolutePathBuf> {
if matches!(app_server_target, AppServerTarget::Remote { .. }) {
return AbsolutePathBuf::current_dir();
}
match cwd {
Some(path) => AbsolutePathBuf::from_absolute_path(path.canonicalize()?),
None => AbsolutePathBuf::current_dir(),
}
}
fn latest_session_cwd_filter<'a>(
remote_mode: bool,
remote_cwd_override: Option<&'a Path>,
config: &'a Config,
show_all: bool,
) -> Option<&'a Path> {
if show_all {
return None;
}
if remote_mode {
remote_cwd_override
} else {
Some(config.cwd.as_path())
}
}
pub async fn run_main(
mut cli: Cli,
arg0_paths: Arg0DispatchPaths,
@@ -631,10 +606,6 @@ pub async fn run_main(
auth_token: remote_auth_token.clone(),
})
.unwrap_or(AppServerTarget::Embedded);
let remote_cwd_override = cli
.cwd
.clone()
.filter(|_| matches!(app_server_target, AppServerTarget::Remote { .. }));
let (sandbox_mode, approval_policy) = if cli.full_auto {
(
Some(SandboxMode::WorkspaceWrite),
@@ -685,7 +656,10 @@ pub async fn run_main(
};
let cwd = cli.cwd.clone();
let config_cwd = config_cwd_for_app_server_target(cwd.as_deref(), &app_server_target)?;
let config_cwd = match cwd.as_deref() {
Some(path) => AbsolutePathBuf::from_absolute_path(path.canonicalize()?)?,
None => AbsolutePathBuf::current_dir()?,
};
#[allow(clippy::print_stderr)]
let config_toml = match load_config_as_toml_with_cli_overrides(
@@ -773,11 +747,7 @@ pub async fn run_main(
model,
approval_policy,
sandbox_mode,
cwd: if matches!(app_server_target, AppServerTarget::Remote { .. }) {
None
} else {
cwd
},
cwd,
model_provider: model_provider_override.clone(),
config_profile: cli.config_profile.clone(),
codex_self_exe: arg0_paths.codex_self_exe.clone(),
@@ -939,7 +909,6 @@ pub async fn run_main(
arg0_paths,
loader_overrides,
app_server_target,
remote_cwd_override,
config,
overrides,
cli_kv_overrides,
@@ -958,7 +927,6 @@ async fn run_ratatui_app(
arg0_paths: Arg0DispatchPaths,
loader_overrides: LoaderOverrides,
app_server_target: AppServerTarget,
remote_cwd_override: Option<PathBuf>,
initial_config: Config,
overrides: ConfigOverrides,
cli_kv_overrides: Vec<(String, toml::Value)>,
@@ -1017,21 +985,18 @@ async fn run_ratatui_app(
let needs_onboarding_app_server =
should_show_trust_screen_flag || initial_config.model_provider.requires_openai_auth;
let mut onboarding_app_server = if needs_onboarding_app_server {
Some(
AppServerSession::new(
start_app_server(
&app_server_target,
arg0_paths.clone(),
initial_config.clone(),
cli_kv_overrides.clone(),
loader_overrides.clone(),
cloud_requirements.clone(),
feedback.clone(),
)
.await?,
Some(AppServerSession::new(
start_app_server(
&app_server_target,
arg0_paths.clone(),
initial_config.clone(),
cli_kv_overrides.clone(),
loader_overrides.clone(),
cloud_requirements.clone(),
feedback.clone(),
)
.with_remote_cwd_override(remote_cwd_override.clone()),
)
.await?,
))
} else {
None
};
@@ -1134,21 +1099,18 @@ async fn run_ratatui_app(
|| cli.resume_picker
|| cli.fork_picker;
let mut session_lookup_app_server = if needs_app_server_session_lookup {
Some(
AppServerSession::new(
start_app_server(
&app_server_target,
arg0_paths.clone(),
config.clone(),
cli_kv_overrides.clone(),
loader_overrides.clone(),
cloud_requirements.clone(),
feedback.clone(),
)
.await?,
Some(AppServerSession::new(
start_app_server(
&app_server_target,
arg0_paths.clone(),
config.clone(),
cli_kv_overrides.clone(),
loader_overrides.clone(),
cloud_requirements.clone(),
feedback.clone(),
)
.with_remote_cwd_override(remote_cwd_override.clone()),
)
.await?,
))
} else {
None
};
@@ -1167,21 +1129,12 @@ async fn run_ratatui_app(
}
}
} else if cli.fork_last {
let filter_cwd = if remote_mode {
latest_session_cwd_filter(
remote_mode,
remote_cwd_override.as_deref(),
&config,
cli.fork_show_all,
)
} else {
None
};
let Some(app_server) = session_lookup_app_server.as_mut() else {
unreachable!("session lookup app server should be initialized for --fork --last");
};
match lookup_latest_session_target_with_app_server(
app_server, &config, filter_cwd, /*include_non_interactive*/ false,
app_server, &config, /*cwd_filter*/ None,
/*include_non_interactive*/ false,
)
.await?
{
@@ -1228,12 +1181,11 @@ async fn run_ratatui_app(
}
}
} else if cli.resume_last {
let filter_cwd = latest_session_cwd_filter(
remote_mode,
remote_cwd_override.as_deref(),
&config,
cli.resume_show_all,
);
let filter_cwd = if cli.resume_show_all {
None
} else {
Some(config.cwd.as_path())
};
let Some(app_server) = session_lookup_app_server.as_mut() else {
unreachable!("session lookup app server should be initialized for --resume --last");
};
@@ -1384,7 +1336,7 @@ async fn run_ratatui_app(
let app_result = App::run(
&mut tui,
AppServerSession::new(app_server).with_remote_cwd_override(remote_cwd_override),
AppServerSession::new(app_server),
config,
cli_kv_overrides.clone(),
overrides.clone(),
@@ -1845,9 +1797,12 @@ mod tests {
-> std::io::Result<()> {
let temp_dir = TempDir::new()?;
let config = build_config(&temp_dir).await?;
let cwd = temp_dir.path().join("project");
let params = latest_session_lookup_params(
/*is_remote*/ true, &config, /*cwd_filter*/ None,
/*is_remote*/ true,
&config,
Some(cwd.as_path()),
/*include_non_interactive*/ false,
);
@@ -1856,58 +1811,6 @@ mod tests {
Ok(())
}
#[tokio::test]
async fn latest_session_lookup_params_keep_explicit_cwd_filter_for_remote_sessions()
-> std::io::Result<()> {
let temp_dir = TempDir::new()?;
let config = build_config(&temp_dir).await?;
let cwd = Path::new("repo/on/server");
let params = latest_session_lookup_params(
/*is_remote*/ true,
&config,
Some(cwd),
/*include_non_interactive*/ false,
);
assert_eq!(params.model_providers, None);
assert_eq!(params.cwd.as_deref(), Some("repo/on/server"));
Ok(())
}
#[test]
fn config_cwd_for_app_server_target_uses_current_dir_for_remote_sessions() -> std::io::Result<()>
{
let remote_only_cwd = if cfg!(windows) {
Path::new(r"C:\definitely\not\local\to\this\test")
} else {
Path::new("/definitely/not/local/to/this/test")
};
let target = AppServerTarget::Remote {
websocket_url: "ws://127.0.0.1:1234/".to_string(),
auth_token: None,
};
let config_cwd = config_cwd_for_app_server_target(Some(remote_only_cwd), &target)?;
assert_eq!(config_cwd, AbsolutePathBuf::current_dir()?);
Ok(())
}
#[test]
fn config_cwd_for_app_server_target_canonicalizes_embedded_cli_cwd() -> std::io::Result<()> {
let temp_dir = TempDir::new()?;
let target = AppServerTarget::Embedded;
let config_cwd = config_cwd_for_app_server_target(Some(temp_dir.path()), &target)?;
assert_eq!(
config_cwd,
AbsolutePathBuf::from_absolute_path(temp_dir.path().canonicalize()?)?
);
Ok(())
}
#[tokio::test]
async fn read_session_cwd_returns_none_without_sqlite_or_rollout_path() -> std::io::Result<()> {
let temp_dir = TempDir::new()?;

View File

@@ -0,0 +1,237 @@
use std::collections::VecDeque;
use std::sync::Arc;
use tracing::warn;
use webrtc_audio_processing::Config as AudioProcessingConfig;
use webrtc_audio_processing::Processor;
use webrtc_audio_processing::config::EchoCanceller;
use webrtc_audio_processing::config::GainController;
use webrtc_audio_processing::config::HighPassFilter;
use webrtc_audio_processing::config::NoiseSuppression;
use webrtc_audio_processing::config::NoiseSuppressionLevel;
use webrtc_audio_processing::config::Pipeline;
// Canonical format shared by both processing stages: every capture/render
// sample is converted to 24 kHz mono PCM16 before entering the processor.
pub(crate) const AUDIO_PROCESSING_SAMPLE_RATE: u32 = 24_000;
pub(crate) const AUDIO_PROCESSING_CHANNELS: u16 = 1;
/// Shared handle to a single WebRTC audio processor used by both the
/// microphone (capture) and speaker (render) paths. `Clone` hands out
/// another reference to the same underlying `Processor` via the `Arc`.
#[derive(Clone)]
pub(crate) struct RealtimeAudioProcessor {
    processor: Arc<Processor>,
}
impl RealtimeAudioProcessor {
    /// Build a processor configured for echo cancellation, noise suppression,
    /// automatic gain control, and high-pass filtering on mono audio at
    /// [`AUDIO_PROCESSING_SAMPLE_RATE`].
    ///
    /// Errors are stringified so callers can surface them directly in the UI.
    pub(crate) fn new() -> Result<Self, String> {
        let processor = Processor::new(AUDIO_PROCESSING_SAMPLE_RATE)
            .map_err(|err| format!("failed to initialize realtime audio processor: {err}"))?;
        processor.set_config(AudioProcessingConfig {
            pipeline: Pipeline {
                // Both directions are mono (AUDIO_PROCESSING_CHANNELS == 1).
                multi_channel_capture: false,
                multi_channel_render: false,
                ..Default::default()
            },
            // NOTE(review): stream_delay_ms = None leaves delay estimation to
            // the library — confirm this works well with the bundled build.
            echo_canceller: Some(EchoCanceller::Full {
                stream_delay_ms: None,
            }),
            noise_suppression: Some(NoiseSuppression {
                level: NoiseSuppressionLevel::High,
                ..Default::default()
            }),
            gain_controller: Some(GainController::GainController2(Default::default())),
            high_pass_filter: Some(HighPassFilter::default()),
            ..Default::default()
        });
        // No speaker output exists yet, so start in the "output muted" state;
        // the render path flips this once real audio is enqueued.
        processor.set_output_will_be_muted(true);
        Ok(Self {
            processor: Arc::new(processor),
        })
    }
    /// Create the microphone-side stage. `input_sample_rate`/`input_channels`
    /// describe the cpal input stream's native format.
    pub(crate) fn capture_stage(
        &self,
        input_sample_rate: u32,
        input_channels: u16,
    ) -> RealtimeCaptureAudioProcessor {
        RealtimeCaptureAudioProcessor {
            // Same Arc as the render stage: both stages drive one processor.
            processor: self.processor.clone(),
            input_sample_rate,
            input_channels,
            pending_samples: VecDeque::new(),
        }
    }
    /// Create the speaker-side stage. `output_sample_rate`/`output_channels`
    /// describe the cpal output stream's native format.
    pub(crate) fn render_stage(
        &self,
        output_sample_rate: u32,
        output_channels: u16,
    ) -> RealtimeRenderAudioProcessor {
        RealtimeRenderAudioProcessor {
            processor: self.processor.clone(),
            output_sample_rate,
            output_channels,
            pending_samples: VecDeque::new(),
        }
    }
    /// Tell the shared processor whether the speaker path is currently silent.
    pub(crate) fn set_output_will_be_muted(&self, muted: bool) {
        self.processor.set_output_will_be_muted(muted);
    }
}
/// Capture-side (microphone) stage of the shared realtime audio processor.
/// Buffers converted samples until full processor-sized frames are available.
pub(crate) struct RealtimeCaptureAudioProcessor {
    // Shared WebRTC processor (same Arc as the render stage).
    processor: Arc<Processor>,
    // Native sample rate of the cpal input stream feeding this stage.
    input_sample_rate: u32,
    // Channel count of the cpal input stream.
    input_channels: u16,
    // Samples already converted to 24 kHz mono, awaiting a complete frame.
    pending_samples: VecDeque<i16>,
}
impl RealtimeCaptureAudioProcessor {
    /// Run microphone samples through the shared WebRTC processor.
    ///
    /// Input is first converted to the canonical 24 kHz mono format, then
    /// buffered; every complete processor frame is processed in place and
    /// appended to the returned vector. A frame that fails to process is
    /// logged and dropped rather than forwarded unprocessed.
    pub(crate) fn process_samples(&mut self, samples: &[i16]) -> Vec<i16> {
        let resampled = convert_pcm16(
            samples,
            self.input_sample_rate,
            self.input_channels,
            AUDIO_PROCESSING_SAMPLE_RATE,
            AUDIO_PROCESSING_CHANNELS,
        );
        self.pending_samples.extend(resampled);
        let frame_len = self.processor.num_samples_per_frame();
        let mut output = Vec::new();
        while self.pending_samples.len() >= frame_len {
            // The processor consumes f32 frames; convert on the way in and
            // back to PCM16 on the way out.
            let mut frame: Vec<f32> = self
                .pending_samples
                .drain(..frame_len)
                .map(i16_to_f32)
                .collect();
            if let Err(err) = self.processor.process_capture_frame([frame.as_mut_slice()]) {
                warn!("failed to process realtime capture audio: {err}");
                continue;
            }
            output.extend(frame.into_iter().map(f32_to_i16));
        }
        output
    }
}
/// Render-side (speaker) stage of the shared realtime audio processor.
/// Speaker samples are fed through it so the processor has the far-end
/// reference signal needed to cancel echo from the microphone capture.
pub(crate) struct RealtimeRenderAudioProcessor {
    // Shared WebRTC processor (same Arc as the capture stage).
    processor: Arc<Processor>,
    // Native sample rate of the cpal output stream.
    output_sample_rate: u32,
    // Channel count of the cpal output stream.
    output_channels: u16,
    // Samples converted to 24 kHz mono, awaiting a complete frame.
    pending_samples: VecDeque<i16>,
}
impl RealtimeRenderAudioProcessor {
    /// Feed speaker-bound samples into the shared processor as the far-end
    /// reference. Nothing is returned: the processed frames are discarded and
    /// only the processor's internal state is updated.
    pub(crate) fn process_samples(&mut self, samples: &[i16]) {
        // A buffer of pure silence means no far-end audio is playing.
        let silent = samples.iter().all(|&sample| sample == 0);
        self.processor.set_output_will_be_muted(silent);
        self.pending_samples.extend(convert_pcm16(
            samples,
            self.output_sample_rate,
            self.output_channels,
            AUDIO_PROCESSING_SAMPLE_RATE,
            AUDIO_PROCESSING_CHANNELS,
        ));
        let frame_len = self.processor.num_samples_per_frame();
        while self.pending_samples.len() >= frame_len {
            let mut frame: Vec<f32> = self
                .pending_samples
                .drain(..frame_len)
                .map(i16_to_f32)
                .collect();
            if let Err(err) = self.processor.process_render_frame([frame.as_mut_slice()]) {
                warn!("failed to process realtime render audio: {err}");
            }
        }
    }
}
/// Convert interleaved PCM16 audio between sample rates and channel counts.
///
/// Resampling is nearest-neighbor: each output frame index is mapped linearly
/// onto the input frame timeline and the closest input frame is copied.
/// Channel conversion: mono fan-out duplicates the sample, many-to-one
/// downmixes by arithmetic mean, extra channels are truncated, and missing
/// channels are padded by repeating the last input channel.
///
/// Returns an empty vector for empty input or a zero channel count.
pub(crate) fn convert_pcm16(
    input: &[i16],
    input_sample_rate: u32,
    input_channels: u16,
    output_sample_rate: u32,
    output_channels: u16,
) -> Vec<i16> {
    let src_ch = input_channels as usize;
    let dst_ch = output_channels as usize;
    if input.is_empty() || src_ch == 0 || dst_ch == 0 {
        return Vec::new();
    }
    let src_frames = input.len() / src_ch;
    if src_frames == 0 {
        return Vec::new();
    }
    let dst_frames = if input_sample_rate == output_sample_rate {
        src_frames
    } else {
        // At least one output frame is produced for any non-empty input.
        ((src_frames as u64) * (output_sample_rate as u64) / (input_sample_rate as u64)).max(1)
            as usize
    };
    let mut result = Vec::with_capacity(dst_frames.saturating_mul(dst_ch));
    for dst_idx in 0..dst_frames {
        // Nearest-neighbor mapping of the output frame onto the input timeline.
        let src_idx = if dst_frames <= 1 || src_frames <= 1 {
            0
        } else {
            ((dst_idx as u64) * ((src_frames - 1) as u64) / ((dst_frames - 1) as u64)) as usize
        };
        let start = src_idx.saturating_mul(src_ch);
        let frame = &input[start..start + src_ch];
        if src_ch == dst_ch {
            result.extend_from_slice(frame);
        } else if src_ch == 1 {
            // Mono fan-out: duplicate the single sample into every channel.
            result.extend(std::iter::repeat(frame[0]).take(dst_ch));
        } else if dst_ch == 1 {
            // Downmix by arithmetic mean over all input channels.
            let sum: i32 = frame.iter().map(|&s| i32::from(s)).sum();
            result.push((sum / (src_ch as i32)) as i16);
        } else if src_ch > dst_ch {
            result.extend_from_slice(&frame[..dst_ch]);
        } else {
            // Fewer input channels: copy them, then pad with the last one.
            result.extend_from_slice(frame);
            let pad = *frame.last().unwrap_or(&0);
            result.extend(std::iter::repeat(pad).take(dst_ch - src_ch));
        }
    }
    result
}
/// Convert a PCM16 sample to float, scaling by `i16::MAX` so `i16::MAX` maps
/// exactly to 1.0 (and `i16::MIN` lands just below -1.0).
#[inline]
fn i16_to_f32(sample: i16) -> f32 {
    f32::from(sample) / f32::from(i16::MAX)
}
/// Convert a float sample back to PCM16, clamping to [-1.0, 1.0] first so
/// out-of-range values saturate instead of wrapping on the cast.
#[inline]
fn f32_to_i16(sample: f32) -> i16 {
    let bounded = sample.clamp(-1.0, 1.0);
    (bounded * f32::from(i16::MAX)) as i16
}
#[cfg(test)]
mod tests {
    use super::convert_pcm16;
    use pretty_assertions::assert_eq;
    #[test]
    fn convert_pcm16_downmixes_and_resamples_for_model_input() {
        // Four stereo frames at 48 kHz should become two mono frames at
        // 24 kHz; each output sample is the average of the selected frame's
        // two channels (frames 0 and 3 under nearest-neighbor mapping).
        let input = vec![100, 300, 200, 400, 500, 700, 600, 800];
        let converted = convert_pcm16(
            &input, /*input_sample_rate*/ 48_000, /*input_channels*/ 2,
            /*output_sample_rate*/ 24_000, /*output_channels*/ 1,
        );
        assert_eq!(converted, vec![200, 700]);
    }
}

View File

@@ -1,4 +1,10 @@
use crate::app_event_sender::AppEventSender;
use crate::realtime_audio_processing::AUDIO_PROCESSING_CHANNELS;
use crate::realtime_audio_processing::AUDIO_PROCESSING_SAMPLE_RATE;
use crate::realtime_audio_processing::RealtimeAudioProcessor;
use crate::realtime_audio_processing::RealtimeCaptureAudioProcessor;
use crate::realtime_audio_processing::RealtimeRenderAudioProcessor;
use crate::realtime_audio_processing::convert_pcm16;
use base64::Engine;
use codex_core::config::Config;
use codex_protocol::protocol::ConversationAudioParams;
@@ -23,7 +29,11 @@ pub struct VoiceCapture {
}
impl VoiceCapture {
pub fn start_realtime(config: &Config, tx: AppEventSender) -> Result<Self, String> {
pub fn start_realtime(
config: &Config,
tx: AppEventSender,
audio_processor: RealtimeAudioProcessor,
) -> Result<Self, String> {
let (device, config) = select_realtime_input_device_and_config(config)?;
let sample_rate = config.sample_rate().0;
@@ -34,9 +44,8 @@ impl VoiceCapture {
let stream = build_realtime_input_stream(
&device,
&config,
sample_rate,
channels,
tx,
audio_processor.capture_stage(sample_rate, channels),
last_peak.clone(),
)?;
stream
@@ -138,50 +147,76 @@ fn select_realtime_input_device_and_config(
fn build_realtime_input_stream(
device: &cpal::Device,
config: &cpal::SupportedStreamConfig,
sample_rate: u32,
channels: u16,
tx: AppEventSender,
capture_processor: RealtimeCaptureAudioProcessor,
last_peak: Arc<AtomicU16>,
) -> Result<cpal::Stream, String> {
match config.sample_format() {
cpal::SampleFormat::F32 => device
.build_input_stream(
&config.clone().into(),
move |input: &[f32], _| {
let peak = peak_f32(input);
last_peak.store(peak, Ordering::Relaxed);
let samples = input.iter().copied().map(f32_to_i16).collect::<Vec<_>>();
send_realtime_audio_chunk(&tx, samples, sample_rate, channels);
},
move |err| error!("audio input error: {err}"),
None,
)
.map_err(|e| format!("failed to build input stream: {e}")),
cpal::SampleFormat::I16 => device
.build_input_stream(
&config.clone().into(),
move |input: &[i16], _| {
let peak = peak_i16(input);
last_peak.store(peak, Ordering::Relaxed);
send_realtime_audio_chunk(&tx, input.to_vec(), sample_rate, channels);
},
move |err| error!("audio input error: {err}"),
None,
)
.map_err(|e| format!("failed to build input stream: {e}")),
cpal::SampleFormat::U16 => device
.build_input_stream(
&config.clone().into(),
move |input: &[u16], _| {
let mut samples = Vec::with_capacity(input.len());
let peak = convert_u16_to_i16_and_peak(input, &mut samples);
last_peak.store(peak, Ordering::Relaxed);
send_realtime_audio_chunk(&tx, samples, sample_rate, channels);
},
move |err| error!("audio input error: {err}"),
None,
)
.map_err(|e| format!("failed to build input stream: {e}")),
cpal::SampleFormat::F32 => {
let mut capture_processor = capture_processor;
device
.build_input_stream(
&config.clone().into(),
move |input: &[f32], _| {
let peak = peak_f32(input);
last_peak.store(peak, Ordering::Relaxed);
let samples = input.iter().copied().map(f32_to_i16).collect::<Vec<_>>();
let samples = capture_processor.process_samples(&samples);
send_realtime_audio_chunk(
&tx,
samples,
AUDIO_PROCESSING_SAMPLE_RATE,
AUDIO_PROCESSING_CHANNELS,
);
},
move |err| error!("audio input error: {err}"),
None,
)
.map_err(|e| format!("failed to build input stream: {e}"))
}
cpal::SampleFormat::I16 => {
let mut capture_processor = capture_processor;
device
.build_input_stream(
&config.clone().into(),
move |input: &[i16], _| {
let peak = peak_i16(input);
last_peak.store(peak, Ordering::Relaxed);
let samples = capture_processor.process_samples(input);
send_realtime_audio_chunk(
&tx,
samples,
AUDIO_PROCESSING_SAMPLE_RATE,
AUDIO_PROCESSING_CHANNELS,
);
},
move |err| error!("audio input error: {err}"),
None,
)
.map_err(|e| format!("failed to build input stream: {e}"))
}
cpal::SampleFormat::U16 => {
let mut capture_processor = capture_processor;
device
.build_input_stream(
&config.clone().into(),
move |input: &[u16], _| {
let mut samples = Vec::with_capacity(input.len());
let peak = convert_u16_to_i16_and_peak(input, &mut samples);
last_peak.store(peak, Ordering::Relaxed);
let samples = capture_processor.process_samples(&samples);
send_realtime_audio_chunk(
&tx,
samples,
AUDIO_PROCESSING_SAMPLE_RATE,
AUDIO_PROCESSING_CHANNELS,
);
},
move |err| error!("audio input error: {err}"),
None,
)
.map_err(|e| format!("failed to build input stream: {e}"))
}
_ => Err("unsupported input sample format".to_string()),
}
}
@@ -283,24 +318,34 @@ fn convert_u16_to_i16_and_peak(input: &[u16], out: &mut Vec<i16>) -> u16 {
pub(crate) struct RealtimeAudioPlayer {
_stream: cpal::Stream,
queue: Arc<Mutex<VecDeque<i16>>>,
audio_processor: RealtimeAudioProcessor,
output_sample_rate: u32,
output_channels: u16,
}
impl RealtimeAudioPlayer {
pub(crate) fn start(config: &Config) -> Result<Self, String> {
pub(crate) fn start(
config: &Config,
audio_processor: RealtimeAudioProcessor,
) -> Result<Self, String> {
let (device, config) =
crate::audio_device::select_configured_output_device_and_config(config)?;
let output_sample_rate = config.sample_rate().0;
let output_channels = config.channels();
let queue = Arc::new(Mutex::new(VecDeque::new()));
let stream = build_output_stream(&device, &config, Arc::clone(&queue))?;
let stream = build_output_stream(
&device,
&config,
Arc::clone(&queue),
audio_processor.render_stage(output_sample_rate, output_channels),
)?;
stream
.play()
.map_err(|e| format!("failed to start output stream: {e}"))?;
Ok(Self {
_stream: stream,
queue,
audio_processor,
output_sample_rate,
output_channels,
})
@@ -336,6 +381,8 @@ impl RealtimeAudioPlayer {
.map_err(|_| "failed to lock output audio queue".to_string())?;
// TODO(aibrahim): Cap or trim this queue if we observe producer bursts outrunning playback.
guard.extend(converted);
drop(guard);
self.audio_processor.set_output_will_be_muted(false);
Ok(())
}
@@ -343,6 +390,7 @@ impl RealtimeAudioPlayer {
if let Ok(mut guard) = self.queue.lock() {
guard.clear();
}
self.audio_processor.set_output_will_be_muted(true);
}
}
@@ -350,140 +398,94 @@ fn build_output_stream(
device: &cpal::Device,
config: &cpal::SupportedStreamConfig,
queue: Arc<Mutex<VecDeque<i16>>>,
render_processor: RealtimeRenderAudioProcessor,
) -> Result<cpal::Stream, String> {
let config_any: cpal::StreamConfig = config.clone().into();
match config.sample_format() {
cpal::SampleFormat::F32 => device
.build_output_stream(
&config_any,
move |output: &mut [f32], _| fill_output_f32(output, &queue),
move |err| error!("audio output error: {err}"),
None,
)
.map_err(|e| format!("failed to build f32 output stream: {e}")),
cpal::SampleFormat::I16 => device
.build_output_stream(
&config_any,
move |output: &mut [i16], _| fill_output_i16(output, &queue),
move |err| error!("audio output error: {err}"),
None,
)
.map_err(|e| format!("failed to build i16 output stream: {e}")),
cpal::SampleFormat::U16 => device
.build_output_stream(
&config_any,
move |output: &mut [u16], _| fill_output_u16(output, &queue),
move |err| error!("audio output error: {err}"),
None,
)
.map_err(|e| format!("failed to build u16 output stream: {e}")),
cpal::SampleFormat::F32 => {
let mut render_processor = render_processor;
device
.build_output_stream(
&config_any,
move |output: &mut [f32], _| {
fill_output_f32(output, &queue, &mut render_processor)
},
move |err| error!("audio output error: {err}"),
None,
)
.map_err(|e| format!("failed to build f32 output stream: {e}"))
}
cpal::SampleFormat::I16 => {
let mut render_processor = render_processor;
device
.build_output_stream(
&config_any,
move |output: &mut [i16], _| {
fill_output_i16(output, &queue, &mut render_processor)
},
move |err| error!("audio output error: {err}"),
None,
)
.map_err(|e| format!("failed to build i16 output stream: {e}"))
}
cpal::SampleFormat::U16 => {
let mut render_processor = render_processor;
device
.build_output_stream(
&config_any,
move |output: &mut [u16], _| {
fill_output_u16(output, &queue, &mut render_processor)
},
move |err| error!("audio output error: {err}"),
None,
)
.map_err(|e| format!("failed to build u16 output stream: {e}"))
}
other => Err(format!("unsupported output sample format: {other:?}")),
}
}
fn fill_output_i16(output: &mut [i16], queue: &Arc<Mutex<VecDeque<i16>>>) {
if let Ok(mut guard) = queue.lock() {
for sample in output {
*sample = guard.pop_front().unwrap_or(0);
}
return;
}
output.fill(0);
/// cpal i16 output callback helper: fill `output` from the playback queue and
/// feed the same samples to the render-side processor stage.
fn fill_output_i16(
    output: &mut [i16],
    queue: &Arc<Mutex<VecDeque<i16>>>,
    render_processor: &mut RealtimeRenderAudioProcessor,
) {
    // drain_output_samples always returns exactly output.len() samples
    // (zero-padded when the queue runs dry), so copy_from_slice cannot panic.
    let samples = drain_output_samples(output.len(), queue);
    output.copy_from_slice(&samples);
    render_processor.process_samples(output);
}
fn fill_output_f32(output: &mut [f32], queue: &Arc<Mutex<VecDeque<i16>>>) {
if let Ok(mut guard) = queue.lock() {
for sample in output {
let v = guard.pop_front().unwrap_or(0);
*sample = (v as f32) / (i16::MAX as f32);
}
return;
fn fill_output_f32(
output: &mut [f32],
queue: &Arc<Mutex<VecDeque<i16>>>,
render_processor: &mut RealtimeRenderAudioProcessor,
) {
let samples = drain_output_samples(output.len(), queue);
for (output_sample, sample) in output.iter_mut().zip(samples.iter()) {
*output_sample = (*sample as f32) / (i16::MAX as f32);
}
output.fill(0.0);
render_processor.process_samples(&samples);
}
fn fill_output_u16(output: &mut [u16], queue: &Arc<Mutex<VecDeque<i16>>>) {
if let Ok(mut guard) = queue.lock() {
for sample in output {
let v = guard.pop_front().unwrap_or(0);
*sample = (v as i32 + 32768).clamp(0, u16::MAX as i32) as u16;
}
return;
fn fill_output_u16(
output: &mut [u16],
queue: &Arc<Mutex<VecDeque<i16>>>,
render_processor: &mut RealtimeRenderAudioProcessor,
) {
let samples = drain_output_samples(output.len(), queue);
for (output_sample, sample) in output.iter_mut().zip(samples.iter()) {
*output_sample = (*sample as i32 + 32768).clamp(0, u16::MAX as i32) as u16;
}
output.fill(32768);
render_processor.process_samples(&samples);
}
fn convert_pcm16(
input: &[i16],
input_sample_rate: u32,
input_channels: u16,
output_sample_rate: u32,
output_channels: u16,
) -> Vec<i16> {
if input.is_empty() || input_channels == 0 || output_channels == 0 {
return Vec::new();
}
let in_channels = input_channels as usize;
let out_channels = output_channels as usize;
let in_frames = input.len() / in_channels;
if in_frames == 0 {
return Vec::new();
}
let out_frames = if input_sample_rate == output_sample_rate {
in_frames
} else {
(((in_frames as u64) * (output_sample_rate as u64)) / (input_sample_rate as u64)).max(1)
as usize
fn drain_output_samples(output_len: usize, queue: &Arc<Mutex<VecDeque<i16>>>) -> Vec<i16> {
let mut samples = vec![0; output_len];
let Ok(mut guard) = queue.lock() else {
return samples;
};
let mut out = Vec::with_capacity(out_frames.saturating_mul(out_channels));
for out_frame_idx in 0..out_frames {
let src_frame_idx = if out_frames <= 1 || in_frames <= 1 {
0
} else {
((out_frame_idx as u64) * ((in_frames - 1) as u64) / ((out_frames - 1) as u64)) as usize
};
let src_start = src_frame_idx.saturating_mul(in_channels);
let src = &input[src_start..src_start + in_channels];
match (in_channels, out_channels) {
(1, 1) => out.push(src[0]),
(1, n) => {
for _ in 0..n {
out.push(src[0]);
}
}
(n, 1) if n >= 2 => {
let sum: i32 = src.iter().map(|s| *s as i32).sum();
out.push((sum / (n as i32)) as i16);
}
(n, m) if n == m => out.extend_from_slice(src),
(n, m) if n > m => out.extend_from_slice(&src[..m]),
(n, m) => {
out.extend_from_slice(src);
let last = *src.last().unwrap_or(&0);
for _ in n..m {
out.push(last);
}
}
}
}
out
}
#[cfg(test)]
mod tests {
use super::convert_pcm16;
use pretty_assertions::assert_eq;
#[test]
fn convert_pcm16_downmixes_and_resamples_for_model_input() {
let input = vec![100, 300, 200, 400, 500, 700, 600, 800];
let converted = convert_pcm16(
&input, /*input_sample_rate*/ 48_000, /*input_channels*/ 2,
/*output_sample_rate*/ 24_000, /*output_channels*/ 1,
);
assert_eq!(converted, vec![200, 700]);
for sample in &mut samples {
*sample = guard.pop_front().unwrap_or(0);
}
samples
}