Tighten file upload auth and failure checks

This commit is contained in:
Liang-Ting Jiang
2026-04-24 17:01:02 -07:00
parent 0307d8f791
commit 0b2f823e99
2 changed files with 96 additions and 9 deletions

View File

@@ -401,7 +401,10 @@ fn non_empty_string(value: Option<String>) -> Option<String> {
fn is_process_upload_stream_error_event(event: &str) -> bool {
let event_tail = event.rsplit(['.', '_']).next().unwrap_or(event);
matches!(event_tail, "error" | "cancelled" | "unknown")
matches!(
event_tail,
"error" | "failed" | "cancelled" | "canceled" | "unknown"
)
}
fn authorized_request(
@@ -489,10 +492,14 @@ fn should_attach_auth_to_openai_file_url(download_url: &Url, base_url: &str) ->
let Ok(base_url) = Url::parse(base_url) else {
return false;
};
match (download_url.host_str(), base_url.host_str()) {
(Some(download_host), Some(base_host)) => download_host.eq_ignore_ascii_case(base_host),
_ => false,
}
download_url
.scheme()
.eq_ignore_ascii_case(base_url.scheme())
&& download_url.port_or_known_default() == base_url.port_or_known_default()
&& match (download_url.host_str(), base_url.host_str()) {
(Some(download_host), Some(base_host)) => download_host.eq_ignore_ascii_case(base_host),
_ => false,
}
}
fn build_reqwest_client() -> reqwest::Client {
@@ -686,4 +693,84 @@ mod tests {
assert_eq!(uploaded.mime_type, Some("text/plain".to_string()));
assert_eq!(uploaded.library_file_id, Some("library_123".to_string()));
}
#[test]
fn should_attach_auth_only_for_same_origin() {
let base_url = "https://chatgpt.com/backend-api";
assert!(should_attach_auth_to_openai_file_url(
&Url::parse("https://chatgpt.com/files/file_123/content").expect("valid url"),
base_url,
));
assert!(!should_attach_auth_to_openai_file_url(
&Url::parse("http://chatgpt.com/files/file_123/content").expect("valid url"),
base_url,
));
assert!(!should_attach_auth_to_openai_file_url(
&Url::parse("https://chatgpt.com:8443/files/file_123/content").expect("valid url"),
base_url,
));
}
#[tokio::test]
async fn upload_local_file_fails_when_process_upload_stream_reports_failed_event() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/backend-api/files"))
.and(header("chatgpt-account-id", "account_id"))
.and(body_json(serde_json::json!({
"file_name": "hello.txt",
"file_size": 5,
"use_case": "codex",
"store_in_library": true,
})))
.respond_with(
ResponseTemplate::new(200)
.set_body_json(serde_json::json!({"file_id": "file_123", "upload_url": format!("{}/upload/file_123", server.uri())})),
)
.mount(&server)
.await;
Mock::given(method("PUT"))
.and(path("/upload/file_123"))
.and(header("content-length", "5"))
.respond_with(ResponseTemplate::new(200))
.mount(&server)
.await;
Mock::given(method("POST"))
.and(path("/backend-api/files/process_upload_stream"))
.respond_with(ResponseTemplate::new(200).set_body_bytes(
concat!(
"{\"file_id\":\"file_123\",\"event\":\"indexing.completed\",\"message\":\"\",",
"\"extra\":{\"metadata_object_id\":\"library_123\"}}\n",
"{\"file_id\":\"file_123\",\"event\":\"indexing.failed\",",
"\"message\":\"indexing failed\",\"extra\":null}\n",
)
.as_bytes()
.to_vec(),
))
.mount(&server)
.await;
let base_url = base_url_for(&server);
let dir = TempDir::new().expect("temp dir");
let path = dir.path().join("hello.txt");
tokio::fs::write(&path, b"hello").await.expect("write file");
let error = upload_local_file(
&base_url,
&chatgpt_auth(),
&path,
&OpenAiFileUploadOptions {
store_in_library: true,
},
)
.await
.expect_err("upload should fail");
assert!(matches!(
error,
OpenAiFileError::UploadFailed { ref file_id, ref message }
if file_id == "file_123" && message == "indexing failed"
));
}
}

View File

@@ -148,10 +148,10 @@ async fn build_uploaded_local_argument_value(
"uri": uploaded.uri,
"file_size_bytes": uploaded.file_size_bytes,
});
if uploaded.library_file_id.is_none() {
if let Some(uploaded_object) = uploaded_value.as_object_mut() {
uploaded_object.remove("library_file_id");
}
if uploaded.library_file_id.is_none()
&& let Some(uploaded_object) = uploaded_value.as_object_mut()
{
uploaded_object.remove("library_file_id");
}
Ok(uploaded_value)
}