mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
Fix realtime websocket URL params and conversation_id naming
This commit is contained in:
@@ -248,7 +248,10 @@ impl RealtimeWebsocketClient {
|
||||
default_headers: HeaderMap,
|
||||
) -> Result<RealtimeWebsocketConnection, ApiError> {
|
||||
ensure_rustls_crypto_provider();
|
||||
let ws_url = websocket_url_from_api_url(config.api_url.as_str())?;
|
||||
let ws_url = websocket_url_from_api_url(
|
||||
config.api_url.as_str(),
|
||||
self.provider.query_params.as_ref(),
|
||||
)?;
|
||||
|
||||
let mut request = ws_url
|
||||
.as_str()
|
||||
@@ -268,7 +271,7 @@ impl RealtimeWebsocketClient {
|
||||
let (stream, rx_message) = WebsocketPump::new(stream);
|
||||
let connection = RealtimeWebsocketConnection::new(stream, rx_message);
|
||||
connection
|
||||
.send_session_create(config.prompt, config.session_id)
|
||||
.send_session_create(config.prompt, config.conversation_id)
|
||||
.await?;
|
||||
Ok(connection)
|
||||
}
|
||||
@@ -293,7 +296,10 @@ fn websocket_config() -> WebSocketConfig {
|
||||
WebSocketConfig::default()
|
||||
}
|
||||
|
||||
fn websocket_url_from_api_url(api_url: &str) -> Result<Url, ApiError> {
|
||||
fn websocket_url_from_api_url(
|
||||
api_url: &str,
|
||||
query_params: Option<&std::collections::HashMap<String, String>>,
|
||||
) -> Result<Url, ApiError> {
|
||||
let mut url = Url::parse(api_url)
|
||||
.map_err(|err| ApiError::Stream(format!("failed to parse realtime api_url: {err}")))?;
|
||||
|
||||
@@ -315,7 +321,18 @@ fn websocket_url_from_api_url(api_url: &str) -> Result<Url, ApiError> {
|
||||
scheme => Err(ApiError::Stream(format!(
|
||||
"unsupported realtime api_url scheme: {scheme}"
|
||||
))),
|
||||
}?;
|
||||
|
||||
if let Some(params) = query_params
|
||||
&& !params.is_empty()
|
||||
{
|
||||
let mut url_query = url.query_pairs_mut();
|
||||
for (key, value) in params {
|
||||
url_query.append_pair(key, value);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(url)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -420,16 +437,32 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn websocket_url_from_http_base_defaults_to_ws_path() {
|
||||
let url = websocket_url_from_api_url("http://127.0.0.1:8011").expect("build ws url");
|
||||
let url = websocket_url_from_api_url("http://127.0.0.1:8011", None).expect("build ws url");
|
||||
assert_eq!(url.as_str(), "ws://127.0.0.1:8011/ws");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn websocket_url_from_ws_base_defaults_to_ws_path() {
|
||||
let url = websocket_url_from_api_url("wss://example.com").expect("build ws url");
|
||||
let url = websocket_url_from_api_url("wss://example.com", None).expect("build ws url");
|
||||
assert_eq!(url.as_str(), "wss://example.com/ws");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn websocket_url_includes_provider_query_params() {
|
||||
let mut query_params = HashMap::new();
|
||||
query_params.insert("api-version".to_string(), "2024-10-01-preview".to_string());
|
||||
|
||||
let url = websocket_url_from_api_url("https://example.com/ws", Some(&query_params))
|
||||
.expect("build ws url");
|
||||
let api_version = url
|
||||
.query_pairs()
|
||||
.find(|(key, _)| key == "api-version")
|
||||
.map(|(_, value)| value.into_owned());
|
||||
|
||||
assert_eq!(url.scheme(), "wss");
|
||||
assert_eq!(api_version, Some("2024-10-01-preview".to_string()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn e2e_connect_and_exchange_events_against_mock_ws_server() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
|
||||
@@ -534,7 +567,7 @@ mod tests {
|
||||
RealtimeSessionConfig {
|
||||
api_url: format!("ws://{addr}"),
|
||||
prompt: "backend prompt".to_string(),
|
||||
session_id: Some("conv_1".to_string()),
|
||||
conversation_id: Some("conv_1".to_string()),
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
@@ -661,7 +694,7 @@ mod tests {
|
||||
RealtimeSessionConfig {
|
||||
api_url: format!("ws://{addr}"),
|
||||
prompt: "backend prompt".to_string(),
|
||||
session_id: Some("conv_1".to_string()),
|
||||
conversation_id: Some("conv_1".to_string()),
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
|
||||
@@ -7,7 +7,7 @@ use tracing::debug;
|
||||
pub struct RealtimeSessionConfig {
|
||||
pub api_url: String,
|
||||
pub prompt: String,
|
||||
pub session_id: Option<String>,
|
||||
pub conversation_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
|
||||
@@ -77,7 +77,7 @@ async fn realtime_ws_e2e_session_create_and_event_flow() {
|
||||
RealtimeSessionConfig {
|
||||
api_url,
|
||||
prompt: "backend prompt".to_string(),
|
||||
session_id: Some("conv_123".to_string()),
|
||||
conversation_id: Some("conv_123".to_string()),
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
@@ -168,7 +168,7 @@ async fn realtime_ws_e2e_send_while_next_event_waits() {
|
||||
RealtimeSessionConfig {
|
||||
api_url,
|
||||
prompt: "backend prompt".to_string(),
|
||||
session_id: Some("conv_123".to_string()),
|
||||
conversation_id: Some("conv_123".to_string()),
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
@@ -230,7 +230,7 @@ async fn realtime_ws_e2e_disconnected_emitted_once() {
|
||||
RealtimeSessionConfig {
|
||||
api_url,
|
||||
prompt: "backend prompt".to_string(),
|
||||
session_id: Some("conv_123".to_string()),
|
||||
conversation_id: Some("conv_123".to_string()),
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
@@ -290,7 +290,7 @@ async fn realtime_ws_e2e_ignores_unknown_text_events() {
|
||||
RealtimeSessionConfig {
|
||||
api_url,
|
||||
prompt: "backend prompt".to_string(),
|
||||
session_id: Some("conv_123".to_string()),
|
||||
conversation_id: Some("conv_123".to_string()),
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
|
||||
Reference in New Issue
Block a user