diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index cca5ca0b43..4209b8f39f 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -1061,6 +1061,8 @@ dependencies = [ "clap", "codex-app-server-protocol", "codex-core", + "codex-lmstudio", + "codex-ollama", "codex-protocol", "once_cell", "serde", @@ -1159,7 +1161,6 @@ dependencies = [ "codex-arg0", "codex-common", "codex-core", - "codex-ollama", "codex-protocol", "core_test_support", "libc", @@ -1278,6 +1279,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "codex-lmstudio" +version = "0.0.0" +dependencies = [ + "codex-core", + "reqwest", + "serde_json", + "tokio", + "tracing", + "which", + "wiremock", +] + [[package]] name = "codex-login" version = "0.0.0" @@ -1475,7 +1489,6 @@ dependencies = [ "codex-feedback", "codex-file-search", "codex-login", - "codex-ollama", "codex-protocol", "codex-windows-sandbox", "color-eyre", @@ -1498,6 +1511,7 @@ dependencies = [ "ratatui", "ratatui-macros", "regex-lite", + "reqwest", "serde", "serde_json", "serial_test", diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml index a8d16dd8be..5e2bd05e47 100644 --- a/codex-rs/Cargo.toml +++ b/codex-rs/Cargo.toml @@ -21,6 +21,7 @@ members = [ "keyring-store", "file-search", "linux-sandbox", + "lmstudio", "login", "mcp-server", "mcp-types", @@ -70,6 +71,7 @@ codex-file-search = { path = "file-search" } codex-git = { path = "utils/git" } codex-keyring-store = { path = "keyring-store" } codex-linux-sandbox = { path = "linux-sandbox" } +codex-lmstudio = { path = "lmstudio" } codex-login = { path = "login" } codex-mcp-server = { path = "mcp-server" } codex-ollama = { path = "ollama" } diff --git a/codex-rs/common/Cargo.toml b/codex-rs/common/Cargo.toml index 285d56b99b..cff9c4b307 100644 --- a/codex-rs/common/Cargo.toml +++ b/codex-rs/common/Cargo.toml @@ -10,6 +10,8 @@ workspace = true clap = { workspace = true, features = ["derive", "wrap_help"], optional = true } codex-app-server-protocol = { workspace = true } codex-core = { workspace = true } +codex-lmstudio = { workspace = true } +codex-ollama = { workspace = true } codex-protocol = { workspace = true } once_cell = { workspace = true } serde = { workspace = true, optional = true } diff --git a/codex-rs/common/src/lib.rs b/codex-rs/common/src/lib.rs index 276bfca069..5092b3be24 100644 --- a/codex-rs/common/src/lib.rs +++ b/codex-rs/common/src/lib.rs @@ -37,3 +37,5 @@ pub mod model_presets; // Shared approval presets (AskForApproval + Sandbox) used by TUI and MCP server // Not to be confused with AskForApproval, which we should probably rename to EscalationPolicy. pub mod approval_presets; +// Shared OSS provider utilities used by TUI and exec +pub mod oss; diff --git a/codex-rs/common/src/oss.rs b/codex-rs/common/src/oss.rs new file mode 100644 index 0000000000..b2f511e478 --- /dev/null +++ b/codex-rs/common/src/oss.rs @@ -0,0 +1,60 @@ +//! OSS provider utilities shared between TUI and exec. + +use codex_core::LMSTUDIO_OSS_PROVIDER_ID; +use codex_core::OLLAMA_OSS_PROVIDER_ID; +use codex_core::config::Config; + +/// Returns the default model for a given OSS provider. +pub fn get_default_model_for_oss_provider(provider_id: &str) -> Option<&'static str> { + match provider_id { + LMSTUDIO_OSS_PROVIDER_ID => Some(codex_lmstudio::DEFAULT_OSS_MODEL), + OLLAMA_OSS_PROVIDER_ID => Some(codex_ollama::DEFAULT_OSS_MODEL), + _ => None, + } +} + +/// Ensures the specified OSS provider is ready (models downloaded, service reachable). +pub async fn ensure_oss_provider_ready( + provider_id: &str, + config: &Config, +) -> Result<(), std::io::Error> { + match provider_id { + LMSTUDIO_OSS_PROVIDER_ID => { + codex_lmstudio::ensure_oss_ready(config) + .await + .map_err(|e| std::io::Error::other(format!("OSS setup failed: {e}")))?; + } + OLLAMA_OSS_PROVIDER_ID => { + codex_ollama::ensure_oss_ready(config) + .await + .map_err(|e| std::io::Error::other(format!("OSS setup failed: {e}")))?; + } + _ => { + // Unknown provider, skip setup + } + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_get_default_model_for_provider_lmstudio() { + let result = get_default_model_for_oss_provider(LMSTUDIO_OSS_PROVIDER_ID); + assert_eq!(result, Some(codex_lmstudio::DEFAULT_OSS_MODEL)); + } + + #[test] + fn test_get_default_model_for_provider_ollama() { + let result = get_default_model_for_oss_provider(OLLAMA_OSS_PROVIDER_ID); + assert_eq!(result, Some(codex_ollama::DEFAULT_OSS_MODEL)); + } + + #[test] + fn test_get_default_model_for_provider_unknown() { + let result = get_default_model_for_oss_provider("unknown-provider"); + assert_eq!(result, None); + } +} diff --git a/codex-rs/core/src/config/mod.rs b/codex-rs/core/src/config/mod.rs index b1e5b7f98a..1bf180dd36 100644 --- a/codex-rs/core/src/config/mod.rs +++ b/codex-rs/core/src/config/mod.rs @@ -25,7 +25,9 @@ use crate::git_info::resolve_root_git_project_for_trust; use crate::model_family::ModelFamily; use crate::model_family::derive_default_model_family; use crate::model_family::find_family_for_model; +use crate::model_provider_info::LMSTUDIO_OSS_PROVIDER_ID; use crate::model_provider_info::ModelProviderInfo; +use crate::model_provider_info::OLLAMA_OSS_PROVIDER_ID; use crate::model_provider_info::built_in_model_providers; use crate::openai_model_info::get_model_info; use crate::project_doc::DEFAULT_PROJECT_DOC_FILENAME; @@ -466,6 +468,48 @@ pub fn set_project_trust_level( .apply_blocking() } +/// Save the default OSS provider preference to config.toml +pub fn set_default_oss_provider(codex_home: &Path, provider: &str) -> std::io::Result<()> { + // Validate that the provider is one of the known OSS providers + match provider { + LMSTUDIO_OSS_PROVIDER_ID | OLLAMA_OSS_PROVIDER_ID => { + // Valid provider, continue + } + _ => { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + format!( + "Invalid OSS provider '{provider}'. Must be one of: {LMSTUDIO_OSS_PROVIDER_ID}, {OLLAMA_OSS_PROVIDER_ID}" + ), + )); + } + } + let config_path = codex_home.join(CONFIG_TOML_FILE); + + // Read existing config or create empty string if file doesn't exist + let content = match std::fs::read_to_string(&config_path) { + Ok(content) => content, + Err(e) if e.kind() == std::io::ErrorKind::NotFound => String::new(), + Err(e) => return Err(e), + }; + + // Parse as DocumentMut for editing while preserving structure + let mut doc = content.parse::().map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("failed to parse config.toml: {e}"), + ) + })?; + + // Set the default_oss_provider at root level + use toml_edit::value; + doc["oss_provider"] = value(provider); + + // Write the modified document back + std::fs::write(&config_path, doc.to_string())?; + Ok(()) +} + /// Apply a single dotted-path override onto a TOML value. fn apply_toml_override(root: &mut TomlValue, path: &str, value: TomlValue) { use toml::value::Table; @@ -663,6 +707,8 @@ pub struct ConfigToml { pub experimental_use_rmcp_client: Option, pub experimental_use_freeform_apply_patch: Option, pub experimental_sandbox_command_assessment: Option, + /// Preferred OSS provider for local models, e.g. "lmstudio" or "ollama". + pub oss_provider: Option, } impl From for UserSavedConfig { @@ -851,6 +897,34 @@ pub struct ConfigOverrides { pub additional_writable_roots: Vec, } +/// Resolves the OSS provider from CLI override, profile config, or global config. +/// Returns `None` if no provider is configured at any level. +pub fn resolve_oss_provider( + explicit_provider: Option<&str>, + config_toml: &ConfigToml, + config_profile: Option, +) -> Option { + if let Some(provider) = explicit_provider { + // Explicit provider specified (e.g., via --local-provider) + Some(provider.to_string()) + } else { + // Check profile config first, then global config + let profile = config_toml.get_config_profile(config_profile).ok(); + if let Some(profile) = &profile { + // Check if profile has an oss provider + if let Some(profile_oss_provider) = &profile.oss_provider { + Some(profile_oss_provider.clone()) + } + // If not then check if the toml has an oss provider + else { + config_toml.oss_provider.clone() + } + } else { + config_toml.oss_provider.clone() + } + } +} + impl Config { /// Meant to be used exclusively for tests: `load_with_overrides()` should /// be used in all other cases. @@ -3265,6 +3339,41 @@ trust_level = "trusted" Ok(()) } + #[test] + fn test_set_default_oss_provider() -> std::io::Result<()> { + let temp_dir = TempDir::new()?; + let codex_home = temp_dir.path(); + let config_path = codex_home.join(CONFIG_TOML_FILE); + + // Test setting valid provider on empty config + set_default_oss_provider(codex_home, OLLAMA_OSS_PROVIDER_ID)?; + let content = std::fs::read_to_string(&config_path)?; + assert!(content.contains("oss_provider = \"ollama\"")); + + // Test updating existing config + std::fs::write(&config_path, "model = \"gpt-4\"\n")?; + set_default_oss_provider(codex_home, LMSTUDIO_OSS_PROVIDER_ID)?; + let content = std::fs::read_to_string(&config_path)?; + assert!(content.contains("oss_provider = \"lmstudio\"")); + assert!(content.contains("model = \"gpt-4\"")); + + // Test overwriting existing oss_provider + set_default_oss_provider(codex_home, OLLAMA_OSS_PROVIDER_ID)?; + let content = std::fs::read_to_string(&config_path)?; + assert!(content.contains("oss_provider = \"ollama\"")); + assert!(!content.contains("oss_provider = \"lmstudio\"")); + + // Test invalid provider + let result = set_default_oss_provider(codex_home, "invalid_provider"); + assert!(result.is_err()); + let error = result.unwrap_err(); + assert_eq!(error.kind(), std::io::ErrorKind::InvalidInput); + assert!(error.to_string().contains("Invalid OSS provider")); + assert!(error.to_string().contains("invalid_provider")); + + Ok(()) + } + #[test] fn test_untrusted_project_gets_workspace_write_sandbox() -> anyhow::Result<()> { let config_with_untrusted = r#" @@ -3295,6 +3404,85 @@ trust_level = "untrusted" Ok(()) } + #[test] + fn test_resolve_oss_provider_explicit_override() { + let config_toml = ConfigToml::default(); + let result = resolve_oss_provider(Some("custom-provider"), &config_toml, None); + assert_eq!(result, Some("custom-provider".to_string())); + } + + #[test] + fn test_resolve_oss_provider_from_profile() { + let mut profiles = std::collections::HashMap::new(); + let profile = ConfigProfile { + oss_provider: Some("profile-provider".to_string()), + ..Default::default() + }; + profiles.insert("test-profile".to_string(), profile); + let config_toml = ConfigToml { + profiles, + ..Default::default() + }; + + let result = resolve_oss_provider(None, &config_toml, Some("test-profile".to_string())); + assert_eq!(result, Some("profile-provider".to_string())); + } + + #[test] + fn test_resolve_oss_provider_from_global_config() { + let config_toml = ConfigToml { + oss_provider: Some("global-provider".to_string()), + ..Default::default() + }; + + let result = resolve_oss_provider(None, &config_toml, None); + assert_eq!(result, Some("global-provider".to_string())); + } + + #[test] + fn test_resolve_oss_provider_profile_fallback_to_global() { + let mut profiles = std::collections::HashMap::new(); + let profile = ConfigProfile::default(); // No oss_provider set + profiles.insert("test-profile".to_string(), profile); + let config_toml = ConfigToml { + oss_provider: Some("global-provider".to_string()), + profiles, + ..Default::default() + }; + + let result = resolve_oss_provider(None, &config_toml, Some("test-profile".to_string())); + assert_eq!(result, Some("global-provider".to_string())); + } + + #[test] + fn test_resolve_oss_provider_none_when_not_configured() { + let config_toml = ConfigToml::default(); + let result = resolve_oss_provider(None, &config_toml, None); + assert_eq!(result, None); + } + + #[test] + fn test_resolve_oss_provider_explicit_overrides_all() { + let mut profiles = std::collections::HashMap::new(); + let profile = ConfigProfile { + oss_provider: Some("profile-provider".to_string()), + ..Default::default() + }; + profiles.insert("test-profile".to_string(), profile); + let config_toml = ConfigToml { + oss_provider: Some("global-provider".to_string()), + profiles, + ..Default::default() + }; + + let result = resolve_oss_provider( + Some("explicit-provider"), + &config_toml, + Some("test-profile".to_string()), + ); + assert_eq!(result, Some("explicit-provider".to_string())); + } + #[test] fn test_untrusted_project_gets_unless_trusted_approval_policy() -> std::io::Result<()> { let codex_home = TempDir::new()?; diff --git a/codex-rs/core/src/config/profile.rs b/codex-rs/core/src/config/profile.rs index 6d872546af..3d9e60b8e5 100644 --- a/codex-rs/core/src/config/profile.rs +++ b/codex-rs/core/src/config/profile.rs @@ -33,6 +33,7 @@ pub struct ConfigProfile { /// Optional feature toggles scoped to this profile. #[serde(default)] pub features: Option, + pub oss_provider: Option, } impl From for codex_app_server_protocol::Profile { diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index 5229d00606..e684781194 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -40,8 +40,11 @@ pub mod token_data; mod truncate; mod unified_exec; mod user_instructions; -pub use model_provider_info::BUILT_IN_OSS_MODEL_PROVIDER_ID; +pub use model_provider_info::DEFAULT_LMSTUDIO_PORT; +pub use model_provider_info::DEFAULT_OLLAMA_PORT; +pub use model_provider_info::LMSTUDIO_OSS_PROVIDER_ID; pub use model_provider_info::ModelProviderInfo; +pub use model_provider_info::OLLAMA_OSS_PROVIDER_ID; pub use model_provider_info::WireApi; pub use model_provider_info::built_in_model_providers; pub use model_provider_info::create_oss_provider_with_base_url; diff --git a/codex-rs/core/src/model_provider_info.rs b/codex-rs/core/src/model_provider_info.rs index 8dc252aa7c..6aa7a31dc5 100644 --- a/codex-rs/core/src/model_provider_info.rs +++ b/codex-rs/core/src/model_provider_info.rs @@ -258,9 +258,11 @@ impl ModelProviderInfo { } } -const DEFAULT_OLLAMA_PORT: u32 = 11434; +pub const DEFAULT_LMSTUDIO_PORT: u16 = 1234; +pub const DEFAULT_OLLAMA_PORT: u16 = 11434; -pub const BUILT_IN_OSS_MODEL_PROVIDER_ID: &str = "oss"; +pub const LMSTUDIO_OSS_PROVIDER_ID: &str = "lmstudio"; +pub const OLLAMA_OSS_PROVIDER_ID: &str = "ollama"; /// Built-in default provider list. pub fn built_in_model_providers() -> HashMap { @@ -311,14 +313,21 @@ pub fn built_in_model_providers() -> HashMap { requires_openai_auth: true, }, ), - (BUILT_IN_OSS_MODEL_PROVIDER_ID, create_oss_provider()), + ( + OLLAMA_OSS_PROVIDER_ID, + create_oss_provider(DEFAULT_OLLAMA_PORT, WireApi::Chat), + ), + ( + LMSTUDIO_OSS_PROVIDER_ID, + create_oss_provider(DEFAULT_LMSTUDIO_PORT, WireApi::Responses), + ), ] .into_iter() .map(|(k, v)| (k.to_string(), v)) .collect() } -pub fn create_oss_provider() -> ModelProviderInfo { +pub fn create_oss_provider(default_provider_port: u16, wire_api: WireApi) -> ModelProviderInfo { // These CODEX_OSS_ environment variables are experimental: we may // switch to reading values from config.toml instead. let codex_oss_base_url = match std::env::var("CODEX_OSS_BASE_URL") @@ -331,22 +340,21 @@ pub fn create_oss_provider() -> ModelProviderInfo { port = std::env::var("CODEX_OSS_PORT") .ok() .filter(|v| !v.trim().is_empty()) - .and_then(|v| v.parse::().ok()) - .unwrap_or(DEFAULT_OLLAMA_PORT) + .and_then(|v| v.parse::().ok()) + .unwrap_or(default_provider_port) ), }; - - create_oss_provider_with_base_url(&codex_oss_base_url) + create_oss_provider_with_base_url(&codex_oss_base_url, wire_api) } -pub fn create_oss_provider_with_base_url(base_url: &str) -> ModelProviderInfo { +pub fn create_oss_provider_with_base_url(base_url: &str, wire_api: WireApi) -> ModelProviderInfo { ModelProviderInfo { name: "gpt-oss".into(), base_url: Some(base_url.into()), env_key: None, env_key_instructions: None, experimental_bearer_token: None, - wire_api: WireApi::Chat, + wire_api, query_params: None, http_headers: None, env_http_headers: None, diff --git a/codex-rs/exec/Cargo.toml b/codex-rs/exec/Cargo.toml index 8fc1e38875..4a0ddf7dde 100644 --- a/codex-rs/exec/Cargo.toml +++ b/codex-rs/exec/Cargo.toml @@ -24,7 +24,6 @@ codex-common = { workspace = true, features = [ "sandbox_summary", ] } codex-core = { workspace = true } -codex-ollama = { workspace = true } codex-protocol = { workspace = true } mcp-types = { workspace = true } opentelemetry-appender-tracing = { workspace = true } diff --git a/codex-rs/exec/src/cli.rs b/codex-rs/exec/src/cli.rs index 0b5b0b4b6d..ef20bf6fc7 100644 --- a/codex-rs/exec/src/cli.rs +++ b/codex-rs/exec/src/cli.rs @@ -18,9 +18,15 @@ pub struct Cli { #[arg(long, short = 'm')] pub model: Option, + /// Use open-source provider. #[arg(long = "oss", default_value_t = false)] pub oss: bool, + /// Specify which local provider to use (lmstudio or ollama). + /// If not specified with --oss, will use config default or show selection. + #[arg(long = "local-provider")] + pub oss_provider: Option, + /// Select the sandbox policy to use when executing model-generated shell /// commands. #[arg(long = "sandbox", short = 's', value_enum)] diff --git a/codex-rs/exec/src/lib.rs b/codex-rs/exec/src/lib.rs index a9cf6b2c6d..a003b4ff21 100644 --- a/codex-rs/exec/src/lib.rs +++ b/codex-rs/exec/src/lib.rs @@ -11,20 +11,25 @@ pub mod event_processor_with_jsonl_output; pub mod exec_events; pub use cli::Cli; +use codex_common::oss::ensure_oss_provider_ready; +use codex_common::oss::get_default_model_for_oss_provider; use codex_core::AuthManager; -use codex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID; use codex_core::ConversationManager; +use codex_core::LMSTUDIO_OSS_PROVIDER_ID; use codex_core::NewConversation; +use codex_core::OLLAMA_OSS_PROVIDER_ID; use codex_core::auth::enforce_login_restrictions; use codex_core::config::Config; use codex_core::config::ConfigOverrides; +use codex_core::config::find_codex_home; +use codex_core::config::load_config_as_toml_with_cli_overrides; +use codex_core::config::resolve_oss_provider; use codex_core::git_info::get_git_repo_root; use codex_core::protocol::AskForApproval; use codex_core::protocol::Event; use codex_core::protocol::EventMsg; use codex_core::protocol::Op; use codex_core::protocol::SessionSource; -use codex_ollama::DEFAULT_OSS_MODEL; use codex_protocol::config_types::SandboxMode; use codex_protocol::user_input::UserInput; use event_processor_with_human_output::EventProcessorWithHumanOutput; @@ -57,6 +62,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any images, model: model_cli_arg, oss, + oss_provider, config_profile, full_auto, dangerously_bypass_approvals_and_sandbox, @@ -146,21 +152,64 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any sandbox_mode_cli_arg.map(Into::::into) }; - // When using `--oss`, let the bootstrapper pick the model (defaulting to - // gpt-oss:20b) and ensure it is present locally. Also, force the built‑in - // `oss` model provider. - let model = if let Some(model) = model_cli_arg { - Some(model) - } else if oss { - Some(DEFAULT_OSS_MODEL.to_owned()) - } else { - None // No model specified, will use the default. + // Parse `-c` overrides from the CLI. + let cli_kv_overrides = match config_overrides.parse_overrides() { + Ok(v) => v, + #[allow(clippy::print_stderr)] + Err(e) => { + eprintln!("Error parsing -c overrides: {e}"); + std::process::exit(1); + } + }; + + // we load config.toml here to determine project state. + #[allow(clippy::print_stderr)] + let config_toml = { + let codex_home = match find_codex_home() { + Ok(codex_home) => codex_home, + Err(err) => { + eprintln!("Error finding codex home: {err}"); + std::process::exit(1); + } + }; + + match load_config_as_toml_with_cli_overrides(&codex_home, cli_kv_overrides.clone()).await { + Ok(config_toml) => config_toml, + Err(err) => { + eprintln!("Error loading config.toml: {err}"); + std::process::exit(1); + } + } }; let model_provider = if oss { - Some(BUILT_IN_OSS_MODEL_PROVIDER_ID.to_string()) + let resolved = resolve_oss_provider( + oss_provider.as_deref(), + &config_toml, + config_profile.clone(), + ); + + if let Some(provider) = resolved { + Some(provider) + } else { + return Err(anyhow::anyhow!( + "No default OSS provider configured. Use --local-provider=provider or set oss_provider to either {LMSTUDIO_OSS_PROVIDER_ID} or {OLLAMA_OSS_PROVIDER_ID} in config.toml" + )); + } } else { - None // No specific model provider override. + None // No OSS mode enabled + }; + + // When using `--oss`, let the bootstrapper pick the model based on selected provider + let model = if let Some(model) = model_cli_arg { + Some(model) + } else if oss { + model_provider + .as_ref() + .and_then(|provider_id| get_default_model_for_oss_provider(provider_id)) + .map(std::borrow::ToOwned::to_owned) + } else { + None // No model specified, will use the default. }; // Load configuration and determine approval policy @@ -172,7 +221,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any approval_policy: Some(AskForApproval::Never), sandbox_mode, cwd: cwd.map(|p| p.canonicalize().unwrap_or(p)), - model_provider, + model_provider: model_provider.clone(), codex_linux_sandbox_exe, base_instructions: None, developer_instructions: None, @@ -183,14 +232,6 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any experimental_sandbox_command_assessment: None, additional_writable_roots: add_dir, }; - // Parse `-c` overrides. - let cli_kv_overrides = match config_overrides.parse_overrides() { - Ok(v) => v, - Err(e) => { - eprintln!("Error parsing -c overrides: {e}"); - std::process::exit(1); - } - }; let config = Config::load_with_cli_overrides(cli_kv_overrides, overrides).await?; @@ -233,7 +274,18 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any }; if oss { - codex_ollama::ensure_oss_ready(&config) + // We're in the oss section, so provider_id should be Some + // Let's handle None case gracefully though just in case + let provider_id = match model_provider.as_ref() { + Some(id) => id, + None => { + error!("OSS provider unexpectedly not set when oss flag is used"); + return Err(anyhow::anyhow!( + "OSS provider not set but oss flag was used" + )); + } + }; + ensure_oss_provider_ready(provider_id, &config) .await .map_err(|e| anyhow::anyhow!("OSS setup failed: {e}"))?; } diff --git a/codex-rs/lmstudio/Cargo.toml b/codex-rs/lmstudio/Cargo.toml new file mode 100644 index 0000000000..4035d5529b --- /dev/null +++ b/codex-rs/lmstudio/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "codex-lmstudio" +version.workspace = true +edition.workspace = true + +[lib] +name = "codex_lmstudio" +path = "src/lib.rs" + + +[dependencies] +codex-core = { path = "../core" } +reqwest = { version = "0.12", features = ["json", "stream"] } +serde_json = "1" +tokio = { version = "1", features = ["rt"] } +tracing = { version = "0.1.41", features = ["log"] } +which = "6.0" + +[dev-dependencies] +wiremock = "0.6" +tokio = { version = "1", features = ["full"] } + +[lints] +workspace = true diff --git a/codex-rs/lmstudio/src/client.rs b/codex-rs/lmstudio/src/client.rs new file mode 100644 index 0000000000..a2a8ee03bf --- /dev/null +++ b/codex-rs/lmstudio/src/client.rs @@ -0,0 +1,397 @@ +use codex_core::LMSTUDIO_OSS_PROVIDER_ID; +use codex_core::config::Config; +use std::io; +use std::path::Path; + +#[derive(Clone)] +pub struct LMStudioClient { + client: reqwest::Client, + base_url: String, +} + +const LMSTUDIO_CONNECTION_ERROR: &str = "LM Studio is not responding. Install from https://lmstudio.ai/download and run 'lms server start'."; + +impl LMStudioClient { + pub async fn try_from_provider(config: &Config) -> std::io::Result { + let provider = config + .model_providers + .get(LMSTUDIO_OSS_PROVIDER_ID) + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::NotFound, + format!("Built-in provider {LMSTUDIO_OSS_PROVIDER_ID} not found",), + ) + })?; + let base_url = provider.base_url.as_ref().ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "oss provider must have a base_url", + ) + })?; + + let client = reqwest::Client::builder() + .connect_timeout(std::time::Duration::from_secs(5)) + .build() + .unwrap_or_else(|_| reqwest::Client::new()); + + let client = LMStudioClient { + client, + base_url: base_url.to_string(), + }; + client.check_server().await?; + + Ok(client) + } + + async fn check_server(&self) -> io::Result<()> { + let url = format!("{}/models", self.base_url.trim_end_matches('/')); + let response = self.client.get(&url).send().await; + + if let Ok(resp) = response { + if resp.status().is_success() { + Ok(()) + } else { + Err(io::Error::other(format!( + "Server returned error: {} {LMSTUDIO_CONNECTION_ERROR}", + resp.status() + ))) + } + } else { + Err(io::Error::other(LMSTUDIO_CONNECTION_ERROR)) + } + } + + // Load a model by sending an empty request with max_tokens 1 + pub async fn load_model(&self, model: &str) -> io::Result<()> { + let url = format!("{}/responses", self.base_url.trim_end_matches('/')); + + let request_body = serde_json::json!({ + "model": model, + "input": "", + "max_output_tokens": 1 + }); + + let response = self + .client + .post(&url) + .header("Content-Type", "application/json") + .json(&request_body) + .send() + .await + .map_err(|e| io::Error::other(format!("Request failed: {e}")))?; + + if response.status().is_success() { + tracing::info!("Successfully loaded model '{model}'"); + Ok(()) + } else { + Err(io::Error::other(format!( + "Failed to load model: {}", + response.status() + ))) + } + } + + // Return the list of models available on the LM Studio server. + pub async fn fetch_models(&self) -> io::Result> { + let url = format!("{}/models", self.base_url.trim_end_matches('/')); + let response = self + .client + .get(&url) + .send() + .await + .map_err(|e| io::Error::other(format!("Request failed: {e}")))?; + + if response.status().is_success() { + let json: serde_json::Value = response.json().await.map_err(|e| { + io::Error::new(io::ErrorKind::InvalidData, format!("JSON parse error: {e}")) + })?; + let models = json["data"] + .as_array() + .ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "No 'data' array in response") + })? + .iter() + .filter_map(|model| model["id"].as_str()) + .map(std::string::ToString::to_string) + .collect(); + Ok(models) + } else { + Err(io::Error::other(format!( + "Failed to fetch models: {}", + response.status() + ))) + } + } + + // Find lms, checking fallback paths if not in PATH + fn find_lms() -> std::io::Result { + Self::find_lms_with_home_dir(None) + } + + fn find_lms_with_home_dir(home_dir: Option<&str>) -> std::io::Result { + // First try 'lms' in PATH + if which::which("lms").is_ok() { + return Ok("lms".to_string()); + } + + // Platform-specific fallback paths + let home = match home_dir { + Some(dir) => dir.to_string(), + None => { + #[cfg(unix)] + { + std::env::var("HOME").unwrap_or_default() + } + #[cfg(windows)] + { + std::env::var("USERPROFILE").unwrap_or_default() + } + } + }; + + #[cfg(unix)] + let fallback_path = format!("{home}/.lmstudio/bin/lms"); + + #[cfg(windows)] + let fallback_path = format!("{home}/.lmstudio/bin/lms.exe"); + + if Path::new(&fallback_path).exists() { + Ok(fallback_path) + } else { + Err(std::io::Error::new( + std::io::ErrorKind::NotFound, + "LM Studio not found. Please install LM Studio from https://lmstudio.ai/", + )) + } + } + + pub async fn download_model(&self, model: &str) -> std::io::Result<()> { + let lms = Self::find_lms()?; + eprintln!("Downloading model: {model}"); + + let status = std::process::Command::new(&lms) + .args(["get", "--yes", model]) + .stdout(std::process::Stdio::inherit()) + .stderr(std::process::Stdio::null()) + .status() + .map_err(|e| { + std::io::Error::other(format!("Failed to execute '{lms} get --yes {model}': {e}")) + })?; + + if !status.success() { + return Err(std::io::Error::other(format!( + "Model download failed with exit code: {}", + status.code().unwrap_or(-1) + ))); + } + + tracing::info!("Successfully downloaded model '{model}'"); + Ok(()) + } + + /// Low-level constructor given a raw host root, e.g. "http://localhost:1234". + #[cfg(test)] + fn from_host_root(host_root: impl Into) -> Self { + let client = reqwest::Client::builder() + .connect_timeout(std::time::Duration::from_secs(5)) + .build() + .unwrap_or_else(|_| reqwest::Client::new()); + Self { + client, + base_url: host_root.into(), + } + } +} + +#[cfg(test)] +mod tests { + #![allow(clippy::expect_used, clippy::unwrap_used)] + use super::*; + + #[tokio::test] + async fn test_fetch_models_happy_path() { + if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_fetch_models_happy_path", + codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("GET")) + .and(wiremock::matchers::path("/models")) + .respond_with( + wiremock::ResponseTemplate::new(200).set_body_raw( + serde_json::json!({ + "data": [ + {"id": "openai/gpt-oss-20b"}, + ] + }) + .to_string(), + "application/json", + ), + ) + .mount(&server) + .await; + + let client = LMStudioClient::from_host_root(server.uri()); + let models = client.fetch_models().await.expect("fetch models"); + assert!(models.contains(&"openai/gpt-oss-20b".to_string())); + } + + #[tokio::test] + async fn test_fetch_models_no_data_array() { + if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_fetch_models_no_data_array", + codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("GET")) + .and(wiremock::matchers::path("/models")) + .respond_with( + wiremock::ResponseTemplate::new(200) + .set_body_raw(serde_json::json!({}).to_string(), "application/json"), + ) + .mount(&server) + .await; + + let client = LMStudioClient::from_host_root(server.uri()); + let result = client.fetch_models().await; + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("No 'data' array in response") + ); + } + + #[tokio::test] + async fn test_fetch_models_server_error() { + if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_fetch_models_server_error", + codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("GET")) + .and(wiremock::matchers::path("/models")) + .respond_with(wiremock::ResponseTemplate::new(500)) + .mount(&server) + .await; + + let client = LMStudioClient::from_host_root(server.uri()); + let result = client.fetch_models().await; + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Failed to fetch models: 500") + ); + } + + #[tokio::test] + async fn test_check_server_happy_path() { + if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_check_server_happy_path", + codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("GET")) + .and(wiremock::matchers::path("/models")) + .respond_with(wiremock::ResponseTemplate::new(200)) + .mount(&server) + .await; + + let client = LMStudioClient::from_host_root(server.uri()); + client + .check_server() + .await + .expect("server check should pass"); + } + + #[tokio::test] + async fn test_check_server_error() { + if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_check_server_error", + codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("GET")) + .and(wiremock::matchers::path("/models")) + .respond_with(wiremock::ResponseTemplate::new(404)) + .mount(&server) + .await; + + let client = LMStudioClient::from_host_root(server.uri()); + let result = client.check_server().await; + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Server returned error: 404") + ); + } + + #[test] + fn test_find_lms() { + let result = LMStudioClient::find_lms(); + + match result { + Ok(_) => { + // lms was found in PATH - that's fine + } + Err(e) => { + // Expected error when LM Studio not installed + assert!(e.to_string().contains("LM Studio not found")); + } + } + } + + #[test] + fn test_find_lms_with_mock_home() { + // Test fallback path construction without touching env vars + #[cfg(unix)] + { + let result = LMStudioClient::find_lms_with_home_dir(Some("/test/home")); + if let Err(e) = result { + assert!(e.to_string().contains("LM Studio not found")); + } + } + + #[cfg(windows)] + { + let result = LMStudioClient::find_lms_with_home_dir(Some("C:\\test\\home")); + if let Err(e) = result { + assert!(e.to_string().contains("LM Studio not found")); + } + } + } + + #[test] + fn test_from_host_root() { + let client = LMStudioClient::from_host_root("http://localhost:1234"); + assert_eq!(client.base_url, "http://localhost:1234"); + + let client = LMStudioClient::from_host_root("https://example.com:8080/api"); + assert_eq!(client.base_url, "https://example.com:8080/api"); + } +} diff --git a/codex-rs/lmstudio/src/lib.rs b/codex-rs/lmstudio/src/lib.rs new file mode 100644 index 0000000000..bb8c8cef6a --- /dev/null +++ b/codex-rs/lmstudio/src/lib.rs @@ -0,0 +1,43 @@ +mod client; + +pub use client::LMStudioClient; +use codex_core::config::Config; + +/// Default OSS model to use when `--oss` is passed without an explicit `-m`. +pub const DEFAULT_OSS_MODEL: &str = "openai/gpt-oss-20b"; + +/// Prepare the local OSS environment when `--oss` is selected. +/// +/// - Ensures a local LM Studio server is reachable. +/// - Checks if the model exists locally and downloads it if missing. +pub async fn ensure_oss_ready(config: &Config) -> std::io::Result<()> { + let model: &str = config.model.as_ref(); + + // Verify local LM Studio is reachable. + let lmstudio_client = LMStudioClient::try_from_provider(config).await?; + + match lmstudio_client.fetch_models().await { + Ok(models) => { + if !models.iter().any(|m| m == model) { + lmstudio_client.download_model(model).await?; + } + } + Err(err) => { + // Not fatal; higher layers may still proceed and surface errors later. + tracing::warn!("Failed to query local models from LM Studio: {}.", err); + } + } + + // Load the model in the background + tokio::spawn({ + let client = lmstudio_client.clone(); + let model = model.to_string(); + async move { + if let Err(e) = client.load_model(&model).await { + tracing::warn!("Failed to load model {}: {}", model, e); + } + } + }); + + Ok(()) +} diff --git a/codex-rs/ollama/src/client.rs b/codex-rs/ollama/src/client.rs index 04b7e9dea2..93244cc2e5 100644 --- a/codex-rs/ollama/src/client.rs +++ b/codex-rs/ollama/src/client.rs @@ -10,8 +10,8 @@ use crate::pull::PullEvent; use crate::pull::PullProgressReporter; use crate::url::base_url_to_host_root; use crate::url::is_openai_compatible_base_url; -use codex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID; use codex_core::ModelProviderInfo; +use codex_core::OLLAMA_OSS_PROVIDER_ID; use codex_core::WireApi; use codex_core::config::Config; @@ -34,11 +34,11 @@ impl OllamaClient { // account. let provider = config .model_providers - .get(BUILT_IN_OSS_MODEL_PROVIDER_ID) + .get(OLLAMA_OSS_PROVIDER_ID) .ok_or_else(|| { io::Error::new( io::ErrorKind::NotFound, - format!("Built-in provider {BUILT_IN_OSS_MODEL_PROVIDER_ID} not found",), + format!("Built-in provider {OLLAMA_OSS_PROVIDER_ID} not found",), ) })?; @@ -47,7 +47,8 @@ impl OllamaClient { #[cfg(test)] async fn try_from_provider_with_base_url(base_url: &str) -> io::Result { - let provider = codex_core::create_oss_provider_with_base_url(base_url); + let provider = + codex_core::create_oss_provider_with_base_url(base_url, codex_core::WireApi::Chat); Self::try_from_provider(&provider).await } diff --git a/codex-rs/tui/Cargo.toml b/codex-rs/tui/Cargo.toml index b524d8bfd4..1a1c5f4538 100644 --- a/codex-rs/tui/Cargo.toml +++ b/codex-rs/tui/Cargo.toml @@ -38,7 +38,6 @@ codex-core = { workspace = true } codex-feedback = { workspace = true } codex-file-search = { workspace = true } codex-login = { workspace = true } -codex-ollama = { workspace = true } codex-protocol = { workspace = true } color-eyre = { workspace = true } crossterm = { workspace = true, features = ["bracketed-paste", "event-stream"] } @@ -62,6 +61,7 @@ ratatui = { workspace = true, features = [ ] } ratatui-macros = { workspace = true } regex-lite = { workspace = true } +reqwest = { version = "0.12", features = ["json"] } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true, features = ["preserve_order"] } shlex = { workspace = true } diff --git a/codex-rs/tui/src/cli.rs b/codex-rs/tui/src/cli.rs index 5a7e409932..e7a0c945bf 100644 --- a/codex-rs/tui/src/cli.rs +++ b/codex-rs/tui/src/cli.rs @@ -32,12 +32,16 @@ pub struct Cli { #[arg(long, short = 'm')] pub model: Option, - /// Convenience flag to select the local open source model provider. - /// Equivalent to -c model_provider=oss; verifies a local Ollama server is - /// running. + /// Convenience flag to select the local open source model provider. Equivalent to -c + /// model_provider=oss; verifies a local LM Studio or Ollama server is running. #[arg(long = "oss", default_value_t = false)] pub oss: bool, + /// Specify which local provider to use (lmstudio or ollama). + /// If not specified with --oss, will use config default or show selection. + #[arg(long = "local-provider")] + pub oss_provider: Option, + /// Configuration profile from config.toml to specify default options. #[arg(long = "profile", short = 'p')] pub config_profile: Option, diff --git a/codex-rs/tui/src/lib.rs b/codex-rs/tui/src/lib.rs index 3953f8b50f..ca7cba9c24 100644 --- a/codex-rs/tui/src/lib.rs +++ b/codex-rs/tui/src/lib.rs @@ -7,18 +7,21 @@ use additional_dirs::add_dir_warning_message; use app::App; pub use app::AppExitInfo; use codex_app_server_protocol::AuthMode; +use codex_common::oss::ensure_oss_provider_ready; +use codex_common::oss::get_default_model_for_oss_provider; use codex_core::AuthManager; -use codex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID; use codex_core::CodexAuth; use codex_core::INTERACTIVE_SESSION_SOURCES; use codex_core::RolloutRecorder; use codex_core::auth::enforce_login_restrictions; use codex_core::config::Config; use codex_core::config::ConfigOverrides; +use codex_core::config::find_codex_home; +use codex_core::config::load_config_as_toml_with_cli_overrides; +use codex_core::config::resolve_oss_provider; use codex_core::find_conversation_path_by_id_str; use codex_core::get_platform_sandbox; use codex_core::protocol::AskForApproval; -use codex_ollama::DEFAULT_OSS_MODEL; use codex_protocol::config_types::SandboxMode; use opentelemetry_appender_tracing::layer::OpenTelemetryTracingBridge; use std::fs::OpenOptions; @@ -56,6 +59,7 @@ mod markdown_render; mod markdown_stream; mod model_migration; pub mod onboarding; +mod oss_selection; mod pager_overlay; pub mod public_widgets; mod render; @@ -124,21 +128,75 @@ pub async fn run_main( // When using `--oss`, let the bootstrapper pick the model (defaulting to // gpt-oss:20b) and ensure it is present locally. Also, force the built‑in + let raw_overrides = cli.config_overrides.raw_overrides.clone(); // `oss` model provider. + let overrides_cli = codex_common::CliConfigOverrides { raw_overrides }; + let cli_kv_overrides = match overrides_cli.parse_overrides() { + // Parse `-c` overrides from the CLI. + Ok(v) => v, + #[allow(clippy::print_stderr)] + Err(e) => { + eprintln!("Error parsing -c overrides: {e}"); + std::process::exit(1); + } + }; + + // we load config.toml here to determine project state. + #[allow(clippy::print_stderr)] + let codex_home = match find_codex_home() { + Ok(codex_home) => codex_home.to_path_buf(), + Err(err) => { + eprintln!("Error finding codex home: {err}"); + std::process::exit(1); + } + }; + + #[allow(clippy::print_stderr)] + let config_toml = + match load_config_as_toml_with_cli_overrides(&codex_home, cli_kv_overrides.clone()).await { + Ok(config_toml) => config_toml, + Err(err) => { + eprintln!("Error loading config.toml: {err}"); + std::process::exit(1); + } + }; + + let model_provider_override = if cli.oss { + let resolved = resolve_oss_provider( + cli.oss_provider.as_deref(), + &config_toml, + cli.config_profile.clone(), + ); + + if let Some(provider) = resolved { + Some(provider) + } else { + // No provider configured, prompt the user + let provider = oss_selection::select_oss_provider(&codex_home).await?; + if provider == "__CANCELLED__" { + return Err(std::io::Error::other( + "OSS provider selection was cancelled by user", + )); + } + Some(provider) + } + } else { + None + }; + + // When using `--oss`, let the bootstrapper pick the model based on selected provider let model = if let Some(model) = &cli.model { Some(model.clone()) } else if cli.oss { - Some(DEFAULT_OSS_MODEL.to_owned()) + // Use the provider from model_provider_override + model_provider_override + .as_ref() + .and_then(|provider_id| get_default_model_for_oss_provider(provider_id)) + .map(std::borrow::ToOwned::to_owned) } else { None // No model specified, will use the default. }; - let model_provider_override = if cli.oss { - Some(BUILT_IN_OSS_MODEL_PROVIDER_ID.to_owned()) - } else { - None - }; - // canonicalize the cwd let cwd = cli.cwd.clone().map(|p| p.canonicalize().unwrap_or(p)); let additional_dirs = cli.add_dir.clone(); @@ -149,7 +207,7 @@ pub async fn run_main( approval_policy, sandbox_mode, cwd, - model_provider: model_provider_override, + model_provider: model_provider_override.clone(), config_profile: cli.config_profile.clone(), codex_linux_sandbox_exe, base_instructions: None, @@ -161,16 +219,6 @@ pub async fn run_main( experimental_sandbox_command_assessment: None, additional_writable_roots: additional_dirs, }; - let raw_overrides = cli.config_overrides.raw_overrides.clone(); - let overrides_cli = codex_common::CliConfigOverrides { raw_overrides }; - let cli_kv_overrides = match overrides_cli.parse_overrides() { - Ok(v) => v, - #[allow(clippy::print_stderr)] - Err(e) => { - eprintln!("Error parsing -c overrides: {e}"); - std::process::exit(1); - } - }; let config = load_config_or_exit(cli_kv_overrides.clone(), overrides.clone()).await; @@ -232,10 +280,19 @@ pub async fn run_main( .with_target(false) .with_filter(targets); - if cli.oss { - codex_ollama::ensure_oss_ready(&config) - .await - .map_err(|e| std::io::Error::other(format!("OSS setup failed: {e}")))?; + if cli.oss && model_provider_override.is_some() { + // We're in the oss section, so provider_id should be Some + // Let's handle None case gracefully though just in case + let provider_id = match model_provider_override.as_ref() { + Some(id) => id, + None => { + error!("OSS provider unexpectedly not set when oss flag is used"); + return Err(std::io::Error::other( + "OSS provider not set but oss flag was used", + )); + } + }; + ensure_oss_provider_ready(provider_id, &config).await?; } let otel = codex_core::otel_init::build_provider(&config, env!("CARGO_PKG_VERSION")); diff --git a/codex-rs/tui/src/oss_selection.rs b/codex-rs/tui/src/oss_selection.rs new file mode 100644 index 0000000000..eb1ca18231 --- /dev/null +++ b/codex-rs/tui/src/oss_selection.rs @@ -0,0 +1,369 @@ +use std::io; +use std::sync::LazyLock; + +use codex_core::DEFAULT_LMSTUDIO_PORT; +use codex_core::DEFAULT_OLLAMA_PORT; +use codex_core::LMSTUDIO_OSS_PROVIDER_ID; +use codex_core::OLLAMA_OSS_PROVIDER_ID; +use codex_core::config::set_default_oss_provider; +use crossterm::event::Event; +use crossterm::event::KeyCode; +use crossterm::event::KeyEvent; +use crossterm::event::KeyEventKind; +use crossterm::event::{self}; +use crossterm::execute; +use crossterm::terminal::EnterAlternateScreen; +use crossterm::terminal::LeaveAlternateScreen; +use crossterm::terminal::disable_raw_mode; +use crossterm::terminal::enable_raw_mode; +use ratatui::Terminal; +use ratatui::backend::CrosstermBackend; +use ratatui::buffer::Buffer; +use ratatui::layout::Alignment; +use ratatui::layout::Constraint; +use ratatui::layout::Direction; +use ratatui::layout::Layout; +use ratatui::layout::Margin; +use ratatui::layout::Rect; +use ratatui::prelude::*; +use ratatui::style::Color; +use ratatui::style::Modifier; +use ratatui::style::Style; +use ratatui::text::Line; +use ratatui::text::Span; +use ratatui::widgets::Paragraph; +use ratatui::widgets::Widget; +use ratatui::widgets::WidgetRef; +use ratatui::widgets::Wrap; +use std::time::Duration; + +#[derive(Clone)] +struct ProviderOption { + name: String, + status: ProviderStatus, +} + +#[derive(Clone)] +enum ProviderStatus { + Running, + NotRunning, + Unknown, +} + +/// Options displayed in the *select* mode. +/// +/// The `key` is matched case-insensitively. +struct SelectOption { + label: Line<'static>, + description: &'static str, + key: KeyCode, + provider_id: &'static str, +} + +static OSS_SELECT_OPTIONS: LazyLock> = LazyLock::new(|| { + vec![ + SelectOption { + label: Line::from(vec!["L".underlined(), "M Studio".into()]), + description: "Local LM Studio server (default port 1234)", + key: KeyCode::Char('l'), + provider_id: LMSTUDIO_OSS_PROVIDER_ID, + }, + SelectOption { + label: Line::from(vec!["O".underlined(), "llama".into()]), + description: "Local Ollama server (default port 11434)", + key: KeyCode::Char('o'), + provider_id: OLLAMA_OSS_PROVIDER_ID, + }, + ] +}); + +pub struct OssSelectionWidget<'a> { + select_options: &'a Vec, + confirmation_prompt: Paragraph<'a>, + + /// Currently selected index in *select* mode. + selected_option: usize, + + /// Set to `true` once a decision has been sent – the parent view can then + /// remove this widget from its queue. + done: bool, + + selection: Option, +} + +impl OssSelectionWidget<'_> { + fn new(lmstudio_status: ProviderStatus, ollama_status: ProviderStatus) -> io::Result { + let providers = vec![ + ProviderOption { + name: "LM Studio".to_string(), + status: lmstudio_status, + }, + ProviderOption { + name: "Ollama".to_string(), + status: ollama_status, + }, + ]; + + let mut contents: Vec = vec![ + Line::from(vec![ + "? ".fg(Color::Blue), + "Select an open-source provider".bold(), + ]), + Line::from(""), + Line::from(" Choose which local AI server to use for your session."), + Line::from(""), + ]; + + // Add status indicators for each provider + for provider in &providers { + let (status_symbol, status_color) = get_status_symbol_and_color(&provider.status); + contents.push(Line::from(vec![ + Span::raw(" "), + Span::styled(status_symbol, Style::default().fg(status_color)), + Span::raw(format!(" {} ", provider.name)), + ])); + } + contents.push(Line::from("")); + contents.push(Line::from(" ● Running ○ Not Running").add_modifier(Modifier::DIM)); + + contents.push(Line::from("")); + contents.push( + Line::from(" Press Enter to select • Ctrl+C to exit").add_modifier(Modifier::DIM), + ); + + let confirmation_prompt = Paragraph::new(contents).wrap(Wrap { trim: false }); + + Ok(Self { + select_options: &OSS_SELECT_OPTIONS, + confirmation_prompt, + selected_option: 0, + done: false, + selection: None, + }) + } + + fn get_confirmation_prompt_height(&self, width: u16) -> u16 { + // Should cache this for last value of width. + self.confirmation_prompt.line_count(width) as u16 + } + + /// Process a `KeyEvent` coming from crossterm. Always consumes the event + /// while the modal is visible. + /// Process a key event originating from crossterm. As the modal fully + /// captures input while visible, we don't need to report whether the event + /// was consumed—callers can assume it always is. + pub fn handle_key_event(&mut self, key: KeyEvent) -> Option { + if key.kind == KeyEventKind::Press { + self.handle_select_key(key); + } + if self.done { + self.selection.clone() + } else { + None + } + } + + /// Normalize a key for comparison. + /// - For `KeyCode::Char`, converts to lowercase for case-insensitive matching. + /// - Other key codes are returned unchanged. + fn normalize_keycode(code: KeyCode) -> KeyCode { + match code { + KeyCode::Char(c) => KeyCode::Char(c.to_ascii_lowercase()), + other => other, + } + } + + fn handle_select_key(&mut self, key_event: KeyEvent) { + match key_event.code { + KeyCode::Char('c') + if key_event + .modifiers + .contains(crossterm::event::KeyModifiers::CONTROL) => + { + self.send_decision("__CANCELLED__".to_string()); + } + KeyCode::Left => { + self.selected_option = (self.selected_option + self.select_options.len() - 1) + % self.select_options.len(); + } + KeyCode::Right => { + self.selected_option = (self.selected_option + 1) % self.select_options.len(); + } + KeyCode::Enter => { + let opt = &self.select_options[self.selected_option]; + self.send_decision(opt.provider_id.to_string()); + } + KeyCode::Esc => { + self.send_decision(LMSTUDIO_OSS_PROVIDER_ID.to_string()); + } + other => { + let normalized = Self::normalize_keycode(other); + if let Some(opt) = self + .select_options + .iter() + .find(|opt| Self::normalize_keycode(opt.key) == normalized) + { + self.send_decision(opt.provider_id.to_string()); + } + } + } + } + + fn send_decision(&mut self, selection: String) { + self.selection = Some(selection); + self.done = true; + } + + /// Returns `true` once the user has made a decision and the widget no + /// longer needs to be displayed. + pub fn is_complete(&self) -> bool { + self.done + } + + pub fn desired_height(&self, width: u16) -> u16 { + self.get_confirmation_prompt_height(width) + self.select_options.len() as u16 + } +} + +impl WidgetRef for &OssSelectionWidget<'_> { + fn render_ref(&self, area: Rect, buf: &mut Buffer) { + let prompt_height = self.get_confirmation_prompt_height(area.width); + let [prompt_chunk, response_chunk] = Layout::default() + .direction(Direction::Vertical) + .constraints([Constraint::Length(prompt_height), Constraint::Min(0)]) + .areas(area); + + let lines: Vec = self + .select_options + .iter() + .enumerate() + .map(|(idx, opt)| { + let style = if idx == self.selected_option { + Style::new().bg(Color::Cyan).fg(Color::Black) + } else { + Style::new().bg(Color::DarkGray) + }; + opt.label.clone().alignment(Alignment::Center).style(style) + }) + .collect(); + + let [title_area, button_area, description_area] = Layout::vertical([ + Constraint::Length(1), + Constraint::Length(1), + Constraint::Min(0), + ]) + .areas(response_chunk.inner(Margin::new(1, 0))); + + Line::from("Select provider?").render(title_area, buf); + + self.confirmation_prompt.clone().render(prompt_chunk, buf); + let areas = Layout::horizontal( + lines + .iter() + .map(|l| Constraint::Length(l.width() as u16 + 2)), + ) + .spacing(1) + .split(button_area); + for (idx, area) in areas.iter().enumerate() { + let line = &lines[idx]; + line.render(*area, buf); + } + + Line::from(self.select_options[self.selected_option].description) + .style(Style::new().italic().fg(Color::DarkGray)) + .render(description_area.inner(Margin::new(1, 0)), buf); + } +} + +fn get_status_symbol_and_color(status: &ProviderStatus) -> (&'static str, Color) { + match status { + ProviderStatus::Running => ("●", Color::Green), + ProviderStatus::NotRunning => ("○", Color::Red), + ProviderStatus::Unknown => ("?", Color::Yellow), + } +} + +pub async fn select_oss_provider(codex_home: &std::path::Path) -> io::Result { + // Check provider statuses first + let lmstudio_status = check_lmstudio_status().await; + let ollama_status = check_ollama_status().await; + + // Autoselect if only one is running + match (&lmstudio_status, &ollama_status) { + (ProviderStatus::Running, ProviderStatus::NotRunning) => { + let provider = LMSTUDIO_OSS_PROVIDER_ID.to_string(); + return Ok(provider); + } + (ProviderStatus::NotRunning, ProviderStatus::Running) => { + let provider = OLLAMA_OSS_PROVIDER_ID.to_string(); + return Ok(provider); + } + _ => { + // Both running or both not running - show UI + } + } + + let mut widget = OssSelectionWidget::new(lmstudio_status, ollama_status)?; + + enable_raw_mode()?; + let mut stdout = io::stdout(); + execute!(stdout, EnterAlternateScreen)?; + + let backend = CrosstermBackend::new(stdout); + let mut terminal = Terminal::new(backend)?; + + let result = loop { + terminal.draw(|f| { + (&widget).render_ref(f.area(), f.buffer_mut()); + })?; + + if let Event::Key(key_event) = event::read()? + && let Some(selection) = widget.handle_key_event(key_event) + { + break Ok(selection); + } + }; + + disable_raw_mode()?; + execute!(terminal.backend_mut(), LeaveAlternateScreen)?; + + // If the user manually selected an OSS provider, we save it as the + // default one to use later. + if let Ok(ref provider) = result + && let Err(e) = set_default_oss_provider(codex_home, provider) + { + tracing::warn!("Failed to save OSS provider preference: {e}"); + } + + result +} + +async fn check_lmstudio_status() -> ProviderStatus { + match check_port_status(DEFAULT_LMSTUDIO_PORT).await { + Ok(true) => ProviderStatus::Running, + Ok(false) => ProviderStatus::NotRunning, + Err(_) => ProviderStatus::Unknown, + } +} + +async fn check_ollama_status() -> ProviderStatus { + match check_port_status(DEFAULT_OLLAMA_PORT).await { + Ok(true) => ProviderStatus::Running, + Ok(false) => ProviderStatus::NotRunning, + Err(_) => ProviderStatus::Unknown, + } +} + +async fn check_port_status(port: u16) -> io::Result { + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(2)) + .build() + .map_err(io::Error::other)?; + + let url = format!("http://localhost:{port}"); + + match client.get(&url).send().await { + Ok(response) => Ok(response.status().is_success()), + Err(_) => Ok(false), // Connection failed = not running + } +} diff --git a/docs/config.md b/docs/config.md index 3e7b7e165e..ed5aed6fc7 100644 --- a/docs/config.md +++ b/docs/config.md @@ -253,6 +253,20 @@ This is analogous to `model_context_window`, but for the maximum number of outpu > See also [`codex exec`](./exec.md) to see how these model settings influence non-interactive runs. +### oss_provider + +Specifies the default OSS provider to use when running Codex. This is used when the `--oss` flag is provided without a specific provider. + +Valid values are: + +- `"lmstudio"` - Use LM Studio as the local model provider +- `"ollama"` - Use Ollama as the local model provider + +```toml +# Example: Set default OSS provider to LM Studio +oss_provider = "lmstudio" +``` + ## Execution environment ### approval_policy