mirror of
https://github.com/openai/codex.git
synced 2026-04-24 06:35:50 +00:00
codex: tighten api provisioning implementation
Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
1
codex-rs/Cargo.lock
generated
1
codex-rs/Cargo.lock
generated
@@ -2240,7 +2240,6 @@ dependencies = [
|
||||
"codex-keyring-store",
|
||||
"codex-protocol",
|
||||
"codex-terminal-detection",
|
||||
"codex-utils-home-dir",
|
||||
"core_test_support",
|
||||
"keyring",
|
||||
"once_cell",
|
||||
|
||||
@@ -17,7 +17,6 @@ codex-config = { workspace = true }
|
||||
codex-keyring-store = { workspace = true }
|
||||
codex-protocol = { workspace = true }
|
||||
codex-terminal-detection = { workspace = true }
|
||||
codex-utils-home-dir = { workspace = true }
|
||||
once_cell = { workspace = true }
|
||||
os_info = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
|
||||
@@ -17,10 +17,10 @@ const PLATFORM_HYDRA_CLIENT_ID: &str = "app_2SKx67EdpoN0G6j64rFvigXD";
|
||||
const PLATFORM_AUDIENCE: &str = "https://api.openai.com/v1";
|
||||
const DEFAULT_API_BASE: &str = "https://api.openai.com";
|
||||
const DEFAULT_CALLBACK_PORT: u16 = 5000;
|
||||
const DEFAULT_CALLBACK_PATH: &str = "/auth/callback";
|
||||
const CALLBACK_PATH: &str = "/auth/callback";
|
||||
const DEFAULT_SCOPE: &str = "openid email profile offline_access";
|
||||
const DEFAULT_APP: &str = "api";
|
||||
const DEFAULT_USER_AGENT: &str = "OpenAI-Onboard-Auth-Script/1.0";
|
||||
const USER_AGENT: &str = "Codex-API-Provision/1.0";
|
||||
const DEFAULT_PROJECT_API_KEY_NAME: &str = "Codex CLI";
|
||||
const DEFAULT_PROJECT_POLL_INTERVAL_SECONDS: u64 = 10;
|
||||
const DEFAULT_PROJECT_POLL_TIMEOUT_SECONDS: u64 = 60;
|
||||
@@ -79,16 +79,12 @@ impl PendingApiProvisioning {
|
||||
self.callback_server.open_browser()
|
||||
}
|
||||
|
||||
pub fn open_browser_or_print(&self) -> bool {
|
||||
self.callback_server.open_browser_or_print()
|
||||
}
|
||||
|
||||
pub async fn finish(self) -> Result<ProvisionedApiKey, HelperError> {
|
||||
pub async fn finish(self) -> Result<ProvisionedApiKey, ApiProvisionError> {
|
||||
let code = self
|
||||
.callback_server
|
||||
.wait_for_code(Duration::from_secs(OAUTH_TIMEOUT_SECONDS))
|
||||
.await
|
||||
.map_err(|err| HelperError::message(err.to_string()))?;
|
||||
.map_err(|err| ApiProvisionError::message(err.to_string()))?;
|
||||
provision_from_authorization_code(
|
||||
&self.client,
|
||||
&self.options,
|
||||
@@ -102,30 +98,28 @@ impl PendingApiProvisioning {
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ProvisionedApiKey {
|
||||
pub sensitive_id: String,
|
||||
pub organization_id: String,
|
||||
pub organization_title: Option<String>,
|
||||
pub default_project_id: String,
|
||||
pub default_project_title: Option<String>,
|
||||
pub project_api_key: String,
|
||||
pub access_token: String,
|
||||
}
|
||||
|
||||
pub fn start_api_provisioning(
|
||||
options: ApiProvisionOptions,
|
||||
) -> Result<PendingApiProvisioning, HelperError> {
|
||||
) -> Result<PendingApiProvisioning, ApiProvisionError> {
|
||||
validate_api_provision_options(&options)?;
|
||||
let client = build_http_client()?;
|
||||
let callback_server = start_authorization_code_server(
|
||||
options.callback_port,
|
||||
DEFAULT_CALLBACK_PATH,
|
||||
CALLBACK_PATH,
|
||||
/*force_state*/ None,
|
||||
|redirect_uri, pkce, state| {
|
||||
build_authorize_url(&options, redirect_uri, pkce, state)
|
||||
.map_err(|err| std::io::Error::other(err.to_string()))
|
||||
},
|
||||
)
|
||||
.map_err(|err| HelperError::message(err.to_string()))?;
|
||||
.map_err(|err| ApiProvisionError::message(err.to_string()))?;
|
||||
let redirect_uri = callback_server.redirect_uri.clone();
|
||||
|
||||
Ok(PendingApiProvisioning {
|
||||
@@ -137,19 +131,19 @@ pub fn start_api_provisioning(
|
||||
})
|
||||
}
|
||||
|
||||
fn validate_api_provision_options(options: &ApiProvisionOptions) -> Result<(), HelperError> {
|
||||
fn validate_api_provision_options(options: &ApiProvisionOptions) -> Result<(), ApiProvisionError> {
|
||||
if options.project_poll_interval_seconds == 0 {
|
||||
return Err(HelperError::message(
|
||||
return Err(ApiProvisionError::message(
|
||||
"project_poll_interval_seconds must be greater than 0.".to_string(),
|
||||
));
|
||||
}
|
||||
if options.project_poll_timeout_seconds == 0 {
|
||||
return Err(HelperError::message(
|
||||
return Err(ApiProvisionError::message(
|
||||
"project_poll_timeout_seconds must be greater than 0.".to_string(),
|
||||
));
|
||||
}
|
||||
if options.api_key_name.trim().is_empty() {
|
||||
return Err(HelperError::message(
|
||||
return Err(ApiProvisionError::message(
|
||||
"api_key_name must not be empty.".to_string(),
|
||||
));
|
||||
}
|
||||
@@ -161,12 +155,12 @@ fn build_authorize_url(
|
||||
redirect_uri: &str,
|
||||
pkce: &PkceCodes,
|
||||
state: &str,
|
||||
) -> Result<String, HelperError> {
|
||||
) -> Result<String, ApiProvisionError> {
|
||||
let mut url = Url::parse(&format!(
|
||||
"{}/oauth/authorize",
|
||||
options.issuer.trim_end_matches('/')
|
||||
))
|
||||
.map_err(|err| HelperError::message(format!("invalid issuer URL: {err}")))?;
|
||||
.map_err(|err| ApiProvisionError::message(format!("invalid issuer URL: {err}")))?;
|
||||
url.query_pairs_mut()
|
||||
.append_pair("audience", &options.audience)
|
||||
.append_pair("client_id", &options.client_id)
|
||||
@@ -179,11 +173,11 @@ fn build_authorize_url(
|
||||
Ok(url.to_string())
|
||||
}
|
||||
|
||||
fn build_http_client() -> Result<Client, HelperError> {
|
||||
fn build_http_client() -> Result<Client, ApiProvisionError> {
|
||||
build_reqwest_client_with_custom_ca(
|
||||
reqwest::Client::builder().timeout(Duration::from_secs(HTTP_TIMEOUT_SECONDS)),
|
||||
)
|
||||
.map_err(|err| HelperError::message(format!("failed to build HTTP client: {err}")))
|
||||
.map_err(|err| ApiProvisionError::message(format!("failed to build HTTP client: {err}")))
|
||||
}
|
||||
|
||||
async fn provision_from_authorization_code(
|
||||
@@ -192,7 +186,7 @@ async fn provision_from_authorization_code(
|
||||
redirect_uri: &str,
|
||||
code_verifier: &str,
|
||||
code: &str,
|
||||
) -> Result<ProvisionedApiKey, HelperError> {
|
||||
) -> Result<ProvisionedApiKey, ApiProvisionError> {
|
||||
let tokens = exchange_authorization_code_for_tokens(
|
||||
client,
|
||||
&options.issuer,
|
||||
@@ -229,13 +223,11 @@ async fn provision_from_authorization_code(
|
||||
.sensitive_id;
|
||||
|
||||
Ok(ProvisionedApiKey {
|
||||
sensitive_id: login.user.session.sensitive_id,
|
||||
organization_id: target.organization_id,
|
||||
organization_title: target.organization_title,
|
||||
default_project_id: target.project_id,
|
||||
default_project_title: target.project_title,
|
||||
project_api_key: api_key,
|
||||
access_token: tokens.access_token,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -246,13 +238,13 @@ async fn exchange_authorization_code_for_tokens(
|
||||
redirect_uri: &str,
|
||||
code_verifier: &str,
|
||||
code: &str,
|
||||
) -> Result<OAuthTokens, HelperError> {
|
||||
) -> Result<OAuthTokens, ApiProvisionError> {
|
||||
let url = format!("{}/oauth/token", issuer.trim_end_matches('/'));
|
||||
execute_json(
|
||||
client
|
||||
.request(Method::POST, &url)
|
||||
.header(reqwest::header::ACCEPT, "application/json")
|
||||
.header(reqwest::header::USER_AGENT, DEFAULT_USER_AGENT)
|
||||
.header(reqwest::header::USER_AGENT, USER_AGENT)
|
||||
.json(&serde_json::json!({
|
||||
"client_id": client_id,
|
||||
"code_verifier": code_verifier,
|
||||
@@ -271,7 +263,7 @@ async fn onboarding_login(
|
||||
api_base: &str,
|
||||
app: &str,
|
||||
access_token: &str,
|
||||
) -> Result<OnboardingLoginResponse, HelperError> {
|
||||
) -> Result<OnboardingLoginResponse, ApiProvisionError> {
|
||||
let url = format!(
|
||||
"{}/dashboard/onboarding/login",
|
||||
api_base.trim_end_matches('/')
|
||||
@@ -280,7 +272,7 @@ async fn onboarding_login(
|
||||
client
|
||||
.request(Method::POST, &url)
|
||||
.header(reqwest::header::ACCEPT, "application/json")
|
||||
.header(reqwest::header::USER_AGENT, DEFAULT_USER_AGENT)
|
||||
.header(reqwest::header::USER_AGENT, USER_AGENT)
|
||||
.bearer_auth(access_token)
|
||||
.json(&serde_json::json!({ "app": app })),
|
||||
"POST",
|
||||
@@ -293,13 +285,13 @@ async fn list_organizations(
|
||||
client: &Client,
|
||||
api_base: &str,
|
||||
session_key: &str,
|
||||
) -> Result<Vec<Organization>, HelperError> {
|
||||
) -> Result<Vec<Organization>, ApiProvisionError> {
|
||||
let url = format!("{}/v1/organizations", api_base.trim_end_matches('/'));
|
||||
let response: DataList<Organization> = execute_json(
|
||||
client
|
||||
.request(Method::GET, &url)
|
||||
.header(reqwest::header::ACCEPT, "application/json")
|
||||
.header(reqwest::header::USER_AGENT, DEFAULT_USER_AGENT)
|
||||
.header(reqwest::header::USER_AGENT, USER_AGENT)
|
||||
.bearer_auth(session_key),
|
||||
"GET",
|
||||
&url,
|
||||
@@ -313,7 +305,7 @@ async fn list_projects(
|
||||
api_base: &str,
|
||||
session_key: &str,
|
||||
organization_id: &str,
|
||||
) -> Result<Vec<Project>, HelperError> {
|
||||
) -> Result<Vec<Project>, ApiProvisionError> {
|
||||
let url = format!(
|
||||
"{}/dashboard/organizations/{}/projects?detail=basic&limit=100",
|
||||
api_base.trim_end_matches('/'),
|
||||
@@ -323,7 +315,7 @@ async fn list_projects(
|
||||
client
|
||||
.request(Method::GET, &url)
|
||||
.header(reqwest::header::ACCEPT, "application/json")
|
||||
.header(reqwest::header::USER_AGENT, DEFAULT_USER_AGENT)
|
||||
.header(reqwest::header::USER_AGENT, USER_AGENT)
|
||||
.header("openai-organization", organization_id)
|
||||
.bearer_auth(session_key),
|
||||
"GET",
|
||||
@@ -339,7 +331,7 @@ async fn wait_for_default_project(
|
||||
session_key: &str,
|
||||
poll_interval_seconds: u64,
|
||||
timeout_seconds: u64,
|
||||
) -> Result<ProvisioningTarget, HelperError> {
|
||||
) -> Result<ProvisioningTarget, ApiProvisionError> {
|
||||
let deadline = std::time::Instant::now() + Duration::from_secs(timeout_seconds);
|
||||
loop {
|
||||
let organizations = list_organizations(client, api_base, session_key).await?;
|
||||
@@ -363,7 +355,7 @@ async fn wait_for_default_project(
|
||||
};
|
||||
|
||||
if std::time::Instant::now() >= deadline {
|
||||
return Err(HelperError::message(format!(
|
||||
return Err(ApiProvisionError::message(format!(
|
||||
"Timed out waiting for an organization and default project. Last observed state: {last_state}"
|
||||
)));
|
||||
}
|
||||
@@ -371,7 +363,7 @@ async fn wait_for_default_project(
|
||||
.saturating_duration_since(std::time::Instant::now())
|
||||
.as_secs();
|
||||
let sleep_seconds = poll_interval_seconds.min(remaining_seconds.max(1));
|
||||
std::thread::sleep(Duration::from_secs(sleep_seconds));
|
||||
tokio::time::sleep(Duration::from_secs(sleep_seconds)).await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -397,7 +389,7 @@ async fn create_project_api_key(
|
||||
session_key: &str,
|
||||
target: &ProvisioningTarget,
|
||||
key_name: &str,
|
||||
) -> Result<CreateApiKeyResponse, HelperError> {
|
||||
) -> Result<CreateApiKeyResponse, ApiProvisionError> {
|
||||
let url = format!(
|
||||
"{}/dashboard/organizations/{}/projects/{}/api_keys",
|
||||
api_base.trim_end_matches('/'),
|
||||
@@ -408,7 +400,7 @@ async fn create_project_api_key(
|
||||
client
|
||||
.request(Method::POST, &url)
|
||||
.header(reqwest::header::ACCEPT, "application/json")
|
||||
.header(reqwest::header::USER_AGENT, DEFAULT_USER_AGENT)
|
||||
.header(reqwest::header::USER_AGENT, USER_AGENT)
|
||||
.bearer_auth(session_key)
|
||||
.json(&serde_json::json!({
|
||||
"action": "create",
|
||||
@@ -424,26 +416,26 @@ async fn execute_json<T>(
|
||||
request: reqwest::RequestBuilder,
|
||||
method: &str,
|
||||
url: &str,
|
||||
) -> Result<T, HelperError>
|
||||
) -> Result<T, ApiProvisionError>
|
||||
where
|
||||
T: for<'de> Deserialize<'de>,
|
||||
{
|
||||
let response = request
|
||||
.send()
|
||||
.await
|
||||
.map_err(|err| HelperError::message(format!("Network error calling {url}: {err}")))?;
|
||||
.map_err(|err| ApiProvisionError::message(format!("Network error calling {url}: {err}")))?;
|
||||
let status = response.status();
|
||||
let body = response.bytes().await.map_err(|err| {
|
||||
HelperError::message(format!("Failed reading response from {url}: {err}"))
|
||||
ApiProvisionError::message(format!("Failed reading response from {url}: {err}"))
|
||||
})?;
|
||||
if !status.is_success() {
|
||||
return Err(HelperError::api(
|
||||
return Err(ApiProvisionError::api(
|
||||
format!("{method} {url} failed with HTTP {status}"),
|
||||
String::from_utf8_lossy(&body).into_owned(),
|
||||
));
|
||||
}
|
||||
serde_json::from_slice(&body)
|
||||
.map_err(|err| HelperError::message(format!("{url} returned invalid JSON: {err}")))
|
||||
.map_err(|err| ApiProvisionError::message(format!("{url} returned invalid JSON: {err}")))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -512,38 +504,29 @@ struct CreatedApiKey {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct HelperError {
|
||||
pub struct ApiProvisionError {
|
||||
message: String,
|
||||
body: Option<String>,
|
||||
}
|
||||
|
||||
impl HelperError {
|
||||
impl ApiProvisionError {
|
||||
fn message(message: String) -> Self {
|
||||
Self {
|
||||
message,
|
||||
body: None,
|
||||
}
|
||||
Self { message }
|
||||
}
|
||||
|
||||
fn api(message: String, body: String) -> Self {
|
||||
Self {
|
||||
message,
|
||||
body: Some(body),
|
||||
message: format!("{message}: {body}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn body(&self) -> Option<&str> {
|
||||
self.body.as_deref()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for HelperError {
|
||||
impl std::fmt::Display for ApiProvisionError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str(&self.message)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for HelperError {}
|
||||
impl std::error::Error for ApiProvisionError {}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "api_provision_tests.rs"]
|
||||
|
||||
@@ -154,13 +154,11 @@ async fn provision_from_authorization_code_provisions_api_key() {
|
||||
assert_eq!(
|
||||
output,
|
||||
ProvisionedApiKey {
|
||||
sensitive_id: "session-123".to_string(),
|
||||
organization_id: "org-default".to_string(),
|
||||
organization_title: Some("Default Org".to_string()),
|
||||
default_project_id: "proj-default".to_string(),
|
||||
default_project_title: Some("Default Project".to_string()),
|
||||
project_api_key: "sk-proj-123".to_string(),
|
||||
access_token: "oauth-access-123".to_string(),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
@@ -16,8 +16,8 @@ pub use server::ServerOptions;
|
||||
pub use server::ShutdownHandle;
|
||||
pub use server::run_login_server;
|
||||
|
||||
pub use api_provision::ApiProvisionError;
|
||||
pub use api_provision::ApiProvisionOptions;
|
||||
pub use api_provision::HelperError as ApiProvisionError;
|
||||
pub use api_provision::PendingApiProvisioning;
|
||||
pub use api_provision::ProvisionedApiKey;
|
||||
pub use api_provision::start_api_provisioning;
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
//! This module therefore keeps the user-facing error path and the structured-log path separate.
|
||||
//! Returned `io::Error` values still carry the detail needed by CLI/browser callers, while
|
||||
//! structured logs only emit explicitly reviewed fields plus redacted URL/error values.
|
||||
use std::future::Future;
|
||||
use std::io::Cursor;
|
||||
use std::io::Read;
|
||||
use std::io::Write;
|
||||
@@ -125,22 +126,6 @@ impl AuthorizationCodeServer {
|
||||
webbrowser::open(&self.auth_url).is_ok()
|
||||
}
|
||||
|
||||
pub fn open_browser_or_print(&self) -> bool {
|
||||
let opened = self.open_browser();
|
||||
if opened {
|
||||
eprintln!(
|
||||
"Starting local auth callback server.\nIf your browser did not open, navigate to this URL to continue:\n\n{}",
|
||||
self.auth_url
|
||||
);
|
||||
} else {
|
||||
eprintln!(
|
||||
"Starting local auth callback server.\nOpen this URL in your browser to continue:\n\n{}",
|
||||
self.auth_url
|
||||
);
|
||||
}
|
||||
opened
|
||||
}
|
||||
|
||||
pub fn code_verifier(&self) -> &str {
|
||||
&self.code_verifier
|
||||
}
|
||||
@@ -197,67 +182,19 @@ where
|
||||
let state = force_state.unwrap_or_else(generate_state);
|
||||
let callback_path = callback_path.to_string();
|
||||
|
||||
let server = bind_server(port)?;
|
||||
let actual_port = match server.server_addr().to_ip() {
|
||||
Some(addr) => addr.port(),
|
||||
None => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::AddrInUse,
|
||||
"Unable to determine the server port",
|
||||
));
|
||||
}
|
||||
};
|
||||
let server = Arc::new(server);
|
||||
|
||||
let (server, actual_port, rx) = bind_server_with_request_channel(port)?;
|
||||
let redirect_uri = format!("http://localhost:{actual_port}{callback_path}");
|
||||
let auth_url = auth_url_builder(&redirect_uri, &pkce, &state)?;
|
||||
|
||||
let (tx, mut rx) = tokio::sync::mpsc::channel::<Request>(16);
|
||||
let _server_handle = {
|
||||
let server = server.clone();
|
||||
thread::spawn(move || -> io::Result<()> {
|
||||
while let Ok(request) = server.recv() {
|
||||
match tx.blocking_send(request) {
|
||||
Ok(()) => {}
|
||||
Err(error) => {
|
||||
eprintln!("Failed to send request to channel: {error}");
|
||||
return Err(io::Error::other("Failed to send request to channel"));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
};
|
||||
|
||||
let shutdown_notify = Arc::new(tokio::sync::Notify::new());
|
||||
let server_handle = {
|
||||
let shutdown_notify = shutdown_notify.clone();
|
||||
tokio::spawn(async move {
|
||||
let result = loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_notify.notified() => {
|
||||
break Err(io::Error::other("Authentication was not completed"));
|
||||
}
|
||||
maybe_req = rx.recv() => {
|
||||
let Some(req) = maybe_req else {
|
||||
break Err(io::Error::other("Authentication was not completed"));
|
||||
};
|
||||
|
||||
let url_raw = req.url().to_string();
|
||||
let response =
|
||||
process_authorization_code_request(&url_raw, &callback_path, &state);
|
||||
|
||||
if let Some(result) = respond_to_request(req, response).await {
|
||||
break result;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
server.unblock();
|
||||
result
|
||||
})
|
||||
};
|
||||
let (server_handle, shutdown_handle) = spawn_callback_server_loop(
|
||||
server,
|
||||
rx,
|
||||
"Authentication was not completed",
|
||||
move |url_raw| {
|
||||
let callback_path = callback_path.clone();
|
||||
let state = state.clone();
|
||||
async move { process_authorization_code_request(&url_raw, &callback_path, &state) }
|
||||
},
|
||||
);
|
||||
|
||||
Ok(AuthorizationCodeServer {
|
||||
auth_url,
|
||||
@@ -265,7 +202,7 @@ where
|
||||
redirect_uri,
|
||||
code_verifier: pkce.code_verifier,
|
||||
server_handle,
|
||||
shutdown_handle: ShutdownHandle { shutdown_notify },
|
||||
shutdown_handle,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -274,18 +211,7 @@ pub fn run_login_server(opts: ServerOptions) -> io::Result<LoginServer> {
|
||||
let pkce = generate_pkce();
|
||||
let state = opts.force_state.clone().unwrap_or_else(generate_state);
|
||||
|
||||
let server = bind_server(opts.port)?;
|
||||
let actual_port = match server.server_addr().to_ip() {
|
||||
Some(addr) => addr.port(),
|
||||
None => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::AddrInUse,
|
||||
"Unable to determine the server port",
|
||||
));
|
||||
}
|
||||
};
|
||||
let server = Arc::new(server);
|
||||
|
||||
let (server, actual_port, rx) = bind_server_with_request_channel(opts.port)?;
|
||||
let redirect_uri = format!("http://localhost:{actual_port}/auth/callback");
|
||||
let auth_url = build_authorize_url(
|
||||
&opts.issuer,
|
||||
@@ -299,63 +225,22 @@ pub fn run_login_server(opts: ServerOptions) -> io::Result<LoginServer> {
|
||||
if opts.open_browser {
|
||||
let _ = webbrowser::open(&auth_url);
|
||||
}
|
||||
|
||||
// Map blocking reads from server.recv() to an async channel.
|
||||
let (tx, mut rx) = tokio::sync::mpsc::channel::<Request>(16);
|
||||
let _server_handle = {
|
||||
let server = server.clone();
|
||||
thread::spawn(move || -> io::Result<()> {
|
||||
while let Ok(request) = server.recv() {
|
||||
match tx.blocking_send(request) {
|
||||
Ok(()) => {}
|
||||
Err(error) => {
|
||||
eprintln!("Failed to send request to channel: {error}");
|
||||
return Err(io::Error::other("Failed to send request to channel"));
|
||||
}
|
||||
}
|
||||
let (server_handle, shutdown_handle) =
|
||||
spawn_callback_server_loop(server, rx, "Login was not completed", move |url_raw| {
|
||||
let redirect_uri = redirect_uri.clone();
|
||||
let state = state.clone();
|
||||
let opts = opts.clone();
|
||||
let pkce = pkce.clone();
|
||||
async move {
|
||||
process_request(&url_raw, &opts, &redirect_uri, &pkce, actual_port, &state).await
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
};
|
||||
|
||||
let shutdown_notify = Arc::new(tokio::sync::Notify::new());
|
||||
let server_handle = {
|
||||
let shutdown_notify = shutdown_notify.clone();
|
||||
let server = server;
|
||||
tokio::spawn(async move {
|
||||
let result = loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_notify.notified() => {
|
||||
break Err(io::Error::other("Login was not completed"));
|
||||
}
|
||||
maybe_req = rx.recv() => {
|
||||
let Some(req) = maybe_req else {
|
||||
break Err(io::Error::other("Login was not completed"));
|
||||
};
|
||||
|
||||
let url_raw = req.url().to_string();
|
||||
let response =
|
||||
process_request(&url_raw, &opts, &redirect_uri, &pkce, actual_port, &state).await;
|
||||
|
||||
if let Some(result) = respond_to_request(req, response).await {
|
||||
break result;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Ensure that the server is unblocked so the thread dedicated to
|
||||
// running `server.recv()` in a loop exits cleanly.
|
||||
server.unblock();
|
||||
result
|
||||
})
|
||||
};
|
||||
});
|
||||
|
||||
Ok(LoginServer {
|
||||
auth_url,
|
||||
actual_port,
|
||||
server_handle,
|
||||
shutdown_handle: ShutdownHandle { shutdown_notify },
|
||||
shutdown_handle,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -626,6 +511,87 @@ fn html_headers() -> Vec<Header> {
|
||||
}
|
||||
}
|
||||
|
||||
fn bind_server_with_request_channel(
|
||||
port: u16,
|
||||
) -> io::Result<(Arc<Server>, u16, tokio::sync::mpsc::Receiver<Request>)> {
|
||||
let server = bind_server(port)?;
|
||||
let actual_port = match server.server_addr().to_ip() {
|
||||
Some(addr) => addr.port(),
|
||||
None => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::AddrInUse,
|
||||
"Unable to determine the server port",
|
||||
));
|
||||
}
|
||||
};
|
||||
let server = Arc::new(server);
|
||||
|
||||
// Map blocking reads from server.recv() to an async channel.
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<Request>(16);
|
||||
let _server_handle = {
|
||||
let server = server.clone();
|
||||
thread::spawn(move || -> io::Result<()> {
|
||||
while let Ok(request) = server.recv() {
|
||||
match tx.blocking_send(request) {
|
||||
Ok(()) => {}
|
||||
Err(error) => {
|
||||
eprintln!("Failed to send request to channel: {error}");
|
||||
return Err(io::Error::other("Failed to send request to channel"));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
};
|
||||
|
||||
Ok((server, actual_port, rx))
|
||||
}
|
||||
|
||||
fn spawn_callback_server_loop<T, F, Fut>(
|
||||
server: Arc<Server>,
|
||||
mut rx: tokio::sync::mpsc::Receiver<Request>,
|
||||
incomplete_message: &'static str,
|
||||
mut process_request: F,
|
||||
) -> (tokio::task::JoinHandle<io::Result<T>>, ShutdownHandle)
|
||||
where
|
||||
T: Send + 'static,
|
||||
F: FnMut(String) -> Fut + Send + 'static,
|
||||
Fut: Future<Output = HandledRequest<T>> + Send + 'static,
|
||||
{
|
||||
let shutdown_notify = Arc::new(tokio::sync::Notify::new());
|
||||
let server_handle = {
|
||||
let shutdown_notify = shutdown_notify.clone();
|
||||
tokio::spawn(async move {
|
||||
let result = loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_notify.notified() => {
|
||||
break Err(io::Error::other(incomplete_message));
|
||||
}
|
||||
maybe_req = rx.recv() => {
|
||||
let Some(req) = maybe_req else {
|
||||
break Err(io::Error::other(incomplete_message));
|
||||
};
|
||||
|
||||
let url_raw = req.url().to_string();
|
||||
let response = process_request(url_raw).await;
|
||||
|
||||
if let Some(result) = respond_to_request(req, response).await {
|
||||
break result;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Ensure that the server is unblocked so the thread dedicated to
|
||||
// running `server.recv()` in a loop exits cleanly.
|
||||
server.unblock();
|
||||
result
|
||||
})
|
||||
};
|
||||
|
||||
(server_handle, ShutdownHandle { shutdown_notify })
|
||||
}
|
||||
|
||||
async fn respond_to_request<T>(req: Request, response: HandledRequest<T>) -> Option<io::Result<T>> {
|
||||
match response {
|
||||
HandledRequest::Response(response) => {
|
||||
|
||||
@@ -261,13 +261,11 @@ mod tests {
|
||||
fn success_cell_snapshot() {
|
||||
let cell = success_cell(
|
||||
&ProvisionedApiKey {
|
||||
sensitive_id: "session-123".to_string(),
|
||||
organization_id: "org-default".to_string(),
|
||||
organization_title: Some("Default Org".to_string()),
|
||||
default_project_id: "proj-default".to_string(),
|
||||
default_project_title: Some("Default Project".to_string()),
|
||||
project_api_key: "sk-proj-123".to_string(),
|
||||
access_token: "oauth-access-123".to_string(),
|
||||
},
|
||||
Path::new("/tmp/workspace/.env.local"),
|
||||
LiveApplyOutcome::Applied,
|
||||
@@ -280,13 +278,11 @@ mod tests {
|
||||
fn success_cell_snapshot_when_live_apply_is_skipped() {
|
||||
let cell = success_cell(
|
||||
&ProvisionedApiKey {
|
||||
sensitive_id: "session-123".to_string(),
|
||||
organization_id: "org-default".to_string(),
|
||||
organization_title: None,
|
||||
default_project_id: "proj-default".to_string(),
|
||||
default_project_title: None,
|
||||
project_api_key: "sk-proj-123".to_string(),
|
||||
access_token: "oauth-access-123".to_string(),
|
||||
},
|
||||
Path::new("/tmp/workspace/.env.local"),
|
||||
LiveApplyOutcome::Skipped(
|
||||
|
||||
Reference in New Issue
Block a user