diff --git a/codex-rs/app-server/src/in_process.rs b/codex-rs/app-server/src/in_process.rs index 35b92548b8..63d55cd0a8 100644 --- a/codex-rs/app-server/src/in_process.rs +++ b/codex-rs/app-server/src/in_process.rs @@ -444,6 +444,7 @@ async fn start_uninitialized(args: InProcessStartArgs) -> IoResult + Send + Sync + 'static>; -fn configured_thread_config_loader(config: &Config) -> Arc { - match config.experimental_thread_config_endpoint.as_deref() { - Some(endpoint) => Arc::new(RemoteThreadConfigLoader::new(endpoint)), - None => Arc::new(NoopThreadConfigLoader), - } -} - /// Control-plane messages from the processor/transport side to the outbound router task. /// /// `run_main_with_transport` now uses two loops/tasks: @@ -373,15 +366,17 @@ pub enum PluginStartupTasks { Skip, } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct AppServerRuntimeOptions { pub plugin_startup_tasks: PluginStartupTasks, + pub core_api_options: CoreApiOptions, } impl Default for AppServerRuntimeOptions { fn default() -> Self { Self { plugin_startup_tasks: PluginStartupTasks::Start, + core_api_options: CoreApiOptions::default(), } } } @@ -456,7 +451,10 @@ pub async fn run_main_with_transport_options( .await { Ok(config) => { - let discovered_thread_config_loader = configured_thread_config_loader(&config); + let discovered_thread_config_loader = thread_config_loader_from_config_with_options( + &config, + &runtime_options.core_api_options, + ); config_manager .replace_thread_config_loader(Arc::clone(&discovered_thread_config_loader)); let auth_manager = @@ -770,6 +768,7 @@ pub async fn run_main_with_transport_options( rpc_transport: analytics_rpc_transport(&transport), remote_control_handle: Some(remote_control_handle.clone()), plugin_startup_tasks: runtime_options.plugin_startup_tasks, + core_api_options: runtime_options.core_api_options.clone(), })); let mut thread_created_rx = processor.thread_created_receiver(); let mut running_turn_count_rx = processor.subscribe_running_assistant_turn_count(); diff --git a/codex-rs/app-server/src/message_processor.rs b/codex-rs/app-server/src/message_processor.rs index a3c3877fdc..10c6f198c0 100644 --- a/codex-rs/app-server/src/message_processor.rs +++ b/codex-rs/app-server/src/message_processor.rs @@ -60,10 +60,11 @@ use codex_app_server_protocol::ServerRequestPayload; use codex_app_server_protocol::experimental_required_message; use codex_arg0::Arg0DispatchPaths; use codex_chatgpt::workspace_settings; +use codex_core::CoreApiOptions; use codex_core::ThreadManager; use codex_core::agent_graph_store_from_state_db; use codex_core::config::Config; -use codex_core::thread_store_from_config; +use codex_core::thread_store_from_config_with_options; use codex_exec_server::EnvironmentManager; use codex_feedback::CodexFeedback; use codex_login::AuthManager; @@ -263,6 +264,7 @@ pub(crate) struct MessageProcessorArgs { pub(crate) rpc_transport: AppServerRpcTransport, pub(crate) remote_control_handle: Option, pub(crate) plugin_startup_tasks: crate::PluginStartupTasks, + pub(crate) core_api_options: CoreApiOptions, } impl MessageProcessor { @@ -286,6 +288,7 @@ impl MessageProcessor { rpc_transport, remote_control_handle, plugin_startup_tasks, + core_api_options, } = args; auth_manager.set_external_auth(Arc::new(ExternalAuthRefreshBridge { outgoing: outgoing.clone(), @@ -293,7 +296,11 @@ impl MessageProcessor { // The thread store is intentionally process-scoped. Config reloads can // affect per-thread behavior, but they must not move newly started, // resumed, or forked threads to a different persistence backend/root. - let thread_store = thread_store_from_config(config.as_ref(), state_db.clone()); + let thread_store = thread_store_from_config_with_options( + config.as_ref(), + state_db.clone(), + &core_api_options, + ); let agent_graph_store = agent_graph_store_from_state_db(state_db.clone()); let thread_manager = Arc::new(ThreadManager::new( config.as_ref(), diff --git a/codex-rs/app-server/src/message_processor_tracing_tests.rs b/codex-rs/app-server/src/message_processor_tracing_tests.rs index 27e2c2f473..3581a5ee53 100644 --- a/codex-rs/app-server/src/message_processor_tracing_tests.rs +++ b/codex-rs/app-server/src/message_processor_tracing_tests.rs @@ -302,6 +302,7 @@ async fn build_test_processor( rpc_transport: AppServerRpcTransport::Stdio, remote_control_handle: None, plugin_startup_tasks: crate::PluginStartupTasks::Start, + core_api_options: codex_core::CoreApiOptions::default(), })); (processor, outgoing_rx) } diff --git a/codex-rs/config/src/lib.rs b/codex-rs/config/src/lib.rs index e88c736db0..c28fb40488 100644 --- a/codex-rs/config/src/lib.rs +++ b/codex-rs/config/src/lib.rs @@ -31,6 +31,8 @@ pub use cloud_requirements::CloudRequirementsLoadError; pub use cloud_requirements::CloudRequirementsLoadErrorCode; pub use cloud_requirements::CloudRequirementsLoader; pub use codex_app_server_protocol::ConfigLayerSource; +pub use codex_protocol::CODEX_CORE_IDENTITY_HEADER; +pub use codex_protocol::OpaqueIdentity; pub use codex_utils_absolute_path::AbsolutePathBuf; pub use config_requirements::AppRequirementToml; pub use config_requirements::AppsRequirementsToml; diff --git a/codex-rs/config/src/thread_config/remote.rs b/codex-rs/config/src/thread_config/remote.rs index 7b7feacec5..0868441406 100644 --- a/codex-rs/config/src/thread_config/remote.rs +++ b/codex-rs/config/src/thread_config/remote.rs @@ -6,6 +6,8 @@ use std::time::Duration; use async_trait::async_trait; use codex_model_provider_info::ModelProviderInfo; use codex_model_provider_info::WireApi; +use codex_protocol::CODEX_CORE_IDENTITY_HEADER; +use codex_protocol::OpaqueIdentity; use codex_protocol::config_types::ModelProviderAuthInfo; use codex_utils_absolute_path::AbsolutePathBuf; @@ -17,29 +19,75 @@ use super::ThreadConfigLoader; use super::ThreadConfigSource; use super::UserThreadConfig; use proto::thread_config_loader_client::ThreadConfigLoaderClient; +use tonic::codegen::InterceptedService; +use tonic::metadata::BinaryMetadataValue; +use tonic::service::Interceptor; +use tonic::transport::Channel; +use tonic::transport::Endpoint; #[path = "proto/codex.thread_config.v1.rs"] mod proto; const REMOTE_THREAD_CONFIG_LOAD_TIMEOUT: Duration = Duration::from_secs(5); +#[derive(Clone, Debug)] +struct IdentityInterceptor { + identity: Option, +} + +impl Interceptor for IdentityInterceptor { + fn call( + &mut self, + mut request: tonic::Request<()>, + ) -> Result, tonic::Status> { + if let Some(identity) = &self.identity { + request.metadata_mut().insert_bin( + CODEX_CORE_IDENTITY_HEADER, + BinaryMetadataValue::from_bytes(identity.as_bytes()), + ); + } + Ok(request) + } +} + +type RemoteThreadConfigLoaderClient = + ThreadConfigLoaderClient>; + /// gRPC-backed [`ThreadConfigLoader`] implementation. #[derive(Clone, Debug)] pub struct RemoteThreadConfigLoader { endpoint: String, + identity: Option, } impl RemoteThreadConfigLoader { pub fn new(endpoint: impl Into) -> Self { Self { endpoint: endpoint.into(), + identity: None, } } - async fn client( - &self, - ) -> Result, ThreadConfigLoadError> { - ThreadConfigLoaderClient::connect(self.endpoint.clone()) + pub fn new_with_identity( + endpoint: impl Into, + identity: Option, + ) -> Self { + Self { + endpoint: endpoint.into(), + identity, + } + } + + async fn client(&self) -> Result { + let channel = Endpoint::new(self.endpoint.clone()) + .map_err(|err| { + ThreadConfigLoadError::new( + ThreadConfigLoadErrorCode::RequestFailed, + /*status_code*/ None, + format!("invalid remote thread config loader endpoint: {err}"), + ) + })? + .connect() .await .map_err(|err| { ThreadConfigLoadError::new( @@ -47,7 +95,13 @@ impl RemoteThreadConfigLoader { /*status_code*/ None, format!("failed to connect to remote thread config loader: {err}"), ) - }) + })?; + Ok(ThreadConfigLoaderClient::with_interceptor( + channel, + IdentityInterceptor { + identity: self.identity.clone(), + }, + )) } } @@ -299,9 +353,13 @@ mod tests { use std::collections::BTreeMap; use std::collections::HashMap; use std::num::NonZeroU64; + use std::sync::Arc; + use std::sync::Mutex; use codex_model_provider_info::ModelProviderInfo; use codex_model_provider_info::WireApi; + use codex_protocol::CODEX_CORE_IDENTITY_HEADER; + use codex_protocol::OpaqueIdentity; use codex_protocol::config_types::ModelProviderAuthInfo; use codex_utils_absolute_path::AbsolutePathBuf; use pretty_assertions::assert_eq; @@ -319,6 +377,7 @@ mod tests { struct TestServer { sources: Vec, expected_cwd: String, + captured_identity: Option>>>>, } #[tonic::async_trait] @@ -327,6 +386,16 @@ mod tests { &self, request: Request, ) -> Result, Status> { + if let Some(captured_identity) = &self.captured_identity { + let identity = request + .metadata() + .get_bin(CODEX_CORE_IDENTITY_HEADER) + .and_then(|value| value.to_bytes().ok()) + .map(|value| value.to_vec()); + *captured_identity + .lock() + .expect("captured identity mutex poisoned") = identity; + } assert_eq!( request.into_inner(), proto::LoadThreadConfigRequest { @@ -355,6 +424,7 @@ mod tests { .add_service(ThreadConfigLoaderServer::new(TestServer { sources: proto_sources(), expected_cwd, + captured_identity: None, })) .serve_with_incoming_shutdown( tokio_stream::wrappers::TcpListenerStream::new(listener), @@ -379,6 +449,57 @@ mod tests { assert_eq!(loaded.expect("load thread config"), expected_sources()); } + #[tokio::test] + async fn load_thread_config_forwards_identity_as_metadata() { + let cwd = workspace_dir().join("project"); + let expected_cwd = cwd.to_string_lossy().into_owned(); + let captured_identity = Arc::new(Mutex::new(None)); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("bind test server"); + let addr = listener.local_addr().expect("test server addr"); + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); + let server_identity = captured_identity.clone(); + let server = tokio::spawn(async move { + Server::builder() + .add_service(ThreadConfigLoaderServer::new(TestServer { + sources: proto_sources(), + expected_cwd, + captured_identity: Some(server_identity), + })) + .serve_with_incoming_shutdown( + tokio_stream::wrappers::TcpListenerStream::new(listener), + async { + let _ = shutdown_rx.await; + }, + ) + .await + }); + + let loader = RemoteThreadConfigLoader::new_with_identity( + format!("http://{addr}"), + Some(OpaqueIdentity::from_bytes(b"tenant-key-\x00\xff".to_vec())), + ); + let loaded = loader + .load(ThreadConfigContext { + thread_id: Some("thread-1".to_string()), + cwd: Some(cwd), + }) + .await; + + let _ = shutdown_tx.send(()); + server.await.expect("join server").expect("server"); + + assert_eq!(loaded.expect("load thread config"), expected_sources()); + assert_eq!( + captured_identity + .lock() + .expect("captured identity mutex poisoned") + .as_deref(), + Some(&b"tenant-key-\x00\xff"[..]) + ); + } + #[test] fn load_thread_config_request_sets_timeout() { let request = load_thread_config_request(ThreadConfigContext::default()); diff --git a/codex-rs/core-api/src/lib.rs b/codex-rs/core-api/src/lib.rs index 9af459830a..4367a52fa9 100644 --- a/codex-rs/core-api/src/lib.rs +++ b/codex-rs/core-api/src/lib.rs @@ -8,6 +8,9 @@ pub use codex_app_server_protocol::item_event_to_server_notification; pub use codex_arg0::Arg0DispatchPaths; pub use codex_arg0::arg0_dispatch_or_else; pub use codex_config::ConfigLayerStack; +pub use codex_config::NoopThreadConfigLoader; +pub use codex_config::RemoteThreadConfigLoader; +pub use codex_config::ThreadConfigLoader; pub use codex_config::config_toml::ProjectConfig; pub use codex_config::config_toml::RealtimeAudioConfig; pub use codex_config::config_toml::RealtimeConfig; @@ -24,6 +27,7 @@ pub use codex_config::types::TuiKeymap; pub use codex_config::types::TuiNotificationSettings; pub use codex_config::types::UriBasedFileOpener; pub use codex_core::CodexThread; +pub use codex_core::CoreApiOptions; pub use codex_core::ForkSnapshot; pub use codex_core::McpManager; pub use codex_core::NewThread; @@ -44,7 +48,10 @@ pub use codex_core::init_state_db; pub use codex_core::init_state_db_from_config; pub use codex_core::resolve_installation_id; pub use codex_core::skills::SkillsManager; +pub use codex_core::thread_config_loader_from_config; +pub use codex_core::thread_config_loader_from_config_with_options; pub use codex_core::thread_store_from_config; +pub use codex_core::thread_store_from_config_with_options; pub use codex_exec_server::EnvironmentManager; pub use codex_exec_server::EnvironmentManagerArgs; pub use codex_exec_server::ExecServerRuntimePaths; @@ -56,6 +63,8 @@ pub use codex_model_provider_info::OPENAI_PROVIDER_ID; pub use codex_model_provider_info::built_in_model_providers; pub use codex_models_manager::manager::RefreshStrategy; pub use codex_models_manager::manager::SharedModelsManager; +pub use codex_protocol::CODEX_CORE_IDENTITY_HEADER; +pub use codex_protocol::OpaqueIdentity; pub use codex_protocol::ThreadId; pub use codex_protocol::config_types::AltScreenMode; pub use codex_protocol::config_types::ApprovalsReviewer; diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index 4cdfc5ea23..41930527b3 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -117,6 +117,7 @@ pub mod review_prompts; mod thread_manager; pub(crate) mod web_search; pub(crate) mod windows_sandbox_read_grants; +pub use thread_manager::CoreApiOptions; pub use thread_manager::ForkSnapshot; pub use thread_manager::NewThread; pub use thread_manager::StartThreadOptions; @@ -125,7 +126,10 @@ pub use thread_manager::ThreadShutdownReport; pub use thread_manager::agent_graph_store_from_state_db; pub use thread_manager::build_models_manager; pub use thread_manager::init_state_db_from_config; +pub use thread_manager::thread_config_loader_from_config; +pub use thread_manager::thread_config_loader_from_config_with_options; pub use thread_manager::thread_store_from_config; +pub use thread_manager::thread_store_from_config_with_options; pub use web_search::web_search_action_detail; pub use web_search::web_search_detail; pub use windows_sandbox_read_grants::grant_read_root_non_elevated; diff --git a/codex-rs/core/src/thread_manager.rs b/codex-rs/core/src/thread_manager.rs index 331ed3ca15..f3b84cf0db 100644 --- a/codex-rs/core/src/thread_manager.rs +++ b/codex-rs/core/src/thread_manager.rs @@ -24,6 +24,9 @@ use codex_agent_graph_store::LocalAgentGraphStore; use codex_analytics::AnalyticsEventsClient; use codex_app_server_protocol::ThreadHistoryBuilder; use codex_app_server_protocol::TurnStatus; +use codex_config::NoopThreadConfigLoader; +use codex_config::RemoteThreadConfigLoader; +use codex_config::ThreadConfigLoader; use codex_core_plugins::PluginsManager; use codex_exec_server::EnvironmentManager; use codex_login::AuthManager; @@ -33,6 +36,7 @@ use codex_model_provider_info::ModelProviderInfo; use codex_model_provider_info::OPENAI_PROVIDER_ID; use codex_models_manager::manager::RefreshStrategy; use codex_models_manager::manager::SharedModelsManager; +use codex_protocol::OpaqueIdentity; use codex_protocol::ThreadId; use codex_protocol::config_types::CollaborationModeMask; use codex_protocol::error::CodexErr; @@ -228,6 +232,12 @@ pub struct StartThreadOptions { pub environments: Vec, } +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct CoreApiOptions { + /// Opaque caller identity forwarded to remote core contract implementations. + pub opaque_identity: Option, +} + pub(crate) struct ResumeThreadWithHistoryOptions { pub(crate) config: Config, pub(crate) initial_history: InitialHistory, @@ -276,16 +286,44 @@ pub async fn init_state_db_from_config(config: &Config) -> Option } pub fn thread_store_from_config(config: &Config, state_db: StateDbHandle) -> Arc { + thread_store_from_config_with_options(config, state_db, &CoreApiOptions::default()) +} + +pub fn thread_store_from_config_with_options( + config: &Config, + state_db: StateDbHandle, + options: &CoreApiOptions, +) -> Arc { match &config.experimental_thread_store { ThreadStoreConfig::Local => Arc::new(LocalThreadStore::new( LocalThreadStoreConfig::from_config(config), state_db, )), - ThreadStoreConfig::Remote { endpoint } => Arc::new(RemoteThreadStore::new(endpoint)), + ThreadStoreConfig::Remote { endpoint } => Arc::new(RemoteThreadStore::new_with_identity( + endpoint, + options.opaque_identity.clone(), + )), ThreadStoreConfig::InMemory { id } => InMemoryThreadStore::for_id(id), } } +pub fn thread_config_loader_from_config(config: &Config) -> Arc { + thread_config_loader_from_config_with_options(config, &CoreApiOptions::default()) +} + +pub fn thread_config_loader_from_config_with_options( + config: &Config, + options: &CoreApiOptions, +) -> Arc { + match config.experimental_thread_config_endpoint.as_deref() { + Some(endpoint) => Arc::new(RemoteThreadConfigLoader::new_with_identity( + endpoint, + options.opaque_identity.clone(), + )), + None => Arc::new(NoopThreadConfigLoader), + } +} + pub fn agent_graph_store_from_state_db(state_db: StateDbHandle) -> Arc { Arc::new(LocalAgentGraphStore::new(state_db)) } diff --git a/codex-rs/protocol/src/lib.rs b/codex-rs/protocol/src/lib.rs index a945b1a927..1e973b6a23 100644 --- a/codex-rs/protocol/src/lib.rs +++ b/codex-rs/protocol/src/lib.rs @@ -1,10 +1,13 @@ pub mod account; mod agent_path; pub mod auth; +mod opaque_identity; mod session_id; mod thread_id; mod tool_name; pub use agent_path::AgentPath; +pub use opaque_identity::CODEX_CORE_IDENTITY_HEADER; +pub use opaque_identity::OpaqueIdentity; pub use session_id::SessionId; pub use thread_id::ThreadId; pub use tool_name::ToolName; diff --git a/codex-rs/protocol/src/opaque_identity.rs b/codex-rs/protocol/src/opaque_identity.rs new file mode 100644 index 0000000000..dd0d7aeb72 --- /dev/null +++ b/codex-rs/protocol/src/opaque_identity.rs @@ -0,0 +1,71 @@ +use std::ffi::OsString; + +/// Binary gRPC metadata key used to forward the caller's opaque identity to +/// remote core contract implementations. +pub const CODEX_CORE_IDENTITY_HEADER: &str = "x-codex-core-identity-bin"; + +/// Opaque identity supplied by a Codex core API caller. +/// +/// Codex core treats this as uninterpreted bytes and only forwards it to remote +/// contract implementations that need to perform their own authorization. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct OpaqueIdentity { + bytes: Vec, +} + +impl OpaqueIdentity { + pub fn from_bytes(bytes: impl Into>) -> Self { + Self { + bytes: bytes.into(), + } + } + + pub fn from_os_string(value: OsString) -> Self { + Self::from_bytes(os_string_to_bytes(value)) + } + + pub fn as_bytes(&self) -> &[u8] { + &self.bytes + } + + pub fn into_bytes(self) -> Vec { + self.bytes + } +} + +#[cfg(unix)] +fn os_string_to_bytes(value: OsString) -> Vec { + use std::os::unix::ffi::OsStrExt; + + value.as_os_str().as_bytes().to_vec() +} + +#[cfg(not(unix))] +fn os_string_to_bytes(value: OsString) -> Vec { + value.to_string_lossy().into_owned().into_bytes() +} + +#[cfg(test)] +mod tests { + use super::OpaqueIdentity; + use pretty_assertions::assert_eq; + + #[test] + fn opaque_identity_preserves_bytes() { + let identity = OpaqueIdentity::from_bytes(b"tenant-key-\x00\xff".to_vec()); + + assert_eq!(identity.as_bytes(), &b"tenant-key-\x00\xff"[..]); + } + + #[cfg(unix)] + #[test] + fn opaque_identity_preserves_unix_argv_bytes() { + use std::ffi::OsString; + use std::os::unix::ffi::OsStringExt; + + let identity = + OpaqueIdentity::from_os_string(OsString::from_vec(b"tenant-key-\xff".to_vec())); + + assert_eq!(identity.as_bytes(), &b"tenant-key-\xff"[..]); + } +} diff --git a/codex-rs/thread-store/src/lib.rs b/codex-rs/thread-store/src/lib.rs index 52b7f5ea1f..9d54666679 100644 --- a/codex-rs/thread-store/src/lib.rs +++ b/codex-rs/thread-store/src/lib.rs @@ -12,6 +12,8 @@ mod remote; mod store; mod types; +pub use codex_protocol::CODEX_CORE_IDENTITY_HEADER; +pub use codex_protocol::OpaqueIdentity; pub use error::ThreadStoreError; pub use error::ThreadStoreResult; pub use in_memory::InMemoryThreadStore; diff --git a/codex-rs/thread-store/src/remote/list_threads.rs b/codex-rs/thread-store/src/remote/list_threads.rs index cf562497f4..95d856c169 100644 --- a/codex-rs/thread-store/src/remote/list_threads.rs +++ b/codex-rs/thread-store/src/remote/list_threads.rs @@ -66,7 +66,11 @@ pub(super) async fn list_threads( #[cfg(test)] mod tests { use std::path::PathBuf; + use std::sync::Arc; + use std::sync::Mutex; + use codex_protocol::CODEX_CORE_IDENTITY_HEADER; + use codex_protocol::OpaqueIdentity; use codex_protocol::openai_models::ReasoningEffort; use codex_protocol::protocol::SessionSource; use pretty_assertions::assert_eq; @@ -83,7 +87,9 @@ mod tests { use crate::ThreadStore; #[derive(Default)] - struct TestServer; + struct TestServer { + captured_identity: Option>>>>, + } #[tonic::async_trait] impl thread_store_server::ThreadStore for TestServer { @@ -91,6 +97,16 @@ mod tests { &self, request: Request, ) -> Result, Status> { + if let Some(captured_identity) = &self.captured_identity { + let identity = request + .metadata() + .get_bin(CODEX_CORE_IDENTITY_HEADER) + .and_then(|value| value.to_bytes().ok()) + .map(|value| value.to_vec()); + *captured_identity + .lock() + .expect("captured identity mutex poisoned") = identity; + } let request = request.into_inner(); assert_eq!(request.page_size, 2); assert_eq!(request.cursor.as_deref(), Some("cursor-1")); @@ -171,7 +187,7 @@ mod tests { let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); let server = tokio::spawn(async move { Server::builder() - .add_service(ThreadStoreServer::new(TestServer)) + .add_service(ThreadStoreServer::new(TestServer::default())) .serve_with_incoming_shutdown( tokio_stream::wrappers::TcpListenerStream::new(listener), async { @@ -226,6 +242,61 @@ mod tests { server.await.expect("join server").expect("server"); } + #[tokio::test] + async fn list_threads_forwards_identity_as_metadata() { + let captured_identity = Arc::new(Mutex::new(None)); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("bind test server"); + let addr = listener.local_addr().expect("test server addr"); + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); + let server_identity = captured_identity.clone(); + let server = tokio::spawn(async move { + Server::builder() + .add_service(ThreadStoreServer::new(TestServer { + captured_identity: Some(server_identity), + })) + .serve_with_incoming_shutdown( + tokio_stream::wrappers::TcpListenerStream::new(listener), + async { + let _ = shutdown_rx.await; + }, + ) + .await + }); + + let store = RemoteThreadStore::new_with_identity( + format!("http://{addr}"), + Some(OpaqueIdentity::from_bytes(b"tenant-key-\x00\xff".to_vec())), + ); + store + .list_threads(ListThreadsParams { + page_size: 2, + cursor: Some("cursor-1".to_string()), + sort_key: ThreadSortKey::UpdatedAt, + sort_direction: crate::SortDirection::Desc, + allowed_sources: vec![SessionSource::Cli], + model_providers: Some(vec!["openai".to_string()]), + cwd_filters: Some(vec![PathBuf::from("/workspace")]), + archived: true, + search_term: Some("needle".to_string()), + use_state_db_only: true, + }) + .await + .expect("list threads"); + + assert_eq!( + captured_identity + .lock() + .expect("captured identity mutex poisoned") + .as_deref(), + Some(&b"tenant-key-\x00\xff"[..]) + ); + + let _ = shutdown_tx.send(()); + server.await.expect("join server").expect("server"); + } + #[test] fn stored_thread_proto_roundtrips_through_domain_type() { let thread = proto::StoredThread { diff --git a/codex-rs/thread-store/src/remote/mod.rs b/codex-rs/thread-store/src/remote/mod.rs index 013b74c933..c29945a447 100644 --- a/codex-rs/thread-store/src/remote/mod.rs +++ b/codex-rs/thread-store/src/remote/mod.rs @@ -2,6 +2,8 @@ mod helpers; mod list_threads; use async_trait::async_trait; +use codex_protocol::CODEX_CORE_IDENTITY_HEADER; +use codex_protocol::OpaqueIdentity; use codex_protocol::ThreadId; use crate::AppendThreadItemsParams; @@ -20,10 +22,37 @@ use crate::ThreadStoreError; use crate::ThreadStoreResult; use crate::UpdateThreadMetadataParams; use proto::thread_store_client::ThreadStoreClient; +use tonic::codegen::InterceptedService; +use tonic::metadata::BinaryMetadataValue; +use tonic::service::Interceptor; +use tonic::transport::Channel; +use tonic::transport::Endpoint; #[path = "proto/codex.thread_store.v1.rs"] mod proto; +#[derive(Clone, Debug)] +struct IdentityInterceptor { + identity: Option, +} + +impl Interceptor for IdentityInterceptor { + fn call( + &mut self, + mut request: tonic::Request<()>, + ) -> Result, tonic::Status> { + if let Some(identity) = &self.identity { + request.metadata_mut().insert_bin( + CODEX_CORE_IDENTITY_HEADER, + BinaryMetadataValue::from_bytes(identity.as_bytes()), + ); + } + Ok(request) + } +} + +type RemoteThreadStoreClient = ThreadStoreClient>; + /// gRPC-backed [`ThreadStore`] implementation for deployments whose durable thread data lives /// outside the app-server process. /// @@ -33,21 +62,43 @@ mod proto; #[derive(Clone, Debug)] pub struct RemoteThreadStore { endpoint: String, + identity: Option, } impl RemoteThreadStore { pub fn new(endpoint: impl Into) -> Self { Self { endpoint: endpoint.into(), + identity: None, } } - async fn client(&self) -> ThreadStoreResult> { - ThreadStoreClient::connect(self.endpoint.clone()) + pub fn new_with_identity( + endpoint: impl Into, + identity: Option, + ) -> Self { + Self { + endpoint: endpoint.into(), + identity, + } + } + + async fn client(&self) -> ThreadStoreResult { + let channel = Endpoint::new(self.endpoint.clone()) + .map_err(|err| ThreadStoreError::InvalidRequest { + message: format!("invalid remote thread store endpoint: {err}"), + })? + .connect() .await .map_err(|err| ThreadStoreError::Internal { message: format!("failed to connect to remote thread store: {err}"), - }) + })?; + Ok(ThreadStoreClient::with_interceptor( + channel, + IdentityInterceptor { + identity: self.identity.clone(), + }, + )) } }