use arc instead of overusing mutex

This commit is contained in:
Ahmed Ibrahim
2025-08-04 19:17:44 -07:00
parent f68bf94db1
commit c1e9083cbd
4 changed files with 76 additions and 57 deletions

View File

@@ -30,6 +30,10 @@ pub(crate) struct Conversation {
session_id: Uuid,
outgoing: Arc<OutgoingMessageSender>,
request_id: RequestId,
state: Mutex<ConversationState>,
}
struct ConversationState {
running: bool,
streaming_enabled: bool,
buffered_events: Vec<CodexEventNotificationParams>,
@@ -42,40 +46,49 @@ impl Conversation {
outgoing: Arc<OutgoingMessageSender>,
request_id: RequestId,
session_id: Uuid,
) -> Arc<Mutex<Self>> {
let conv = Arc::new(Mutex::new(Self {
) -> Arc<Self> {
let conv = Arc::new(Self {
codex,
session_id,
outgoing,
request_id,
running: false,
streaming_enabled: false,
buffered_events: Vec::new(),
pending_elicitations: Vec::new(),
}));
state: Mutex::new(ConversationState {
running: false,
streaming_enabled: false,
buffered_events: Vec::new(),
pending_elicitations: Vec::new(),
}),
});
// Detach a background loop tied to this Conversation
Conversation::spawn_loop(conv.clone());
conv
}
pub(crate) async fn set_streaming(&mut self, enabled: bool) {
pub(crate) async fn set_streaming(&self, enabled: bool) {
if enabled {
self.streaming_enabled = true;
self.emit_initial_state().await;
self.drain_pending_elicitations().await;
let (events_snapshot, pending_snapshot) = {
let mut st = self.state.lock().await;
st.streaming_enabled = true;
(
st.buffered_events.clone(),
std::mem::take(&mut st.pending_elicitations),
)
};
self.emit_initial_state_with(events_snapshot).await;
self.drain_pending_elicitations_from(pending_snapshot).await;
} else {
self.streaming_enabled = false;
let mut st = self.state.lock().await;
st.streaming_enabled = false;
}
}
fn spawn_loop(this: Arc<Mutex<Self>>) {
fn spawn_loop(this: Arc<Self>) {
tokio::spawn(async move {
loop {
// We clone codex to avoid holding the lock while awaiting next_event
let codex = { this.lock().await.codex.clone() };
// Codex can be awaited without locking Conversation
let codex = this.codex.clone();
let res = codex.next_event().await;
let mut guard = this.lock().await;
guard.handle_next_event(res).await;
this.handle_next_event(res).await;
}
});
}
@@ -85,15 +98,17 @@ impl Conversation {
}
pub(crate) async fn try_submit_user_input(
&mut self,
&self,
request_id: RequestId,
items: Vec<codex_core::protocol::InputItem>,
) -> Result<(), String> {
if self.running {
return Err("Session is already running".to_string());
{
let mut st = self.state.lock().await;
if st.running {
return Err("Session is already running".to_string());
}
st.running = true;
}
// Optimistically mark running to avoid races between quick successive submits
self.running = true;
let request_id_string = match &request_id {
RequestId::String(s) => s.clone(),
RequestId::Integer(i) => i.to_string(),
@@ -106,23 +121,26 @@ impl Conversation {
})
.await;
if let Err(e) = submit_res {
// Revert running on error
self.running = false;
let mut st = self.state.lock().await;
st.running = false;
return Err(format!("Failed to submit user input: {e}"));
}
Ok(())
}
async fn handle_next_event<E>(&mut self, res: Result<Event, E>)
async fn handle_next_event<E>(&self, res: Result<Event, E>)
where
E: std::fmt::Display,
{
match res {
Ok(event) => {
self.buffered_events.push(CodexEventNotificationParams {
meta: None,
msg: event.msg.clone(),
});
{
let mut st = self.state.lock().await;
st.buffered_events.push(CodexEventNotificationParams {
meta: None,
msg: event.msg.clone(),
});
}
self.stream_event_if_enabled(&event.msg).await;
match event.msg {
@@ -191,15 +209,13 @@ impl Conversation {
}
}
async fn emit_initial_state(&self) {
async fn emit_initial_state_with(&self, events: Vec<CodexEventNotificationParams>) {
let params = InitialStateNotificationParams {
meta: Some(NotificationMeta {
conversation_id: Some(ConversationId(self.session_id)),
request_id: None,
}),
initial_state: InitialStatePayload {
events: self.buffered_events.clone(),
},
initial_state: InitialStatePayload { events },
};
if let Ok(params_val) = serde_json::to_value(&params) {
self.outgoing
@@ -210,8 +226,8 @@ impl Conversation {
}
}
async fn drain_pending_elicitations(&mut self) {
for item in self.pending_elicitations.drain(..) {
async fn drain_pending_elicitations_from(&self, items: Vec<PendingElicitation>) {
for item in items {
match item {
PendingElicitation::ExecRequest(ExecRequest {
command,
@@ -262,13 +278,14 @@ impl Conversation {
}
async fn process_exec_request(
&mut self,
&self,
command: Vec<String>,
cwd: PathBuf,
call_id: String,
event_id: String,
) {
if self.streaming_enabled {
let should_stream = { self.state.lock().await.streaming_enabled };
if should_stream {
handle_exec_approval_request(
command,
cwd,
@@ -284,7 +301,8 @@ impl Conversation {
)
.await;
} else {
self.pending_elicitations
let mut st = self.state.lock().await;
st.pending_elicitations
.push(PendingElicitation::ExecRequest(ExecRequest {
command,
cwd,
@@ -294,7 +312,7 @@ impl Conversation {
}
}
async fn process_patch_request(&mut self, req: PatchRequest) {
async fn process_patch_request(&self, req: PatchRequest) {
let PatchRequest {
call_id,
reason,
@@ -302,7 +320,8 @@ impl Conversation {
changes,
event_id,
} = req;
if self.streaming_enabled {
let should_stream = { self.state.lock().await.streaming_enabled };
if should_stream {
handle_patch_approval_request(
call_id,
reason,
@@ -319,7 +338,8 @@ impl Conversation {
)
.await;
} else {
self.pending_elicitations
let mut st = self.state.lock().await;
st.pending_elicitations
.push(PendingElicitation::PatchRequest(PatchRequest {
call_id,
reason,
@@ -331,7 +351,7 @@ impl Conversation {
}
async fn stream_event_if_enabled(&self, msg: &EventMsg) {
if !self.streaming_enabled {
if !{ self.state.lock().await.streaming_enabled } {
return;
}
let method = msg.to_string();
@@ -348,12 +368,14 @@ impl Conversation {
}
}
async fn handle_task_started(&mut self) {
self.running = true;
async fn handle_task_started(&self) {
let mut st = self.state.lock().await;
st.running = true;
}
async fn handle_task_clear(&mut self) {
self.running = false;
async fn handle_task_clear(&self) {
let mut st = self.state.lock().await;
st.running = false;
}
}

View File

@@ -44,7 +44,7 @@ pub(crate) struct MessageProcessor {
initialized: bool,
codex_linux_sandbox_exe: Option<PathBuf>,
session_map: Arc<Mutex<HashMap<Uuid, Arc<Codex>>>>,
conversation_map: Arc<Mutex<HashMap<Uuid, Arc<Mutex<crate::conversation_loop::Conversation>>>>>,
conversation_map: Arc<Mutex<HashMap<Uuid, Arc<crate::conversation_loop::Conversation>>>>,
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
/// Track request IDs to the original ToolCallRequestParams for cancellation handling
tool_request_map: Arc<Mutex<HashMap<RequestId, ToolCallRequestParams>>>,
@@ -70,7 +70,7 @@ impl MessageProcessor {
pub(crate) fn conversation_map(
&self,
) -> Arc<Mutex<HashMap<Uuid, Arc<Mutex<crate::conversation_loop::Conversation>>>>> {
) -> Arc<Mutex<HashMap<Uuid, Arc<crate::conversation_loop::Conversation>>>> {
self.conversation_map.clone()
}
@@ -632,7 +632,7 @@ impl MessageProcessor {
let codex_arc = {
let sessions_guard = self.conversation_map.lock().await;
match sessions_guard.get(&session_id) {
Some(conv) => conv.lock().await.codex().clone(),
Some(conv) => conv.codex().clone(),
None => {
tracing::warn!(
"Cancel send_message: session not found for session_id: {session_id}"

View File

@@ -55,10 +55,7 @@ pub(crate) async fn handle_send_message(
return;
};
let res = {
let mut guard = conversation.lock().await;
guard.try_submit_user_input(id.clone(), items).await
};
let res = conversation.try_submit_user_input(id.clone(), items).await;
if let Err(e) = res {
message_processor
@@ -86,8 +83,8 @@ pub(crate) async fn handle_send_message(
pub(crate) async fn get_session(
session_id: Uuid,
conversation_map: Arc<Mutex<HashMap<Uuid, Arc<Mutex<Conversation>>>>>,
) -> Option<Arc<Mutex<Conversation>>> {
conversation_map: Arc<Mutex<HashMap<Uuid, Arc<Conversation>>>>,
) -> Option<Arc<Conversation>> {
let guard = conversation_map.lock().await;
guard.get(&session_id).cloned()
}

View File

@@ -40,7 +40,7 @@ pub(crate) async fn handle_stream_conversation(
if let Some(conv) = conv {
tokio::spawn(async move {
conv.lock().await.set_streaming(true).await;
conv.set_streaming(true).await;
});
}
}
@@ -52,6 +52,6 @@ pub(crate) async fn handle_cancel(
) {
let session_id = args.conversation_id.0;
if let Some(conv) = get_session(session_id, message_processor.conversation_map()).await {
conv.lock().await.set_streaming(false).await;
conv.set_streaming(false).await;
}
}