fix: exec-server stream was erroring for large requests (#7654)

Previous to this change, large `EscalateRequest` payloads exceeded the
kernel send buffer, causing our single `sendmsg(2)` call (with attached
FDs) to be split and retried without proper control handling; this led
to `EINVAL`/broken pipe in the
`handle_escalate_session_respects_run_in_sandbox_decision()` test when
using an `env` with large contents.

**Before:** `AsyncSocket::send_with_fds()` called `send_json_message()`,
which called `send_message_bytes()`, which made one `socket.sendmsg()`
call followed by additional `socket.send()` calls, as necessary:


2e4a402521/codex-rs/exec-server/src/posix/socket.rs (L198-L209)

**After:** `AsyncSocket::send_with_fds()` now calls
`send_stream_frame()`, which calls `send_stream_chunk()` one or more
times. Each call to `send_stream_chunk()` calls `socket.sendmsg()`.

In the previous implementation, the subsequent `socket.send()` writes
had no control information associated with them, whereas in the new
`send_stream_chunk()` implementation, a fresh `MsgHdr` (using
`with_control()`, as appropriate) is created for `socket.sendmsg()` each
time.

Additionally, with this PR, stream sending attaches `SCM_RIGHTS` only on
the first chunk, and omits control data when there are no FDs, allowing
oversized payloads to deliver correctly while preserving FD limits and
error checks.
This commit is contained in:
Michael Bolin
2025-12-06 10:16:47 -08:00
committed by GitHub
parent f521d29726
commit 82090803d9
2 changed files with 100 additions and 73 deletions

View File

@@ -258,12 +258,18 @@ mod tests {
}),
));
let mut env = HashMap::new();
for i in 0..10 {
let value = "A".repeat(1024);
env.insert(format!("CODEX_TEST_VAR{i}"), value);
}
client
.send(EscalateRequest {
file: PathBuf::from("/bin/echo"),
argv: vec!["echo".to_string()],
workdir: PathBuf::from("/tmp"),
env: HashMap::new(),
env,
})
.await?;

View File

