Fix realtime websocket URL params and conversation_id naming

This commit is contained in:
Ahmed Ibrahim
2026-02-17 18:13:57 -08:00
parent 71b1d9ff0d
commit bd967b3b00
3 changed files with 45 additions and 12 deletions

View File

@@ -248,7 +248,10 @@ impl RealtimeWebsocketClient {
default_headers: HeaderMap,
) -> Result<RealtimeWebsocketConnection, ApiError> {
ensure_rustls_crypto_provider();
let ws_url = websocket_url_from_api_url(config.api_url.as_str())?;
let ws_url = websocket_url_from_api_url(
config.api_url.as_str(),
self.provider.query_params.as_ref(),
)?;
let mut request = ws_url
.as_str()
@@ -268,7 +271,7 @@ impl RealtimeWebsocketClient {
let (stream, rx_message) = WebsocketPump::new(stream);
let connection = RealtimeWebsocketConnection::new(stream, rx_message);
connection
.send_session_create(config.prompt, config.session_id)
.send_session_create(config.prompt, config.conversation_id)
.await?;
Ok(connection)
}
@@ -293,7 +296,10 @@ fn websocket_config() -> WebSocketConfig {
WebSocketConfig::default()
}
fn websocket_url_from_api_url(api_url: &str) -> Result<Url, ApiError> {
fn websocket_url_from_api_url(
api_url: &str,
query_params: Option<&std::collections::HashMap<String, String>>,
) -> Result<Url, ApiError> {
let mut url = Url::parse(api_url)
.map_err(|err| ApiError::Stream(format!("failed to parse realtime api_url: {err}")))?;
@@ -315,7 +321,18 @@ fn websocket_url_from_api_url(api_url: &str) -> Result<Url, ApiError> {
scheme => Err(ApiError::Stream(format!(
"unsupported realtime api_url scheme: {scheme}"
))),
}?;
if let Some(params) = query_params
&& !params.is_empty()
{
let mut url_query = url.query_pairs_mut();
for (key, value) in params {
url_query.append_pair(key, value);
}
}
Ok(url)
}
#[cfg(test)]
@@ -420,16 +437,32 @@ mod tests {
#[test]
fn websocket_url_from_http_base_defaults_to_ws_path() {
let url = websocket_url_from_api_url("http://127.0.0.1:8011").expect("build ws url");
let url = websocket_url_from_api_url("http://127.0.0.1:8011", None).expect("build ws url");
assert_eq!(url.as_str(), "ws://127.0.0.1:8011/ws");
}
#[test]
fn websocket_url_from_ws_base_defaults_to_ws_path() {
let url = websocket_url_from_api_url("wss://example.com").expect("build ws url");
let url = websocket_url_from_api_url("wss://example.com", None).expect("build ws url");
assert_eq!(url.as_str(), "wss://example.com/ws");
}
#[test]
fn websocket_url_includes_provider_query_params() {
let mut query_params = HashMap::new();
query_params.insert("api-version".to_string(), "2024-10-01-preview".to_string());
let url = websocket_url_from_api_url("https://example.com/ws", Some(&query_params))
.expect("build ws url");
let api_version = url
.query_pairs()
.find(|(key, _)| key == "api-version")
.map(|(_, value)| value.into_owned());
assert_eq!(url.scheme(), "wss");
assert_eq!(api_version, Some("2024-10-01-preview".to_string()));
}
#[tokio::test]
async fn e2e_connect_and_exchange_events_against_mock_ws_server() {
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
@@ -534,7 +567,7 @@ mod tests {
RealtimeSessionConfig {
api_url: format!("ws://{addr}"),
prompt: "backend prompt".to_string(),
session_id: Some("conv_1".to_string()),
conversation_id: Some("conv_1".to_string()),
},
HeaderMap::new(),
HeaderMap::new(),
@@ -661,7 +694,7 @@ mod tests {
RealtimeSessionConfig {
api_url: format!("ws://{addr}"),
prompt: "backend prompt".to_string(),
session_id: Some("conv_1".to_string()),
conversation_id: Some("conv_1".to_string()),
},
HeaderMap::new(),
HeaderMap::new(),

View File

@@ -7,7 +7,7 @@ use tracing::debug;
pub struct RealtimeSessionConfig {
pub api_url: String,
pub prompt: String,
pub session_id: Option<String>,
pub conversation_id: Option<String>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]

View File

@@ -77,7 +77,7 @@ async fn realtime_ws_e2e_session_create_and_event_flow() {
RealtimeSessionConfig {
api_url,
prompt: "backend prompt".to_string(),
session_id: Some("conv_123".to_string()),
conversation_id: Some("conv_123".to_string()),
},
HeaderMap::new(),
HeaderMap::new(),
@@ -168,7 +168,7 @@ async fn realtime_ws_e2e_send_while_next_event_waits() {
RealtimeSessionConfig {
api_url,
prompt: "backend prompt".to_string(),
session_id: Some("conv_123".to_string()),
conversation_id: Some("conv_123".to_string()),
},
HeaderMap::new(),
HeaderMap::new(),
@@ -230,7 +230,7 @@ async fn realtime_ws_e2e_disconnected_emitted_once() {
RealtimeSessionConfig {
api_url,
prompt: "backend prompt".to_string(),
session_id: Some("conv_123".to_string()),
conversation_id: Some("conv_123".to_string()),
},
HeaderMap::new(),
HeaderMap::new(),
@@ -290,7 +290,7 @@ async fn realtime_ws_e2e_ignores_unknown_text_events() {
RealtimeSessionConfig {
api_url,
prompt: "backend prompt".to_string(),
session_id: Some("conv_123".to_string()),
conversation_id: Some("conv_123".to_string()),
},
HeaderMap::new(),
HeaderMap::new(),