diff --git a/codex-rs/login/src/create_api_key.rs b/codex-rs/login/src/create_api_key.rs index 1b30ddb457..e587644936 100644 --- a/codex-rs/login/src/create_api_key.rs +++ b/codex-rs/login/src/create_api_key.rs @@ -313,23 +313,29 @@ async fn wait_for_default_project( let deadline = std::time::Instant::now() + Duration::from_secs(timeout_seconds); loop { let organizations = list_organizations(client, api_base, session_key).await?; - let last_state = if let Some(organization) = select_active_organization(&organizations) { - let projects = list_projects(client, api_base, session_key, &organization.id).await?; - if let Some(project) = find_default_project(&projects) { - return Ok(ProvisioningTarget { - organization_id: organization.id.clone(), - organization_title: organization.title.clone(), - project_id: project.id.clone(), - project_title: project.title.clone(), - }); + let last_state = if organizations.is_empty() { + "no organization found".to_string() + } else { + let ordered_organizations = organizations_by_preference(&organizations); + let mut project_count = 0; + for organization in ordered_organizations { + let projects = + list_projects(client, api_base, session_key, &organization.id).await?; + project_count += projects.len(); + if let Some(project) = find_default_project(&projects) { + return Ok(ProvisioningTarget { + organization_id: organization.id.clone(), + organization_title: organization.title.clone(), + project_id: project.id.clone(), + project_title: project.title.clone(), + }); + } } format!( - "organization `{}` exists, but no default project is ready yet (saw {} projects).", - organization.id, - projects.len() + "checked {} organizations and {} projects, but no default project is ready yet.", + organizations.len(), + project_count ) - } else { - "no organization found".to_string() }; if std::time::Instant::now() >= deadline { @@ -345,16 +351,22 @@ async fn wait_for_default_project( } } -fn select_active_organization(organizations: &[Organization]) -> Option<&Organization> { - organizations - .iter() - .find(|organization| organization.is_default) - .or_else(|| { - organizations - .iter() - .find(|organization| organization.personal) - }) - .or_else(|| organizations.first()) +fn organizations_by_preference(organizations: &[Organization]) -> Vec<&Organization> { + let mut ordered_organizations = organizations.iter().enumerate().collect::>(); + ordered_organizations.sort_by_key(|(index, organization)| { + let rank = if organization.is_default { + 0 + } else if organization.personal { + 1 + } else { + 2 + }; + (rank, *index) + }); + ordered_organizations + .into_iter() + .map(|(_, organization)| organization) + .collect() } fn find_default_project(projects: &[Project]) -> Option<&Project> { diff --git a/codex-rs/login/src/create_api_key_tests.rs b/codex-rs/login/src/create_api_key_tests.rs index a1f2d666ca..611b309c84 100644 --- a/codex-rs/login/src/create_api_key_tests.rs +++ b/codex-rs/login/src/create_api_key_tests.rs @@ -11,7 +11,7 @@ use wiremock::matchers::path; use wiremock::matchers::query_param; #[test] -fn select_active_organization_prefers_default_then_personal_then_first() { +fn organizations_by_preference_orders_default_then_personal_then_input_order() { let organizations = vec![ Organization { id: "org-first".to_string(), @@ -33,9 +33,12 @@ fn select_active_organization_prefers_default_then_personal_then_first() { }, ]; - let selected = select_active_organization(&organizations); + let selected = organizations_by_preference(&organizations); - assert_eq!(selected, organizations.get(2)); + assert_eq!( + selected, + vec![&organizations[2], &organizations[1], &organizations[0]] + ); } #[test] @@ -97,6 +100,10 @@ async fn create_api_key_from_authorization_code_creates_api_key() { "id": "org-default", "title": "Default Org", "is_default": true, + }, + { + "id": "org-secondary", + "title": "Secondary Org", } ] }))) @@ -106,6 +113,15 @@ async fn create_api_key_from_authorization_code_creates_api_key() { .and(path("/dashboard/organizations/org-default/projects")) .and(query_param("detail", "basic")) .and(query_param("limit", "100")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "data": [] + }))) + .mount(&server) + .await; + Mock::given(method("GET")) + .and(path("/dashboard/organizations/org-secondary/projects")) + .and(query_param("detail", "basic")) + .and(query_param("limit", "100")) .respond_with(ResponseTemplate::new(200).set_body_json(json!({ "data": [ { @@ -119,7 +135,7 @@ async fn create_api_key_from_authorization_code_creates_api_key() { .await; Mock::given(method("POST")) .and(path( - "/dashboard/organizations/org-default/projects/proj-default/api_keys", + "/dashboard/organizations/org-secondary/projects/proj-default/api_keys", )) .respond_with(ResponseTemplate::new(200).set_body_json(json!({ "key": { @@ -156,8 +172,8 @@ async fn create_api_key_from_authorization_code_creates_api_key() { assert_eq!( output, CreatedApiKey { - organization_id: "org-default".to_string(), - organization_title: Some("Default Org".to_string()), + organization_id: "org-secondary".to_string(), + organization_title: Some("Secondary Org".to_string()), default_project_id: "proj-default".to_string(), default_project_title: Some("Default Project".to_string()), project_api_key: "sk-proj-123".to_string(), diff --git a/codex-rs/login/src/lib.rs b/codex-rs/login/src/lib.rs index c25d138623..6443c4177f 100644 --- a/codex-rs/login/src/lib.rs +++ b/codex-rs/login/src/lib.rs @@ -17,10 +17,6 @@ pub use server::LoginServer; pub use server::ServerOptions; pub use server::run_login_server; -pub use create_api_key::CreateApiKeyError; -pub use create_api_key::CreatedApiKey; -pub use create_api_key::PendingCreateApiKey; -pub use create_api_key::start_create_api_key; pub use auth::AuthConfig; pub use auth::AuthCredentialsStoreMode; pub use auth::AuthDotJson; @@ -40,4 +36,8 @@ pub use auth::logout; pub use auth::read_openai_api_key_from_env; pub use auth::save_auth; pub use codex_app_server_protocol::AuthMode; +pub use create_api_key::CreateApiKeyError; +pub use create_api_key::CreatedApiKey; +pub use create_api_key::PendingCreateApiKey; +pub use create_api_key::start_create_api_key; pub use token_data::TokenData; diff --git a/codex-rs/login/src/oauth_callback_server.rs b/codex-rs/login/src/oauth_callback_server.rs index f4f4997417..369c880ec5 100644 --- a/codex-rs/login/src/oauth_callback_server.rs +++ b/codex-rs/login/src/oauth_callback_server.rs @@ -385,16 +385,14 @@ fn process_authorization_code_request( parsed_url.query_pairs().into_owned().collect(); if params.get("state").map(String::as_str) != Some(expected_state) { - return HandledRequest::ResponseAndExit { - status: StatusCode(400), - headers: html_headers(), - body: b"

State mismatch

Return to your terminal and try again.

" - .to_vec(), - result: Err(io::Error::new( - io::ErrorKind::PermissionDenied, - "State mismatch in OAuth callback.", - )), - }; + let mut response = Response::from_string( + "

State mismatch

Return to your terminal and try again.

", + ) + .with_status_code(400); + if let Some(header) = html_headers().into_iter().next() { + response = response.with_header(header); + } + return HandledRequest::Response(response); } if let Some(error_code) = params.get("error") { @@ -451,3 +449,24 @@ fn authorization_code_error_message(error_code: &str, error_description: Option< format!("Authentication failed: {error_code}") } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn process_authorization_code_request_keeps_server_running_on_state_mismatch() { + let response = process_authorization_code_request( + "/auth/callback?state=wrong-state&code=auth-code", + "/auth/callback", + "expected-state", + ); + + match response { + HandledRequest::Response(_) => {} + HandledRequest::RedirectWithHeader(_) | HandledRequest::ResponseAndExit { .. } => { + panic!("state mismatch should return a response without exiting") + } + } + } +}