@@ -171,42 +171,24 @@ async fn read_frame_payload(
unreachable!("loop exits only after returning payload")
}
fn send_message_bytes(socket: &Socket, data: &[u8], fds: &[OwnedFd]) -> std::io::Result<()> {
if fds.len() > MAX_FDS_PER_MESSAGE {
fn send_datagram_bytes(socket: &Socket, data: &[u8], fds: &[OwnedFd]) -> std::io::Result<()> {
let control = make_control_message(fds)?;
let payload = [IoSlice::new(data)];
let msg = if control.is_empty() {
MsgHdr::new().with_buffers(&payload)
} else {
MsgHdr::new().with_buffers(&payload).with_control(&control)
};
let written = socket.sendmsg(&msg, 0)?;
if written != data.len() {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("too many fds: {}", fds.len()),
std::io::ErrorKind::WriteZero,
format!(
"short datagram write: wrote {written} bytes out of {}",
data.len()
),
));
}
let mut frame = Vec::with_capacity(LENGTH_PREFIX_SIZE + data.len());
frame.extend_from_slice(&encode_length(data.len())?);
frame.extend_from_slice(data);
let mut control = vec![0u8; control_space_for_fds(fds.len())];
unsafe {
let cmsg = control.as_mut_ptr().cast::<libc::cmsghdr>();
(*cmsg).cmsg_len = libc::CMSG_LEN(size_of::<RawFd>() as c_uint * fds.len() as c_uint) as _;
(*cmsg).cmsg_level = libc::SOL_SOCKET;
(*cmsg).cmsg_type = libc::SCM_RIGHTS;
let data_ptr = libc::CMSG_DATA(cmsg).cast::<RawFd>();
for (i, fd) in fds.iter().enumerate() {
data_ptr.add(i).write(fd.as_raw_fd());
}
}
let payload = [IoSlice::new(&frame)];
let msg = MsgHdr::new().with_buffers(&payload).with_control(&control);
let mut sent = socket.sendmsg(&msg, 0)?;
while sent < frame.len() {
let bytes = socket.send(&frame[sent..])?;
if bytes == 0 {
return Err(std::io::Error::new(
std::io::ErrorKind::WriteZero,
"socket closed while sending frame payload",
));
}
sent += bytes;
}
Ok(())
}
@@ -220,24 +202,16 @@ fn encode_length(len: usize) -> std::io::Result<[u8; LENGTH_PREFIX_SIZE]> {
Ok(len_u32.to_le_bytes())
}
pub(crate) fn send_json_message<T: Serialize>(
socket: &Socket,
msg: T,
fds: &[OwnedFd],
) -> std::io::Result<()> {
let data = serde_json::to_vec(&msg)?;
send_message_bytes(socket, &data, fds)
}
fn send_datagram_bytes(socket: &Socket, data: &[u8], fds: &[OwnedFd]) -> std::io::Result<()> {
fn make_control_message(fds: &[OwnedFd]) -> std::io::Result<Vec<u8>> {
if fds.len() > MAX_FDS_PER_MESSAGE {
return Err(std::io::Error::new(
Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("too many fds: {}", fds.len()),
));
}
let mut control = vec![0u8; control_space_for_fds(fds.len())];
if !fds.is_empty() {
))
} else if fds.is_empty() {
Ok(Vec::new())
} else {
let mut control = vec![0u8; control_space_for_fds(fds.len())];
unsafe {
let cmsg = control.as_mut_ptr().cast::<libc::cmsghdr>();
(*cmsg).cmsg_len =
@@ -249,20 +223,8 @@ fn send_datagram_bytes(socket: &Socket, data: &[u8], fds: &[OwnedFd]) -> std::io
data_ptr.add(i).write(fd.as_raw_fd());
}
}
Ok(control)
}
let payload = [IoSlice::new(data)];
let msg = MsgHdr::new().with_buffers(&payload).with_control(&control);
let written = socket.sendmsg(&msg, 0)?;
if written != data.len() {
return Err(std::io::Error::new(
std::io::ErrorKind::WriteZero,
format!(
"short datagram write: wrote {written} bytes out of {}",
data.len()
),
));
}
Ok(())
}
fn receive_datagram_bytes(socket: &Socket) -> std::io::Result<(Vec<u8>, Vec<OwnedFd>)> {
@@ -308,11 +270,11 @@ impl AsyncSocket {
msg: T,
fds: &[OwnedFd],
) -> std::io::Result<()> {
self.inner
.async_io(Interest::WRITABLE, |socket| {
send_json_message(socket, &msg, fds)
})
.await
let payload = serde_json::to_vec(&msg)?;
let mut frame = Vec::with_capacity(LENGTH_PREFIX_SIZE + payload.len());
frame.extend_from_slice(&encode_length(payload.len())?);
frame.extend_from_slice(&payload);
send_stream_frame(&self.inner, &frame, fds).await
}
pub async fn receive_with_fds<T: for<'de> Deserialize<'de>>(
@@ -343,6 +305,54 @@ impl AsyncSocket {
}
}
async fn send_stream_frame(
socket: &AsyncFd<Socket>,
frame: &[u8],
fds: &[OwnedFd],
) -> std::io::Result<()> {
let mut written = 0;
let mut include_fds = !fds.is_empty();
while written < frame.len() {
let mut guard = socket.writable().await?;
let result = guard.try_io(|inner| {
send_stream_chunk(inner.get_ref(), &frame[written..], fds, include_fds)
});
let bytes_written = match result {
Ok(bytes_written) => bytes_written?,
Err(_would_block) => continue,
};
if bytes_written == 0 {
return Err(std::io::Error::new(
std::io::ErrorKind::WriteZero,
"socket closed while sending frame payload",
));
}
written += bytes_written;
include_fds = false;
}
Ok(())
}
fn send_stream_chunk(
socket: &Socket,
frame: &[u8],
fds: &[OwnedFd],
include_fds: bool,
) -> std::io::Result<usize> {
let control = if include_fds {
make_control_message(fds)?
} else {
Vec::new()
};
let payload = [IoSlice::new(frame)];
let msg = if control.is_empty() {
MsgHdr::new().with_buffers(&payload)
} else {
MsgHdr::new().with_buffers(&payload).with_control(&control)
};
socket.sendmsg(&msg, 0)
}
pub(crate) struct AsyncDatagramSocket {
inner: AsyncFd<Socket>,
}
@@ -433,6 +443,17 @@ mod tests {
Ok(())
}
#[tokio::test]
async fn async_socket_handles_large_payload() -> std::io::Result<()> {
let (server, client) = AsyncSocket::pair()?;
let payload = vec![b'A'; 10_000];
let receive_task = tokio::spawn(async move { server.receive::<Vec<u8>>().await });
client.send(payload.clone()).await?;
let received_payload = receive_task.await.unwrap()?;
assert_eq!(payload, received_payload);
Ok(())
}
#[tokio::test]
async fn async_datagram_sockets_round_trip_messages() -> std::io::Result<()> {
let (server, client) = AsyncDatagramSocket::pair()?;
@@ -450,19 +471,19 @@ mod tests {
}
#[test]
fn send_message_bytes_rejects_excessive_fd_counts() -> std::io::Result<()> {
let (socket, _peer) = Socket::pair(Domain::UNIX, Type::STREAM, None)?;
fn send_datagram_bytes_rejects_excessive_fd_counts() -> std::io::Result<()> {
let (socket, _peer) = Socket::pair(Domain::UNIX, Type::DGRAM, None)?;
let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?;
let err = send_message_bytes(&socket, b"hello", &fds).unwrap_err();
let err = send_datagram_bytes(&socket, b"hi", &fds).unwrap_err();
assert_eq!(std::io::ErrorKind::InvalidInput, err.kind());
Ok(())
}
#[test]
fn send_datagram_bytes_rejects_excessive_fd_counts() -> std::io::Result<()> {
let (socket, _peer) = Socket::pair(Domain::UNIX, Type::DGRAM, None)?;
fn send_stream_chunk_rejects_excessive_fd_counts() -> std::io::Result<()> {
let (socket, _peer) = Socket::pair(Domain::UNIX, Type::STREAM, None)?;
let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?;
let err = send_datagram_bytes(&socket, b"hi", &fds).unwrap_err();
let err = send_stream_chunk(&socket, b"hello", &fds, true).unwrap_err();
assert_eq!(std::io::ErrorKind::InvalidInput, err.kind());
Ok(())
}