Compare commits

..

3 Commits

Author SHA1 Message Date
Celia Chen
49710533fe Merge branch 'main' into dev/cc/exp 2026-03-23 16:11:59 -07:00
Celia Chen
2e30bbbe02 Merge branch 'main' into dev/cc/exp 2026-03-23 14:18:10 -07:00
celia-oai
8e5aeddda4 changes 2026-03-23 12:35:02 -07:00
4 changed files with 89 additions and 61 deletions

View File

@@ -294,13 +294,14 @@ async fn returns_fresh_tokens_as_is() -> Result<()> {
.await;
let ctx = RefreshTokenTestContext::new(&server)?;
let initial_last_refresh = Utc::now() - Duration::days(1);
let initial_tokens = build_tokens(INITIAL_ACCESS_TOKEN, INITIAL_REFRESH_TOKEN);
let stale_refresh = Utc::now() - Duration::days(9);
let fresh_access_token = access_token_with_expiration(Utc::now() + Duration::hours(1));
let initial_tokens = build_tokens(&fresh_access_token, INITIAL_REFRESH_TOKEN);
let initial_auth = AuthDotJson {
auth_mode: Some(AuthMode::Chatgpt),
openai_api_key: None,
tokens: Some(initial_tokens.clone()),
last_refresh: Some(initial_last_refresh),
last_refresh: Some(stale_refresh),
};
ctx.write_auth(&initial_auth)?;
@@ -325,7 +326,7 @@ async fn returns_fresh_tokens_as_is() -> Result<()> {
#[serial_test::serial(auth_refresh)]
#[tokio::test]
async fn refreshes_token_when_last_refresh_is_stale() -> Result<()> {
async fn refreshes_token_when_access_token_is_expired() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = MockServer::start().await;
@@ -340,13 +341,14 @@ async fn refreshes_token_when_last_refresh_is_stale() -> Result<()> {
.await;
let ctx = RefreshTokenTestContext::new(&server)?;
let stale_refresh = Utc::now() - Duration::days(9);
let initial_tokens = build_tokens(INITIAL_ACCESS_TOKEN, INITIAL_REFRESH_TOKEN);
let fresh_refresh = Utc::now() - Duration::days(1);
let expired_access_token = access_token_with_expiration(Utc::now() - Duration::hours(1));
let initial_tokens = build_tokens(&expired_access_token, INITIAL_REFRESH_TOKEN);
let initial_auth = AuthDotJson {
auth_mode: Some(AuthMode::Chatgpt),
openai_api_key: None,
tokens: Some(initial_tokens.clone()),
last_refresh: Some(stale_refresh),
last_refresh: Some(fresh_refresh),
};
ctx.write_auth(&initial_auth)?;
@@ -373,7 +375,7 @@ async fn refreshes_token_when_last_refresh_is_stale() -> Result<()> {
.as_ref()
.context("last_refresh should be recorded")?;
assert!(
*refreshed_at >= stale_refresh,
*refreshed_at >= fresh_refresh,
"last_refresh should advance"
);
@@ -867,7 +869,7 @@ impl Drop for EnvGuard {
}
}
fn minimal_jwt() -> String {
fn jwt_with_payload(payload: serde_json::Value) -> String {
#[derive(Serialize)]
struct Header {
alg: &'static str,
@@ -878,7 +880,6 @@ fn minimal_jwt() -> String {
alg: "none",
typ: "JWT",
};
let payload = json!({ "sub": "user-123" });
fn b64(data: &[u8]) -> String {
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(data)
@@ -898,6 +899,14 @@ fn minimal_jwt() -> String {
format!("{header_b64}.{payload_b64}.{signature_b64}")
}
fn minimal_jwt() -> String {
jwt_with_payload(json!({ "sub": "user-123" }))
}
fn access_token_with_expiration(expires_at: chrono::DateTime<Utc>) -> String {
jwt_with_payload(json!({ "sub": "user-123", "exp": expires_at.timestamp() }))
}
fn build_tokens(access_token: &str, refresh_token: &str) -> TokenData {
let id_token = IdTokenInfo {
raw_jwt: minimal_jwt(),

View File

@@ -28,6 +28,7 @@ use crate::token_data::KnownPlan as InternalKnownPlan;
use crate::token_data::PlanType as InternalPlanType;
use crate::token_data::TokenData;
use crate::token_data::parse_chatgpt_jwt_claims;
use crate::token_data::parse_jwt_expiration;
use codex_client::CodexHttpClient;
use codex_protocol::account::PlanType as AccountPlanType;
use serde_json::Value;
@@ -69,7 +70,6 @@ impl PartialEq for CodexAuth {
}
}
// TODO(pakrym): use token exp field to check for expiration instead
const TOKEN_REFRESH_INTERVAL: i64 = 8;
const REFRESH_TOKEN_EXPIRED_MESSAGE: &str = "Your access token could not be refreshed because your refresh token has expired. Please log out and sign in again.";
@@ -1333,6 +1333,11 @@ impl AuthManager {
Some(auth_dot_json) => auth_dot_json,
None => return false,
};
if let Some(tokens) = auth_dot_json.tokens.as_ref()
&& let Ok(Some(expires_at)) = parse_jwt_expiration(&tokens.access_token)
{
return expires_at <= Utc::now();
}
let last_refresh = match auth_dot_json.last_refresh {
Some(last_refresh) => last_refresh,
None => return false,

View File

@@ -1,6 +1,9 @@
use base64::Engine;
use chrono::DateTime;
use chrono::Utc;
use serde::Deserialize;
use serde::Serialize;
use serde::de::DeserializeOwned;
use thiserror::Error;
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Default)]
@@ -117,6 +120,12 @@ struct AuthClaims {
chatgpt_account_id: Option<String>,
}
#[derive(Deserialize)]
struct StandardJwtClaims {
#[serde(default)]
exp: Option<i64>,
}
#[derive(Debug, Error)]
pub enum IdTokenInfoError {
#[error("invalid ID token format")]
@@ -127,7 +136,7 @@ pub enum IdTokenInfoError {
Json(#[from] serde_json::Error),
}
pub fn parse_chatgpt_jwt_claims(jwt: &str) -> Result<IdTokenInfo, IdTokenInfoError> {
fn decode_jwt_payload<T: DeserializeOwned>(jwt: &str) -> Result<T, IdTokenInfoError> {
// JWT format: header.payload.signature
let mut parts = jwt.split('.');
let (_header_b64, payload_b64, _sig_b64) = match (parts.next(), parts.next(), parts.next()) {
@@ -136,7 +145,19 @@ pub fn parse_chatgpt_jwt_claims(jwt: &str) -> Result<IdTokenInfo, IdTokenInfoErr
};
let payload_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload_b64)?;
let claims: IdClaims = serde_json::from_slice(&payload_bytes)?;
let claims = serde_json::from_slice(&payload_bytes)?;
Ok(claims)
}
pub fn parse_jwt_expiration(jwt: &str) -> Result<Option<DateTime<Utc>>, IdTokenInfoError> {
let claims: StandardJwtClaims = decode_jwt_payload(jwt)?;
Ok(claims
.exp
.and_then(|exp| DateTime::<Utc>::from_timestamp(exp, 0)))
}
pub fn parse_chatgpt_jwt_claims(jwt: &str) -> Result<IdTokenInfo, IdTokenInfoError> {
let claims: IdClaims = decode_jwt_payload(jwt)?;
let email = claims
.email
.or_else(|| claims.profile.and_then(|profile| profile.email));

View File

@@ -1,9 +1,10 @@
use super::*;
use chrono::TimeZone;
use chrono::Utc;
use pretty_assertions::assert_eq;
use serde::Serialize;
#[test]
fn id_token_info_parses_email_and_plan() {
fn fake_jwt(payload: serde_json::Value) -> String {
#[derive(Serialize)]
struct Header {
alg: &'static str,
@@ -13,12 +14,6 @@ fn id_token_info_parses_email_and_plan() {
alg: "none",
typ: "JWT",
};
let payload = serde_json::json!({
"email": "user@example.com",
"https://api.openai.com/auth": {
"chatgpt_plan_type": "pro"
}
});
fn b64url_no_pad(bytes: &[u8]) -> String {
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
@@ -27,7 +22,17 @@ fn id_token_info_parses_email_and_plan() {
let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap());
let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap());
let signature_b64 = b64url_no_pad(b"sig");
let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");
format!("{header_b64}.{payload_b64}.{signature_b64}")
}
#[test]
fn id_token_info_parses_email_and_plan() {
let fake_jwt = fake_jwt(serde_json::json!({
"email": "user@example.com",
"https://api.openai.com/auth": {
"chatgpt_plan_type": "pro"
}
}));
let info = parse_chatgpt_jwt_claims(&fake_jwt).expect("should parse");
assert_eq!(info.email.as_deref(), Some("user@example.com"));
@@ -36,30 +41,12 @@ fn id_token_info_parses_email_and_plan() {
#[test]
fn id_token_info_parses_go_plan() {
#[derive(Serialize)]
struct Header {
alg: &'static str,
typ: &'static str,
}
let header = Header {
alg: "none",
typ: "JWT",
};
let payload = serde_json::json!({
let fake_jwt = fake_jwt(serde_json::json!({
"email": "user@example.com",
"https://api.openai.com/auth": {
"chatgpt_plan_type": "go"
}
});
fn b64url_no_pad(bytes: &[u8]) -> String {
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
}
let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap());
let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap());
let signature_b64 = b64url_no_pad(b"sig");
let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");
}));
let info = parse_chatgpt_jwt_claims(&fake_jwt).expect("should parse");
assert_eq!(info.email.as_deref(), Some("user@example.com"));
@@ -68,31 +55,37 @@ fn id_token_info_parses_go_plan() {
#[test]
fn id_token_info_handles_missing_fields() {
#[derive(Serialize)]
struct Header {
alg: &'static str,
typ: &'static str,
}
let header = Header {
alg: "none",
typ: "JWT",
};
let payload = serde_json::json!({ "sub": "123" });
fn b64url_no_pad(bytes: &[u8]) -> String {
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
}
let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap());
let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap());
let signature_b64 = b64url_no_pad(b"sig");
let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");
let fake_jwt = fake_jwt(serde_json::json!({ "sub": "123" }));
let info = parse_chatgpt_jwt_claims(&fake_jwt).expect("should parse");
assert!(info.email.is_none());
assert!(info.get_chatgpt_plan_type().is_none());
}
#[test]
fn jwt_expiration_parses_exp_claim() {
let fake_jwt = fake_jwt(serde_json::json!({
"exp": 1_700_000_000_i64,
}));
let expires_at = parse_jwt_expiration(&fake_jwt).expect("should parse");
assert_eq!(expires_at, Utc.timestamp_opt(1_700_000_000, 0).single());
}
#[test]
fn jwt_expiration_handles_missing_exp() {
let fake_jwt = fake_jwt(serde_json::json!({ "sub": "123" }));
let expires_at = parse_jwt_expiration(&fake_jwt).expect("should parse");
assert_eq!(expires_at, None);
}
#[test]
fn jwt_expiration_rejects_malformed_jwt() {
let err = parse_jwt_expiration("not-a-jwt").expect_err("should fail");
assert_eq!(err.to_string(), "invalid ID token format");
}
#[test]
fn workspace_account_detection_matches_workspace_plans() {
let workspace = IdTokenInfo {