From 79908b64a199d3bb46db499f525764a042ec5e6c Mon Sep 17 00:00:00 2001 From: celia-oai Date: Thu, 30 Apr 2026 20:08:54 -0700 Subject: [PATCH] changes --- codex-rs/aws-auth/src/lib.rs | 13 ++++++ codex-rs/core/src/client.rs | 9 ++++ codex-rs/core/src/lib.rs | 1 + codex-rs/core/src/session/session.rs | 1 + .../core/src/session_startup_auth_prewarm.rs | 16 +++++++ .../model-provider/src/amazon_bedrock/auth.rs | 23 +++++++--- .../src/amazon_bedrock/mantle.rs | 16 +++---- .../model-provider/src/amazon_bedrock/mod.rs | 46 +++++++++++++++++-- codex-rs/model-provider/src/provider.rs | 10 ++++ 9 files changed, 114 insertions(+), 21 deletions(-) create mode 100644 codex-rs/core/src/session_startup_auth_prewarm.rs diff --git a/codex-rs/aws-auth/src/lib.rs b/codex-rs/aws-auth/src/lib.rs index 13425f2297..b6c8564746 100644 --- a/codex-rs/aws-auth/src/lib.rs +++ b/codex-rs/aws-auth/src/lib.rs @@ -97,6 +97,11 @@ impl AwsAuthContext { &self.service } + pub async fn preload_credentials(&self) -> Result<(), AwsAuthError> { + let _ = self.credentials_provider.provide_credentials().await?; + Ok(()) + } + pub async fn sign(&self, request: AwsRequestToSign) -> Result { self.sign_at(request, SystemTime::now()).await } @@ -202,6 +207,14 @@ mod tests { assert!(signing::header_value(&signed.headers, "x-amz-date").is_some()); } + #[tokio::test] + async fn preload_credentials_resolves_provider() { + test_context(/*session_token*/ None) + .preload_credentials() + .await + .expect("static credentials should resolve"); + } + #[test] fn credentials_provider_failures_are_retryable() { assert!( diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index ba81b451a7..23922062ac 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -673,6 +673,15 @@ impl ModelClient { true } + /// Resolves provider credentials during session startup when the provider requests it. + pub(crate) async fn prewarm_provider_auth(&self) -> Result<()> { + if !self.state.provider.prewarms_auth_on_startup() { + return Ok(()); + } + + self.state.provider.prewarm_auth().await + } + /// Returns auth + provider configuration resolved from the current session auth state. /// /// This centralizes setup used by both prewarm and normal request paths so they stay in diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index 6a61079a3b..7944769094 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -82,6 +82,7 @@ pub(crate) mod mentions { mod sandbox_tags; pub mod sandboxing; mod session_prefix; +mod session_startup_auth_prewarm; mod session_startup_prewarm; mod shell_detect; pub mod skills; diff --git a/codex-rs/core/src/session/session.rs b/codex-rs/core/src/session/session.rs index 6c6cc22a0e..b7ed859f65 100644 --- a/codex-rs/core/src/session/session.rs +++ b/codex-rs/core/src/session/session.rs @@ -980,6 +980,7 @@ impl Session { anyhow::bail!("required MCP servers failed to initialize: {details}"); } } + sess.schedule_startup_auth_prewarm().await; sess.schedule_startup_prewarm(session_configuration.base_instructions.clone()) .await; let session_start_source = match &initial_history { diff --git a/codex-rs/core/src/session_startup_auth_prewarm.rs b/codex-rs/core/src/session_startup_auth_prewarm.rs new file mode 100644 index 0000000000..0e672cdc9f --- /dev/null +++ b/codex-rs/core/src/session_startup_auth_prewarm.rs @@ -0,0 +1,16 @@ +use std::sync::Arc; + +use tracing::warn; + +use crate::session::session::Session; + +impl Session { + pub(crate) async fn schedule_startup_auth_prewarm(self: &Arc) { + let model_client = self.services.model_client.clone(); + tokio::spawn(async move { + if let Err(err) = model_client.prewarm_provider_auth().await { + warn!("startup provider auth prewarm failed: {err:#}"); + } + }); + } +} diff --git a/codex-rs/model-provider/src/amazon_bedrock/auth.rs b/codex-rs/model-provider/src/amazon_bedrock/auth.rs index 96c233207f..c4096879ae 100644 --- a/codex-rs/model-provider/src/amazon_bedrock/auth.rs +++ b/codex-rs/model-provider/src/amazon_bedrock/auth.rs @@ -22,6 +22,7 @@ use super::mantle::region_from_config; const AWS_BEARER_TOKEN_BEDROCK_ENV_VAR: &str = "AWS_BEARER_TOKEN_BEDROCK"; const LEGACY_SESSION_ID_HEADER: &str = "session_id"; +#[derive(Clone, Debug)] pub(super) enum BedrockAuthMethod { EnvBearerToken { token: String, region: String }, AwsSdkAuth { context: AwsAuthContext }, @@ -42,17 +43,25 @@ pub(super) async fn resolve_auth_method( Ok(BedrockAuthMethod::AwsSdkAuth { context }) } -pub(super) async fn resolve_provider_auth( - aws: &ModelProviderAwsAuthInfo, -) -> Result { - match resolve_auth_method(aws).await? { - BedrockAuthMethod::EnvBearerToken { token, .. } => Ok(Arc::new(BearerAuthProvider { +pub(super) async fn prewarm_credentials(auth_method: &BedrockAuthMethod) -> Result<()> { + match auth_method { + BedrockAuthMethod::EnvBearerToken { .. } => Ok(()), + BedrockAuthMethod::AwsSdkAuth { context } => context + .preload_credentials() + .await + .map_err(aws_auth_error_to_codex_error), + } +} + +pub(super) fn provider_auth_from_method(auth_method: BedrockAuthMethod) -> SharedAuthProvider { + match auth_method { + BedrockAuthMethod::EnvBearerToken { token, .. } => Arc::new(BearerAuthProvider { token: Some(token), account_id: None, is_fedramp_account: false, - })), + }), BedrockAuthMethod::AwsSdkAuth { context } => { - Ok(Arc::new(BedrockMantleSigV4AuthProvider::new(context))) + Arc::new(BedrockMantleSigV4AuthProvider::new(context)) } } } diff --git a/codex-rs/model-provider/src/amazon_bedrock/mantle.rs b/codex-rs/model-provider/src/amazon_bedrock/mantle.rs index 7881845e45..d0e1e89f79 100644 --- a/codex-rs/model-provider/src/amazon_bedrock/mantle.rs +++ b/codex-rs/model-provider/src/amazon_bedrock/mantle.rs @@ -4,7 +4,6 @@ use codex_protocol::error::CodexErr; use codex_protocol::error::Result; use super::auth::BedrockAuthMethod; -use super::auth::resolve_auth_method; const BEDROCK_MANTLE_SERVICE_NAME: &str = "bedrock-mantle"; const BEDROCK_MANTLE_SUPPORTED_REGIONS: [&str; 12] = [ @@ -48,16 +47,15 @@ pub(super) fn base_url(region: &str) -> Result { } } -pub(super) async fn runtime_base_url(aws: &ModelProviderAwsAuthInfo) -> Result { - let region = resolve_region(aws).await?; - base_url(®ion) +pub(super) fn region_from_auth_method(auth_method: &BedrockAuthMethod) -> String { + match auth_method { + BedrockAuthMethod::EnvBearerToken { region, .. } => region.clone(), + BedrockAuthMethod::AwsSdkAuth { context } => context.region().to_string(), + } } -async fn resolve_region(aws: &ModelProviderAwsAuthInfo) -> Result { - match resolve_auth_method(aws).await? { - BedrockAuthMethod::EnvBearerToken { region, .. } => Ok(region), - BedrockAuthMethod::AwsSdkAuth { context } => Ok(context.region().to_string()), - } +pub(super) fn runtime_base_url_from_auth_method(auth_method: &BedrockAuthMethod) -> Result { + base_url(®ion_from_auth_method(auth_method)) } #[cfg(test)] diff --git a/codex-rs/model-provider/src/amazon_bedrock/mod.rs b/codex-rs/model-provider/src/amazon_bedrock/mod.rs index adca7d7d91..63469d9eab 100644 --- a/codex-rs/model-provider/src/amazon_bedrock/mod.rs +++ b/codex-rs/model-provider/src/amazon_bedrock/mod.rs @@ -16,20 +16,26 @@ use codex_models_manager::manager::StaticModelsManager; use codex_protocol::account::ProviderAccount; use codex_protocol::error::Result; use codex_protocol::openai_models::ModelsResponse; +use tokio::sync::OnceCell; use crate::provider::ModelProvider; use crate::provider::ProviderAccountResult; use crate::provider::ProviderAccountState; use crate::provider::ProviderCapabilities; -use auth::resolve_provider_auth; +use auth::BedrockAuthMethod; +use auth::prewarm_credentials; +use auth::provider_auth_from_method; +use auth::resolve_auth_method; pub(crate) use catalog::static_model_catalog; -use mantle::runtime_base_url; +use mantle::runtime_base_url_from_auth_method; /// Runtime provider for Amazon Bedrock's OpenAI-compatible Mantle endpoint. #[derive(Clone, Debug)] pub(crate) struct AmazonBedrockModelProvider { pub(crate) info: ModelProviderInfo, pub(crate) aws: ModelProviderAwsAuthInfo, + auth_method: Arc>, + credentials_prewarmed: Arc>, } impl AmazonBedrockModelProvider { @@ -44,8 +50,25 @@ impl AmazonBedrockModelProvider { Self { info: provider_info, aws, + auth_method: Arc::new(OnceCell::new()), + credentials_prewarmed: Arc::new(OnceCell::new()), } } + + async fn auth_method(&self) -> Result { + self.auth_method + .get_or_try_init(|| resolve_auth_method(&self.aws)) + .await + .cloned() + } + + async fn prewarm_bedrock_credentials(&self) -> Result<()> { + let auth_method = self.auth_method().await?; + self.credentials_prewarmed + .get_or_try_init(|| async move { prewarm_credentials(&auth_method).await }) + .await?; + Ok(()) + } } #[async_trait::async_trait] @@ -70,6 +93,14 @@ impl ModelProvider for AmazonBedrockModelProvider { None } + fn prewarms_auth_on_startup(&self) -> bool { + true + } + + async fn prewarm_auth(&self) -> Result<()> { + self.prewarm_bedrock_credentials().await + } + fn account_state(&self) -> ProviderAccountResult { Ok(ProviderAccountState { account: Some(ProviderAccount::AmazonBedrock), @@ -79,16 +110,21 @@ impl ModelProvider for AmazonBedrockModelProvider { async fn api_provider(&self) -> Result { let mut api_provider_info = self.info.clone(); - api_provider_info.base_url = Some(runtime_base_url(&self.aws).await?); + api_provider_info.base_url = Some(runtime_base_url_from_auth_method( + &self.auth_method().await?, + )?); api_provider_info.to_api_provider(/*auth_mode*/ None) } async fn runtime_base_url(&self) -> Result> { - Ok(Some(runtime_base_url(&self.aws).await?)) + Ok(Some(runtime_base_url_from_auth_method( + &self.auth_method().await?, + )?)) } async fn api_auth(&self) -> Result { - resolve_provider_auth(&self.aws).await + self.prewarm_bedrock_credentials().await?; + Ok(provider_auth_from_method(self.auth_method().await?)) } fn models_manager( diff --git a/codex-rs/model-provider/src/provider.rs b/codex-rs/model-provider/src/provider.rs index 0c5e8e0ffe..d61607636d 100644 --- a/codex-rs/model-provider/src/provider.rs +++ b/codex-rs/model-provider/src/provider.rs @@ -96,6 +96,16 @@ pub trait ModelProvider: fmt::Debug + Send + Sync { /// Returns the current provider-scoped auth value, if one is configured. async fn auth(&self) -> Option; + /// Returns whether this provider should resolve request credentials during session startup. + fn prewarms_auth_on_startup(&self) -> bool { + false + } + + /// Resolves provider credentials before the first model request when startup prewarm is enabled. + async fn prewarm_auth(&self) -> codex_protocol::error::Result<()> { + Ok(()) + } + /// Returns the current app-visible account state for this provider. fn account_state(&self) -> ProviderAccountResult;