codex: tighten api provisioning implementation

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
Michael Fan
2026-03-25 14:18:43 -04:00
parent e044807690
commit 4ceea3e3c2
7 changed files with 146 additions and 205 deletions

1
codex-rs/Cargo.lock generated
View File

@@ -2240,7 +2240,6 @@ dependencies = [
"codex-keyring-store",
"codex-protocol",
"codex-terminal-detection",
"codex-utils-home-dir",
"core_test_support",
"keyring",
"once_cell",

View File

@@ -17,7 +17,6 @@ codex-config = { workspace = true }
codex-keyring-store = { workspace = true }
codex-protocol = { workspace = true }
codex-terminal-detection = { workspace = true }
codex-utils-home-dir = { workspace = true }
once_cell = { workspace = true }
os_info = { workspace = true }
rand = { workspace = true }

View File

@@ -17,10 +17,10 @@ const PLATFORM_HYDRA_CLIENT_ID: &str = "app_2SKx67EdpoN0G6j64rFvigXD";
const PLATFORM_AUDIENCE: &str = "https://api.openai.com/v1";
const DEFAULT_API_BASE: &str = "https://api.openai.com";
const DEFAULT_CALLBACK_PORT: u16 = 5000;
const DEFAULT_CALLBACK_PATH: &str = "/auth/callback";
const CALLBACK_PATH: &str = "/auth/callback";
const DEFAULT_SCOPE: &str = "openid email profile offline_access";
const DEFAULT_APP: &str = "api";
const DEFAULT_USER_AGENT: &str = "OpenAI-Onboard-Auth-Script/1.0";
const USER_AGENT: &str = "Codex-API-Provision/1.0";
const DEFAULT_PROJECT_API_KEY_NAME: &str = "Codex CLI";
const DEFAULT_PROJECT_POLL_INTERVAL_SECONDS: u64 = 10;
const DEFAULT_PROJECT_POLL_TIMEOUT_SECONDS: u64 = 60;
@@ -79,16 +79,12 @@ impl PendingApiProvisioning {
self.callback_server.open_browser()
}
pub fn open_browser_or_print(&self) -> bool {
self.callback_server.open_browser_or_print()
}
pub async fn finish(self) -> Result<ProvisionedApiKey, HelperError> {
pub async fn finish(self) -> Result<ProvisionedApiKey, ApiProvisionError> {
let code = self
.callback_server
.wait_for_code(Duration::from_secs(OAUTH_TIMEOUT_SECONDS))
.await
.map_err(|err| HelperError::message(err.to_string()))?;
.map_err(|err| ApiProvisionError::message(err.to_string()))?;
provision_from_authorization_code(
&self.client,
&self.options,
@@ -102,30 +98,28 @@ impl PendingApiProvisioning {
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProvisionedApiKey {
pub sensitive_id: String,
pub organization_id: String,
pub organization_title: Option<String>,
pub default_project_id: String,
pub default_project_title: Option<String>,
pub project_api_key: String,
pub access_token: String,
}
pub fn start_api_provisioning(
options: ApiProvisionOptions,
) -> Result<PendingApiProvisioning, HelperError> {
) -> Result<PendingApiProvisioning, ApiProvisionError> {
validate_api_provision_options(&options)?;
let client = build_http_client()?;
let callback_server = start_authorization_code_server(
options.callback_port,
DEFAULT_CALLBACK_PATH,
CALLBACK_PATH,
/*force_state*/ None,
|redirect_uri, pkce, state| {
build_authorize_url(&options, redirect_uri, pkce, state)
.map_err(|err| std::io::Error::other(err.to_string()))
},
)
.map_err(|err| HelperError::message(err.to_string()))?;
.map_err(|err| ApiProvisionError::message(err.to_string()))?;
let redirect_uri = callback_server.redirect_uri.clone();
Ok(PendingApiProvisioning {
@@ -137,19 +131,19 @@ pub fn start_api_provisioning(
})
}
fn validate_api_provision_options(options: &ApiProvisionOptions) -> Result<(), HelperError> {
fn validate_api_provision_options(options: &ApiProvisionOptions) -> Result<(), ApiProvisionError> {
if options.project_poll_interval_seconds == 0 {
return Err(HelperError::message(
return Err(ApiProvisionError::message(
"project_poll_interval_seconds must be greater than 0.".to_string(),
));
}
if options.project_poll_timeout_seconds == 0 {
return Err(HelperError::message(
return Err(ApiProvisionError::message(
"project_poll_timeout_seconds must be greater than 0.".to_string(),
));
}
if options.api_key_name.trim().is_empty() {
return Err(HelperError::message(
return Err(ApiProvisionError::message(
"api_key_name must not be empty.".to_string(),
));
}
@@ -161,12 +155,12 @@ fn build_authorize_url(
redirect_uri: &str,
pkce: &PkceCodes,
state: &str,
) -> Result<String, HelperError> {
) -> Result<String, ApiProvisionError> {
let mut url = Url::parse(&format!(
"{}/oauth/authorize",
options.issuer.trim_end_matches('/')
))
.map_err(|err| HelperError::message(format!("invalid issuer URL: {err}")))?;
.map_err(|err| ApiProvisionError::message(format!("invalid issuer URL: {err}")))?;
url.query_pairs_mut()
.append_pair("audience", &options.audience)
.append_pair("client_id", &options.client_id)
@@ -179,11 +173,11 @@ fn build_authorize_url(
Ok(url.to_string())
}
fn build_http_client() -> Result<Client, HelperError> {
fn build_http_client() -> Result<Client, ApiProvisionError> {
build_reqwest_client_with_custom_ca(
reqwest::Client::builder().timeout(Duration::from_secs(HTTP_TIMEOUT_SECONDS)),
)
.map_err(|err| HelperError::message(format!("failed to build HTTP client: {err}")))
.map_err(|err| ApiProvisionError::message(format!("failed to build HTTP client: {err}")))
}
async fn provision_from_authorization_code(
@@ -192,7 +186,7 @@ async fn provision_from_authorization_code(
redirect_uri: &str,
code_verifier: &str,
code: &str,
) -> Result<ProvisionedApiKey, HelperError> {
) -> Result<ProvisionedApiKey, ApiProvisionError> {
let tokens = exchange_authorization_code_for_tokens(
client,
&options.issuer,
@@ -229,13 +223,11 @@ async fn provision_from_authorization_code(
.sensitive_id;
Ok(ProvisionedApiKey {
sensitive_id: login.user.session.sensitive_id,
organization_id: target.organization_id,
organization_title: target.organization_title,
default_project_id: target.project_id,
default_project_title: target.project_title,
project_api_key: api_key,
access_token: tokens.access_token,
})
}
@@ -246,13 +238,13 @@ async fn exchange_authorization_code_for_tokens(
redirect_uri: &str,
code_verifier: &str,
code: &str,
) -> Result<OAuthTokens, HelperError> {
) -> Result<OAuthTokens, ApiProvisionError> {
let url = format!("{}/oauth/token", issuer.trim_end_matches('/'));
execute_json(
client
.request(Method::POST, &url)
.header(reqwest::header::ACCEPT, "application/json")
.header(reqwest::header::USER_AGENT, DEFAULT_USER_AGENT)
.header(reqwest::header::USER_AGENT, USER_AGENT)
.json(&serde_json::json!({
"client_id": client_id,
"code_verifier": code_verifier,
@@ -271,7 +263,7 @@ async fn onboarding_login(
api_base: &str,
app: &str,
access_token: &str,
) -> Result<OnboardingLoginResponse, HelperError> {
) -> Result<OnboardingLoginResponse, ApiProvisionError> {
let url = format!(
"{}/dashboard/onboarding/login",
api_base.trim_end_matches('/')
@@ -280,7 +272,7 @@ async fn onboarding_login(
client
.request(Method::POST, &url)
.header(reqwest::header::ACCEPT, "application/json")
.header(reqwest::header::USER_AGENT, DEFAULT_USER_AGENT)
.header(reqwest::header::USER_AGENT, USER_AGENT)
.bearer_auth(access_token)
.json(&serde_json::json!({ "app": app })),
"POST",
@@ -293,13 +285,13 @@ async fn list_organizations(
client: &Client,
api_base: &str,
session_key: &str,
) -> Result<Vec<Organization>, HelperError> {
) -> Result<Vec<Organization>, ApiProvisionError> {
let url = format!("{}/v1/organizations", api_base.trim_end_matches('/'));
let response: DataList<Organization> = execute_json(
client
.request(Method::GET, &url)
.header(reqwest::header::ACCEPT, "application/json")
.header(reqwest::header::USER_AGENT, DEFAULT_USER_AGENT)
.header(reqwest::header::USER_AGENT, USER_AGENT)
.bearer_auth(session_key),
"GET",
&url,
@@ -313,7 +305,7 @@ async fn list_projects(
api_base: &str,
session_key: &str,
organization_id: &str,
) -> Result<Vec<Project>, HelperError> {
) -> Result<Vec<Project>, ApiProvisionError> {
let url = format!(
"{}/dashboard/organizations/{}/projects?detail=basic&limit=100",
api_base.trim_end_matches('/'),
@@ -323,7 +315,7 @@ async fn list_projects(
client
.request(Method::GET, &url)
.header(reqwest::header::ACCEPT, "application/json")
.header(reqwest::header::USER_AGENT, DEFAULT_USER_AGENT)
.header(reqwest::header::USER_AGENT, USER_AGENT)
.header("openai-organization", organization_id)
.bearer_auth(session_key),
"GET",
@@ -339,7 +331,7 @@ async fn wait_for_default_project(
session_key: &str,
poll_interval_seconds: u64,
timeout_seconds: u64,
) -> Result<ProvisioningTarget, HelperError> {
) -> Result<ProvisioningTarget, ApiProvisionError> {
let deadline = std::time::Instant::now() + Duration::from_secs(timeout_seconds);
loop {
let organizations = list_organizations(client, api_base, session_key).await?;
@@ -363,7 +355,7 @@ async fn wait_for_default_project(
};
if std::time::Instant::now() >= deadline {
return Err(HelperError::message(format!(
return Err(ApiProvisionError::message(format!(
"Timed out waiting for an organization and default project. Last observed state: {last_state}"
)));
}
@@ -371,7 +363,7 @@ async fn wait_for_default_project(
.saturating_duration_since(std::time::Instant::now())
.as_secs();
let sleep_seconds = poll_interval_seconds.min(remaining_seconds.max(1));
std::thread::sleep(Duration::from_secs(sleep_seconds));
tokio::time::sleep(Duration::from_secs(sleep_seconds)).await;
}
}
@@ -397,7 +389,7 @@ async fn create_project_api_key(
session_key: &str,
target: &ProvisioningTarget,
key_name: &str,
) -> Result<CreateApiKeyResponse, HelperError> {
) -> Result<CreateApiKeyResponse, ApiProvisionError> {
let url = format!(
"{}/dashboard/organizations/{}/projects/{}/api_keys",
api_base.trim_end_matches('/'),
@@ -408,7 +400,7 @@ async fn create_project_api_key(
client
.request(Method::POST, &url)
.header(reqwest::header::ACCEPT, "application/json")
.header(reqwest::header::USER_AGENT, DEFAULT_USER_AGENT)
.header(reqwest::header::USER_AGENT, USER_AGENT)
.bearer_auth(session_key)
.json(&serde_json::json!({
"action": "create",
@@ -424,26 +416,26 @@ async fn execute_json<T>(
request: reqwest::RequestBuilder,
method: &str,
url: &str,
) -> Result<T, HelperError>
) -> Result<T, ApiProvisionError>
where
T: for<'de> Deserialize<'de>,
{
let response = request
.send()
.await
.map_err(|err| HelperError::message(format!("Network error calling {url}: {err}")))?;
.map_err(|err| ApiProvisionError::message(format!("Network error calling {url}: {err}")))?;
let status = response.status();
let body = response.bytes().await.map_err(|err| {
HelperError::message(format!("Failed reading response from {url}: {err}"))
ApiProvisionError::message(format!("Failed reading response from {url}: {err}"))
})?;
if !status.is_success() {
return Err(HelperError::api(
return Err(ApiProvisionError::api(
format!("{method} {url} failed with HTTP {status}"),
String::from_utf8_lossy(&body).into_owned(),
));
}
serde_json::from_slice(&body)
.map_err(|err| HelperError::message(format!("{url} returned invalid JSON: {err}")))
.map_err(|err| ApiProvisionError::message(format!("{url} returned invalid JSON: {err}")))
}
#[derive(Debug, Deserialize)]
@@ -512,38 +504,29 @@ struct CreatedApiKey {
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HelperError {
pub struct ApiProvisionError {
message: String,
body: Option<String>,
}
impl HelperError {
impl ApiProvisionError {
fn message(message: String) -> Self {
Self {
message,
body: None,
}
Self { message }
}
fn api(message: String, body: String) -> Self {
Self {
message,
body: Some(body),
message: format!("{message}: {body}"),
}
}
pub fn body(&self) -> Option<&str> {
self.body.as_deref()
}
}
impl std::fmt::Display for HelperError {
impl std::fmt::Display for ApiProvisionError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(&self.message)
}
}
impl std::error::Error for HelperError {}
impl std::error::Error for ApiProvisionError {}
#[cfg(test)]
#[path = "api_provision_tests.rs"]

View File

@@ -154,13 +154,11 @@ async fn provision_from_authorization_code_provisions_api_key() {
assert_eq!(
output,
ProvisionedApiKey {
sensitive_id: "session-123".to_string(),
organization_id: "org-default".to_string(),
organization_title: Some("Default Org".to_string()),
default_project_id: "proj-default".to_string(),
default_project_title: Some("Default Project".to_string()),
project_api_key: "sk-proj-123".to_string(),
access_token: "oauth-access-123".to_string(),
}
);
}

View File

@@ -16,8 +16,8 @@ pub use server::ServerOptions;
pub use server::ShutdownHandle;
pub use server::run_login_server;
pub use api_provision::ApiProvisionError;
pub use api_provision::ApiProvisionOptions;
pub use api_provision::HelperError as ApiProvisionError;
pub use api_provision::PendingApiProvisioning;
pub use api_provision::ProvisionedApiKey;
pub use api_provision::start_api_provisioning;

View File

@@ -11,6 +11,7 @@
//! This module therefore keeps the user-facing error path and the structured-log path separate.
//! Returned `io::Error` values still carry the detail needed by CLI/browser callers, while
//! structured logs only emit explicitly reviewed fields plus redacted URL/error values.
use std::future::Future;
use std::io::Cursor;
use std::io::Read;
use std::io::Write;
@@ -125,22 +126,6 @@ impl AuthorizationCodeServer {
webbrowser::open(&self.auth_url).is_ok()
}
pub fn open_browser_or_print(&self) -> bool {
let opened = self.open_browser();
if opened {
eprintln!(
"Starting local auth callback server.\nIf your browser did not open, navigate to this URL to continue:\n\n{}",
self.auth_url
);
} else {
eprintln!(
"Starting local auth callback server.\nOpen this URL in your browser to continue:\n\n{}",
self.auth_url
);
}
opened
}
pub fn code_verifier(&self) -> &str {
&self.code_verifier
}
@@ -197,67 +182,19 @@ where
let state = force_state.unwrap_or_else(generate_state);
let callback_path = callback_path.to_string();
let server = bind_server(port)?;
let actual_port = match server.server_addr().to_ip() {
Some(addr) => addr.port(),
None => {
return Err(io::Error::new(
io::ErrorKind::AddrInUse,
"Unable to determine the server port",
));
}
};
let server = Arc::new(server);
let (server, actual_port, rx) = bind_server_with_request_channel(port)?;
let redirect_uri = format!("http://localhost:{actual_port}{callback_path}");
let auth_url = auth_url_builder(&redirect_uri, &pkce, &state)?;
let (tx, mut rx) = tokio::sync::mpsc::channel::<Request>(16);
let _server_handle = {
let server = server.clone();
thread::spawn(move || -> io::Result<()> {
while let Ok(request) = server.recv() {
match tx.blocking_send(request) {
Ok(()) => {}
Err(error) => {
eprintln!("Failed to send request to channel: {error}");
return Err(io::Error::other("Failed to send request to channel"));
}
}
}
Ok(())
})
};
let shutdown_notify = Arc::new(tokio::sync::Notify::new());
let server_handle = {
let shutdown_notify = shutdown_notify.clone();
tokio::spawn(async move {
let result = loop {
tokio::select! {
_ = shutdown_notify.notified() => {
break Err(io::Error::other("Authentication was not completed"));
}
maybe_req = rx.recv() => {
let Some(req) = maybe_req else {
break Err(io::Error::other("Authentication was not completed"));
};
let url_raw = req.url().to_string();
let response =
process_authorization_code_request(&url_raw, &callback_path, &state);
if let Some(result) = respond_to_request(req, response).await {
break result;
}
}
}
};
server.unblock();
result
})
};
let (server_handle, shutdown_handle) = spawn_callback_server_loop(
server,
rx,
"Authentication was not completed",
move |url_raw| {
let callback_path = callback_path.clone();
let state = state.clone();
async move { process_authorization_code_request(&url_raw, &callback_path, &state) }
},
);
Ok(AuthorizationCodeServer {
auth_url,
@@ -265,7 +202,7 @@ where
redirect_uri,
code_verifier: pkce.code_verifier,
server_handle,
shutdown_handle: ShutdownHandle { shutdown_notify },
shutdown_handle,
})
}
@@ -274,18 +211,7 @@ pub fn run_login_server(opts: ServerOptions) -> io::Result<LoginServer> {
let pkce = generate_pkce();
let state = opts.force_state.clone().unwrap_or_else(generate_state);
let server = bind_server(opts.port)?;
let actual_port = match server.server_addr().to_ip() {
Some(addr) => addr.port(),
None => {
return Err(io::Error::new(
io::ErrorKind::AddrInUse,
"Unable to determine the server port",
));
}
};
let server = Arc::new(server);
let (server, actual_port, rx) = bind_server_with_request_channel(opts.port)?;
let redirect_uri = format!("http://localhost:{actual_port}/auth/callback");
let auth_url = build_authorize_url(
&opts.issuer,
@@ -299,63 +225,22 @@ pub fn run_login_server(opts: ServerOptions) -> io::Result<LoginServer> {
if opts.open_browser {
let _ = webbrowser::open(&auth_url);
}
// Map blocking reads from server.recv() to an async channel.
let (tx, mut rx) = tokio::sync::mpsc::channel::<Request>(16);
let _server_handle = {
let server = server.clone();
thread::spawn(move || -> io::Result<()> {
while let Ok(request) = server.recv() {
match tx.blocking_send(request) {
Ok(()) => {}
Err(error) => {
eprintln!("Failed to send request to channel: {error}");
return Err(io::Error::other("Failed to send request to channel"));
}
}
let (server_handle, shutdown_handle) =
spawn_callback_server_loop(server, rx, "Login was not completed", move |url_raw| {
let redirect_uri = redirect_uri.clone();
let state = state.clone();
let opts = opts.clone();
let pkce = pkce.clone();
async move {
process_request(&url_raw, &opts, &redirect_uri, &pkce, actual_port, &state).await
}
Ok(())
})
};
let shutdown_notify = Arc::new(tokio::sync::Notify::new());
let server_handle = {
let shutdown_notify = shutdown_notify.clone();
let server = server;
tokio::spawn(async move {
let result = loop {
tokio::select! {
_ = shutdown_notify.notified() => {
break Err(io::Error::other("Login was not completed"));
}
maybe_req = rx.recv() => {
let Some(req) = maybe_req else {
break Err(io::Error::other("Login was not completed"));
};
let url_raw = req.url().to_string();
let response =
process_request(&url_raw, &opts, &redirect_uri, &pkce, actual_port, &state).await;
if let Some(result) = respond_to_request(req, response).await {
break result;
}
}
}
};
// Ensure that the server is unblocked so the thread dedicated to
// running `server.recv()` in a loop exits cleanly.
server.unblock();
result
})
};
});
Ok(LoginServer {
auth_url,
actual_port,
server_handle,
shutdown_handle: ShutdownHandle { shutdown_notify },
shutdown_handle,
})
}
@@ -626,6 +511,87 @@ fn html_headers() -> Vec<Header> {
}
}
fn bind_server_with_request_channel(
port: u16,
) -> io::Result<(Arc<Server>, u16, tokio::sync::mpsc::Receiver<Request>)> {
let server = bind_server(port)?;
let actual_port = match server.server_addr().to_ip() {
Some(addr) => addr.port(),
None => {
return Err(io::Error::new(
io::ErrorKind::AddrInUse,
"Unable to determine the server port",
));
}
};
let server = Arc::new(server);
// Map blocking reads from server.recv() to an async channel.
let (tx, rx) = tokio::sync::mpsc::channel::<Request>(16);
let _server_handle = {
let server = server.clone();
thread::spawn(move || -> io::Result<()> {
while let Ok(request) = server.recv() {
match tx.blocking_send(request) {
Ok(()) => {}
Err(error) => {
eprintln!("Failed to send request to channel: {error}");
return Err(io::Error::other("Failed to send request to channel"));
}
}
}
Ok(())
})
};
Ok((server, actual_port, rx))
}
fn spawn_callback_server_loop<T, F, Fut>(
server: Arc<Server>,
mut rx: tokio::sync::mpsc::Receiver<Request>,
incomplete_message: &'static str,
mut process_request: F,
) -> (tokio::task::JoinHandle<io::Result<T>>, ShutdownHandle)
where
T: Send + 'static,
F: FnMut(String) -> Fut + Send + 'static,
Fut: Future<Output = HandledRequest<T>> + Send + 'static,
{
let shutdown_notify = Arc::new(tokio::sync::Notify::new());
let server_handle = {
let shutdown_notify = shutdown_notify.clone();
tokio::spawn(async move {
let result = loop {
tokio::select! {
_ = shutdown_notify.notified() => {
break Err(io::Error::other(incomplete_message));
}
maybe_req = rx.recv() => {
let Some(req) = maybe_req else {
break Err(io::Error::other(incomplete_message));
};
let url_raw = req.url().to_string();
let response = process_request(url_raw).await;
if let Some(result) = respond_to_request(req, response).await {
break result;
}
}
}
};
// Ensure that the server is unblocked so the thread dedicated to
// running `server.recv()` in a loop exits cleanly.
server.unblock();
result
})
};
(server_handle, ShutdownHandle { shutdown_notify })
}
async fn respond_to_request<T>(req: Request, response: HandledRequest<T>) -> Option<io::Result<T>> {
match response {
HandledRequest::Response(response) => {

View File

@@ -261,13 +261,11 @@ mod tests {
fn success_cell_snapshot() {
let cell = success_cell(
&ProvisionedApiKey {
sensitive_id: "session-123".to_string(),
organization_id: "org-default".to_string(),
organization_title: Some("Default Org".to_string()),
default_project_id: "proj-default".to_string(),
default_project_title: Some("Default Project".to_string()),
project_api_key: "sk-proj-123".to_string(),
access_token: "oauth-access-123".to_string(),
},
Path::new("/tmp/workspace/.env.local"),
LiveApplyOutcome::Applied,
@@ -280,13 +278,11 @@ mod tests {
fn success_cell_snapshot_when_live_apply_is_skipped() {
let cell = success_cell(
&ProvisionedApiKey {
sensitive_id: "session-123".to_string(),
organization_id: "org-default".to_string(),
organization_title: None,
default_project_id: "proj-default".to_string(),
default_project_title: None,
project_api_key: "sk-proj-123".to_string(),
access_token: "oauth-access-123".to_string(),
},
Path::new("/tmp/workspace/.env.local"),
LiveApplyOutcome::Skipped(