diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 09c5786a9c..6dfc44d844 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -2952,12 +2952,19 @@ dependencies = [ name = "codex-goal-extension" version = "0.0.0" dependencies = [ + "anyhow", "async-trait", + "chrono", "codex-extension-api", "codex-protocol", + "codex-state", "codex-tools", + "pretty_assertions", "serde", "serde_json", + "tempfile", + "tokio", + "tracing", ] [[package]] diff --git a/codex-rs/ext/goal/Cargo.toml b/codex-rs/ext/goal/Cargo.toml index 7f8b9bd308..b6590f68e6 100644 --- a/codex-rs/ext/goal/Cargo.toml +++ b/codex-rs/ext/goal/Cargo.toml @@ -17,6 +17,15 @@ workspace = true async-trait = { workspace = true } codex-extension-api = { workspace = true } codex-protocol = { workspace = true } +codex-state = { workspace = true } codex-tools = { workspace = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } +tracing = { workspace = true } + +[dev-dependencies] +anyhow = { workspace = true } +chrono = { workspace = true } +pretty_assertions = { workspace = true } +tempfile = { workspace = true } +tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } diff --git a/codex-rs/ext/goal/src/extension.rs b/codex-rs/ext/goal/src/extension.rs index a8d4f5c289..657d285f74 100644 --- a/codex-rs/ext/goal/src/extension.rs +++ b/codex-rs/ext/goal/src/extension.rs @@ -5,7 +5,6 @@ use codex_extension_api::ConfigContributor; use codex_extension_api::ExtensionData; use codex_extension_api::ExtensionEventSink; use codex_extension_api::ExtensionRegistryBuilder; -use codex_extension_api::NoopExtensionEventSink; use codex_extension_api::ThreadLifecycleContributor; use codex_extension_api::ThreadStartInput; use codex_extension_api::TokenUsageContributor; @@ -19,14 +18,12 @@ use codex_extension_api::TurnLifecycleContributor; use codex_extension_api::TurnStartInput; use codex_extension_api::TurnStopInput; use codex_protocol::ThreadId; -use codex_protocol::protocol::ThreadGoal; use codex_protocol::protocol::TokenUsageInfo; use codex_protocol::protocol::TurnAbortReason; use crate::accounting::GoalAccountingState; use crate::events::GoalEventEmitter; use crate::spec::UPDATE_GOAL_TOOL_NAME; -use crate::tool::CreateGoalRequest; use crate::tool::GoalToolExecutor; #[derive(Clone, Debug)] @@ -42,7 +39,7 @@ impl GoalExtensionConfig { #[derive(Clone)] pub struct GoalExtension { - backend: Arc, + state_dbs: Arc, event_emitter: GoalEventEmitter, goals_enabled: Arc bool + Send + Sync>, } @@ -54,70 +51,17 @@ impl std::fmt::Debug for GoalExtension { } impl GoalExtension { - pub fn new( - backend: Arc, - goals_enabled: impl Fn(&C) -> bool + Send + Sync + 'static, - ) -> Self { - Self::new_with_event_sink(backend, Arc::new(NoopExtensionEventSink), goals_enabled) - } - - pub fn new_with_event_sink( - backend: Arc, + pub(crate) fn new_with_event_sink( + state_dbs: Arc, event_sink: Arc, goals_enabled: impl Fn(&C) -> bool + Send + Sync + 'static, ) -> Self { Self { - backend, + state_dbs, event_emitter: GoalEventEmitter::new(event_sink), goals_enabled: Arc::new(goals_enabled), } } - - pub fn without_backend(goals_enabled: impl Fn(&C) -> bool + Send + Sync + 'static) -> Self { - Self::new(Arc::new(NoGoalToolBackend), goals_enabled) - } -} - -#[async_trait] -pub trait GoalToolBackend: Send + Sync { - async fn get_goal(&self, thread_id: ThreadId) -> Result, String>; - - async fn create_goal( - &self, - thread_id: ThreadId, - request: CreateGoalRequest, - ) -> Result; - - async fn complete_goal(&self, thread_id: ThreadId) -> Result; -} - -#[derive(Clone, Copy, Debug, Default)] -pub struct NoGoalToolBackend; - -#[async_trait] -impl GoalToolBackend for NoGoalToolBackend { - async fn get_goal(&self, _thread_id: ThreadId) -> Result, String> { - Err(missing_backend_message()) - } - - async fn create_goal( - &self, - _thread_id: ThreadId, - _request: CreateGoalRequest, - ) -> Result { - Err(missing_backend_message()) - } - - async fn complete_goal(&self, _thread_id: ThreadId) -> Result { - Err(missing_backend_message()) - } -} - -fn missing_backend_message() -> String { - // TODO: replace this fallback with a host-provided goal backend once - // ToolContributor invocations can reach thread-scoped goal persistence and - // the current turn context. - "goal tools are not connected to host goal persistence yet".to_string() } #[async_trait] @@ -270,41 +214,32 @@ where vec![ Arc::new(GoalToolExecutor::get( thread_id, - Arc::clone(&self.backend), + Arc::clone(&self.state_dbs), self.event_emitter.clone(), )), Arc::new(GoalToolExecutor::create( thread_id, - Arc::clone(&self.backend), + Arc::clone(&self.state_dbs), self.event_emitter.clone(), )), Arc::new(GoalToolExecutor::update( thread_id, - Arc::clone(&self.backend), + Arc::clone(&self.state_dbs), self.event_emitter.clone(), )), ] } } -pub fn install( - registry: &mut ExtensionRegistryBuilder, - goals_enabled: impl Fn(&C) -> bool + Send + Sync + 'static, -) where - C: Send + Sync + 'static, -{ - install_with_backend(registry, Arc::new(NoGoalToolBackend), goals_enabled); -} - pub fn install_with_backend( registry: &mut ExtensionRegistryBuilder, - backend: Arc, + state_dbs: Arc, goals_enabled: impl Fn(&C) -> bool + Send + Sync + 'static, ) where C: Send + Sync + 'static, { let extension = Arc::new(GoalExtension::new_with_event_sink( - backend, + state_dbs, registry.event_sink(), goals_enabled, )); diff --git a/codex-rs/ext/goal/src/lib.rs b/codex-rs/ext/goal/src/lib.rs index 38332aee41..1625aeae14 100644 --- a/codex-rs/ext/goal/src/lib.rs +++ b/codex-rs/ext/goal/src/lib.rs @@ -12,9 +12,6 @@ mod tool; pub use extension::GoalExtension; pub use extension::GoalExtensionConfig; -pub use extension::GoalToolBackend; -pub use extension::NoGoalToolBackend; -pub use extension::install; pub use extension::install_with_backend; pub use spec::CREATE_GOAL_TOOL_NAME; pub use spec::GET_GOAL_TOOL_NAME; diff --git a/codex-rs/ext/goal/src/tool.rs b/codex-rs/ext/goal/src/tool.rs index 8ad4541a3c..96b160f938 100644 --- a/codex-rs/ext/goal/src/tool.rs +++ b/codex-rs/ext/goal/src/tool.rs @@ -16,7 +16,6 @@ use serde::Deserialize; use serde::Serialize; use crate::events::GoalEventEmitter; -use crate::extension::GoalToolBackend; use crate::spec::CREATE_GOAL_TOOL_NAME; use crate::spec::GET_GOAL_TOOL_NAME; use crate::spec::UPDATE_GOAL_TOOL_NAME; @@ -28,7 +27,7 @@ use crate::spec::create_update_goal_tool; pub(crate) struct GoalToolExecutor { kind: GoalToolKind, thread_id: ThreadId, - backend: Arc, + state_db: Arc, event_emitter: GoalEventEmitter, } @@ -69,39 +68,39 @@ enum CompletionBudgetReport { impl GoalToolExecutor { pub(crate) fn get( thread_id: ThreadId, - backend: Arc, + state_db: Arc, event_emitter: GoalEventEmitter, ) -> Self { Self { kind: GoalToolKind::Get, thread_id, - backend, + state_db, event_emitter, } } pub(crate) fn create( thread_id: ThreadId, - backend: Arc, + state_db: Arc, event_emitter: GoalEventEmitter, ) -> Self { Self { kind: GoalToolKind::Create, thread_id, - backend, + state_db, event_emitter, } } pub(crate) fn update( thread_id: ThreadId, - backend: Arc, + state_db: Arc, event_emitter: GoalEventEmitter, ) -> Self { Self { kind: GoalToolKind::Update, thread_id, - backend, + state_db, event_emitter, } } @@ -141,10 +140,14 @@ impl GoalToolExecutor { ) -> Result, FunctionCallError> { let _ = invocation.function_arguments()?; let goal = self - .backend - .get_goal(self.thread_id) + .state_db + .thread_goals() + .get_thread_goal(self.thread_id) .await - .map_err(FunctionCallError::RespondToModel)?; + .map(|goal| goal.map(protocol_goal_from_state)) + .map_err(|err| { + FunctionCallError::RespondToModel(format!("failed to read goal: {err}")) + })?; goal_response(goal, CompletionBudgetReport::Omit) } @@ -159,10 +162,24 @@ impl GoalToolExecutor { validate_goal_budget(request.token_budget).map_err(FunctionCallError::RespondToModel)?; let goal = self - .backend - .create_goal(self.thread_id, request) + .state_db + .thread_goals() + .insert_thread_goal( + self.thread_id, + request.objective.as_str(), + codex_state::ThreadGoalStatus::Active, + request.token_budget, + ) .await - .map_err(FunctionCallError::RespondToModel)?; + .map_err(|err| FunctionCallError::RespondToModel(format!("failed to create goal: {err}")))? + .ok_or_else(|| { + FunctionCallError::RespondToModel( + "cannot create a new goal because this thread already has a goal; use update_goal only when the existing goal is complete" + .to_string(), + ) + })?; + fill_empty_thread_preview_if_possible(self.state_db.as_ref(), self.thread_id, &goal).await; + let goal = protocol_goal_from_state(goal); self.emit_goal_updated_from_tool_call(&invocation, goal.clone()); goal_response(Some(goal), CompletionBudgetReport::Omit) } @@ -182,10 +199,27 @@ impl GoalToolExecutor { // TODO: update_goal needs a host callback before completion to flush // final active-turn accounting with budget steering suppressed. let goal = self - .backend - .complete_goal(self.thread_id) + .state_db + .thread_goals() + .update_thread_goal( + self.thread_id, + codex_state::GoalUpdate { + objective: None, + status: Some(codex_state::ThreadGoalStatus::Complete), + token_budget: None, + expected_goal_id: None, + }, + ) .await - .map_err(FunctionCallError::RespondToModel)?; + .map_err(|err| { + FunctionCallError::RespondToModel(format!("failed to complete goal: {err}")) + })? + .map(protocol_goal_from_state) + .ok_or_else(|| { + FunctionCallError::RespondToModel( + "cannot update goal because this thread has no goal".to_string(), + ) + })?; self.emit_goal_updated_from_tool_call(&invocation, goal.clone()); goal_response(Some(goal), CompletionBudgetReport::Include) } @@ -249,6 +283,45 @@ impl GoalToolResponse { } } +async fn fill_empty_thread_preview_if_possible( + state_db: &codex_state::StateRuntime, + thread_id: ThreadId, + goal: &codex_state::ThreadGoal, +) { + if let Err(err) = state_db + .set_thread_preview_if_empty(thread_id, goal.objective.as_str()) + .await + { + tracing::warn!( + "failed to set empty thread preview from goal objective for {thread_id}: {err}" + ); + } +} + +fn protocol_goal_from_state(goal: codex_state::ThreadGoal) -> ThreadGoal { + ThreadGoal { + thread_id: goal.thread_id, + objective: goal.objective, + status: protocol_status_from_state(goal.status), + token_budget: goal.token_budget, + tokens_used: goal.tokens_used, + time_used_seconds: goal.time_used_seconds, + created_at: goal.created_at.timestamp(), + updated_at: goal.updated_at.timestamp(), + } +} + +fn protocol_status_from_state(status: codex_state::ThreadGoalStatus) -> ThreadGoalStatus { + match status { + codex_state::ThreadGoalStatus::Active => ThreadGoalStatus::Active, + codex_state::ThreadGoalStatus::Paused => ThreadGoalStatus::Paused, + codex_state::ThreadGoalStatus::Blocked => ThreadGoalStatus::Blocked, + codex_state::ThreadGoalStatus::UsageLimited => ThreadGoalStatus::UsageLimited, + codex_state::ThreadGoalStatus::BudgetLimited => ThreadGoalStatus::BudgetLimited, + codex_state::ThreadGoalStatus::Complete => ThreadGoalStatus::Complete, + } +} + fn completion_budget_report(goal: &ThreadGoal) -> Option { if goal.token_budget.is_none() && goal.time_used_seconds <= 0 { None diff --git a/codex-rs/ext/goal/tests/goal_extension_backend.rs b/codex-rs/ext/goal/tests/goal_extension_backend.rs new file mode 100644 index 0000000000..a79c69c633 --- /dev/null +++ b/codex-rs/ext/goal/tests/goal_extension_backend.rs @@ -0,0 +1,171 @@ +use std::sync::Arc; + +use codex_extension_api::ExtensionData; +use codex_extension_api::ExtensionRegistryBuilder; +use codex_extension_api::FunctionCallError; +use codex_extension_api::ThreadStartInput; +use codex_extension_api::ToolCall; +use codex_extension_api::ToolExecutor; +use codex_extension_api::ToolPayload; +use codex_goal_extension::install_with_backend; +use codex_protocol::ThreadId; +use codex_protocol::ToolName; +use codex_protocol::protocol::SessionSource; +use pretty_assertions::assert_eq; +use serde_json::json; +use tempfile::TempDir; + +#[tokio::test] +async fn installed_goal_tools_create_goal_and_fill_empty_preview() -> anyhow::Result<()> { + let runtime = test_runtime().await?; + let thread_id = test_thread_id()?; + seed_thread_metadata(runtime.as_ref(), thread_id).await?; + let tools = installed_tools(runtime.clone(), thread_id).await; + + let create_tool = tool_by_name(&tools, "create_goal"); + let invocation = ToolCall { + call_id: "call-create-goal".to_string(), + tool_name: ToolName::plain("create_goal"), + payload: ToolPayload::Function { + arguments: json!({ + "objective": "ship goal extension backend", + "token_budget": 123, + }) + .to_string(), + }, + }; + let output = create_tool.handle(invocation.clone()).await?; + let result = output.code_mode_result(&invocation.payload); + assert_eq!( + result, + json!({ + "goal": { + "threadId": thread_id, + "objective": "ship goal extension backend", + "status": "active", + "tokenBudget": 123, + "tokensUsed": 0, + "timeUsedSeconds": 0, + "createdAt": result["goal"]["createdAt"], + "updatedAt": result["goal"]["updatedAt"], + }, + "remainingTokens": 123, + "completionBudgetReport": serde_json::Value::Null, + }) + ); + + let metadata = runtime + .get_thread(thread_id) + .await? + .ok_or_else(|| anyhow::anyhow!("seeded thread metadata should exist"))?; + assert_eq!( + metadata.preview.as_deref(), + Some("ship goal extension backend") + ); + Ok(()) +} + +#[tokio::test] +async fn installed_goal_tools_reject_duplicate_goal_creation() -> anyhow::Result<()> { + let runtime = test_runtime().await?; + let thread_id = test_thread_id()?; + seed_thread_metadata(runtime.as_ref(), thread_id).await?; + let tools = installed_tools(runtime, thread_id).await; + + let create_tool = tool_by_name(&tools, "create_goal"); + let first = tool_call( + "create_goal", + "call-create-goal-1", + json!({ "objective": "first goal" }), + ); + create_tool.handle(first).await?; + + let second = tool_call( + "create_goal", + "call-create-goal-2", + json!({ "objective": "second goal" }), + ); + let err = match create_tool.handle(second).await { + Ok(_) => panic!("duplicate create should fail"), + Err(err) => err, + }; + + assert_eq!( + err, + FunctionCallError::RespondToModel( + "cannot create a new goal because this thread already has a goal; use update_goal only when the existing goal is complete" + .to_string() + ) + ); + Ok(()) +} + +async fn installed_tools( + runtime: Arc, + thread_id: ThreadId, +) -> Vec>> { + let mut builder = ExtensionRegistryBuilder::<()>::new(); + install_with_backend(&mut builder, runtime, |_| true); + let registry = builder.build(); + let session_store = ExtensionData::new("session-1"); + let thread_store = ExtensionData::new(thread_id.to_string()); + for contributor in registry.thread_lifecycle_contributors() { + contributor + .on_thread_start(ThreadStartInput { + config: &(), + session_store: &session_store, + thread_store: &thread_store, + }) + .await; + } + + registry + .tool_contributors() + .iter() + .flat_map(|contributor| contributor.tools(&session_store, &thread_store)) + .collect() +} + +fn tool_by_name<'a>( + tools: &'a [Arc>], + name: &str, +) -> &'a Arc> { + tools + .iter() + .find(|tool| tool.tool_name().namespace.is_none() && tool.tool_name().name == name) + .unwrap_or_else(|| panic!("missing tool {name}")) +} + +fn tool_call(tool_name: &str, call_id: &str, arguments: serde_json::Value) -> ToolCall { + ToolCall { + call_id: call_id.to_string(), + tool_name: ToolName::plain(tool_name), + payload: ToolPayload::Function { + arguments: arguments.to_string(), + }, + } +} + +async fn test_runtime() -> anyhow::Result> { + let tempdir = TempDir::new()?; + codex_state::StateRuntime::init(tempdir.keep(), "test-provider".to_string()).await +} + +fn test_thread_id() -> anyhow::Result { + ThreadId::from_string("11111111-1111-4111-8111-111111111111").map_err(anyhow::Error::msg) +} + +async fn seed_thread_metadata( + runtime: &codex_state::StateRuntime, + thread_id: ThreadId, +) -> anyhow::Result<()> { + let builder = codex_state::ThreadMetadataBuilder::new( + thread_id, + runtime + .codex_home() + .join(format!("rollout-{thread_id}.jsonl")), + chrono::Utc::now(), + SessionSource::Cli, + ); + runtime.upsert_thread(&builder.build("test-provider")).await +}