Send message by default mid-turn; queue messages by tab (#9077)

https://github.com/user-attachments/assets/03838730-4ddc-44df-a2c7-cb8ecda78660
This commit is contained in:
Ahmed Ibrahim
2026-01-12 23:06:35 -08:00
committed by GitHub
parent e726a82c8a
commit cbca43d57a
24 changed files with 875 additions and 342 deletions

View File

@@ -19,6 +19,7 @@ pub struct StreamingSseChunk {
/// Minimal streaming SSE server for tests that need gated per-chunk delivery.
pub struct StreamingSseServer {
uri: String,
requests: Arc<TokioMutex<Vec<Vec<u8>>>>,
shutdown: oneshot::Sender<()>,
task: tokio::task::JoinHandle<()>,
}
@@ -28,6 +29,10 @@ impl StreamingSseServer {
&self.uri
}
pub async fn requests(&self) -> Vec<Vec<u8>> {
self.requests.lock().await.clone()
}
pub async fn shutdown(self) {
let _ = self.shutdown.send(());
let _ = self.task.await;
@@ -61,6 +66,8 @@ pub async fn start_streaming_sse_server(
responses: VecDeque::from(responses),
completions: VecDeque::from(completion_senders),
}));
let requests = Arc::new(TokioMutex::new(Vec::new()));
let requests_for_task = Arc::clone(&requests);
let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
let task = tokio::spawn(async move {
@@ -70,6 +77,7 @@ pub async fn start_streaming_sse_server(
accept_res = listener.accept() => {
let (mut stream, _) = accept_res.expect("accept streaming SSE connection");
let state = Arc::clone(&state);
let requests = Arc::clone(&requests_for_task);
tokio::spawn(async move {
let (request, body_prefix) = read_http_request(&mut stream).await;
let Some((method, path)) = parse_request_line(&request) else {
@@ -78,7 +86,7 @@ pub async fn start_streaming_sse_server(
};
 if method == "GET" && path == "/v1/models" {
-if drain_request_body(&mut stream, &request, body_prefix)
+if read_request_body(&mut stream, &request, body_prefix)
 .await
 .is_err()
 {
@@ -95,13 +103,16 @@ pub async fn start_streaming_sse_server(
}
 if method == "POST" && path == "/v1/responses" {
-if drain_request_body(&mut stream, &request, body_prefix)
-.await
-.is_err()
-{
-let _ = write_http_response(&mut stream, 400, "bad request", "text/plain").await;
-return;
-}
+let body = match read_request_body(&mut stream, &request, body_prefix)
+.await
+{
+Ok(body) => body,
+Err(_) => {
+let _ = write_http_response(&mut stream, 400, "bad request", "text/plain").await;
+return;
+}
+};
+requests.lock().await.push(body);
let Some((chunks, completion)) = take_next_stream(&state).await else {
let _ = write_http_response(&mut stream, 500, "no responses queued", "text/plain").await;
return;
@@ -137,6 +148,7 @@ pub async fn start_streaming_sse_server(
(
StreamingSseServer {
uri,
requests,
shutdown: shutdown_tx,
task,
},
@@ -202,13 +214,13 @@ fn content_length(headers: &str) -> Option<usize> {
})
}
-async fn drain_request_body(
+async fn read_request_body(
 stream: &mut tokio::net::TcpStream,
 headers: &str,
 mut body_prefix: Vec<u8>,
-) -> std::io::Result<()> {
+) -> std::io::Result<Vec<u8>> {
 let Some(content_len) = content_length(headers) else {
-return Ok(());
+return Ok(body_prefix);
 };
 if body_prefix.len() > content_len {
@@ -217,12 +229,13 @@ async fn drain_request_body(
 let remaining = content_len.saturating_sub(body_prefix.len());
 if remaining == 0 {
-return Ok(());
+return Ok(body_prefix);
 }
 let mut rest = vec![0u8; remaining];
 stream.read_exact(&mut rest).await?;
-Ok(())
+body_prefix.extend_from_slice(&rest);
+Ok(body_prefix)
}
async fn write_sse_headers(stream: &mut tokio::net::TcpStream) -> std::io::Result<()> {