diff --git a/codex-rs/exec-server/src/posix/escalate_server.rs b/codex-rs/exec-server/src/posix/escalate_server.rs index b71142d5b1..72934607a3 100644 --- a/codex-rs/exec-server/src/posix/escalate_server.rs +++ b/codex-rs/exec-server/src/posix/escalate_server.rs @@ -258,12 +258,18 @@ mod tests { }), )); + let mut env = HashMap::new(); + for i in 0..10 { + let value = "A".repeat(1024); + env.insert(format!("CODEX_TEST_VAR{i}"), value); + } + client .send(EscalateRequest { file: PathBuf::from("/bin/echo"), argv: vec!["echo".to_string()], workdir: PathBuf::from("/tmp"), - env: HashMap::new(), + env, }) .await?; diff --git a/codex-rs/exec-server/src/posix/socket.rs b/codex-rs/exec-server/src/posix/socket.rs index 92c93dcc7d..35292367a6 100644 --- a/codex-rs/exec-server/src/posix/socket.rs +++ b/codex-rs/exec-server/src/posix/socket.rs @@ -171,42 +171,24 @@ async fn read_frame_payload( unreachable!("loop exits only after returning payload") } -fn send_message_bytes(socket: &Socket, data: &[u8], fds: &[OwnedFd]) -> std::io::Result<()> { - if fds.len() > MAX_FDS_PER_MESSAGE { +fn send_datagram_bytes(socket: &Socket, data: &[u8], fds: &[OwnedFd]) -> std::io::Result<()> { + let control = make_control_message(fds)?; + let payload = [IoSlice::new(data)]; + let msg = if control.is_empty() { + MsgHdr::new().with_buffers(&payload) + } else { + MsgHdr::new().with_buffers(&payload).with_control(&control) + }; + let written = socket.sendmsg(&msg, 0)?; + if written != data.len() { return Err(std::io::Error::new( - std::io::ErrorKind::InvalidInput, - format!("too many fds: {}", fds.len()), + std::io::ErrorKind::WriteZero, + format!( + "short datagram write: wrote {written} bytes out of {}", + data.len() + ), )); } - let mut frame = Vec::with_capacity(LENGTH_PREFIX_SIZE + data.len()); - frame.extend_from_slice(&encode_length(data.len())?); - frame.extend_from_slice(data); - - let mut control = vec![0u8; control_space_for_fds(fds.len())]; - unsafe { - let cmsg = control.as_mut_ptr().cast::(); - (*cmsg).cmsg_len = libc::CMSG_LEN(size_of::() as c_uint * fds.len() as c_uint) as _; - (*cmsg).cmsg_level = libc::SOL_SOCKET; - (*cmsg).cmsg_type = libc::SCM_RIGHTS; - let data_ptr = libc::CMSG_DATA(cmsg).cast::(); - for (i, fd) in fds.iter().enumerate() { - data_ptr.add(i).write(fd.as_raw_fd()); - } - } - - let payload = [IoSlice::new(&frame)]; - let msg = MsgHdr::new().with_buffers(&payload).with_control(&control); - let mut sent = socket.sendmsg(&msg, 0)?; - while sent < frame.len() { - let bytes = socket.send(&frame[sent..])?; - if bytes == 0 { - return Err(std::io::Error::new( - std::io::ErrorKind::WriteZero, - "socket closed while sending frame payload", - )); - } - sent += bytes; - } Ok(()) } @@ -220,24 +202,16 @@ fn encode_length(len: usize) -> std::io::Result<[u8; LENGTH_PREFIX_SIZE]> { Ok(len_u32.to_le_bytes()) } -pub(crate) fn send_json_message( - socket: &Socket, - msg: T, - fds: &[OwnedFd], -) -> std::io::Result<()> { - let data = serde_json::to_vec(&msg)?; - send_message_bytes(socket, &data, fds) -} - -fn send_datagram_bytes(socket: &Socket, data: &[u8], fds: &[OwnedFd]) -> std::io::Result<()> { +fn make_control_message(fds: &[OwnedFd]) -> std::io::Result> { if fds.len() > MAX_FDS_PER_MESSAGE { - return Err(std::io::Error::new( + Err(std::io::Error::new( std::io::ErrorKind::InvalidInput, format!("too many fds: {}", fds.len()), - )); - } - let mut control = vec![0u8; control_space_for_fds(fds.len())]; - if !fds.is_empty() { + )) + } else if fds.is_empty() { + Ok(Vec::new()) + } else { + let mut control = vec![0u8; control_space_for_fds(fds.len())]; unsafe { let cmsg = control.as_mut_ptr().cast::(); (*cmsg).cmsg_len = @@ -249,20 +223,8 @@ fn send_datagram_bytes(socket: &Socket, data: &[u8], fds: &[OwnedFd]) -> std::io data_ptr.add(i).write(fd.as_raw_fd()); } } + Ok(control) } - let payload = [IoSlice::new(data)]; - let msg = MsgHdr::new().with_buffers(&payload).with_control(&control); - let written = socket.sendmsg(&msg, 0)?; - if written != data.len() { - return Err(std::io::Error::new( - std::io::ErrorKind::WriteZero, - format!( - "short datagram write: wrote {written} bytes out of {}", - data.len() - ), - )); - } - Ok(()) } fn receive_datagram_bytes(socket: &Socket) -> std::io::Result<(Vec, Vec)> { @@ -308,11 +270,11 @@ impl AsyncSocket { msg: T, fds: &[OwnedFd], ) -> std::io::Result<()> { - self.inner - .async_io(Interest::WRITABLE, |socket| { - send_json_message(socket, &msg, fds) - }) - .await + let payload = serde_json::to_vec(&msg)?; + let mut frame = Vec::with_capacity(LENGTH_PREFIX_SIZE + payload.len()); + frame.extend_from_slice(&encode_length(payload.len())?); + frame.extend_from_slice(&payload); + send_stream_frame(&self.inner, &frame, fds).await } pub async fn receive_with_fds Deserialize<'de>>( @@ -343,6 +305,54 @@ impl AsyncSocket { } } +async fn send_stream_frame( + socket: &AsyncFd, + frame: &[u8], + fds: &[OwnedFd], +) -> std::io::Result<()> { + let mut written = 0; + let mut include_fds = !fds.is_empty(); + while written < frame.len() { + let mut guard = socket.writable().await?; + let result = guard.try_io(|inner| { + send_stream_chunk(inner.get_ref(), &frame[written..], fds, include_fds) + }); + let bytes_written = match result { + Ok(bytes_written) => bytes_written?, + Err(_would_block) => continue, + }; + if bytes_written == 0 { + return Err(std::io::Error::new( + std::io::ErrorKind::WriteZero, + "socket closed while sending frame payload", + )); + } + written += bytes_written; + include_fds = false; + } + Ok(()) +} + +fn send_stream_chunk( + socket: &Socket, + frame: &[u8], + fds: &[OwnedFd], + include_fds: bool, +) -> std::io::Result { + let control = if include_fds { + make_control_message(fds)? + } else { + Vec::new() + }; + let payload = [IoSlice::new(frame)]; + let msg = if control.is_empty() { + MsgHdr::new().with_buffers(&payload) + } else { + MsgHdr::new().with_buffers(&payload).with_control(&control) + }; + socket.sendmsg(&msg, 0) +} + pub(crate) struct AsyncDatagramSocket { inner: AsyncFd, } @@ -433,6 +443,17 @@ mod tests { Ok(()) } + #[tokio::test] + async fn async_socket_handles_large_payload() -> std::io::Result<()> { + let (server, client) = AsyncSocket::pair()?; + let payload = vec![b'A'; 10_000]; + let receive_task = tokio::spawn(async move { server.receive::>().await }); + client.send(payload.clone()).await?; + let received_payload = receive_task.await.unwrap()?; + assert_eq!(payload, received_payload); + Ok(()) + } + #[tokio::test] async fn async_datagram_sockets_round_trip_messages() -> std::io::Result<()> { let (server, client) = AsyncDatagramSocket::pair()?; @@ -450,19 +471,19 @@ mod tests { } #[test] - fn send_message_bytes_rejects_excessive_fd_counts() -> std::io::Result<()> { - let (socket, _peer) = Socket::pair(Domain::UNIX, Type::STREAM, None)?; + fn send_datagram_bytes_rejects_excessive_fd_counts() -> std::io::Result<()> { + let (socket, _peer) = Socket::pair(Domain::UNIX, Type::DGRAM, None)?; let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?; - let err = send_message_bytes(&socket, b"hello", &fds).unwrap_err(); + let err = send_datagram_bytes(&socket, b"hi", &fds).unwrap_err(); assert_eq!(std::io::ErrorKind::InvalidInput, err.kind()); Ok(()) } #[test] - fn send_datagram_bytes_rejects_excessive_fd_counts() -> std::io::Result<()> { - let (socket, _peer) = Socket::pair(Domain::UNIX, Type::DGRAM, None)?; + fn send_stream_chunk_rejects_excessive_fd_counts() -> std::io::Result<()> { + let (socket, _peer) = Socket::pair(Domain::UNIX, Type::STREAM, None)?; let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?; - let err = send_datagram_bytes(&socket, b"hi", &fds).unwrap_err(); + let err = send_stream_chunk(&socket, b"hello", &fds, true).unwrap_err(); assert_eq!(std::io::ErrorKind::InvalidInput, err.kind()); Ok(()) }