mirror of
https://github.com/openai/codex.git
synced 2026-05-23 20:44:50 +00:00
changes
This commit is contained in:
@@ -97,6 +97,11 @@ impl AwsAuthContext {
|
||||
&self.service
|
||||
}
|
||||
|
||||
pub async fn preload_credentials(&self) -> Result<(), AwsAuthError> {
|
||||
let _ = self.credentials_provider.provide_credentials().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn sign(&self, request: AwsRequestToSign) -> Result<AwsSignedRequest, AwsAuthError> {
|
||||
self.sign_at(request, SystemTime::now()).await
|
||||
}
|
||||
@@ -202,6 +207,14 @@ mod tests {
|
||||
assert!(signing::header_value(&signed.headers, "x-amz-date").is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn preload_credentials_resolves_provider() {
|
||||
test_context(/*session_token*/ None)
|
||||
.preload_credentials()
|
||||
.await
|
||||
.expect("static credentials should resolve");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn credentials_provider_failures_are_retryable() {
|
||||
assert!(
|
||||
|
||||
@@ -673,6 +673,15 @@ impl ModelClient {
|
||||
true
|
||||
}
|
||||
|
||||
/// Resolves provider credentials during session startup when the provider requests it.
|
||||
pub(crate) async fn prewarm_provider_auth(&self) -> Result<()> {
|
||||
if !self.state.provider.prewarms_auth_on_startup() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.state.provider.prewarm_auth().await
|
||||
}
|
||||
|
||||
/// Returns auth + provider configuration resolved from the current session auth state.
|
||||
///
|
||||
/// This centralizes setup used by both prewarm and normal request paths so they stay in
|
||||
|
||||
@@ -82,6 +82,7 @@ pub(crate) mod mentions {
|
||||
mod sandbox_tags;
|
||||
pub mod sandboxing;
|
||||
mod session_prefix;
|
||||
mod session_startup_auth_prewarm;
|
||||
mod session_startup_prewarm;
|
||||
mod shell_detect;
|
||||
pub mod skills;
|
||||
|
||||
@@ -980,6 +980,7 @@ impl Session {
|
||||
anyhow::bail!("required MCP servers failed to initialize: {details}");
|
||||
}
|
||||
}
|
||||
sess.schedule_startup_auth_prewarm().await;
|
||||
sess.schedule_startup_prewarm(session_configuration.base_instructions.clone())
|
||||
.await;
|
||||
let session_start_source = match &initial_history {
|
||||
|
||||
16
codex-rs/core/src/session_startup_auth_prewarm.rs
Normal file
16
codex-rs/core/src/session_startup_auth_prewarm.rs
Normal file
@@ -0,0 +1,16 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use tracing::warn;
|
||||
|
||||
use crate::session::session::Session;
|
||||
|
||||
impl Session {
|
||||
pub(crate) async fn schedule_startup_auth_prewarm(self: &Arc<Self>) {
|
||||
let model_client = self.services.model_client.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = model_client.prewarm_provider_auth().await {
|
||||
warn!("startup provider auth prewarm failed: {err:#}");
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -22,6 +22,7 @@ use super::mantle::region_from_config;
|
||||
const AWS_BEARER_TOKEN_BEDROCK_ENV_VAR: &str = "AWS_BEARER_TOKEN_BEDROCK";
|
||||
const LEGACY_SESSION_ID_HEADER: &str = "session_id";
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(super) enum BedrockAuthMethod {
|
||||
EnvBearerToken { token: String, region: String },
|
||||
AwsSdkAuth { context: AwsAuthContext },
|
||||
@@ -42,17 +43,25 @@ pub(super) async fn resolve_auth_method(
|
||||
Ok(BedrockAuthMethod::AwsSdkAuth { context })
|
||||
}
|
||||
|
||||
pub(super) async fn resolve_provider_auth(
|
||||
aws: &ModelProviderAwsAuthInfo,
|
||||
) -> Result<SharedAuthProvider> {
|
||||
match resolve_auth_method(aws).await? {
|
||||
BedrockAuthMethod::EnvBearerToken { token, .. } => Ok(Arc::new(BearerAuthProvider {
|
||||
pub(super) async fn prewarm_credentials(auth_method: &BedrockAuthMethod) -> Result<()> {
|
||||
match auth_method {
|
||||
BedrockAuthMethod::EnvBearerToken { .. } => Ok(()),
|
||||
BedrockAuthMethod::AwsSdkAuth { context } => context
|
||||
.preload_credentials()
|
||||
.await
|
||||
.map_err(aws_auth_error_to_codex_error),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn provider_auth_from_method(auth_method: BedrockAuthMethod) -> SharedAuthProvider {
|
||||
match auth_method {
|
||||
BedrockAuthMethod::EnvBearerToken { token, .. } => Arc::new(BearerAuthProvider {
|
||||
token: Some(token),
|
||||
account_id: None,
|
||||
is_fedramp_account: false,
|
||||
})),
|
||||
}),
|
||||
BedrockAuthMethod::AwsSdkAuth { context } => {
|
||||
Ok(Arc::new(BedrockMantleSigV4AuthProvider::new(context)))
|
||||
Arc::new(BedrockMantleSigV4AuthProvider::new(context))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ use codex_protocol::error::CodexErr;
|
||||
use codex_protocol::error::Result;
|
||||
|
||||
use super::auth::BedrockAuthMethod;
|
||||
use super::auth::resolve_auth_method;
|
||||
|
||||
const BEDROCK_MANTLE_SERVICE_NAME: &str = "bedrock-mantle";
|
||||
const BEDROCK_MANTLE_SUPPORTED_REGIONS: [&str; 12] = [
|
||||
@@ -48,16 +47,15 @@ pub(super) fn base_url(region: &str) -> Result<String> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn runtime_base_url(aws: &ModelProviderAwsAuthInfo) -> Result<String> {
|
||||
let region = resolve_region(aws).await?;
|
||||
base_url(®ion)
|
||||
pub(super) fn region_from_auth_method(auth_method: &BedrockAuthMethod) -> String {
|
||||
match auth_method {
|
||||
BedrockAuthMethod::EnvBearerToken { region, .. } => region.clone(),
|
||||
BedrockAuthMethod::AwsSdkAuth { context } => context.region().to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn resolve_region(aws: &ModelProviderAwsAuthInfo) -> Result<String> {
|
||||
match resolve_auth_method(aws).await? {
|
||||
BedrockAuthMethod::EnvBearerToken { region, .. } => Ok(region),
|
||||
BedrockAuthMethod::AwsSdkAuth { context } => Ok(context.region().to_string()),
|
||||
}
|
||||
pub(super) fn runtime_base_url_from_auth_method(auth_method: &BedrockAuthMethod) -> Result<String> {
|
||||
base_url(®ion_from_auth_method(auth_method))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -16,20 +16,26 @@ use codex_models_manager::manager::StaticModelsManager;
|
||||
use codex_protocol::account::ProviderAccount;
|
||||
use codex_protocol::error::Result;
|
||||
use codex_protocol::openai_models::ModelsResponse;
|
||||
use tokio::sync::OnceCell;
|
||||
|
||||
use crate::provider::ModelProvider;
|
||||
use crate::provider::ProviderAccountResult;
|
||||
use crate::provider::ProviderAccountState;
|
||||
use crate::provider::ProviderCapabilities;
|
||||
use auth::resolve_provider_auth;
|
||||
use auth::BedrockAuthMethod;
|
||||
use auth::prewarm_credentials;
|
||||
use auth::provider_auth_from_method;
|
||||
use auth::resolve_auth_method;
|
||||
pub(crate) use catalog::static_model_catalog;
|
||||
use mantle::runtime_base_url;
|
||||
use mantle::runtime_base_url_from_auth_method;
|
||||
|
||||
/// Runtime provider for Amazon Bedrock's OpenAI-compatible Mantle endpoint.
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct AmazonBedrockModelProvider {
|
||||
pub(crate) info: ModelProviderInfo,
|
||||
pub(crate) aws: ModelProviderAwsAuthInfo,
|
||||
auth_method: Arc<OnceCell<BedrockAuthMethod>>,
|
||||
credentials_prewarmed: Arc<OnceCell<()>>,
|
||||
}
|
||||
|
||||
impl AmazonBedrockModelProvider {
|
||||
@@ -44,8 +50,25 @@ impl AmazonBedrockModelProvider {
|
||||
Self {
|
||||
info: provider_info,
|
||||
aws,
|
||||
auth_method: Arc::new(OnceCell::new()),
|
||||
credentials_prewarmed: Arc::new(OnceCell::new()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn auth_method(&self) -> Result<BedrockAuthMethod> {
|
||||
self.auth_method
|
||||
.get_or_try_init(|| resolve_auth_method(&self.aws))
|
||||
.await
|
||||
.cloned()
|
||||
}
|
||||
|
||||
async fn prewarm_bedrock_credentials(&self) -> Result<()> {
|
||||
let auth_method = self.auth_method().await?;
|
||||
self.credentials_prewarmed
|
||||
.get_or_try_init(|| async move { prewarm_credentials(&auth_method).await })
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
@@ -70,6 +93,14 @@ impl ModelProvider for AmazonBedrockModelProvider {
|
||||
None
|
||||
}
|
||||
|
||||
fn prewarms_auth_on_startup(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
async fn prewarm_auth(&self) -> Result<()> {
|
||||
self.prewarm_bedrock_credentials().await
|
||||
}
|
||||
|
||||
fn account_state(&self) -> ProviderAccountResult {
|
||||
Ok(ProviderAccountState {
|
||||
account: Some(ProviderAccount::AmazonBedrock),
|
||||
@@ -79,16 +110,21 @@ impl ModelProvider for AmazonBedrockModelProvider {
|
||||
|
||||
async fn api_provider(&self) -> Result<Provider> {
|
||||
let mut api_provider_info = self.info.clone();
|
||||
api_provider_info.base_url = Some(runtime_base_url(&self.aws).await?);
|
||||
api_provider_info.base_url = Some(runtime_base_url_from_auth_method(
|
||||
&self.auth_method().await?,
|
||||
)?);
|
||||
api_provider_info.to_api_provider(/*auth_mode*/ None)
|
||||
}
|
||||
|
||||
async fn runtime_base_url(&self) -> Result<Option<String>> {
|
||||
Ok(Some(runtime_base_url(&self.aws).await?))
|
||||
Ok(Some(runtime_base_url_from_auth_method(
|
||||
&self.auth_method().await?,
|
||||
)?))
|
||||
}
|
||||
|
||||
async fn api_auth(&self) -> Result<SharedAuthProvider> {
|
||||
resolve_provider_auth(&self.aws).await
|
||||
self.prewarm_bedrock_credentials().await?;
|
||||
Ok(provider_auth_from_method(self.auth_method().await?))
|
||||
}
|
||||
|
||||
fn models_manager(
|
||||
|
||||
@@ -96,6 +96,16 @@ pub trait ModelProvider: fmt::Debug + Send + Sync {
|
||||
/// Returns the current provider-scoped auth value, if one is configured.
|
||||
async fn auth(&self) -> Option<CodexAuth>;
|
||||
|
||||
/// Returns whether this provider should resolve request credentials during session startup.
|
||||
fn prewarms_auth_on_startup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
/// Resolves provider credentials before the first model request when startup prewarm is enabled.
|
||||
async fn prewarm_auth(&self) -> codex_protocol::error::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the current app-visible account state for this provider.
|
||||
fn account_state(&self) -> ProviderAccountResult;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user