mirror of
https://github.com/openai/codex.git
synced 2026-06-02 11:22:01 +00:00
refactor: narrow async lock guard lifetimes (#18211)
Follow-up to https://github.com/openai/codex/pull/18178, where we called out enabling the await-holding lint as a follow-up. The long-term goal is to enable Clippy coverage for async guards held across awaits. This PR is intentionally only the first, low-risk cleanup pass: it narrows obvious lock guard lifetimes and leaves `codex-rs/Cargo.toml` unchanged so the lint is not enabled until the remaining cases are fixed or explicitly justified. It intentionally leaves the active-turn/turn-state locking pattern alone because those checks and mutations need to stay atomic. ## Common fixes used here These are the main patterns reviewers should expect in this PR, and they are also the patterns to reach for when fixing future `await_holding_*` findings: - **Scope the guard to the synchronous work.** If the code only needs data from a locked value, move the lock into a small block, clone or compute the needed values, and do the later `.await` after the block. - **Use direct one-line mutations when there is no later await.** Cases like `map.lock().await.remove(&id)` are acceptable when the guard is only needed for that single mutation and the statement ends before any async work. - **Drain or clone work out of the lock before notifying or awaiting.** For example, the JS REPL drains pending exec senders into a local vector and the websocket writer clones buffered envelopes before it serializes or sends them. - **Use a `Semaphore` only when serialization is intentional across async work.** The test serialization guards intentionally span awaited setup or execution, so using a semaphore communicates "one at a time" without holding a mutex guard. - **Remove the mutex when there is only one owner.** The PTY stdin writer task owns `stdin` directly; the old `Arc<Mutex<_>>` did not protect shared access because nothing else had access to the writer. - **Do not split locks that protect an atomic invariant.** This PR deliberately leaves active-turn/turn-state paths alone because those checks and mutations need to stay atomic. Those cases should be fixed separately with a design change or documented with `#[expect]`. ## What changed - Narrow scoped async mutex guards in app-server, JS REPL, network approval, remote-control websocket, and the RMCP test server. - Replace test-only async mutex serialization guards with semaphores where the guard intentionally lives across async work. - Let the PTY pipe writer task own stdin directly instead of wrapping it in an async mutex. ## Verification - `just fix -p codex-core -p codex-app-server -p codex-rmcp-client -p codex-shell-escalation -p codex-utils-pty -p codex-utils-readiness` - `just clippy -p codex-core` - `cargo test -p codex-core -p codex-app-server -p codex-rmcp-client -p codex-shell-escalation -p codex-utils-pty -p codex-utils-readiness` was run; the app-server suite passed, and `codex-core` failed in the local sandbox on six otel approval tests plus `suite::user_shell_cmd::user_shell_command_does_not_set_network_sandbox_env_var`, which appear to depend on local command approval/default rules and `CODEX_SANDBOX_NETWORK_DISABLED=1` in this environment.
This commit is contained in:
@@ -2016,9 +2016,12 @@ async fn complete_file_change_item(
|
||||
outgoing: &ThreadScopedOutgoingMessageSender,
|
||||
thread_state: &Arc<Mutex<ThreadState>>,
|
||||
) {
|
||||
let mut state = thread_state.lock().await;
|
||||
state.turn_summary.file_change_started.remove(&item_id);
|
||||
drop(state);
|
||||
thread_state
|
||||
.lock()
|
||||
.await
|
||||
.turn_summary
|
||||
.file_change_started
|
||||
.remove(&item_id);
|
||||
|
||||
let notification = ItemCompletedNotification {
|
||||
thread_id: conversation_id.to_string(),
|
||||
@@ -2087,12 +2090,12 @@ async fn complete_command_execution_item(
|
||||
outgoing: &ThreadScopedOutgoingMessageSender,
|
||||
thread_state: &Arc<Mutex<ThreadState>>,
|
||||
) {
|
||||
let mut state = thread_state.lock().await;
|
||||
let should_emit = state
|
||||
let should_emit = thread_state
|
||||
.lock()
|
||||
.await
|
||||
.turn_summary
|
||||
.command_execution_started
|
||||
.remove(&item_id);
|
||||
drop(state);
|
||||
if !should_emit {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -3521,9 +3521,7 @@ impl CodexMessageProcessor {
|
||||
// No ThreadRollback event will arrive if an error occurs.
|
||||
// Clean up and reply immediately.
|
||||
let thread_state = self.thread_state_manager.thread_state(thread_id).await;
|
||||
let mut thread_state = thread_state.lock().await;
|
||||
thread_state.pending_rollbacks = None;
|
||||
drop(thread_state);
|
||||
thread_state.lock().await.pending_rollbacks = None;
|
||||
|
||||
self.send_internal_error(request, format!("failed to start rollback: {err}"))
|
||||
.await;
|
||||
@@ -8234,8 +8232,10 @@ impl CodexMessageProcessor {
|
||||
start_fuzzy_file_search_session(session_id.clone(), roots, self.outgoing.clone());
|
||||
match session {
|
||||
Ok(session) => {
|
||||
let mut sessions = self.fuzzy_search_sessions.lock().await;
|
||||
sessions.insert(session_id, session);
|
||||
self.fuzzy_search_sessions
|
||||
.lock()
|
||||
.await
|
||||
.insert(session_id, session);
|
||||
self.outgoing
|
||||
.send_response(request_id, FuzzyFileSearchSessionStartResponse {})
|
||||
.await;
|
||||
@@ -10857,13 +10857,9 @@ mod tests {
|
||||
assert_eq!(cancel_rx.await, Ok(()));
|
||||
|
||||
let state = manager.thread_state(thread_id).await;
|
||||
let subscribed_connection_ids = manager.subscribed_connection_ids(thread_id).await;
|
||||
assert!(subscribed_connection_ids.is_empty());
|
||||
let state = state.lock().await;
|
||||
assert!(
|
||||
manager
|
||||
.subscribed_connection_ids(thread_id)
|
||||
.await
|
||||
.is_empty()
|
||||
);
|
||||
assert!(state.cancel_tx.is_none());
|
||||
assert!(state.active_turn_snapshot().is_none());
|
||||
Ok(())
|
||||
|
||||
@@ -392,7 +392,14 @@ impl RemoteControlWebsocket {
|
||||
ping_interval: std::time::Duration,
|
||||
shutdown_token: CancellationToken,
|
||||
) -> io::Result<()> {
|
||||
for server_envelope in state.lock().await.outbound_buffer.server_envelopes() {
|
||||
let server_envelopes = state
|
||||
.lock()
|
||||
.await
|
||||
.outbound_buffer
|
||||
.server_envelopes()
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
for server_envelope in server_envelopes {
|
||||
let payload = match serde_json::to_string(&server_envelope) {
|
||||
Ok(payload) => payload,
|
||||
Err(err) => {
|
||||
@@ -594,21 +601,22 @@ impl RemoteControlWebsocket {
|
||||
}
|
||||
};
|
||||
|
||||
let mut websocket_state = state.lock().await;
|
||||
if let Some(cursor) = client_envelope.cursor.as_deref() {
|
||||
websocket_state.subscribe_cursor = Some(cursor.to_string());
|
||||
}
|
||||
if let ClientEvent::Ack = &client_envelope.event
|
||||
&& let Some(acked_seq_id) = client_envelope.seq_id
|
||||
&& let Some(stream_id) = client_envelope.stream_id.as_ref()
|
||||
{
|
||||
websocket_state.outbound_buffer.ack(
|
||||
&client_envelope.client_id,
|
||||
stream_id,
|
||||
acked_seq_id,
|
||||
);
|
||||
let mut websocket_state = state.lock().await;
|
||||
if let Some(cursor) = client_envelope.cursor.as_deref() {
|
||||
websocket_state.subscribe_cursor = Some(cursor.to_string());
|
||||
}
|
||||
if let ClientEvent::Ack = &client_envelope.event
|
||||
&& let Some(acked_seq_id) = client_envelope.seq_id
|
||||
&& let Some(stream_id) = client_envelope.stream_id.as_ref()
|
||||
{
|
||||
websocket_state.outbound_buffer.ack(
|
||||
&client_envelope.client_id,
|
||||
stream_id,
|
||||
acked_seq_id,
|
||||
);
|
||||
}
|
||||
}
|
||||
drop(websocket_state);
|
||||
|
||||
if client_tracker
|
||||
.handle_message(client_envelope)
|
||||
|
||||
@@ -1385,37 +1385,44 @@ impl Session {
|
||||
&self,
|
||||
updates: SessionSettingsUpdate,
|
||||
) -> ConstraintResult<()> {
|
||||
let mut state = self.state.lock().await;
|
||||
|
||||
match state.session_configuration.apply(&updates) {
|
||||
Ok(updated) => {
|
||||
let previous_cwd = state.session_configuration.cwd.clone();
|
||||
let sandbox_policy_changed =
|
||||
state.session_configuration.sandbox_policy != updated.sandbox_policy;
|
||||
let next_cwd = updated.cwd.clone();
|
||||
let codex_home = updated.codex_home.clone();
|
||||
let session_source = updated.session_source.clone();
|
||||
state.session_configuration = updated;
|
||||
drop(state);
|
||||
|
||||
self.maybe_refresh_shell_snapshot_for_cwd(
|
||||
&previous_cwd,
|
||||
&next_cwd,
|
||||
&codex_home,
|
||||
&session_source,
|
||||
);
|
||||
if sandbox_policy_changed {
|
||||
self.refresh_managed_network_proxy_for_current_sandbox_policy()
|
||||
.await;
|
||||
let (previous_cwd, sandbox_policy_changed, next_cwd, codex_home, session_source) = {
|
||||
let mut state = self.state.lock().await;
|
||||
let updated = match state.session_configuration.apply(&updates) {
|
||||
Ok(updated) => updated,
|
||||
Err(err) => {
|
||||
warn!("rejected session settings update: {err}");
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("rejected session settings update: {err}");
|
||||
Err(err)
|
||||
}
|
||||
let previous_cwd = state.session_configuration.cwd.clone();
|
||||
let sandbox_policy_changed =
|
||||
state.session_configuration.sandbox_policy != updated.sandbox_policy;
|
||||
let next_cwd = updated.cwd.clone();
|
||||
let codex_home = updated.codex_home.clone();
|
||||
let session_source = updated.session_source.clone();
|
||||
state.session_configuration = updated;
|
||||
(
|
||||
previous_cwd,
|
||||
sandbox_policy_changed,
|
||||
next_cwd,
|
||||
codex_home,
|
||||
session_source,
|
||||
)
|
||||
};
|
||||
|
||||
self.maybe_refresh_shell_snapshot_for_cwd(
|
||||
&previous_cwd,
|
||||
&next_cwd,
|
||||
&codex_home,
|
||||
&session_source,
|
||||
);
|
||||
if sandbox_policy_changed {
|
||||
self.refresh_managed_network_proxy_for_current_sandbox_policy()
|
||||
.await;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn set_session_startup_prewarm(
|
||||
|
||||
@@ -446,13 +446,7 @@ impl Session {
|
||||
sub_id: String,
|
||||
updates: SessionSettingsUpdate,
|
||||
) -> ConstraintResult<Arc<TurnContext>> {
|
||||
let (
|
||||
session_configuration,
|
||||
sandbox_policy_changed,
|
||||
previous_cwd,
|
||||
codex_home,
|
||||
session_source,
|
||||
) = {
|
||||
let update_result = {
|
||||
let mut state = self.state.lock().await;
|
||||
match state.session_configuration.clone().apply(&updates) {
|
||||
Ok(next) => {
|
||||
@@ -462,26 +456,36 @@ impl Session {
|
||||
let codex_home = next.codex_home.clone();
|
||||
let session_source = next.session_source.clone();
|
||||
state.session_configuration = next.clone();
|
||||
(
|
||||
Ok((
|
||||
next,
|
||||
sandbox_policy_changed,
|
||||
previous_cwd,
|
||||
codex_home,
|
||||
session_source,
|
||||
)
|
||||
}
|
||||
Err(err) => {
|
||||
drop(state);
|
||||
self.send_event_raw(Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: err.to_string(),
|
||||
codex_error_info: Some(CodexErrorInfo::BadRequest),
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
return Err(err);
|
||||
))
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
};
|
||||
|
||||
let (
|
||||
session_configuration,
|
||||
sandbox_policy_changed,
|
||||
previous_cwd,
|
||||
codex_home,
|
||||
session_source,
|
||||
) = match update_result {
|
||||
Ok(update) => update,
|
||||
Err(err) => {
|
||||
self.send_event_raw(Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: err.to_string(),
|
||||
codex_error_info: Some(CodexErrorInfo::BadRequest),
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -884,9 +884,8 @@ impl JsReplManager {
|
||||
|
||||
let (req_id, rx) = {
|
||||
let req_id = Uuid::new_v4().to_string();
|
||||
let mut pending = pending_execs.lock().await;
|
||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||
pending.insert(req_id.clone(), tx);
|
||||
pending_execs.lock().await.insert(req_id.clone(), tx);
|
||||
exec_contexts.lock().await.insert(
|
||||
req_id.clone(),
|
||||
ExecContext {
|
||||
@@ -956,9 +955,7 @@ impl JsReplManager {
|
||||
let response = match tokio::time::timeout(Duration::from_millis(timeout_ms), rx).await {
|
||||
Ok(Ok(msg)) => msg,
|
||||
Ok(Err(_)) => {
|
||||
let mut pending = pending_execs.lock().await;
|
||||
let removed = pending.remove(&req_id).is_some();
|
||||
drop(pending);
|
||||
let removed = pending_execs.lock().await.remove(&req_id).is_some();
|
||||
if removed {
|
||||
self.clear_top_level_exec_if_matches(&req_id).await;
|
||||
}
|
||||
@@ -1340,40 +1337,40 @@ impl JsReplManager {
|
||||
KernelToHost::EmitImage(req) => {
|
||||
let exec_id = req.exec_id.clone();
|
||||
let emit_id = req.id.clone();
|
||||
let response =
|
||||
if let Some(ctx) = exec_contexts.lock().await.get(&exec_id).cloned() {
|
||||
match validate_emitted_image_url(&req.image_url) {
|
||||
Ok(()) => {
|
||||
let content_item = emitted_image_content_item(
|
||||
ctx.turn.as_ref(),
|
||||
req.image_url,
|
||||
req.detail,
|
||||
);
|
||||
JsReplManager::record_exec_content_item(
|
||||
&exec_tool_calls,
|
||||
&exec_id,
|
||||
content_item,
|
||||
)
|
||||
.await;
|
||||
HostToKernel::EmitImageResult(EmitImageResult {
|
||||
id: emit_id,
|
||||
ok: true,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
Err(error) => HostToKernel::EmitImageResult(EmitImageResult {
|
||||
let context = exec_contexts.lock().await.get(&exec_id).cloned();
|
||||
let response = if let Some(ctx) = context {
|
||||
match validate_emitted_image_url(&req.image_url) {
|
||||
Ok(()) => {
|
||||
let content_item = emitted_image_content_item(
|
||||
ctx.turn.as_ref(),
|
||||
req.image_url,
|
||||
req.detail,
|
||||
);
|
||||
JsReplManager::record_exec_content_item(
|
||||
&exec_tool_calls,
|
||||
&exec_id,
|
||||
content_item,
|
||||
)
|
||||
.await;
|
||||
HostToKernel::EmitImageResult(EmitImageResult {
|
||||
id: emit_id,
|
||||
ok: false,
|
||||
error: Some(error),
|
||||
}),
|
||||
ok: true,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
HostToKernel::EmitImageResult(EmitImageResult {
|
||||
Err(error) => HostToKernel::EmitImageResult(EmitImageResult {
|
||||
id: emit_id,
|
||||
ok: false,
|
||||
error: Some("js_repl exec context not found".to_string()),
|
||||
})
|
||||
};
|
||||
error: Some(error),
|
||||
}),
|
||||
}
|
||||
} else {
|
||||
HostToKernel::EmitImageResult(EmitImageResult {
|
||||
id: emit_id,
|
||||
ok: false,
|
||||
error: Some("js_repl exec context not found".to_string()),
|
||||
})
|
||||
};
|
||||
|
||||
if let Err(err) = JsReplManager::write_message(&stdin, &response).await {
|
||||
let snapshot =
|
||||
@@ -1424,7 +1421,7 @@ impl JsReplManager {
|
||||
let exec_id = req.exec_id.clone();
|
||||
let tool_call_id = req.id.clone();
|
||||
let tool_name = req.tool_name.clone();
|
||||
let context = { exec_contexts.lock().await.get(&exec_id).cloned() };
|
||||
let context = exec_contexts.lock().await.get(&exec_id).cloned();
|
||||
let result = match context {
|
||||
Some(ctx) => {
|
||||
tokio::select! {
|
||||
@@ -1502,14 +1499,17 @@ impl JsReplManager {
|
||||
}
|
||||
}
|
||||
|
||||
let mut pending = pending_execs.lock().await;
|
||||
let pending_exec_ids = pending.keys().cloned().collect::<Vec<_>>();
|
||||
for (_id, tx) in pending.drain() {
|
||||
let pending_execs_to_notify = {
|
||||
let mut pending = pending_execs.lock().await;
|
||||
pending.drain().collect::<Vec<_>>()
|
||||
};
|
||||
let mut pending_exec_ids = Vec::with_capacity(pending_execs_to_notify.len());
|
||||
for (id, tx) in pending_execs_to_notify {
|
||||
pending_exec_ids.push(id);
|
||||
let _ = tx.send(ExecResultMessage::Err {
|
||||
message: kernel_exit_message.clone(),
|
||||
});
|
||||
}
|
||||
drop(pending);
|
||||
if !pending_exec_ids.is_empty() {
|
||||
Self::clear_top_level_exec_if_matches_any_map(&manager_kernel, &pending_exec_ids).await;
|
||||
}
|
||||
|
||||
@@ -223,10 +223,8 @@ impl NetworkApprovalService {
|
||||
}
|
||||
|
||||
pub(crate) async fn unregister_call(&self, registration_id: &str) {
|
||||
let mut active_calls = self.active_calls.lock().await;
|
||||
active_calls.shift_remove(registration_id);
|
||||
let mut call_outcomes = self.call_outcomes.lock().await;
|
||||
call_outcomes.remove(registration_id);
|
||||
self.active_calls.lock().await.shift_remove(registration_id);
|
||||
self.call_outcomes.lock().await.remove(registration_id);
|
||||
}
|
||||
|
||||
async fn resolve_single_active_call(&self) -> Option<Arc<ActiveNetworkApprovalCall>> {
|
||||
@@ -344,8 +342,7 @@ impl NetworkApprovalService {
|
||||
|
||||
let Some(turn_context) = Self::active_turn_context(session.as_ref()).await else {
|
||||
pending.set_decision(PendingApprovalDecision::Deny).await;
|
||||
let mut pending_approvals = self.pending_host_approvals.lock().await;
|
||||
pending_approvals.remove(&key);
|
||||
self.pending_host_approvals.lock().await.remove(&key);
|
||||
self.record_outcome_for_single_active_call(NetworkApprovalOutcome::DeniedByPolicy(
|
||||
policy_denial_message,
|
||||
))
|
||||
@@ -354,8 +351,7 @@ impl NetworkApprovalService {
|
||||
};
|
||||
if !sandbox_policy_allows_network_approval_flow(turn_context.sandbox_policy.get()) {
|
||||
pending.set_decision(PendingApprovalDecision::Deny).await;
|
||||
let mut pending_approvals = self.pending_host_approvals.lock().await;
|
||||
pending_approvals.remove(&key);
|
||||
self.pending_host_approvals.lock().await.remove(&key);
|
||||
self.record_outcome_for_single_active_call(NetworkApprovalOutcome::DeniedByPolicy(
|
||||
policy_denial_message,
|
||||
))
|
||||
@@ -364,8 +360,7 @@ impl NetworkApprovalService {
|
||||
}
|
||||
if !allows_network_approval_flow(turn_context.approval_policy.value()) {
|
||||
pending.set_decision(PendingApprovalDecision::Deny).await;
|
||||
let mut pending_approvals = self.pending_host_approvals.lock().await;
|
||||
pending_approvals.remove(&key);
|
||||
self.pending_host_approvals.lock().await.remove(&key);
|
||||
self.record_outcome_for_single_active_call(NetworkApprovalOutcome::DeniedByPolicy(
|
||||
policy_denial_message,
|
||||
))
|
||||
|
||||
@@ -418,22 +418,23 @@ async fn fail_session_post_when_armed(
|
||||
return next.run(request).await;
|
||||
}
|
||||
|
||||
let mut armed_failure = state.armed_failure.lock().await;
|
||||
if let Some(failure) = armed_failure.as_mut()
|
||||
&& failure.remaining > 0
|
||||
{
|
||||
failure.remaining -= 1;
|
||||
let status = failure.status;
|
||||
if failure.remaining == 0 {
|
||||
*armed_failure = None;
|
||||
let mut armed_failure = state.armed_failure.lock().await;
|
||||
if let Some(failure) = armed_failure.as_mut()
|
||||
&& failure.remaining > 0
|
||||
{
|
||||
failure.remaining -= 1;
|
||||
let status = failure.status;
|
||||
if failure.remaining == 0 {
|
||||
*armed_failure = None;
|
||||
}
|
||||
let mut response = Response::new(Body::from(format!(
|
||||
"forced session failure with status {status}"
|
||||
)));
|
||||
*response.status_mut() = status;
|
||||
return response;
|
||||
}
|
||||
let mut response = Response::new(Body::from(format!(
|
||||
"forced session failure with status {status}"
|
||||
)));
|
||||
*response.status_mut() = status;
|
||||
return response;
|
||||
}
|
||||
|
||||
drop(armed_failure);
|
||||
next.run(request).await
|
||||
}
|
||||
|
||||
@@ -391,11 +391,11 @@ mod tests {
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tempfile::TempDir;
|
||||
use tokio::sync::Semaphore;
|
||||
use tokio::time::Instant;
|
||||
use tokio::time::sleep;
|
||||
|
||||
static ESCALATE_SERVER_TEST_LOCK: LazyLock<tokio::sync::Mutex<()>> =
|
||||
LazyLock::new(|| tokio::sync::Mutex::new(()));
|
||||
static ESCALATE_SERVER_TEST_LOCK: LazyLock<Semaphore> = LazyLock::new(|| Semaphore::new(1));
|
||||
|
||||
struct DeterministicEscalationPolicy {
|
||||
decision: EscalationDecision,
|
||||
@@ -596,7 +596,7 @@ mod tests {
|
||||
/// until `close_client_socket()` is called.
|
||||
#[tokio::test]
|
||||
async fn start_session_exposes_wrapper_env_overlay() -> anyhow::Result<()> {
|
||||
let _guard = ESCALATE_SERVER_TEST_LOCK.lock().await;
|
||||
let _guard = ESCALATE_SERVER_TEST_LOCK.acquire().await?;
|
||||
let execve_wrapper = PathBuf::from("/tmp/codex-execve-wrapper");
|
||||
let execve_wrapper_str = execve_wrapper.to_string_lossy().to_string();
|
||||
let server = EscalateServer::new(
|
||||
@@ -638,7 +638,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn exec_closes_parent_socket_after_shell_spawn() -> anyhow::Result<()> {
|
||||
let _guard = ESCALATE_SERVER_TEST_LOCK.lock().await;
|
||||
let _guard = ESCALATE_SERVER_TEST_LOCK.acquire().await?;
|
||||
let after_spawn_invoked = Arc::new(AtomicBool::new(false));
|
||||
let server = EscalateServer::new(
|
||||
PathBuf::from("/bin/bash"),
|
||||
@@ -672,7 +672,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn handle_escalate_session_respects_run_in_sandbox_decision() -> anyhow::Result<()> {
|
||||
let _guard = ESCALATE_SERVER_TEST_LOCK.lock().await;
|
||||
let _guard = ESCALATE_SERVER_TEST_LOCK.acquire().await?;
|
||||
let (server, client) = AsyncSocket::pair()?;
|
||||
let server_task = tokio::spawn(handle_escalate_session_with_policy(
|
||||
server,
|
||||
@@ -712,7 +712,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn handle_escalate_session_resolves_relative_file_against_request_workdir()
|
||||
-> anyhow::Result<()> {
|
||||
let _guard = ESCALATE_SERVER_TEST_LOCK.lock().await;
|
||||
let _guard = ESCALATE_SERVER_TEST_LOCK.acquire().await?;
|
||||
let (server, client) = AsyncSocket::pair()?;
|
||||
let tmp = tempfile::TempDir::new()?;
|
||||
let workdir = tmp.path().join("workspace");
|
||||
@@ -751,7 +751,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn handle_escalate_session_executes_escalated_command() -> anyhow::Result<()> {
|
||||
let _guard = ESCALATE_SERVER_TEST_LOCK.lock().await;
|
||||
let _guard = ESCALATE_SERVER_TEST_LOCK.acquire().await?;
|
||||
let (server, client) = AsyncSocket::pair()?;
|
||||
let server_task = tokio::spawn(handle_escalate_session_with_policy(
|
||||
server,
|
||||
@@ -844,7 +844,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn handle_escalate_session_accepts_received_fds_that_overlap_destinations()
|
||||
-> anyhow::Result<()> {
|
||||
let _guard = ESCALATE_SERVER_TEST_LOCK.lock().await;
|
||||
let _guard = ESCALATE_SERVER_TEST_LOCK.acquire().await?;
|
||||
let mut pipe_fds = [0; 2];
|
||||
if unsafe { libc::pipe(pipe_fds.as_mut_ptr()) } == -1 {
|
||||
return Err(std::io::Error::last_os_error().into());
|
||||
@@ -916,7 +916,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn handle_escalate_session_passes_permissions_to_executor() -> anyhow::Result<()> {
|
||||
let _guard = ESCALATE_SERVER_TEST_LOCK.lock().await;
|
||||
let _guard = ESCALATE_SERVER_TEST_LOCK.acquire().await?;
|
||||
let (server, client) = AsyncSocket::pair()?;
|
||||
let server_task = tokio::spawn(handle_escalate_session_with_policy(
|
||||
server,
|
||||
@@ -972,7 +972,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn dropping_session_aborts_intercept_workers_and_kills_spawned_child()
|
||||
-> anyhow::Result<()> {
|
||||
let _guard = ESCALATE_SERVER_TEST_LOCK.lock().await;
|
||||
let _guard = ESCALATE_SERVER_TEST_LOCK.acquire().await?;
|
||||
let tmp = TempDir::new()?;
|
||||
let pid_file = tmp.path().join("escalated-child.pid");
|
||||
let pid_file_display = pid_file.display().to_string();
|
||||
|
||||
@@ -162,12 +162,11 @@ async fn spawn_process_with_stdin_mode(
|
||||
let (stdout_tx, stdout_rx) = mpsc::channel::<Vec<u8>>(128);
|
||||
let (stderr_tx, stderr_rx) = mpsc::channel::<Vec<u8>>(128);
|
||||
let writer_handle = if let Some(stdin) = stdin {
|
||||
let writer = Arc::new(tokio::sync::Mutex::new(stdin));
|
||||
tokio::spawn(async move {
|
||||
let mut writer = stdin;
|
||||
while let Some(bytes) = writer_rx.recv().await {
|
||||
let mut guard = writer.lock().await;
|
||||
let _ = guard.write_all(&bytes).await;
|
||||
let _ = guard.flush().await;
|
||||
let _ = writer.write_all(&bytes).await;
|
||||
let _ = writer.flush().await;
|
||||
}
|
||||
})
|
||||
} else {
|
||||
|
||||
@@ -277,17 +277,36 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn subscribe_returns_error_when_lock_is_held() {
|
||||
let flag = ReadinessFlag::new();
|
||||
let _guard = flag
|
||||
.tokens
|
||||
.try_lock()
|
||||
.expect("initial lock acquisition should succeed");
|
||||
let flag = Arc::new(ReadinessFlag::new());
|
||||
let (locked_tx, locked_rx) = std::sync::mpsc::channel();
|
||||
let (release_tx, release_rx) = std::sync::mpsc::channel();
|
||||
let lock_thread = {
|
||||
let flag = Arc::clone(&flag);
|
||||
std::thread::spawn(move || {
|
||||
let _guard = flag.tokens.blocking_lock();
|
||||
locked_tx
|
||||
.send(())
|
||||
.expect("test should receive lock acquisition notification");
|
||||
release_rx
|
||||
.recv()
|
||||
.expect("test should release held readiness lock");
|
||||
})
|
||||
};
|
||||
locked_rx
|
||||
.recv()
|
||||
.expect("test should observe held readiness lock");
|
||||
|
||||
let err = flag
|
||||
.subscribe()
|
||||
.await
|
||||
.expect_err("contended subscribe should report a lock failure");
|
||||
assert_matches!(err, ReadinessError::TokenLockFailed);
|
||||
release_tx
|
||||
.send(())
|
||||
.expect("test should release readiness lock thread");
|
||||
lock_thread
|
||||
.join()
|
||||
.expect("readiness lock thread should not panic");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
Reference in New Issue
Block a user