mirror of
https://github.com/openai/codex.git
synced 2026-03-26 16:43:58 +00:00
Compare commits
31 Commits
dev/imalch
...
codex/api-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8c3bfb9462 | ||
|
|
93e362994e | ||
|
|
99448163d7 | ||
|
|
6327dc755c | ||
|
|
bcfcdbc552 | ||
|
|
62b3bacd9d | ||
|
|
a9b4f1e327 | ||
|
|
1e173fc832 | ||
|
|
ee431c25fe | ||
|
|
b3e7bfdf1c | ||
|
|
af79effb1f | ||
|
|
86ae9fc582 | ||
|
|
9906a1a0de | ||
|
|
cc7bfdd5e4 | ||
|
|
e8ed25dd38 | ||
|
|
76bb4d670f | ||
|
|
20e96a78da | ||
|
|
4ceea3e3c2 | ||
|
|
e044807690 | ||
|
|
0ef284aa6e | ||
|
|
d114d21272 | ||
|
|
46c5a066b6 | ||
|
|
792d78750f | ||
|
|
c91d265ff6 | ||
|
|
cb6ade2f1c | ||
|
|
be42df8246 | ||
|
|
8243031353 | ||
|
|
9473126c29 | ||
|
|
abe9c7a984 | ||
|
|
89b60b9f09 | ||
|
|
c12c688998 |
@@ -22,6 +22,7 @@ use codex_protocol::protocol::SessionSource;
|
||||
use codex_protocol::protocol::TokenUsage;
|
||||
use codex_protocol::protocol::W3cTraceContext;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::watch;
|
||||
@@ -112,6 +113,10 @@ impl CodexThread {
|
||||
self.codex.agent_status().await
|
||||
}
|
||||
|
||||
pub async fn set_dependency_env(&self, values: HashMap<String, String>) {
|
||||
self.codex.session.set_dependency_env(values).await;
|
||||
}
|
||||
|
||||
pub(crate) fn subscribe_status(&self) -> watch::Receiver<AgentStatus> {
|
||||
self.codex.agent_status.clone()
|
||||
}
|
||||
|
||||
527
codex-rs/login/src/create_api_key.rs
Normal file
527
codex-rs/login/src/create_api_key.rs
Normal file
@@ -0,0 +1,527 @@
|
||||
//! Browser-based OAuth flow for creating OpenAI project API keys.
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_client::build_reqwest_client_with_custom_ca;
|
||||
use reqwest::Client;
|
||||
use reqwest::Method;
|
||||
use serde::Deserialize;
|
||||
use url::Url;
|
||||
|
||||
use crate::oauth_callback_server::AuthorizationCodeServer;
|
||||
use crate::oauth_callback_server::PortConflictStrategy;
|
||||
use crate::oauth_callback_server::start_authorization_code_server;
|
||||
use crate::pkce::PkceCodes;
|
||||
|
||||
const AUTH_ISSUER: &str = "https://auth.openai.com";
|
||||
const PLATFORM_HYDRA_CLIENT_ID: &str = "app_2SKx67EdpoN0G6j64rFvigXD";
|
||||
const PLATFORM_AUDIENCE: &str = "https://api.openai.com/v1";
|
||||
const API_BASE: &str = "https://api.openai.com";
|
||||
// This client is registered with Hydra for http://localhost:5000/auth/callback,
|
||||
// so the browser redirect must stay on port 5000.
|
||||
const CALLBACK_PORT: u16 = 5000;
|
||||
const CALLBACK_PATH: &str = "/auth/callback";
|
||||
const SCOPE: &str = "openid email profile offline_access";
|
||||
const APP: &str = "api";
|
||||
const USER_AGENT: &str = "Codex-Create-API-Key/1.0";
|
||||
const PROJECT_API_KEY_NAME: &str = "Codex CLI";
|
||||
const PROJECT_POLL_INTERVAL_SECONDS: u64 = 10;
|
||||
const PROJECT_POLL_TIMEOUT_SECONDS: u64 = 60;
|
||||
const OAUTH_TIMEOUT_SECONDS: u64 = 15 * 60;
|
||||
const HTTP_TIMEOUT_SECONDS: u64 = 30;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
struct CreateApiKeyOptions {
|
||||
issuer: String,
|
||||
client_id: String,
|
||||
audience: String,
|
||||
api_base: String,
|
||||
app: String,
|
||||
callback_port: u16,
|
||||
scope: String,
|
||||
api_key_name: String,
|
||||
project_poll_interval_seconds: u64,
|
||||
project_poll_timeout_seconds: u64,
|
||||
}
|
||||
|
||||
pub struct PendingCreateApiKey {
|
||||
client: Client,
|
||||
options: CreateApiKeyOptions,
|
||||
redirect_uri: String,
|
||||
code_verifier: String,
|
||||
callback_server: AuthorizationCodeServer,
|
||||
}
|
||||
|
||||
impl PendingCreateApiKey {
|
||||
pub fn auth_url(&self) -> &str {
|
||||
&self.callback_server.auth_url
|
||||
}
|
||||
|
||||
pub fn callback_port(&self) -> u16 {
|
||||
self.callback_server.actual_port
|
||||
}
|
||||
|
||||
pub fn open_browser(&self) -> bool {
|
||||
self.callback_server.open_browser()
|
||||
}
|
||||
|
||||
pub async fn finish(self) -> Result<CreatedApiKey, CreateApiKeyError> {
|
||||
let code = self
|
||||
.callback_server
|
||||
.wait_for_code(Duration::from_secs(OAUTH_TIMEOUT_SECONDS))
|
||||
.await
|
||||
.map_err(|err| CreateApiKeyError::message(err.to_string()))?;
|
||||
create_api_key_from_authorization_code(
|
||||
&self.client,
|
||||
&self.options,
|
||||
&self.redirect_uri,
|
||||
&self.code_verifier,
|
||||
&code,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CreatedApiKey {
|
||||
pub organization_id: String,
|
||||
pub organization_title: Option<String>,
|
||||
pub default_project_id: String,
|
||||
pub default_project_title: Option<String>,
|
||||
pub project_api_key: String,
|
||||
}
|
||||
|
||||
pub fn start_create_api_key() -> Result<PendingCreateApiKey, CreateApiKeyError> {
|
||||
let options = CreateApiKeyOptions {
|
||||
issuer: AUTH_ISSUER.to_string(),
|
||||
client_id: PLATFORM_HYDRA_CLIENT_ID.to_string(),
|
||||
audience: PLATFORM_AUDIENCE.to_string(),
|
||||
api_base: API_BASE.to_string(),
|
||||
app: APP.to_string(),
|
||||
callback_port: CALLBACK_PORT,
|
||||
scope: SCOPE.to_string(),
|
||||
api_key_name: PROJECT_API_KEY_NAME.to_string(),
|
||||
project_poll_interval_seconds: PROJECT_POLL_INTERVAL_SECONDS,
|
||||
project_poll_timeout_seconds: PROJECT_POLL_TIMEOUT_SECONDS,
|
||||
};
|
||||
let client = build_http_client()?;
|
||||
let callback_server = start_authorization_code_server(
|
||||
options.callback_port,
|
||||
PortConflictStrategy::Fail,
|
||||
CALLBACK_PATH,
|
||||
/*force_state*/ None,
|
||||
|redirect_uri, pkce, state| {
|
||||
build_authorize_url(&options, redirect_uri, pkce, state)
|
||||
.map_err(|err| std::io::Error::other(err.to_string()))
|
||||
},
|
||||
)
|
||||
.map_err(|err| CreateApiKeyError::message(err.to_string()))?;
|
||||
let redirect_uri = callback_server.redirect_uri.clone();
|
||||
|
||||
Ok(PendingCreateApiKey {
|
||||
client,
|
||||
options,
|
||||
redirect_uri,
|
||||
code_verifier: callback_server.code_verifier().to_string(),
|
||||
callback_server,
|
||||
})
|
||||
}
|
||||
|
||||
fn build_authorize_url(
|
||||
options: &CreateApiKeyOptions,
|
||||
redirect_uri: &str,
|
||||
pkce: &PkceCodes,
|
||||
state: &str,
|
||||
) -> Result<String, CreateApiKeyError> {
|
||||
let mut url = Url::parse(&format!(
|
||||
"{}/oauth/authorize",
|
||||
options.issuer.trim_end_matches('/')
|
||||
))
|
||||
.map_err(|err| CreateApiKeyError::message(format!("invalid issuer URL: {err}")))?;
|
||||
url.query_pairs_mut()
|
||||
.append_pair("audience", &options.audience)
|
||||
.append_pair("client_id", &options.client_id)
|
||||
.append_pair("code_challenge_method", "S256")
|
||||
.append_pair("code_challenge", &pkce.code_challenge)
|
||||
.append_pair("redirect_uri", redirect_uri)
|
||||
.append_pair("response_type", "code")
|
||||
.append_pair("scope", &options.scope)
|
||||
.append_pair("state", state);
|
||||
Ok(url.to_string())
|
||||
}
|
||||
|
||||
fn build_http_client() -> Result<Client, CreateApiKeyError> {
|
||||
build_reqwest_client_with_custom_ca(
|
||||
reqwest::Client::builder().timeout(Duration::from_secs(HTTP_TIMEOUT_SECONDS)),
|
||||
)
|
||||
.map_err(|err| CreateApiKeyError::message(format!("failed to build HTTP client: {err}")))
|
||||
}
|
||||
|
||||
async fn create_api_key_from_authorization_code(
|
||||
client: &Client,
|
||||
options: &CreateApiKeyOptions,
|
||||
redirect_uri: &str,
|
||||
code_verifier: &str,
|
||||
code: &str,
|
||||
) -> Result<CreatedApiKey, CreateApiKeyError> {
|
||||
let tokens = exchange_authorization_code_for_tokens(
|
||||
client,
|
||||
&options.issuer,
|
||||
&options.client_id,
|
||||
redirect_uri,
|
||||
code_verifier,
|
||||
code,
|
||||
)
|
||||
.await?;
|
||||
let login = onboarding_login(
|
||||
client,
|
||||
&options.api_base,
|
||||
&options.app,
|
||||
&tokens.access_token,
|
||||
)
|
||||
.await?;
|
||||
let target = wait_for_default_project(
|
||||
client,
|
||||
&options.api_base,
|
||||
&login.user.session.sensitive_id,
|
||||
options.project_poll_interval_seconds,
|
||||
options.project_poll_timeout_seconds,
|
||||
)
|
||||
.await?;
|
||||
let api_key = create_project_api_key(
|
||||
client,
|
||||
&options.api_base,
|
||||
&login.user.session.sensitive_id,
|
||||
&target,
|
||||
&options.api_key_name,
|
||||
)
|
||||
.await?
|
||||
.key
|
||||
.sensitive_id;
|
||||
|
||||
Ok(CreatedApiKey {
|
||||
organization_id: target.organization_id,
|
||||
organization_title: target.organization_title,
|
||||
default_project_id: target.project_id,
|
||||
default_project_title: target.project_title,
|
||||
project_api_key: api_key,
|
||||
})
|
||||
}
|
||||
|
||||
async fn exchange_authorization_code_for_tokens(
|
||||
client: &Client,
|
||||
issuer: &str,
|
||||
client_id: &str,
|
||||
redirect_uri: &str,
|
||||
code_verifier: &str,
|
||||
code: &str,
|
||||
) -> Result<OAuthTokens, CreateApiKeyError> {
|
||||
let url = format!("{}/oauth/token", issuer.trim_end_matches('/'));
|
||||
execute_json(
|
||||
client
|
||||
.request(Method::POST, &url)
|
||||
.header(reqwest::header::ACCEPT, "application/json")
|
||||
.header(
|
||||
reqwest::header::CONTENT_TYPE,
|
||||
"application/x-www-form-urlencoded",
|
||||
)
|
||||
.header(reqwest::header::USER_AGENT, USER_AGENT)
|
||||
.body(format!(
|
||||
"client_id={}&code_verifier={}&code={}&grant_type={}&redirect_uri={}",
|
||||
urlencoding::encode(client_id),
|
||||
urlencoding::encode(code_verifier),
|
||||
urlencoding::encode(code),
|
||||
urlencoding::encode("authorization_code"),
|
||||
urlencoding::encode(redirect_uri)
|
||||
)),
|
||||
"POST",
|
||||
&url,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn onboarding_login(
|
||||
client: &Client,
|
||||
api_base: &str,
|
||||
app: &str,
|
||||
access_token: &str,
|
||||
) -> Result<OnboardingLoginResponse, CreateApiKeyError> {
|
||||
let url = format!(
|
||||
"{}/dashboard/onboarding/login",
|
||||
api_base.trim_end_matches('/')
|
||||
);
|
||||
execute_json(
|
||||
client
|
||||
.request(Method::POST, &url)
|
||||
.header(reqwest::header::ACCEPT, "application/json")
|
||||
.header(reqwest::header::USER_AGENT, USER_AGENT)
|
||||
.bearer_auth(access_token)
|
||||
.json(&serde_json::json!({ "app": app })),
|
||||
"POST",
|
||||
&url,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn list_organizations(
|
||||
client: &Client,
|
||||
api_base: &str,
|
||||
session_key: &str,
|
||||
) -> Result<Vec<Organization>, CreateApiKeyError> {
|
||||
let url = format!("{}/v1/organizations", api_base.trim_end_matches('/'));
|
||||
let response: DataList<Organization> = execute_json(
|
||||
client
|
||||
.request(Method::GET, &url)
|
||||
.header(reqwest::header::ACCEPT, "application/json")
|
||||
.header(reqwest::header::USER_AGENT, USER_AGENT)
|
||||
.bearer_auth(session_key),
|
||||
"GET",
|
||||
&url,
|
||||
)
|
||||
.await?;
|
||||
Ok(response.data)
|
||||
}
|
||||
|
||||
async fn list_projects(
|
||||
client: &Client,
|
||||
api_base: &str,
|
||||
session_key: &str,
|
||||
organization_id: &str,
|
||||
) -> Result<Vec<Project>, CreateApiKeyError> {
|
||||
let url = format!(
|
||||
"{}/dashboard/organizations/{}/projects?detail=basic&limit=100",
|
||||
api_base.trim_end_matches('/'),
|
||||
urlencoding::encode(organization_id)
|
||||
);
|
||||
let response: DataList<Project> = execute_json(
|
||||
client
|
||||
.request(Method::GET, &url)
|
||||
.header(reqwest::header::ACCEPT, "application/json")
|
||||
.header(reqwest::header::USER_AGENT, USER_AGENT)
|
||||
.header("openai-organization", organization_id)
|
||||
.bearer_auth(session_key),
|
||||
"GET",
|
||||
&url,
|
||||
)
|
||||
.await?;
|
||||
Ok(response.data)
|
||||
}
|
||||
|
||||
async fn wait_for_default_project(
|
||||
client: &Client,
|
||||
api_base: &str,
|
||||
session_key: &str,
|
||||
poll_interval_seconds: u64,
|
||||
timeout_seconds: u64,
|
||||
) -> Result<ProjectApiKeyTarget, CreateApiKeyError> {
|
||||
let deadline = std::time::Instant::now() + Duration::from_secs(timeout_seconds);
|
||||
loop {
|
||||
let organizations = list_organizations(client, api_base, session_key).await?;
|
||||
let last_state = if organizations.is_empty() {
|
||||
"no organization found".to_string()
|
||||
} else {
|
||||
let ordered_organizations = organizations_by_preference(&organizations);
|
||||
let mut project_count = 0;
|
||||
for organization in ordered_organizations {
|
||||
let projects =
|
||||
list_projects(client, api_base, session_key, &organization.id).await?;
|
||||
project_count += projects.len();
|
||||
if let Some(project) = find_default_project(&projects) {
|
||||
return Ok(ProjectApiKeyTarget {
|
||||
organization_id: organization.id.clone(),
|
||||
organization_title: organization.title.clone(),
|
||||
project_id: project.id.clone(),
|
||||
project_title: project.title.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
format!(
|
||||
"checked {} organizations and {} projects, but no default project is ready yet.",
|
||||
organizations.len(),
|
||||
project_count
|
||||
)
|
||||
};
|
||||
|
||||
if std::time::Instant::now() >= deadline {
|
||||
return Err(CreateApiKeyError::message(format!(
|
||||
"Timed out waiting for an organization and default project. Last observed state: {last_state}"
|
||||
)));
|
||||
}
|
||||
let remaining_seconds = deadline
|
||||
.saturating_duration_since(std::time::Instant::now())
|
||||
.as_secs();
|
||||
let sleep_seconds = poll_interval_seconds.min(remaining_seconds.max(1));
|
||||
tokio::time::sleep(Duration::from_secs(sleep_seconds)).await;
|
||||
}
|
||||
}
|
||||
|
||||
fn organizations_by_preference(organizations: &[Organization]) -> Vec<&Organization> {
|
||||
let mut ordered_organizations = organizations.iter().enumerate().collect::<Vec<_>>();
|
||||
ordered_organizations.sort_by_key(|(index, organization)| {
|
||||
let rank = if organization.is_default {
|
||||
0
|
||||
} else if organization.personal {
|
||||
1
|
||||
} else {
|
||||
2
|
||||
};
|
||||
(rank, *index)
|
||||
});
|
||||
ordered_organizations
|
||||
.into_iter()
|
||||
.map(|(_, organization)| organization)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn find_default_project(projects: &[Project]) -> Option<&Project> {
|
||||
projects.iter().find(|project| project.is_initial)
|
||||
}
|
||||
|
||||
async fn create_project_api_key(
|
||||
client: &Client,
|
||||
api_base: &str,
|
||||
session_key: &str,
|
||||
target: &ProjectApiKeyTarget,
|
||||
key_name: &str,
|
||||
) -> Result<CreateProjectApiKeyResponse, CreateApiKeyError> {
|
||||
let url = format!(
|
||||
"{}/dashboard/organizations/{}/projects/{}/api_keys",
|
||||
api_base.trim_end_matches('/'),
|
||||
urlencoding::encode(&target.organization_id),
|
||||
urlencoding::encode(&target.project_id)
|
||||
);
|
||||
execute_json(
|
||||
client
|
||||
.request(Method::POST, &url)
|
||||
.header(reqwest::header::ACCEPT, "application/json")
|
||||
.header(reqwest::header::USER_AGENT, USER_AGENT)
|
||||
.bearer_auth(session_key)
|
||||
.json(&serde_json::json!({
|
||||
"action": "create",
|
||||
"name": key_name,
|
||||
})),
|
||||
"POST",
|
||||
&url,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn execute_json<T>(
|
||||
request: reqwest::RequestBuilder,
|
||||
method: &str,
|
||||
url: &str,
|
||||
) -> Result<T, CreateApiKeyError>
|
||||
where
|
||||
T: for<'de> Deserialize<'de>,
|
||||
{
|
||||
let response = request
|
||||
.send()
|
||||
.await
|
||||
.map_err(|err| CreateApiKeyError::message(format!("Network error calling {url}: {err}")))?;
|
||||
let status = response.status();
|
||||
let body = response.bytes().await.map_err(|err| {
|
||||
CreateApiKeyError::message(format!("Failed reading response from {url}: {err}"))
|
||||
})?;
|
||||
if !status.is_success() {
|
||||
return Err(CreateApiKeyError::api(
|
||||
format!("{method} {url} failed with HTTP {status}"),
|
||||
String::from_utf8_lossy(&body).into_owned(),
|
||||
));
|
||||
}
|
||||
serde_json::from_slice(&body)
|
||||
.map_err(|err| CreateApiKeyError::message(format!("{url} returned invalid JSON: {err}")))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OAuthTokens {
|
||||
#[serde(rename = "id_token")]
|
||||
_id_token: String,
|
||||
access_token: String,
|
||||
#[serde(rename = "refresh_token")]
|
||||
_refresh_token: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OnboardingLoginResponse {
|
||||
user: OnboardingUser,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OnboardingUser {
|
||||
session: OnboardingSession,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OnboardingSession {
|
||||
sensitive_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DataList<T> {
|
||||
data: Vec<T>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
|
||||
struct Organization {
|
||||
id: String,
|
||||
title: Option<String>,
|
||||
#[serde(default)]
|
||||
is_default: bool,
|
||||
#[serde(default)]
|
||||
personal: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
|
||||
struct Project {
|
||||
id: String,
|
||||
title: Option<String>,
|
||||
#[serde(default)]
|
||||
is_initial: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
struct ProjectApiKeyTarget {
|
||||
organization_id: String,
|
||||
organization_title: Option<String>,
|
||||
project_id: String,
|
||||
project_title: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CreateProjectApiKeyResponse {
|
||||
key: CreatedProjectApiKey,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CreatedProjectApiKey {
|
||||
sensitive_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CreateApiKeyError {
|
||||
message: String,
|
||||
}
|
||||
|
||||
impl CreateApiKeyError {
|
||||
fn message(message: String) -> Self {
|
||||
Self { message }
|
||||
}
|
||||
|
||||
fn api(message: String, body: String) -> Self {
|
||||
Self {
|
||||
message: format!("{message}: {body}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for CreateApiKeyError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str(&self.message)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for CreateApiKeyError {}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "create_api_key_tests.rs"]
|
||||
mod tests;
|
||||
182
codex-rs/login/src/create_api_key_tests.rs
Normal file
182
codex-rs/login/src/create_api_key_tests.rs
Normal file
@@ -0,0 +1,182 @@
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::ResponseTemplate;
|
||||
use wiremock::matchers::body_string_contains;
|
||||
use wiremock::matchers::header;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
use wiremock::matchers::query_param;
|
||||
|
||||
#[test]
|
||||
fn organizations_by_preference_orders_default_then_personal_then_input_order() {
|
||||
let organizations = vec![
|
||||
Organization {
|
||||
id: "org-first".to_string(),
|
||||
title: Some("First".to_string()),
|
||||
is_default: false,
|
||||
personal: false,
|
||||
},
|
||||
Organization {
|
||||
id: "org-personal".to_string(),
|
||||
title: Some("Personal".to_string()),
|
||||
is_default: false,
|
||||
personal: true,
|
||||
},
|
||||
Organization {
|
||||
id: "org-default".to_string(),
|
||||
title: Some("Default".to_string()),
|
||||
is_default: true,
|
||||
personal: false,
|
||||
},
|
||||
];
|
||||
|
||||
let selected = organizations_by_preference(&organizations);
|
||||
|
||||
assert_eq!(
|
||||
selected,
|
||||
vec![&organizations[2], &organizations[1], &organizations[0]]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn find_default_project_returns_initial_project() {
|
||||
let projects = vec![
|
||||
Project {
|
||||
id: "proj-secondary".to_string(),
|
||||
title: Some("Secondary".to_string()),
|
||||
is_initial: false,
|
||||
},
|
||||
Project {
|
||||
id: "proj-default".to_string(),
|
||||
title: Some("Default".to_string()),
|
||||
is_initial: true,
|
||||
},
|
||||
];
|
||||
|
||||
let selected = find_default_project(&projects);
|
||||
|
||||
assert_eq!(selected, projects.get(1));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn create_api_key_from_authorization_code_creates_api_key() {
|
||||
let server = MockServer::start().await;
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/oauth/token"))
|
||||
.and(header("content-type", "application/x-www-form-urlencoded"))
|
||||
.and(body_string_contains("client_id=client-123"))
|
||||
.and(body_string_contains("code_verifier=verifier-123"))
|
||||
.and(body_string_contains("code=auth-code-123"))
|
||||
.and(body_string_contains("grant_type=authorization_code"))
|
||||
.and(body_string_contains(
|
||||
"redirect_uri=http%3A%2F%2Flocalhost%3A5000%2Fauth%2Fcallback",
|
||||
))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
|
||||
"id_token": "id-token-123",
|
||||
"access_token": "oauth-access-123",
|
||||
"refresh_token": "oauth-refresh-123",
|
||||
})))
|
||||
.mount(&server)
|
||||
.await;
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/dashboard/onboarding/login"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
|
||||
"user": {
|
||||
"session": {
|
||||
"sensitive_id": "session-123",
|
||||
}
|
||||
}
|
||||
})))
|
||||
.mount(&server)
|
||||
.await;
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/v1/organizations"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
|
||||
"data": [
|
||||
{
|
||||
"id": "org-default",
|
||||
"title": "Default Org",
|
||||
"is_default": true,
|
||||
},
|
||||
{
|
||||
"id": "org-secondary",
|
||||
"title": "Secondary Org",
|
||||
}
|
||||
]
|
||||
})))
|
||||
.mount(&server)
|
||||
.await;
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/dashboard/organizations/org-default/projects"))
|
||||
.and(query_param("detail", "basic"))
|
||||
.and(query_param("limit", "100"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
|
||||
"data": []
|
||||
})))
|
||||
.mount(&server)
|
||||
.await;
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/dashboard/organizations/org-secondary/projects"))
|
||||
.and(query_param("detail", "basic"))
|
||||
.and(query_param("limit", "100"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
|
||||
"data": [
|
||||
{
|
||||
"id": "proj-default",
|
||||
"title": "Default Project",
|
||||
"is_initial": true,
|
||||
}
|
||||
]
|
||||
})))
|
||||
.mount(&server)
|
||||
.await;
|
||||
Mock::given(method("POST"))
|
||||
.and(path(
|
||||
"/dashboard/organizations/org-secondary/projects/proj-default/api_keys",
|
||||
))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
|
||||
"key": {
|
||||
"sensitive_id": "sk-proj-123",
|
||||
}
|
||||
})))
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let options = CreateApiKeyOptions {
|
||||
issuer: server.uri(),
|
||||
client_id: "client-123".to_string(),
|
||||
audience: PLATFORM_AUDIENCE.to_string(),
|
||||
api_base: server.uri(),
|
||||
app: APP.to_string(),
|
||||
callback_port: CALLBACK_PORT,
|
||||
scope: SCOPE.to_string(),
|
||||
api_key_name: PROJECT_API_KEY_NAME.to_string(),
|
||||
project_poll_interval_seconds: 1,
|
||||
project_poll_timeout_seconds: 5,
|
||||
};
|
||||
let client = build_http_client().expect("client");
|
||||
|
||||
let output = create_api_key_from_authorization_code(
|
||||
&client,
|
||||
&options,
|
||||
"http://localhost:5000/auth/callback",
|
||||
"verifier-123",
|
||||
"auth-code-123",
|
||||
)
|
||||
.await
|
||||
.expect("provision");
|
||||
|
||||
assert_eq!(
|
||||
output,
|
||||
CreatedApiKey {
|
||||
organization_id: "org-secondary".to_string(),
|
||||
organization_title: Some("Secondary Org".to_string()),
|
||||
default_project_id: "proj-default".to_string(),
|
||||
default_project_title: Some("Default Project".to_string()),
|
||||
project_api_key: "sk-proj-123".to_string(),
|
||||
}
|
||||
);
|
||||
}
|
||||
@@ -1,7 +1,9 @@
|
||||
pub mod auth;
|
||||
pub mod token_data;
|
||||
|
||||
mod create_api_key;
|
||||
mod device_code_auth;
|
||||
mod oauth_callback_server;
|
||||
mod pkce;
|
||||
mod server;
|
||||
|
||||
@@ -10,9 +12,9 @@ pub use device_code_auth::DeviceCode;
|
||||
pub use device_code_auth::complete_device_code_login;
|
||||
pub use device_code_auth::request_device_code;
|
||||
pub use device_code_auth::run_device_code_login;
|
||||
pub use oauth_callback_server::ShutdownHandle;
|
||||
pub use server::LoginServer;
|
||||
pub use server::ServerOptions;
|
||||
pub use server::ShutdownHandle;
|
||||
pub use server::run_login_server;
|
||||
|
||||
pub use auth::AuthConfig;
|
||||
@@ -34,4 +36,8 @@ pub use auth::logout;
|
||||
pub use auth::read_openai_api_key_from_env;
|
||||
pub use auth::save_auth;
|
||||
pub use codex_app_server_protocol::AuthMode;
|
||||
pub use create_api_key::CreateApiKeyError;
|
||||
pub use create_api_key::CreatedApiKey;
|
||||
pub use create_api_key::PendingCreateApiKey;
|
||||
pub use create_api_key::start_create_api_key;
|
||||
pub use token_data::TokenData;
|
||||
|
||||
504
codex-rs/login/src/oauth_callback_server.rs
Normal file
504
codex-rs/login/src/oauth_callback_server.rs
Normal file
@@ -0,0 +1,504 @@
|
||||
//! Shared localhost OAuth callback server machinery.
|
||||
//!
|
||||
//! This module owns the reusable bind/listen/response loop used by OAuth-style browser flows.
|
||||
|
||||
use std::future::Future;
|
||||
use std::io::Cursor;
|
||||
use std::io::Read;
|
||||
use std::io::Write;
|
||||
use std::io::{self};
|
||||
use std::net::SocketAddr;
|
||||
use std::net::TcpStream;
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
use base64::Engine;
|
||||
use rand::RngCore;
|
||||
use tiny_http::Header;
|
||||
use tiny_http::Request;
|
||||
use tiny_http::Response;
|
||||
use tiny_http::Server;
|
||||
use tiny_http::StatusCode;
|
||||
|
||||
use crate::pkce::PkceCodes;
|
||||
use crate::pkce::generate_pkce;
|
||||
|
||||
/// Strategy for handling a callback port that is already in use.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub(crate) enum PortConflictStrategy {
|
||||
/// Attempt to cancel a previous callback server on the same port and retry.
|
||||
CancelPrevious,
|
||||
/// Return an error immediately without sending any request to the occupied port.
|
||||
Fail,
|
||||
}
|
||||
|
||||
/// Handle used to signal the callback server loop to exit.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ShutdownHandle {
|
||||
shutdown_notify: Arc<tokio::sync::Notify>,
|
||||
}
|
||||
|
||||
impl ShutdownHandle {
|
||||
/// Signals the server loop to terminate.
|
||||
pub fn shutdown(&self) {
|
||||
self.shutdown_notify.notify_waiters();
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle for a running authorization-code callback server.
|
||||
pub(crate) struct AuthorizationCodeServer {
|
||||
pub auth_url: String,
|
||||
pub actual_port: u16,
|
||||
pub redirect_uri: String,
|
||||
code_verifier: String,
|
||||
server_handle: tokio::task::JoinHandle<io::Result<String>>,
|
||||
shutdown_handle: ShutdownHandle,
|
||||
}
|
||||
|
||||
impl AuthorizationCodeServer {
|
||||
pub fn open_browser(&self) -> bool {
|
||||
webbrowser::open(&self.auth_url).is_ok()
|
||||
}
|
||||
|
||||
pub fn code_verifier(&self) -> &str {
|
||||
&self.code_verifier
|
||||
}
|
||||
|
||||
pub async fn wait_for_code(self, timeout: Duration) -> io::Result<String> {
|
||||
let AuthorizationCodeServer {
|
||||
server_handle,
|
||||
shutdown_handle,
|
||||
..
|
||||
} = self;
|
||||
let server_handle = server_handle;
|
||||
tokio::pin!(server_handle);
|
||||
|
||||
tokio::select! {
|
||||
result = &mut server_handle => {
|
||||
result
|
||||
.map_err(|err| io::Error::other(format!("authorization-code server thread panicked: {err:?}")))?
|
||||
}
|
||||
_ = tokio::time::sleep(timeout) => {
|
||||
shutdown_handle.shutdown();
|
||||
let _ = server_handle.await;
|
||||
Err(io::Error::new(
|
||||
io::ErrorKind::TimedOut,
|
||||
"OAuth flow timed out waiting for the browser callback.",
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn start_authorization_code_server<F>(
|
||||
port: u16,
|
||||
port_conflict_strategy: PortConflictStrategy,
|
||||
callback_path: &str,
|
||||
force_state: Option<String>,
|
||||
auth_url_builder: F,
|
||||
) -> io::Result<AuthorizationCodeServer>
|
||||
where
|
||||
F: FnOnce(&str, &PkceCodes, &str) -> io::Result<String>,
|
||||
{
|
||||
let pkce = generate_pkce();
|
||||
let state = force_state.unwrap_or_else(generate_state);
|
||||
let callback_path = callback_path.to_string();
|
||||
|
||||
let (server, actual_port, rx) = bind_server_with_request_channel(port, port_conflict_strategy)?;
|
||||
let redirect_uri = format!("http://localhost:{actual_port}{callback_path}");
|
||||
let auth_url = match auth_url_builder(&redirect_uri, &pkce, &state) {
|
||||
Ok(auth_url) => auth_url,
|
||||
Err(err) => {
|
||||
server.unblock();
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
let (server_handle, shutdown_handle) = spawn_callback_server_loop(
|
||||
server,
|
||||
rx,
|
||||
"Authentication was not completed",
|
||||
move |url_raw| {
|
||||
let callback_path = callback_path.clone();
|
||||
let state = state.clone();
|
||||
async move { process_authorization_code_request(&url_raw, &callback_path, &state) }
|
||||
},
|
||||
);
|
||||
|
||||
Ok(AuthorizationCodeServer {
|
||||
auth_url,
|
||||
actual_port,
|
||||
redirect_uri,
|
||||
code_verifier: pkce.code_verifier,
|
||||
server_handle,
|
||||
shutdown_handle,
|
||||
})
|
||||
}
|
||||
|
||||
/// Internal callback handling outcome.
|
||||
pub(crate) enum HandledRequest<T> {
|
||||
Response(Response<Cursor<Vec<u8>>>),
|
||||
RedirectWithHeader(Header),
|
||||
ResponseAndExit {
|
||||
status: StatusCode,
|
||||
headers: Vec<Header>,
|
||||
body: Vec<u8>,
|
||||
result: io::Result<T>,
|
||||
},
|
||||
}
|
||||
|
||||
pub(crate) fn bind_server_with_request_channel(
|
||||
port: u16,
|
||||
port_conflict_strategy: PortConflictStrategy,
|
||||
) -> io::Result<(Arc<Server>, u16, tokio::sync::mpsc::Receiver<Request>)> {
|
||||
let server = bind_server(port, port_conflict_strategy)?;
|
||||
let actual_port = match server.server_addr().to_ip() {
|
||||
Some(addr) => addr.port(),
|
||||
None => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::AddrInUse,
|
||||
"Unable to determine the server port",
|
||||
));
|
||||
}
|
||||
};
|
||||
let server = Arc::new(server);
|
||||
|
||||
// Map blocking reads from server.recv() to an async channel.
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<Request>(16);
|
||||
let _server_handle = {
|
||||
let server = server.clone();
|
||||
thread::spawn(move || -> io::Result<()> {
|
||||
while let Ok(request) = server.recv() {
|
||||
match tx.blocking_send(request) {
|
||||
Ok(()) => {}
|
||||
Err(error) => {
|
||||
eprintln!("Failed to send request to channel: {error}");
|
||||
return Err(io::Error::other("Failed to send request to channel"));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
};
|
||||
|
||||
Ok((server, actual_port, rx))
|
||||
}
|
||||
|
||||
pub(crate) fn spawn_callback_server_loop<T, F, Fut>(
|
||||
server: Arc<Server>,
|
||||
mut rx: tokio::sync::mpsc::Receiver<Request>,
|
||||
incomplete_message: &'static str,
|
||||
mut process_request: F,
|
||||
) -> (tokio::task::JoinHandle<io::Result<T>>, ShutdownHandle)
|
||||
where
|
||||
T: Send + 'static,
|
||||
F: FnMut(String) -> Fut + Send + 'static,
|
||||
Fut: Future<Output = HandledRequest<T>> + Send + 'static,
|
||||
{
|
||||
let shutdown_notify = Arc::new(tokio::sync::Notify::new());
|
||||
let server_handle = {
|
||||
let shutdown_notify = shutdown_notify.clone();
|
||||
tokio::spawn(async move {
|
||||
let result = loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_notify.notified() => {
|
||||
break Err(io::Error::other(incomplete_message));
|
||||
}
|
||||
maybe_req = rx.recv() => {
|
||||
let Some(req) = maybe_req else {
|
||||
break Err(io::Error::other(incomplete_message));
|
||||
};
|
||||
|
||||
let url_raw = req.url().to_string();
|
||||
let response = process_request(url_raw).await;
|
||||
|
||||
if let Some(result) = respond_to_request(req, response).await {
|
||||
break result;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Ensure that the server is unblocked so the thread dedicated to
|
||||
// running `server.recv()` in a loop exits cleanly.
|
||||
server.unblock();
|
||||
result
|
||||
})
|
||||
};
|
||||
|
||||
(server_handle, ShutdownHandle { shutdown_notify })
|
||||
}
|
||||
|
||||
async fn respond_to_request<T>(req: Request, response: HandledRequest<T>) -> Option<io::Result<T>> {
|
||||
match response {
|
||||
HandledRequest::Response(response) => {
|
||||
let _ = tokio::task::spawn_blocking(move || req.respond(response)).await;
|
||||
None
|
||||
}
|
||||
HandledRequest::RedirectWithHeader(header) => {
|
||||
let redirect = Response::empty(302).with_header(header);
|
||||
let _ = tokio::task::spawn_blocking(move || req.respond(redirect)).await;
|
||||
None
|
||||
}
|
||||
HandledRequest::ResponseAndExit {
|
||||
status,
|
||||
headers,
|
||||
body,
|
||||
result,
|
||||
} => {
|
||||
let _ = tokio::task::spawn_blocking(move || {
|
||||
send_response_with_disconnect(req, status, headers, body)
|
||||
})
|
||||
.await;
|
||||
Some(result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// tiny_http filters `Connection` headers out of `Response` objects, so using
|
||||
/// `req.respond` never informs the client (or the library) that a keep-alive
|
||||
/// socket should be closed. That leaves the per-connection worker parked in a
|
||||
/// loop waiting for more requests, which in turn causes the next login attempt
|
||||
/// to hang on the old connection. This helper bypasses tiny_http’s response
|
||||
/// machinery: it extracts the raw writer, prints the HTTP response manually,
|
||||
/// and always appends `Connection: close`, ensuring the socket is closed from
|
||||
/// the server side. Ideally, tiny_http would provide an API to control
|
||||
/// server-side connection persistence, but it does not.
|
||||
fn send_response_with_disconnect(
|
||||
req: Request,
|
||||
status: StatusCode,
|
||||
mut headers: Vec<Header>,
|
||||
body: Vec<u8>,
|
||||
) -> io::Result<()> {
|
||||
let mut writer = req.into_writer();
|
||||
let reason = status.default_reason_phrase();
|
||||
write!(writer, "HTTP/1.1 {} {}\r\n", status.0, reason)?;
|
||||
headers.retain(|h| !h.field.equiv("Connection"));
|
||||
if let Ok(close_header) = Header::from_bytes(&b"Connection"[..], &b"close"[..]) {
|
||||
headers.push(close_header);
|
||||
}
|
||||
|
||||
let content_length_value = format!("{}", body.len());
|
||||
if let Ok(content_length_header) =
|
||||
Header::from_bytes(&b"Content-Length"[..], content_length_value.as_bytes())
|
||||
{
|
||||
headers.push(content_length_header);
|
||||
}
|
||||
|
||||
for header in headers {
|
||||
write!(
|
||||
writer,
|
||||
"{}: {}\r\n",
|
||||
header.field.as_str(),
|
||||
header.value.as_str()
|
||||
)?;
|
||||
}
|
||||
|
||||
writer.write_all(b"\r\n")?;
|
||||
writer.write_all(&body)?;
|
||||
writer.flush()
|
||||
}
|
||||
|
||||
pub(crate) fn generate_state() -> String {
|
||||
let mut bytes = [0u8; 32];
|
||||
rand::rng().fill_bytes(&mut bytes);
|
||||
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
|
||||
}
|
||||
|
||||
fn send_cancel_request(port: u16) -> io::Result<()> {
|
||||
let addr: SocketAddr = format!("127.0.0.1:{port}")
|
||||
.parse()
|
||||
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
|
||||
let mut stream = TcpStream::connect_timeout(&addr, Duration::from_secs(2))?;
|
||||
stream.set_read_timeout(Some(Duration::from_secs(2)))?;
|
||||
stream.set_write_timeout(Some(Duration::from_secs(2)))?;
|
||||
|
||||
stream.write_all(b"GET /cancel HTTP/1.1\r\n")?;
|
||||
stream.write_all(format!("Host: 127.0.0.1:{port}\r\n").as_bytes())?;
|
||||
stream.write_all(b"Connection: close\r\n\r\n")?;
|
||||
|
||||
let mut buf = [0u8; 64];
|
||||
let _ = stream.read(&mut buf);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn bind_server(port: u16, port_conflict_strategy: PortConflictStrategy) -> io::Result<Server> {
|
||||
let bind_address = format!("127.0.0.1:{port}");
|
||||
let mut cancel_attempted = false;
|
||||
let mut attempts = 0;
|
||||
const MAX_ATTEMPTS: u32 = 10;
|
||||
const RETRY_DELAY: Duration = Duration::from_millis(200);
|
||||
|
||||
loop {
|
||||
match Server::http(&bind_address) {
|
||||
Ok(server) => return Ok(server),
|
||||
Err(err) => {
|
||||
attempts += 1;
|
||||
let is_addr_in_use = err
|
||||
.downcast_ref::<io::Error>()
|
||||
.map(|io_err| io_err.kind() == io::ErrorKind::AddrInUse)
|
||||
.unwrap_or(false);
|
||||
|
||||
// If the address is in use, there is probably another instance of the callback
|
||||
// server running. Attempt to cancel it and retry.
|
||||
if is_addr_in_use {
|
||||
if port_conflict_strategy == PortConflictStrategy::Fail {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::AddrInUse,
|
||||
format!("Port {bind_address} is already in use"),
|
||||
));
|
||||
}
|
||||
|
||||
if !cancel_attempted {
|
||||
cancel_attempted = true;
|
||||
if let Err(cancel_err) = send_cancel_request(port) {
|
||||
eprintln!("Failed to cancel previous callback server: {cancel_err}");
|
||||
}
|
||||
}
|
||||
|
||||
thread::sleep(RETRY_DELAY);
|
||||
|
||||
if attempts >= MAX_ATTEMPTS {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::AddrInUse,
|
||||
format!("Port {bind_address} is already in use"),
|
||||
));
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
return Err(io::Error::other(err));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn process_authorization_code_request(
|
||||
url_raw: &str,
|
||||
callback_path: &str,
|
||||
expected_state: &str,
|
||||
) -> HandledRequest<String> {
|
||||
let parsed_url = match url::Url::parse(&format!("http://localhost{url_raw}")) {
|
||||
Ok(u) => u,
|
||||
Err(err) => {
|
||||
return HandledRequest::Response(
|
||||
Response::from_string(format!("Bad Request: {err}")).with_status_code(400),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
match parsed_url.path() {
|
||||
"/cancel" => HandledRequest::ResponseAndExit {
|
||||
status: StatusCode(200),
|
||||
headers: Vec::new(),
|
||||
body: b"Authentication cancelled".to_vec(),
|
||||
result: Err(io::Error::new(
|
||||
io::ErrorKind::Interrupted,
|
||||
"Authentication cancelled",
|
||||
)),
|
||||
},
|
||||
path if path == callback_path => {
|
||||
let params: std::collections::HashMap<String, String> =
|
||||
parsed_url.query_pairs().into_owned().collect();
|
||||
|
||||
if params.get("state").map(String::as_str) != Some(expected_state) {
|
||||
let mut response = Response::from_string(
|
||||
"<h1>State mismatch</h1><p>Return to your terminal and try again.</p>",
|
||||
)
|
||||
.with_status_code(400);
|
||||
if let Some(header) = html_headers().into_iter().next() {
|
||||
response = response.with_header(header);
|
||||
}
|
||||
return HandledRequest::Response(response);
|
||||
}
|
||||
|
||||
if let Some(error_code) = params.get("error") {
|
||||
let message = authorization_code_error_message(
|
||||
error_code,
|
||||
params.get("error_description").map(String::as_str),
|
||||
);
|
||||
return HandledRequest::ResponseAndExit {
|
||||
status: StatusCode(403),
|
||||
headers: html_headers(),
|
||||
body: b"<h1>Authentication failed</h1><p>Return to your terminal.</p>".to_vec(),
|
||||
result: Err(io::Error::new(io::ErrorKind::PermissionDenied, message)),
|
||||
};
|
||||
}
|
||||
|
||||
match params.get("code") {
|
||||
Some(code) if !code.is_empty() => HandledRequest::ResponseAndExit {
|
||||
status: StatusCode(200),
|
||||
headers: html_headers(),
|
||||
body:
|
||||
b"<h1>Authentication complete</h1><p>You can return to your terminal.</p>"
|
||||
.to_vec(),
|
||||
result: Ok(code.clone()),
|
||||
},
|
||||
_ => HandledRequest::ResponseAndExit {
|
||||
status: StatusCode(400),
|
||||
headers: html_headers(),
|
||||
body: b"<h1>Missing authorization code</h1><p>Return to your terminal.</p>"
|
||||
.to_vec(),
|
||||
result: Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Missing authorization code. Authentication could not be completed.",
|
||||
)),
|
||||
},
|
||||
}
|
||||
}
|
||||
_ => HandledRequest::Response(Response::from_string("Not Found").with_status_code(404)),
|
||||
}
|
||||
}
|
||||
|
||||
fn html_headers() -> Vec<Header> {
|
||||
match Header::from_bytes(&b"Content-Type"[..], &b"text/html; charset=utf-8"[..]) {
|
||||
Ok(header) => vec![header],
|
||||
Err(_) => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn authorization_code_error_message(error_code: &str, error_description: Option<&str>) -> String {
|
||||
if let Some(description) = error_description
|
||||
&& !description.trim().is_empty()
|
||||
{
|
||||
return format!("Authentication failed: {description}");
|
||||
}
|
||||
|
||||
format!("Authentication failed: {error_code}")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn bind_server_fails_without_canceling_when_port_conflict_strategy_is_fail() {
|
||||
let listener =
|
||||
std::net::TcpListener::bind("127.0.0.1:0").expect("bind ephemeral test listener");
|
||||
let port = listener.local_addr().expect("read local addr").port();
|
||||
|
||||
let error = match bind_server(port, PortConflictStrategy::Fail) {
|
||||
Ok(_) => panic!("expected occupied port to fail immediately"),
|
||||
Err(error) => error,
|
||||
};
|
||||
|
||||
assert_eq!(error.kind(), io::ErrorKind::AddrInUse);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn process_authorization_code_request_keeps_server_running_on_state_mismatch() {
|
||||
let response = process_authorization_code_request(
|
||||
"/auth/callback?state=wrong-state&code=auth-code",
|
||||
"/auth/callback",
|
||||
"expected-state",
|
||||
);
|
||||
|
||||
match response {
|
||||
HandledRequest::Response(_) => {}
|
||||
HandledRequest::RedirectWithHeader(_) | HandledRequest::ResponseAndExit { .. } => {
|
||||
panic!("state mismatch should return a response without exiting")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -11,17 +11,9 @@
|
||||
//! This module therefore keeps the user-facing error path and the structured-log path separate.
|
||||
//! Returned `io::Error` values still carry the detail needed by CLI/browser callers, while
|
||||
//! structured logs only emit explicitly reviewed fields plus redacted URL/error values.
|
||||
use std::io::Cursor;
|
||||
use std::io::Read;
|
||||
use std::io::Write;
|
||||
use std::io::{self};
|
||||
use std::net::SocketAddr;
|
||||
use std::net::TcpStream;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::auth::AuthCredentialsStoreMode;
|
||||
use crate::auth::AuthDotJson;
|
||||
@@ -35,17 +27,21 @@ use base64::Engine;
|
||||
use chrono::Utc;
|
||||
use codex_app_server_protocol::AuthMode;
|
||||
use codex_client::build_reqwest_client_with_custom_ca;
|
||||
use rand::RngCore;
|
||||
use serde_json::Value as JsonValue;
|
||||
use tiny_http::Header;
|
||||
use tiny_http::Request;
|
||||
use tiny_http::Response;
|
||||
use tiny_http::Server;
|
||||
use tiny_http::StatusCode;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::oauth_callback_server::HandledRequest;
|
||||
use crate::oauth_callback_server::PortConflictStrategy;
|
||||
use crate::oauth_callback_server::ShutdownHandle;
|
||||
use crate::oauth_callback_server::bind_server_with_request_channel;
|
||||
use crate::oauth_callback_server::generate_state;
|
||||
use crate::oauth_callback_server::spawn_callback_server_loop;
|
||||
|
||||
const DEFAULT_ISSUER: &str = "https://auth.openai.com";
|
||||
const DEFAULT_PORT: u16 = 1455;
|
||||
|
||||
@@ -110,36 +106,13 @@ impl LoginServer {
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle used to signal the login server loop to exit.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ShutdownHandle {
|
||||
shutdown_notify: Arc<tokio::sync::Notify>,
|
||||
}
|
||||
|
||||
impl ShutdownHandle {
|
||||
/// Signals the login loop to terminate.
|
||||
pub fn shutdown(&self) {
|
||||
self.shutdown_notify.notify_waiters();
|
||||
}
|
||||
}
|
||||
|
||||
/// Starts a local callback server and returns the browser auth URL.
|
||||
pub fn run_login_server(opts: ServerOptions) -> io::Result<LoginServer> {
|
||||
let pkce = generate_pkce();
|
||||
let state = opts.force_state.clone().unwrap_or_else(generate_state);
|
||||
|
||||
let server = bind_server(opts.port)?;
|
||||
let actual_port = match server.server_addr().to_ip() {
|
||||
Some(addr) => addr.port(),
|
||||
None => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::AddrInUse,
|
||||
"Unable to determine the server port",
|
||||
));
|
||||
}
|
||||
};
|
||||
let server = Arc::new(server);
|
||||
|
||||
let (server, actual_port, rx) =
|
||||
bind_server_with_request_channel(opts.port, PortConflictStrategy::CancelPrevious)?;
|
||||
let redirect_uri = format!("http://localhost:{actual_port}/auth/callback");
|
||||
let auth_url = build_authorize_url(
|
||||
&opts.issuer,
|
||||
@@ -153,100 +126,25 @@ pub fn run_login_server(opts: ServerOptions) -> io::Result<LoginServer> {
|
||||
if opts.open_browser {
|
||||
let _ = webbrowser::open(&auth_url);
|
||||
}
|
||||
|
||||
// Map blocking reads from server.recv() to an async channel.
|
||||
let (tx, mut rx) = tokio::sync::mpsc::channel::<Request>(16);
|
||||
let _server_handle = {
|
||||
let server = server.clone();
|
||||
thread::spawn(move || -> io::Result<()> {
|
||||
while let Ok(request) = server.recv() {
|
||||
match tx.blocking_send(request) {
|
||||
Ok(()) => {}
|
||||
Err(error) => {
|
||||
eprintln!("Failed to send request to channel: {error}");
|
||||
return Err(io::Error::other("Failed to send request to channel"));
|
||||
}
|
||||
}
|
||||
let (server_handle, shutdown_handle) =
|
||||
spawn_callback_server_loop(server, rx, "Login was not completed", move |url_raw| {
|
||||
let redirect_uri = redirect_uri.clone();
|
||||
let state = state.clone();
|
||||
let opts = opts.clone();
|
||||
let pkce = pkce.clone();
|
||||
async move {
|
||||
process_request(&url_raw, &opts, &redirect_uri, &pkce, actual_port, &state).await
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
};
|
||||
|
||||
let shutdown_notify = Arc::new(tokio::sync::Notify::new());
|
||||
let server_handle = {
|
||||
let shutdown_notify = shutdown_notify.clone();
|
||||
let server = server;
|
||||
tokio::spawn(async move {
|
||||
let result = loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_notify.notified() => {
|
||||
break Err(io::Error::other("Login was not completed"));
|
||||
}
|
||||
maybe_req = rx.recv() => {
|
||||
let Some(req) = maybe_req else {
|
||||
break Err(io::Error::other("Login was not completed"));
|
||||
};
|
||||
|
||||
let url_raw = req.url().to_string();
|
||||
let response =
|
||||
process_request(&url_raw, &opts, &redirect_uri, &pkce, actual_port, &state).await;
|
||||
|
||||
let exit_result = match response {
|
||||
HandledRequest::Response(response) => {
|
||||
let _ = tokio::task::spawn_blocking(move || req.respond(response)).await;
|
||||
None
|
||||
}
|
||||
HandledRequest::ResponseAndExit {
|
||||
headers,
|
||||
body,
|
||||
result,
|
||||
} => {
|
||||
let _ = tokio::task::spawn_blocking(move || {
|
||||
send_response_with_disconnect(req, headers, body)
|
||||
})
|
||||
.await;
|
||||
Some(result)
|
||||
}
|
||||
HandledRequest::RedirectWithHeader(header) => {
|
||||
let redirect = Response::empty(302).with_header(header);
|
||||
let _ = tokio::task::spawn_blocking(move || req.respond(redirect)).await;
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(result) = exit_result {
|
||||
break result;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Ensure that the server is unblocked so the thread dedicated to
|
||||
// running `server.recv()` in a loop exits cleanly.
|
||||
server.unblock();
|
||||
result
|
||||
})
|
||||
};
|
||||
});
|
||||
|
||||
Ok(LoginServer {
|
||||
auth_url,
|
||||
actual_port,
|
||||
server_handle,
|
||||
shutdown_handle: ShutdownHandle { shutdown_notify },
|
||||
shutdown_handle,
|
||||
})
|
||||
}
|
||||
|
||||
/// Internal callback handling outcome.
|
||||
enum HandledRequest {
|
||||
Response(Response<Cursor<Vec<u8>>>),
|
||||
RedirectWithHeader(Header),
|
||||
ResponseAndExit {
|
||||
headers: Vec<Header>,
|
||||
body: Vec<u8>,
|
||||
result: io::Result<()>,
|
||||
},
|
||||
}
|
||||
|
||||
async fn process_request(
|
||||
url_raw: &str,
|
||||
opts: &ServerOptions,
|
||||
@@ -254,7 +152,7 @@ async fn process_request(
|
||||
pkce: &PkceCodes,
|
||||
actual_port: u16,
|
||||
state: &str,
|
||||
) -> HandledRequest {
|
||||
) -> HandledRequest<()> {
|
||||
let parsed_url = match url::Url::parse(&format!("http://localhost{url_raw}")) {
|
||||
Ok(u) => u,
|
||||
Err(e) => {
|
||||
@@ -392,6 +290,7 @@ async fn process_request(
|
||||
"/success" => {
|
||||
let body = include_str!("assets/success.html");
|
||||
HandledRequest::ResponseAndExit {
|
||||
status: StatusCode(200),
|
||||
headers: match Header::from_bytes(
|
||||
&b"Content-Type"[..],
|
||||
&b"text/html; charset=utf-8"[..],
|
||||
@@ -404,6 +303,7 @@ async fn process_request(
|
||||
}
|
||||
}
|
||||
"/cancel" => HandledRequest::ResponseAndExit {
|
||||
status: StatusCode(200),
|
||||
headers: Vec::new(),
|
||||
body: b"Login cancelled".to_vec(),
|
||||
result: Err(io::Error::new(
|
||||
@@ -415,50 +315,6 @@ async fn process_request(
|
||||
}
|
||||
}
|
||||
|
||||
/// tiny_http filters `Connection` headers out of `Response` objects, so using
|
||||
/// `req.respond` never informs the client (or the library) that a keep-alive
|
||||
/// socket should be closed. That leaves the per-connection worker parked in a
|
||||
/// loop waiting for more requests, which in turn causes the next login attempt
|
||||
/// to hang on the old connection. This helper bypasses tiny_http’s response
|
||||
/// machinery: it extracts the raw writer, prints the HTTP response manually,
|
||||
/// and always appends `Connection: close`, ensuring the socket is closed from
|
||||
/// the server side. Ideally, tiny_http would provide an API to control
|
||||
/// server-side connection persistence, but it does not.
|
||||
fn send_response_with_disconnect(
|
||||
req: Request,
|
||||
mut headers: Vec<Header>,
|
||||
body: Vec<u8>,
|
||||
) -> io::Result<()> {
|
||||
let status = StatusCode(200);
|
||||
let mut writer = req.into_writer();
|
||||
let reason = status.default_reason_phrase();
|
||||
write!(writer, "HTTP/1.1 {} {}\r\n", status.0, reason)?;
|
||||
headers.retain(|h| !h.field.equiv("Connection"));
|
||||
if let Ok(close_header) = Header::from_bytes(&b"Connection"[..], &b"close"[..]) {
|
||||
headers.push(close_header);
|
||||
}
|
||||
|
||||
let content_length_value = format!("{}", body.len());
|
||||
if let Ok(content_length_header) =
|
||||
Header::from_bytes(&b"Content-Length"[..], content_length_value.as_bytes())
|
||||
{
|
||||
headers.push(content_length_header);
|
||||
}
|
||||
|
||||
for header in headers {
|
||||
write!(
|
||||
writer,
|
||||
"{}: {}\r\n",
|
||||
header.field.as_str(),
|
||||
header.value.as_str()
|
||||
)?;
|
||||
}
|
||||
|
||||
writer.write_all(b"\r\n")?;
|
||||
writer.write_all(&body)?;
|
||||
writer.flush()
|
||||
}
|
||||
|
||||
fn build_authorize_url(
|
||||
issuer: &str,
|
||||
client_id: &str,
|
||||
@@ -497,74 +353,6 @@ fn build_authorize_url(
|
||||
format!("{issuer}/oauth/authorize?{qs}")
|
||||
}
|
||||
|
||||
fn generate_state() -> String {
|
||||
let mut bytes = [0u8; 32];
|
||||
rand::rng().fill_bytes(&mut bytes);
|
||||
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
|
||||
}
|
||||
|
||||
fn send_cancel_request(port: u16) -> io::Result<()> {
|
||||
let addr: SocketAddr = format!("127.0.0.1:{port}")
|
||||
.parse()
|
||||
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
|
||||
let mut stream = TcpStream::connect_timeout(&addr, Duration::from_secs(2))?;
|
||||
stream.set_read_timeout(Some(Duration::from_secs(2)))?;
|
||||
stream.set_write_timeout(Some(Duration::from_secs(2)))?;
|
||||
|
||||
stream.write_all(b"GET /cancel HTTP/1.1\r\n")?;
|
||||
stream.write_all(format!("Host: 127.0.0.1:{port}\r\n").as_bytes())?;
|
||||
stream.write_all(b"Connection: close\r\n\r\n")?;
|
||||
|
||||
let mut buf = [0u8; 64];
|
||||
let _ = stream.read(&mut buf);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn bind_server(port: u16) -> io::Result<Server> {
|
||||
let bind_address = format!("127.0.0.1:{port}");
|
||||
let mut cancel_attempted = false;
|
||||
let mut attempts = 0;
|
||||
const MAX_ATTEMPTS: u32 = 10;
|
||||
const RETRY_DELAY: Duration = Duration::from_millis(200);
|
||||
|
||||
loop {
|
||||
match Server::http(&bind_address) {
|
||||
Ok(server) => return Ok(server),
|
||||
Err(err) => {
|
||||
attempts += 1;
|
||||
let is_addr_in_use = err
|
||||
.downcast_ref::<io::Error>()
|
||||
.map(|io_err| io_err.kind() == io::ErrorKind::AddrInUse)
|
||||
.unwrap_or(false);
|
||||
|
||||
// If the address is in use, there is probably another instance of the login server
|
||||
// running. Attempt to cancel it and retry.
|
||||
if is_addr_in_use {
|
||||
if !cancel_attempted {
|
||||
cancel_attempted = true;
|
||||
if let Err(cancel_err) = send_cancel_request(port) {
|
||||
eprintln!("Failed to cancel previous login server: {cancel_err}");
|
||||
}
|
||||
}
|
||||
|
||||
thread::sleep(RETRY_DELAY);
|
||||
|
||||
if attempts >= MAX_ATTEMPTS {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::AddrInUse,
|
||||
format!("Port {bind_address} is already in use"),
|
||||
));
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
return Err(io::Error::other(err));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Tokens returned by the OAuth authorization-code exchange.
|
||||
pub(crate) struct ExchangedTokens {
|
||||
pub id_token: String,
|
||||
@@ -888,13 +676,14 @@ fn login_error_response(
|
||||
kind: io::ErrorKind,
|
||||
error_code: Option<&str>,
|
||||
error_description: Option<&str>,
|
||||
) -> HandledRequest {
|
||||
) -> HandledRequest<()> {
|
||||
let mut headers = Vec::new();
|
||||
if let Ok(header) = Header::from_bytes(&b"Content-Type"[..], &b"text/html; charset=utf-8"[..]) {
|
||||
headers.push(header);
|
||||
}
|
||||
let body = render_login_error_page(message, error_code, error_description);
|
||||
HandledRequest::ResponseAndExit {
|
||||
status: StatusCode(200),
|
||||
headers,
|
||||
body,
|
||||
result: Err(io::Error::new(kind, message.to_string())),
|
||||
|
||||
@@ -3113,6 +3113,22 @@ impl App {
|
||||
AppEvent::RefreshConnectors { force_refetch } => {
|
||||
self.chat_widget.refresh_connectors(force_refetch);
|
||||
}
|
||||
AppEvent::SetDependencyEnv {
|
||||
thread_id,
|
||||
values,
|
||||
result_tx,
|
||||
} => {
|
||||
let result = match self.server.get_thread(thread_id).await {
|
||||
Ok(thread) => {
|
||||
thread.set_dependency_env(values).await;
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => Err(format!(
|
||||
"failed to load Codex thread {thread_id} for dependency env update: {err}"
|
||||
)),
|
||||
};
|
||||
let _ = result_tx.send(result);
|
||||
}
|
||||
AppEvent::PluginInstallAuthAdvance { refresh_connectors } => {
|
||||
if refresh_connectors {
|
||||
self.chat_widget.refresh_connectors(/*force_refetch*/ true);
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
//! Exit is modelled explicitly via `AppEvent::Exit(ExitMode)` so callers can request shutdown-first
|
||||
//! quits without reaching into the app loop or coupling to shutdown/exit sequencing.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use codex_app_server_protocol::PluginInstallResponse;
|
||||
@@ -37,6 +38,7 @@ use codex_protocol::config_types::ServiceTier;
|
||||
use codex_protocol::openai_models::ReasoningEffort;
|
||||
use codex_protocol::protocol::AskForApproval;
|
||||
use codex_protocol::protocol::SandboxPolicy;
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) enum RealtimeAudioDeviceKind {
|
||||
@@ -168,6 +170,13 @@ pub(crate) enum AppEvent {
|
||||
force_refetch: bool,
|
||||
},
|
||||
|
||||
/// Add environment variables to the specified thread's dependency env override.
|
||||
SetDependencyEnv {
|
||||
thread_id: ThreadId,
|
||||
values: HashMap<String, String>,
|
||||
result_tx: oneshot::Sender<Result<(), String>>,
|
||||
},
|
||||
|
||||
/// Fetch plugin marketplace state for the provided working directory.
|
||||
FetchPluginsList {
|
||||
cwd: PathBuf,
|
||||
|
||||
@@ -291,6 +291,7 @@ use crate::status_indicator_widget::STATUS_DETAILS_DEFAULT_MAX_LINES;
|
||||
use crate::status_indicator_widget::StatusDetailsCapitalization;
|
||||
use crate::text_formatting::truncate_text;
|
||||
use crate::tui::FrameRequester;
|
||||
mod create_api_key;
|
||||
mod interrupts;
|
||||
use self::interrupts::InterruptManager;
|
||||
mod agent;
|
||||
@@ -4733,6 +4734,9 @@ impl ChatWidget {
|
||||
tx.send(AppEvent::DiffResult(text));
|
||||
});
|
||||
}
|
||||
SlashCommand::CreateApiKey => {
|
||||
self.start_create_api_key();
|
||||
}
|
||||
SlashCommand::Copy => {
|
||||
let Some(text) = self.last_copyable_output.as_deref() else {
|
||||
self.add_info_message(
|
||||
|
||||
307
codex-rs/tui/src/chatwidget/create_api_key.rs
Normal file
307
codex-rs/tui/src/chatwidget/create_api_key.rs
Normal file
@@ -0,0 +1,307 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use codex_core::auth::read_openai_api_key_from_env;
|
||||
use codex_login::CreatedApiKey;
|
||||
use codex_login::OPENAI_API_KEY_ENV_VAR;
|
||||
use codex_login::PendingCreateApiKey;
|
||||
use codex_login::start_create_api_key as start_create_api_key_flow;
|
||||
use codex_protocol::ThreadId;
|
||||
use ratatui::style::Stylize;
|
||||
use ratatui::text::Line;
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
use super::ChatWidget;
|
||||
use crate::app_event::AppEvent;
|
||||
use crate::app_event_sender::AppEventSender;
|
||||
use crate::clipboard_text;
|
||||
use crate::history_cell;
|
||||
use crate::history_cell::PlainHistoryCell;
|
||||
|
||||
impl ChatWidget {
|
||||
pub(crate) fn start_create_api_key(&mut self) {
|
||||
match start_create_api_key_command(self.thread_id(), self.app_event_tx.clone()) {
|
||||
Ok(start_message) => {
|
||||
self.add_to_history(start_message);
|
||||
self.request_redraw();
|
||||
}
|
||||
Err(err) => {
|
||||
self.add_error_message(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn start_create_api_key_command(
|
||||
thread_id: Option<ThreadId>,
|
||||
app_event_tx: AppEventSender,
|
||||
) -> Result<PlainHistoryCell, String> {
|
||||
let thread_id =
|
||||
thread_id.ok_or_else(|| "No active Codex thread for API key creation.".to_string())?;
|
||||
|
||||
if read_openai_api_key_from_env().is_some() {
|
||||
return Ok(existing_shell_api_key_message());
|
||||
}
|
||||
|
||||
let session = start_create_api_key_flow()
|
||||
.map_err(|err| format!("Failed to start API key creation: {err}"))?;
|
||||
let browser_opened = session.open_browser();
|
||||
let start_message =
|
||||
continue_in_browser_message(session.auth_url(), session.callback_port(), browser_opened);
|
||||
|
||||
let app_event_tx_for_task = app_event_tx;
|
||||
tokio::spawn(async move {
|
||||
let cell = complete_command(session, thread_id, app_event_tx_for_task.clone()).await;
|
||||
app_event_tx_for_task.send(AppEvent::InsertHistoryCell(Box::new(cell)));
|
||||
});
|
||||
|
||||
Ok(start_message)
|
||||
}
|
||||
|
||||
fn existing_shell_api_key_message() -> PlainHistoryCell {
|
||||
history_cell::new_info_event(
|
||||
format!(
|
||||
"{OPENAI_API_KEY_ENV_VAR} is already set in this Codex session; skipping API key creation."
|
||||
),
|
||||
Some(format!(
|
||||
"Unset {OPENAI_API_KEY_ENV_VAR} and run /create-api-key again if you want Codex to create a different key."
|
||||
)),
|
||||
)
|
||||
}
|
||||
|
||||
fn continue_in_browser_message(
|
||||
auth_url: &str,
|
||||
callback_port: u16,
|
||||
browser_opened: bool,
|
||||
) -> PlainHistoryCell {
|
||||
let mut lines = vec![
|
||||
vec![
|
||||
"• ".dim(),
|
||||
"Finish API key creation via your browser.".into(),
|
||||
]
|
||||
.into(),
|
||||
"".into(),
|
||||
];
|
||||
|
||||
if browser_opened {
|
||||
lines.push(
|
||||
" Codex tried to open this link for you."
|
||||
.dark_gray()
|
||||
.into(),
|
||||
);
|
||||
} else {
|
||||
lines.push(
|
||||
" Codex couldn't auto-open your browser, but the API key creation flow is still waiting."
|
||||
.dark_gray()
|
||||
.into(),
|
||||
);
|
||||
}
|
||||
lines.push("".into());
|
||||
lines.push(" Open the following link to authenticate:".into());
|
||||
lines.push("".into());
|
||||
lines.push(Line::from(vec![
|
||||
" ".into(),
|
||||
auth_url.to_string().cyan().underlined(),
|
||||
]));
|
||||
lines.push("".into());
|
||||
lines.push(
|
||||
format!(" Codex will display the new {OPENAI_API_KEY_ENV_VAR} here and copy it to your clipboard.")
|
||||
.dark_gray()
|
||||
.into(),
|
||||
);
|
||||
lines.push("".into());
|
||||
lines.push(
|
||||
format!(
|
||||
" On a remote or headless machine, forward localhost:{callback_port} back to this Codex host before opening the link."
|
||||
)
|
||||
.dark_gray()
|
||||
.into(),
|
||||
);
|
||||
|
||||
PlainHistoryCell::new(lines)
|
||||
}
|
||||
|
||||
async fn complete_command(
|
||||
session: PendingCreateApiKey,
|
||||
thread_id: ThreadId,
|
||||
app_event_tx: AppEventSender,
|
||||
) -> PlainHistoryCell {
|
||||
let provisioned = match session.finish().await {
|
||||
Ok(provisioned) => provisioned,
|
||||
Err(err) => {
|
||||
return history_cell::new_error_event(format!("API key creation failed: {err}"));
|
||||
}
|
||||
};
|
||||
let copy_result = clipboard_text::copy_text_to_clipboard(&provisioned.project_api_key);
|
||||
let session_env_result =
|
||||
apply_api_key_to_current_session(&provisioned.project_api_key, thread_id, app_event_tx)
|
||||
.await;
|
||||
|
||||
success_cell(&provisioned, copy_result, session_env_result)
|
||||
}
|
||||
|
||||
async fn apply_api_key_to_current_session(
|
||||
api_key: &str,
|
||||
thread_id: ThreadId,
|
||||
app_event_tx: AppEventSender,
|
||||
) -> Result<(), String> {
|
||||
set_current_process_api_key(api_key);
|
||||
|
||||
let (result_tx, result_rx) = oneshot::channel();
|
||||
app_event_tx.send(AppEvent::SetDependencyEnv {
|
||||
thread_id,
|
||||
values: HashMap::from([(OPENAI_API_KEY_ENV_VAR.to_string(), api_key.to_string())]),
|
||||
result_tx,
|
||||
});
|
||||
|
||||
match result_rx.await {
|
||||
Ok(result) => result,
|
||||
Err(err) => Err(format!(
|
||||
"dependency env update response channel closed before completion: {err}"
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn set_current_process_api_key(api_key: &str) {
|
||||
// SAFETY: `/create-api-key` intentionally mutates process-global environment so the running
|
||||
// Codex session can observe `OPENAI_API_KEY` immediately. This is scoped to a single
|
||||
// user-triggered command, and spawned tool environments are updated separately through the
|
||||
// session dependency env override.
|
||||
unsafe {
|
||||
std::env::set_var(OPENAI_API_KEY_ENV_VAR, api_key);
|
||||
}
|
||||
}
|
||||
|
||||
fn success_cell(
|
||||
provisioned: &CreatedApiKey,
|
||||
copy_result: Result<(), String>,
|
||||
session_env_result: Result<(), String>,
|
||||
) -> PlainHistoryCell {
|
||||
let organization = provisioned
|
||||
.organization_title
|
||||
.clone()
|
||||
.unwrap_or_else(|| provisioned.organization_id.clone());
|
||||
let project = provisioned
|
||||
.default_project_title
|
||||
.clone()
|
||||
.unwrap_or_else(|| provisioned.default_project_id.clone());
|
||||
let masked_api_key = mask_api_key(&provisioned.project_api_key);
|
||||
let copy_status = match copy_result {
|
||||
Ok(()) => "Copied the full key to your clipboard.".to_string(),
|
||||
Err(err) => format!("Could not copy the key to your clipboard: {err}"),
|
||||
};
|
||||
let session_env_status = match session_env_result {
|
||||
Ok(()) => {
|
||||
format!("Set {OPENAI_API_KEY_ENV_VAR} in this Codex session for spawned commands.")
|
||||
}
|
||||
Err(err) => {
|
||||
format!("Could not set {OPENAI_API_KEY_ENV_VAR} in this Codex session: {err}")
|
||||
}
|
||||
};
|
||||
let hint = Some(format!("{copy_status} {session_env_status}"));
|
||||
|
||||
history_cell::new_info_event(
|
||||
format!("Created an API key for {organization} / {project}: {masked_api_key}"),
|
||||
hint,
|
||||
)
|
||||
}
|
||||
|
||||
fn mask_api_key(api_key: &str) -> String {
|
||||
const UNMASKED_PREFIX_LEN: usize = 8;
|
||||
const UNMASKED_SUFFIX_LEN: usize = 4;
|
||||
|
||||
if api_key.len() <= UNMASKED_PREFIX_LEN + UNMASKED_SUFFIX_LEN {
|
||||
return api_key.to_string();
|
||||
}
|
||||
|
||||
format!(
|
||||
"{}...{}",
|
||||
&api_key[..UNMASKED_PREFIX_LEN],
|
||||
&api_key[api_key.len() - UNMASKED_SUFFIX_LEN..]
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::history_cell::HistoryCell;
|
||||
use insta::assert_snapshot;
|
||||
|
||||
#[test]
|
||||
fn success_cell_snapshot() {
|
||||
let cell = success_cell(
|
||||
&CreatedApiKey {
|
||||
organization_id: "org-default".to_string(),
|
||||
organization_title: Some("Default Org".to_string()),
|
||||
default_project_id: "proj-default".to_string(),
|
||||
default_project_title: Some("Default Project".to_string()),
|
||||
project_api_key: "sk-proj-123".to_string(),
|
||||
},
|
||||
Ok(()),
|
||||
Ok(()),
|
||||
);
|
||||
|
||||
assert_snapshot!(render_cell(&cell));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn success_cell_snapshot_when_clipboard_copy_fails() {
|
||||
let cell = success_cell(
|
||||
&CreatedApiKey {
|
||||
organization_id: "org-default".to_string(),
|
||||
organization_title: None,
|
||||
default_project_id: "proj-default".to_string(),
|
||||
default_project_title: None,
|
||||
project_api_key: "sk-proj-123".to_string(),
|
||||
},
|
||||
Err("clipboard unavailable".to_string()),
|
||||
Err("dependency env unavailable".to_string()),
|
||||
);
|
||||
|
||||
assert_snapshot!(render_cell(&cell));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn continue_in_browser_message_snapshot() {
|
||||
let cell = continue_in_browser_message(
|
||||
"https://auth.openai.com/oauth/authorize?client_id=abc",
|
||||
/*callback_port*/ 5000,
|
||||
/*browser_opened*/ false,
|
||||
);
|
||||
|
||||
assert_snapshot!(render_cell(&cell));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn existing_shell_api_key_message_mentions_openai_api_key() {
|
||||
let cell = existing_shell_api_key_message();
|
||||
|
||||
assert_eq!(
|
||||
render_cell(&cell),
|
||||
"• OPENAI_API_KEY is already set in this Codex session; skipping API key creation. Unset OPENAI_API_KEY and run /create-api-key again if you want Codex to create a different key."
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn continue_in_browser_message_always_includes_the_auth_url() {
|
||||
let cell = continue_in_browser_message(
|
||||
"https://auth.example.com/oauth/authorize?state=abc",
|
||||
5000,
|
||||
/*browser_opened*/ false,
|
||||
);
|
||||
|
||||
assert!(render_cell(&cell).contains("https://auth.example.com/oauth/authorize?state=abc"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mask_api_key_preserves_prefix_and_suffix() {
|
||||
assert_eq!(mask_api_key("sk-proj-1234567890"), "sk-proj-...7890");
|
||||
}
|
||||
|
||||
fn render_cell(cell: &PlainHistoryCell) -> String {
|
||||
cell.display_lines(120)
|
||||
.into_iter()
|
||||
.map(|line| line.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
---
|
||||
source: tui/src/chatwidget/create_api_key.rs
|
||||
assertion_line: 260
|
||||
expression: render_cell(&cell)
|
||||
---
|
||||
• Finish API key creation via your browser.
|
||||
|
||||
Codex couldn't auto-open your browser, but the API key creation flow is still waiting.
|
||||
|
||||
Open the following link to authenticate:
|
||||
|
||||
https://auth.openai.com/oauth/authorize?client_id=abc
|
||||
|
||||
Codex will display the new OPENAI_API_KEY here and copy it to your clipboard.
|
||||
|
||||
On a remote or headless machine, forward localhost:5000 back to this Codex host before opening the link.
|
||||
@@ -0,0 +1,6 @@
|
||||
---
|
||||
source: tui/src/chatwidget/create_api_key.rs
|
||||
assertion_line: 232
|
||||
expression: render_cell(&cell)
|
||||
---
|
||||
• Created an API key for Default Org / Default Project: sk-proj-123 Copied the full key to your clipboard. Set OPENAI_API_KEY in this Codex session for spawned commands.
|
||||
@@ -0,0 +1,6 @@
|
||||
---
|
||||
source: tui/src/chatwidget/create_api_key.rs
|
||||
assertion_line: 249
|
||||
expression: render_cell(&cell)
|
||||
---
|
||||
• Created an API key for org-default / proj-default: sk-proj-123 Could not copy the key to your clipboard: clipboard unavailable Could not set OPENAI_API_KEY in this Codex session: dependency env unavailable
|
||||
@@ -325,7 +325,7 @@ impl ChatWidget {
|
||||
})
|
||||
}
|
||||
|
||||
fn status_line_cwd(&self) -> &Path {
|
||||
pub(super) fn status_line_cwd(&self) -> &Path {
|
||||
self.current_cwd
|
||||
.as_deref()
|
||||
.unwrap_or(self.config.cwd.as_path())
|
||||
|
||||
@@ -34,6 +34,7 @@ pub enum SlashCommand {
|
||||
Agent,
|
||||
// Undo,
|
||||
Diff,
|
||||
CreateApiKey,
|
||||
Copy,
|
||||
Mention,
|
||||
Status,
|
||||
@@ -82,6 +83,7 @@ impl SlashCommand {
|
||||
// SlashCommand::Undo => "ask Codex to undo a turn",
|
||||
SlashCommand::Quit | SlashCommand::Exit => "exit Codex",
|
||||
SlashCommand::Diff => "show git diff (including untracked files)",
|
||||
SlashCommand::CreateApiKey => "create an API key and copy it to your clipboard",
|
||||
SlashCommand::Copy => "copy the latest Codex output to your clipboard",
|
||||
SlashCommand::Mention => "mention a file",
|
||||
SlashCommand::Skills => "use skills to improve how Codex performs specific tasks",
|
||||
@@ -155,6 +157,7 @@ impl SlashCommand {
|
||||
| SlashCommand::Experimental
|
||||
| SlashCommand::Review
|
||||
| SlashCommand::Plan
|
||||
| SlashCommand::CreateApiKey
|
||||
| SlashCommand::Clear
|
||||
| SlashCommand::Logout
|
||||
| SlashCommand::MemoryDrop
|
||||
@@ -220,4 +223,15 @@ mod tests {
|
||||
fn clean_alias_parses_to_stop_command() {
|
||||
assert_eq!(SlashCommand::from_str("clean"), Ok(SlashCommand::Stop));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn create_api_key_command_metadata() {
|
||||
assert_eq!(
|
||||
SlashCommand::from_str("create-api-key"),
|
||||
Ok(SlashCommand::CreateApiKey)
|
||||
);
|
||||
assert_eq!(SlashCommand::CreateApiKey.command(), "create-api-key");
|
||||
assert!(!SlashCommand::CreateApiKey.supports_inline_args());
|
||||
assert!(!SlashCommand::CreateApiKey.available_during_task());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user