This commit is contained in:
jimmyfraiture
2025-09-07 17:45:26 -07:00
parent 80eea492ff
commit 85b505afb8
2 changed files with 126 additions and 52 deletions

View File

@@ -114,8 +114,6 @@ impl UnifiedExecSessionManager {
&self,
request: UnifiedExecRequest<'_>,
) -> Result<UnifiedExecResult, UnifiedExecError> {
tracing::error!("In the exec");
// todo update the errors
let timeout_ms = request.timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS);
let mut new_session: Option<ManagedUnifiedExecSession> = None;
@@ -125,9 +123,15 @@ impl UnifiedExecSessionManager {
let output_notify;
if let Some(existing_id) = request.session_id {
let sessions = self.sessions.lock().await;
let mut sessions = self.sessions.lock().await;
match sessions.get(&existing_id) {
Some(session) => {
if session.has_exited() {
sessions.remove(&existing_id);
return Err(UnifiedExecError::UnknownSessionId {
session_id: existing_id,
});
}
let (buffer, notify) = session.output_handles();
session_id = existing_id;
writer_tx = session.writer_sender();
@@ -140,6 +144,7 @@ impl UnifiedExecSessionManager {
});
}
}
drop(sessions);
} else {
let command = command_from_chunks(request.input_chunks)?;
let new_id = self.next_session_id.fetch_add(1, Ordering::SeqCst);
@@ -203,6 +208,18 @@ impl UnifiedExecSessionManager {
let should_store_session = if let Some(session) = new_session.as_ref() {
!session.has_exited()
} else if request.session_id.is_some() {
let mut sessions = self.sessions.lock().await;
if let Some(existing) = sessions.get(&session_id) {
if existing.has_exited() {
sessions.remove(&session_id);
false
} else {
true
}
} else {
false
}
} else {
true
};
@@ -322,8 +339,8 @@ async fn create_unified_exec_session(
#[cfg(test)]
mod tests {
use super::*;
use super::path::parse_command_line;
use super::*;
#[test]
fn parse_command_line_splits_words() {
@@ -355,7 +372,7 @@ mod tests {
let open_shell = manager
.handle_request(UnifiedExecRequest {
session_id: None,
input_chunks: &["/bin/bash".to_string(), "-i".to_string()],
input_chunks: &["bash".to_string(), "-i".to_string()],
timeout_ms: Some(1_500),
})
.await?;
@@ -492,13 +509,78 @@ mod tests {
Ok(())
}
#[cfg(unix)]
#[tokio::test]
async fn correct_path_resolution() -> Result<(), UnifiedExecError> {
let manager = UnifiedExecSessionManager::default();
let result = manager
.handle_request(UnifiedExecRequest {
session_id: None,
input_chunks: &["echo".to_string(), "codex".to_string()],
timeout_ms: Some(1_500),
})
.await?;
assert!(result.session_id.is_none());
assert!(result.output.contains("codex"));
assert!(manager.sessions.lock().await.is_empty());
Ok(())
}
#[cfg(unix)]
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn reusing_completed_session_returns_unknown_session() -> Result<(), UnifiedExecError> {
let manager = UnifiedExecSessionManager::default();
let open_shell = manager
.handle_request(UnifiedExecRequest {
session_id: None,
input_chunks: &["/bin/bash".to_string(), "-i".to_string()],
timeout_ms: Some(1_500),
})
.await?;
let session_id = open_shell.session_id.expect("expected session id");
manager
.handle_request(UnifiedExecRequest {
session_id: Some(session_id),
input_chunks: &["exit\n".to_string()],
timeout_ms: Some(1_500),
})
.await?;
tokio::time::sleep(Duration::from_millis(200)).await;
let err = manager
.handle_request(UnifiedExecRequest {
session_id: Some(session_id),
input_chunks: &[],
timeout_ms: Some(100),
})
.await
.expect_err("expected unknown session error");
match err {
UnifiedExecError::UnknownSessionId { session_id: err_id } => {
assert_eq!(err_id, session_id);
}
other => panic!("expected UnknownSessionId, got {other:?}"),
}
assert!(!manager.sessions.lock().await.contains_key(&session_id));
Ok(())
}
#[test]
fn truncate_middle_no_newlines_fallback() {
let s = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
let max_bytes = 16;
let s = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ*";
let max_bytes = 32;
let (out, original) = truncate_middle(s, max_bytes);
assert_eq!(out, "…16 tokens truncated…");
assert_eq!(original, Some(16));
assert_eq!(out, "abcdef[TRUNCATED CONTENT]UVWXYZ*");
assert_eq!(original, Some((s.len() as u64).div_ceil(4)));
}
#[test]
@@ -510,12 +592,23 @@ mod tests {
assert_eq!(s.len(), 80);
let max_bytes = 64;
assert_eq!(
truncate_middle(&s, max_bytes),
(
"001\n002\n003\n004\n…12 tokens truncated…\n017\n018\n019\n020\n".to_string(),
Some(20)
)
);
let (out, tokens) = truncate_middle(&s, max_bytes);
assert!(out.starts_with("001\n002\n003\n004\n"));
assert!(out.contains("[TRUNCATED CONTENT]"));
assert!(out.ends_with("017\n018\n019\n020\n"));
assert_eq!(tokens, Some(20));
}
#[test]
// Ensure truncation is resilient against multi-bytes chars split.
fn truncate_middle_handles_utf8_content() {
let s = "😀😀😀😀😀😀😀😀😀😀\nsecond line with ascii text\n";
let max_bytes = 32;
let (out, tokens) = truncate_middle(s, max_bytes);
assert!(out.contains("[TRUNCATED CONTENT]"));
assert!(out.chars().any(|c| c == '😀'));
assert!(!out.contains('\u{fffd}'));
assert_eq!(tokens, Some((s.len() as u64).div_ceil(4)));
}
}

View File

@@ -1,10 +1,13 @@
const TRUNCATION_MARKER: &str = "[TRUNCATED CONTENT]";
pub(crate) fn truncate_middle(s: &str, max_bytes: usize) -> (String, Option<u64>) {
if s.len() <= max_bytes {
return (s.to_string(), None);
}
let est_tokens = (s.len() as u64).div_ceil(4);
if max_bytes == 0 {
return (format!("{est_tokens} tokens truncated…"), Some(est_tokens));
return (TRUNCATION_MARKER.to_string(), Some(est_tokens));
}
fn truncate_on_boundary(input: &str, max_len: usize) -> &str {
@@ -41,50 +44,28 @@ pub(crate) fn truncate_middle(s: &str, max_bytes: usize) -> (String, Option<u64>
idx
}
let mut guess_tokens = est_tokens;
for _ in 0..4 {
let marker = format!("{guess_tokens} tokens truncated…");
let marker_len = marker.len();
let keep_budget = max_bytes.saturating_sub(marker_len);
if keep_budget == 0 {
return (format!("{est_tokens} tokens truncated…"), Some(est_tokens));
}
let left_budget = keep_budget / 2;
let right_budget = keep_budget - left_budget;
let prefix_end = pick_prefix_end(s, left_budget);
let mut suffix_start = pick_suffix_start(s, right_budget);
if suffix_start < prefix_end {
suffix_start = prefix_end;
}
let kept_content_bytes = prefix_end + (s.len() - suffix_start);
let truncated_content_bytes = s.len().saturating_sub(kept_content_bytes);
let new_tokens = (truncated_content_bytes as u64).div_ceil(4);
if new_tokens == guess_tokens {
let mut out = String::with_capacity(marker_len + kept_content_bytes + 1);
out.push_str(&s[..prefix_end]);
out.push_str(&marker);
out.push('\n');
out.push_str(&s[suffix_start..]);
return (out, Some(est_tokens));
}
guess_tokens = new_tokens;
let marker_len = TRUNCATION_MARKER.len();
if marker_len >= max_bytes {
return (TRUNCATION_MARKER.to_string(), Some(est_tokens));
}
let marker = format!("{guess_tokens} tokens truncated…");
let marker_len = marker.len();
let keep_budget = max_bytes.saturating_sub(marker_len);
if keep_budget == 0 {
return (format!("{est_tokens} tokens truncated…"), Some(est_tokens));
return (TRUNCATION_MARKER.to_string(), Some(est_tokens));
}
let left_budget = keep_budget / 2;
let right_budget = keep_budget - left_budget;
let prefix_end = pick_prefix_end(s, left_budget);
let suffix_start = pick_suffix_start(s, right_budget);
let mut out = String::with_capacity(marker_len + prefix_end + (s.len() - suffix_start) + 1);
let mut suffix_start = pick_suffix_start(s, right_budget);
if suffix_start < prefix_end {
suffix_start = prefix_end;
}
let mut out = String::with_capacity(marker_len + prefix_end + (s.len() - suffix_start));
out.push_str(&s[..prefix_end]);
out.push_str(&marker);
out.push('\n');
out.push_str(TRUNCATION_MARKER);
out.push_str(&s[suffix_start..]);
(out, Some(est_tokens))
}