Compare commits

...

12 Commits

Author SHA1 Message Date
pap   5225cec3a1   add a 5s timeout on ollama queries   2025-08-04 19:20:16 +01:00
pap   c2cf4a3cb9   adding tests ; linting   2025-08-04 18:40:58 +01:00
pap   304d01c099   ollama provider broke down in multiple files   2025-08-04 18:10:46 +01:00
pap   5df778471c   fixing dep alphabetical   2025-08-04 17:51:27 +01:00
pap   3c97fc1423   fix clippy   2025-08-04 17:48:07 +01:00
pap   d5c48cd049   fmt   2025-08-04 17:48:06 +01:00
pap   42c726be79   adding tests   2025-08-04 17:48:06 +01:00
pap   b2ed15430f   code cleaning, toml edit helpers   2025-08-04 17:48:06 +01:00
pap   75eec73fcc   refactor into providers   2025-08-04 17:48:06 +01:00
pap   6b6d2c5e00   fixed ctrl+q/d at config stage and -c model_provider=ollama is similar elsewhere.   2025-08-04 17:48:06 +01:00
pap   47e84d5c05   streaming model download   2025-08-04 17:48:06 +01:00
pap   2cfb2a2265   adding --ollama   2025-08-04 17:48:06 +01:00
21 changed files with 1566 additions and 5 deletions

codex-rs/Cargo.lock generated
View File

@@ -671,6 +671,7 @@ dependencies = [
"anyhow",
"assert_cmd",
"async-channel",
"async-stream",
"base64 0.22.1",
"bytes",
"chrono",
@@ -708,6 +709,7 @@ dependencies = [
"tokio-test",
"tokio-util",
"toml 0.9.2",
"toml_edit",
"tracing",
"tree-sitter",
"tree-sitter-bash",
@@ -4814,6 +4816,7 @@ dependencies = [
"serde",
"serde_spanned 0.6.9",
"toml_datetime 0.6.11",
"toml_write",
"winnow",
]
@@ -4826,6 +4829,12 @@ dependencies = [
"winnow",
]
[[package]]
name = "toml_write"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801"
[[package]]
name = "toml_writer"
version = "1.0.2"

View File

@@ -3,6 +3,7 @@
Codex supports several mechanisms for setting config values:
- Config-specific command-line flags, such as `--model o3` (highest precedence).
- Convenience provider flags, such as `--ollama` (equivalent to `-c model_provider=ollama`).
- A generic `-c`/`--config` flag that takes a `key=value` pair, such as `--config model="o3"`.
- The key can contain dots to set a value deeper than the root, e.g. `--config model_providers.openai.wire_api="chat"`.
- Values can contain objects, such as `--config shell_environment_policy.include_only=["PATH", "HOME", "USER"]`.
@@ -56,6 +57,13 @@ name = "Ollama"
base_url = "http://localhost:11434/v1"
```
Alternatively, you can pass `--ollama` on the CLI, which is equivalent to `-c model_provider=ollama`.
When using `--ollama`, Codex will verify that an Ollama server is running locally and
will create a `[model_providers.ollama]` entry in your `config.toml` with sensible defaults
(`base_url = "http://localhost:11434/v1"`, `wire_api = "chat"`) if one does not already exist.
If no running Ollama server is detected, Codex will print instructions to install and start
Ollama (https://github.com/ollama/ollama?tab=readme-ov-file#ollama) and exit.
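For reference, a minimal sketch of the `[model_providers.ollama]` entry Codex writes in this case, using the defaults described above:
```toml
[model_providers.ollama]
name = "Ollama"
base_url = "http://localhost:11434/v1"
wire_api = "chat"
```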
Or a third-party provider (using a distinct environment variable for the API key):
```toml

View File

@@ -13,6 +13,7 @@ workspace = true
[dependencies]
anyhow = "1"
async-channel = "2.3.1"
async-stream = "0.3"
base64 = "0.22"
bytes = "1.10.1"
chrono = { version = "0.4", features = ["serde"] }
@@ -30,8 +31,8 @@ mime_guess = "2.0"
rand = "0.9"
reqwest = { version = "0.12", features = ["json", "stream"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
serde_bytes = "0.11"
serde_json = "1"
sha1 = "0.10.6"
shlex = "1.3.0"
similar = "2.7.0"
@@ -47,6 +48,7 @@ tokio = { version = "1", features = [
] }
tokio-util = "0.7.14"
toml = "0.9.2"
toml_edit = "0.22"
tracing = { version = "0.1.41", features = ["log"] }
tree-sitter = "0.25.8"
tree-sitter-bash = "0.25.0"

View File

@@ -62,10 +62,14 @@ impl ModelClient {
summary: ReasoningSummaryConfig,
session_id: Uuid,
) -> Self {
let client = reqwest::Client::builder()
.connect_timeout(Duration::from_secs(5))
.build()
.unwrap_or_else(|_| reqwest::Client::new());
Self {
config,
auth,
client: reqwest::Client::new(),
client,
provider,
session_id,
effort,

View File

@@ -428,6 +428,11 @@ impl Config {
.or(config_profile.model_provider)
.or(cfg.model_provider)
.unwrap_or_else(|| "openai".to_string());
// Do not implicitly inject an Ollama provider when selected via
// `-c model_provider=ollama`. Only the `--ollama` flag path sets up the
// provider entry and performs discovery. This ensures parity with
// other providers: if a provider is not defined in config.toml, we
// return a clear error below.
let model_provider = model_providers
.get(&model_provider_id)
.ok_or_else(|| {

View File

@@ -102,6 +102,20 @@ pub enum CodexErr {
#[error("{0}")]
EnvVar(EnvVarError),
// ------------------------------
// Ollama-specific errors
// ------------------------------
#[error(
"No running Ollama server detected. Start it with: `ollama serve` (after installing). Install instructions: https://github.com/ollama/ollama?tab=readme-ov-file#ollama"
)]
OllamaServerUnreachable,
#[error("ollama model not found: {0}")]
OllamaModelNotFound(String),
#[error("ollama pull failed: {0}")]
OllamaPullFailed(String),
}
#[derive(Debug)]

View File

@@ -37,6 +37,7 @@ mod openai_tools;
pub mod plan_tool;
mod project_doc;
pub mod protocol;
pub mod providers;
mod rollout;
pub(crate) mod safety;
pub mod seatbelt;

View File

@@ -0,0 +1 @@
pub mod ollama;

View File

@@ -0,0 +1,259 @@
use bytes::BytesMut;
use futures::StreamExt;
use futures::stream::BoxStream;
use serde_json::Value as JsonValue;
use std::collections::VecDeque;
use std::io;
use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::WireApi;
use super::DEFAULT_BASE_URL;
use super::PullEvent;
use super::PullProgressReporter;
use super::parser::pull_events_from_value;
use super::url::base_url_to_host_root;
use super::url::is_openai_compatible_base_url;
/// Client for interacting with a local Ollama instance.
pub struct OllamaClient {
client: reqwest::Client,
host_root: String,
uses_openai_compat: bool,
}
impl OllamaClient {
/// Build a client from a provider definition. Falls back to the default
/// local URL if no base_url is configured.
pub fn from_provider(provider: &ModelProviderInfo) -> Self {
let base_url = provider
.base_url
.clone()
.unwrap_or_else(|| DEFAULT_BASE_URL.to_string());
let uses_openai_compat = is_openai_compatible_base_url(&base_url)
|| matches!(provider.wire_api, WireApi::Chat)
&& is_openai_compatible_base_url(&base_url);
let host_root = base_url_to_host_root(&base_url);
let client = reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(5))
.build()
.unwrap_or_else(|_| reqwest::Client::new());
Self {
client,
host_root,
uses_openai_compat,
}
}
/// Low-level constructor given a raw host root, e.g. "http://localhost:11434".
pub fn from_host_root(host_root: impl Into<String>) -> Self {
let client = reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(5))
.build()
.unwrap_or_else(|_| reqwest::Client::new());
Self {
client,
host_root: host_root.into(),
uses_openai_compat: false,
}
}
/// Probe whether the server is reachable by hitting the appropriate health endpoint.
pub async fn probe_server(&self) -> io::Result<bool> {
let url = if self.uses_openai_compat {
format!("{}/v1/models", self.host_root.trim_end_matches('/'))
} else {
format!("{}/api/tags", self.host_root.trim_end_matches('/'))
};
let resp = self.client.get(url).send().await;
Ok(matches!(resp, Ok(r) if r.status().is_success()))
}
/// Return the list of model names known to the local Ollama instance.
pub async fn fetch_models(&self) -> io::Result<Vec<String>> {
let tags_url = format!("{}/api/tags", self.host_root.trim_end_matches('/'));
let resp = self
.client
.get(tags_url)
.send()
.await
.map_err(io::Error::other)?;
if !resp.status().is_success() {
return Ok(Vec::new());
}
let val = resp.json::<JsonValue>().await.map_err(io::Error::other)?;
let names = val
.get("models")
.and_then(|m| m.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.get("name").and_then(|n| n.as_str()))
.map(|s| s.to_string())
.collect::<Vec<_>>()
})
.unwrap_or_default();
Ok(names)
}
/// Start a model pull and emit streaming events. The returned stream ends when
/// a Success event is observed or the server closes the connection.
pub async fn pull_model_stream(
&self,
model: &str,
) -> io::Result<BoxStream<'static, PullEvent>> {
let url = format!("{}/api/pull", self.host_root.trim_end_matches('/'));
let resp = self
.client
.post(url)
.json(&serde_json::json!({"model": model, "stream": true}))
.send()
.await
.map_err(io::Error::other)?;
if !resp.status().is_success() {
return Err(io::Error::other(format!(
"failed to start pull: HTTP {}",
resp.status()
)));
}
let mut stream = resp.bytes_stream();
let mut buf = BytesMut::new();
let _pending: VecDeque<PullEvent> = VecDeque::new();
// Using an async stream adaptor backed by an unfold-like manual loop.
let s = async_stream::stream! {
while let Some(chunk) = stream.next().await {
match chunk {
Ok(bytes) => {
buf.extend_from_slice(&bytes);
while let Some(pos) = buf.iter().position(|b| *b == b'\n') {
let line = buf.split_to(pos + 1);
if let Ok(text) = std::str::from_utf8(&line) {
let text = text.trim();
if text.is_empty() { continue; }
if let Ok(value) = serde_json::from_str::<JsonValue>(text) {
for ev in pull_events_from_value(&value) { yield ev; }
if let Some(err_msg) = value.get("error").and_then(|e| e.as_str()) {
yield PullEvent::Status(format!("error: {err_msg}"));
return;
}
if let Some(status) = value.get("status").and_then(|s| s.as_str()) {
if status == "success" { yield PullEvent::Success; return; }
}
}
}
}
}
Err(_) => {
// Connection error: end the stream.
return;
}
}
}
};
Ok(Box::pin(s))
}
/// High-level helper to pull a model and drive a progress reporter.
pub async fn pull_with_reporter(
&self,
model: &str,
reporter: &mut dyn PullProgressReporter,
) -> io::Result<()> {
reporter.on_event(&PullEvent::Status(format!("Pulling model {model}...")))?;
let mut stream = self.pull_model_stream(model).await?;
while let Some(event) = stream.next().await {
reporter.on_event(&event)?;
if matches!(event, PullEvent::Success) {
break;
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
#![allow(clippy::expect_used, clippy::unwrap_used)]
use super::*;
use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::WireApi;
// Happy-path tests using a mock HTTP server; skip if sandbox network is disabled.
#[tokio::test]
async fn test_fetch_models_happy_path() {
if std::env::var(crate::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
tracing::info!(
"{} is set; skipping test_fetch_models_happy_path",
crate::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
);
return;
}
let server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/api/tags"))
.respond_with(
wiremock::ResponseTemplate::new(200).set_body_raw(
serde_json::json!({
"models": [ {"name": "llama3.2:3b"}, {"name":"mistral"} ]
})
.to_string(),
"application/json",
),
)
.mount(&server)
.await;
let client = OllamaClient::from_host_root(server.uri());
let models = client.fetch_models().await.expect("fetch models");
assert!(models.contains(&"llama3.2:3b".to_string()));
assert!(models.contains(&"mistral".to_string()));
}
#[tokio::test]
async fn test_probe_server_happy_path_openai_compat_and_native() {
if std::env::var(crate::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
tracing::info!(
"{} set; skipping test_probe_server_happy_path_openai_compat_and_native",
crate::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
);
return;
}
let server = wiremock::MockServer::start().await;
// Native endpoint
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/api/tags"))
.respond_with(wiremock::ResponseTemplate::new(200))
.mount(&server)
.await;
let native = OllamaClient::from_host_root(server.uri());
assert!(native.probe_server().await.expect("probe native"));
// OpenAI compatibility endpoint
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/v1/models"))
.respond_with(wiremock::ResponseTemplate::new(200))
.mount(&server)
.await;
let provider = ModelProviderInfo {
name: "Ollama".to_string(),
base_url: Some(format!("{}/v1", server.uri())),
env_key: None,
env_key_instructions: None,
wire_api: WireApi::Chat,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
requires_auth: false,
};
let compat = OllamaClient::from_provider(&provider);
assert!(compat.probe_server().await.expect("probe compat"));
}
}
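
As a brief usage sketch (assuming a hypothetical binary that depends on `codex_core` and `tokio`, plus a local Ollama instance on the default port), a frontend might drive this client as follows:

```rust
// Hypothetical example binary; not part of this change set.
use codex_core::providers::ollama::CliProgressReporter;
use codex_core::providers::ollama::OllamaClient;

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // Native host root (no `/v1` suffix); probing hits `/api/tags`.
    let client = OllamaClient::from_host_root("http://localhost:11434");
    if !client.probe_server().await? {
        eprintln!("no Ollama server reachable");
        return Ok(());
    }
    // List models already present locally.
    for name in client.fetch_models().await? {
        println!("model: {name}");
    }
    // Pull a model, rendering inline progress on stderr.
    let mut reporter = CliProgressReporter::new();
    client
        .pull_with_reporter("llama3.2:3b", &mut reporter)
        .await?;
    Ok(())
}
```

Because `from_host_root` targets the native API, `probe_server` and `fetch_models` use `/api/tags` rather than the OpenAI-compatible `/v1` endpoints.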

View File

@@ -0,0 +1,243 @@
use std::io;
use std::path::Path;
use std::str::FromStr;
use toml_edit::DocumentMut as Document;
use toml_edit::Item;
use toml_edit::Table;
use toml_edit::Value as TomlValueEdit;
use super::DEFAULT_BASE_URL;
/// Read the list of models recorded under [model_providers.ollama].models.
pub fn read_ollama_models_list(config_path: &Path) -> Vec<String> {
match std::fs::read_to_string(config_path)
.ok()
.and_then(|s| toml::from_str::<toml::Value>(&s).ok())
{
Some(toml::Value::Table(root)) => root
.get("model_providers")
.and_then(|v| v.as_table())
.and_then(|t| t.get("ollama"))
.and_then(|v| v.as_table())
.and_then(|t| t.get("models"))
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect::<Vec<_>>()
})
.unwrap_or_default(),
_ => Vec::new(),
}
}
/// Convenience wrapper that returns the models list as an io::Result for callers
/// that want a uniform Result-based API.
pub fn read_config_models(config_path: &Path) -> io::Result<Vec<String>> {
Ok(read_ollama_models_list(config_path))
}
/// Overwrite the recorded models list under [model_providers.ollama].models using toml_edit.
pub fn write_ollama_models_list(config_path: &Path, models: &[String]) -> io::Result<()> {
let mut doc = read_document(config_path)?;
{
let tbl = upsert_provider_ollama(&mut doc);
let mut arr = toml_edit::Array::new();
for m in models {
arr.push(TomlValueEdit::from(m.clone()));
}
tbl["models"] = Item::Value(TomlValueEdit::Array(arr));
}
write_document(config_path, &doc)
}
/// Write models list via a uniform name expected by higher layers.
pub fn write_config_models(config_path: &Path, models: &[String]) -> io::Result<()> {
write_ollama_models_list(config_path, models)
}
/// Ensure `[model_providers.ollama]` exists with sensible defaults on disk.
/// Returns true if it created/updated the entry.
pub fn ensure_ollama_provider_entry(codex_home: &Path) -> io::Result<bool> {
let config_path = codex_home.join("config.toml");
let mut doc = read_document(&config_path)?;
let before = doc.to_string();
let _tbl = upsert_provider_ollama(&mut doc);
let after = doc.to_string();
if before != after {
write_document(&config_path, &doc)?;
Ok(true)
} else {
Ok(false)
}
}
/// Alias name mirroring the refactor plan wording.
pub fn ensure_provider_entry_and_defaults(codex_home: &Path) -> io::Result<bool> {
ensure_ollama_provider_entry(codex_home)
}
/// Read whether the provider exists and how many models are recorded under it.
pub fn read_provider_state(config_path: &Path) -> (bool, usize) {
match std::fs::read_to_string(config_path)
.ok()
.and_then(|s| toml::from_str::<toml::Value>(&s).ok())
{
Some(toml::Value::Table(root)) => {
let provider_present = root
.get("model_providers")
.and_then(|v| v.as_table())
.and_then(|t| t.get("ollama"))
.map(|_| true)
.unwrap_or(false);
let models_count = root
.get("model_providers")
.and_then(|v| v.as_table())
.and_then(|t| t.get("ollama"))
.and_then(|v| v.as_table())
.and_then(|t| t.get("models"))
.and_then(|v| v.as_array())
.map(|arr| arr.len())
.unwrap_or(0);
(provider_present, models_count)
}
_ => (false, 0),
}
}
// ---------- toml_edit helpers ----------
fn read_document(path: &Path) -> io::Result<Document> {
match std::fs::read_to_string(path) {
Ok(s) => Document::from_str(&s).map_err(io::Error::other),
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(Document::new()),
Err(e) => Err(e),
}
}
fn write_document(path: &Path, doc: &Document) -> io::Result<()> {
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)?;
}
std::fs::write(path, doc.to_string())
}
pub fn upsert_provider_ollama(doc: &mut Document) -> &mut Table {
// Ensure "model_providers" exists and is a table.
let needs_init = match doc.get("model_providers") {
None => true,
Some(item) => !item.is_table(),
};
if needs_init {
doc.as_table_mut()
.insert("model_providers", Item::Table(Table::new()));
}
// Now, get a mutable reference to the "model_providers" table without `expect`/`unwrap`.
let providers: &mut Table = {
// Insert if missing.
if doc.as_table().get("model_providers").is_none() {
doc.as_table_mut()
.insert("model_providers", Item::Table(Table::new()));
}
match doc.as_table_mut().get_mut("model_providers") {
Some(item) => {
if !item.is_table() {
*item = Item::Table(Table::new());
}
match item.as_table_mut() {
Some(t) => t,
None => unreachable!("model_providers was set to a table"),
}
}
None => unreachable!("model_providers should exist after insertion"),
}
};
// Ensure "ollama" exists and is a table.
let needs_ollama_init = match providers.get("ollama") {
None => true,
Some(item) => !item.is_table(),
};
if needs_ollama_init {
providers.insert("ollama", Item::Table(Table::new()));
}
// Get a mutable reference to the "ollama" table without `expect`/`unwrap`.
let tbl: &mut Table = {
let needs_set = match providers.get("ollama") {
None => true,
Some(item) => !item.is_table(),
};
if needs_set {
providers.insert("ollama", Item::Table(Table::new()));
}
match providers.get_mut("ollama") {
Some(item) => {
if !item.is_table() {
*item = Item::Table(Table::new());
}
match item.as_table_mut() {
Some(t) => t,
None => unreachable!("ollama was set to a table"),
}
}
None => unreachable!("ollama should exist after insertion"),
}
};
if !tbl.contains_key("name") {
tbl["name"] = Item::Value(TomlValueEdit::from("Ollama"));
}
if !tbl.contains_key("base_url") {
tbl["base_url"] = Item::Value(TomlValueEdit::from(DEFAULT_BASE_URL));
}
if !tbl.contains_key("wire_api") {
tbl["wire_api"] = Item::Value(TomlValueEdit::from("chat"));
}
tbl
}
pub fn set_ollama_models(doc: &mut Document, models: &[String]) {
let tbl = upsert_provider_ollama(doc);
let mut arr = toml_edit::Array::new();
for m in models {
arr.push(TomlValueEdit::from(m.clone()));
}
tbl["models"] = Item::Value(TomlValueEdit::Array(arr));
}
#[cfg(test)]
mod tests {
use super::*;
use toml_edit::DocumentMut as Document;
#[test]
fn test_upsert_provider_and_models() {
let mut doc = Document::new();
let tbl = upsert_provider_ollama(&mut doc);
assert!(tbl.contains_key("name"));
assert!(tbl.contains_key("base_url"));
assert!(tbl.contains_key("wire_api"));
set_ollama_models(&mut doc, &[String::from("llama3.2:3b")]);
let root = doc.as_table();
let mp = match root.get("model_providers").and_then(|i| i.as_table()) {
Some(t) => t,
None => panic!("model_providers"),
};
let ollama = match mp.get("ollama").and_then(|i| i.as_table()) {
Some(t) => t,
None => panic!("ollama"),
};
let arr = match ollama.get("models") {
Some(v) => v,
None => panic!("models array"),
};
assert!(arr.is_array(), "models should be an array");
let s = doc.to_string();
assert!(s.contains("model_providers"));
assert!(s.contains("ollama"));
assert!(s.contains("models"));
}
}

View File

@@ -0,0 +1,291 @@
use crate::error::CodexErr;
use crate::error::Result as CoreResult;
use std::collections::HashMap;
use std::io;
use std::io::Write;
use std::path::Path;
pub const DEFAULT_BASE_URL: &str = "http://localhost:11434/v1";
pub const DEFAULT_WIRE_API: crate::model_provider_info::WireApi =
crate::model_provider_info::WireApi::Chat;
pub const DEFAULT_PULL_ALLOWLIST: &[&str] = &["llama3.2:3b"];
pub mod client;
pub mod config;
pub mod parser;
pub mod url;
pub use client::OllamaClient;
pub use config::read_config_models;
pub use config::read_provider_state;
pub use config::write_config_models;
pub use url::base_url_to_host_root;
pub use url::base_url_to_host_root_with_wire;
pub use url::probe_ollama_server;
pub use url::probe_url_for_base;
/// Coordinator wrapper used by frontends when responding to `--ollama`.
///
/// - Probes the server using the configured base_url when present, otherwise
/// falls back to DEFAULT_BASE_URL.
/// - If the server is reachable, ensures an `[model_providers.ollama]` entry
/// exists in `config.toml` with sensible defaults.
/// - If no server is reachable, returns an error.
pub async fn ensure_configured_and_running() -> CoreResult<()> {
use crate::config::find_codex_home;
use toml::Value as TomlValue;
let codex_home = find_codex_home()?;
let config_path = codex_home.join("config.toml");
// Try to read a configured base_url if present.
let base_url = match std::fs::read_to_string(&config_path) {
Ok(contents) => match toml::from_str::<TomlValue>(&contents) {
Ok(TomlValue::Table(root)) => root
.get("model_providers")
.and_then(|v| v.as_table())
.and_then(|t| t.get("ollama"))
.and_then(|v| v.as_table())
.and_then(|t| t.get("base_url"))
.and_then(|v| v.as_str())
.unwrap_or(DEFAULT_BASE_URL)
.to_string(),
_ => DEFAULT_BASE_URL.to_string(),
},
Err(_) => DEFAULT_BASE_URL.to_string(),
};
// Probe reachability; map any probe error to a friendly unreachable message.
let ok: bool = url::probe_ollama_server(&base_url)
.await
.unwrap_or_default();
if !ok {
return Err(CodexErr::OllamaServerUnreachable);
}
// Ensure provider entry exists with defaults.
let _ = config::ensure_ollama_provider_entry(&codex_home)?;
Ok(())
}
#[cfg(test)]
mod ensure_tests {
#![allow(clippy::expect_used, clippy::unwrap_used)]
use super::*;
#[tokio::test]
async fn test_ensure_configured_returns_friendly_error_when_unreachable() {
// Skip in CI sandbox environments without network to avoid false negatives.
if std::env::var(crate::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
tracing::info!(
"{} is set; skipping test_ensure_configured_returns_friendly_error_when_unreachable",
crate::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
);
return;
}
let tmpdir = tempfile::TempDir::new().expect("tempdir");
let config_path = tmpdir.path().join("config.toml");
std::fs::create_dir_all(tmpdir.path()).unwrap();
std::fs::write(
&config_path,
r#"[model_providers.ollama]
name = "Ollama"
base_url = "http://127.0.0.1:1/v1"
wire_api = "chat"
"#,
)
.unwrap();
unsafe {
std::env::set_var("CODEX_HOME", tmpdir.path());
}
let err = ensure_configured_and_running()
.await
.expect_err("should report unreachable server as friendly error");
assert!(matches!(err, CodexErr::OllamaServerUnreachable));
}
}
/// Events emitted while pulling a model from Ollama.
#[derive(Debug, Clone)]
pub enum PullEvent {
/// A human-readable status message (e.g., "verifying", "writing").
Status(String),
/// Byte-level progress update for a specific layer digest.
ChunkProgress {
digest: String,
total: Option<u64>,
completed: Option<u64>,
},
/// The pull finished successfully.
Success,
}
/// A simple observer for pull progress events. Implementations decide how to
/// render progress (CLI, TUI, logs, ...).
pub trait PullProgressReporter {
fn on_event(&mut self, event: &PullEvent) -> io::Result<()>;
}
/// A minimal CLI reporter that writes inline progress to stderr.
pub struct CliProgressReporter {
printed_header: bool,
last_line_len: usize,
last_completed_sum: u64,
last_instant: std::time::Instant,
totals_by_digest: HashMap<String, (u64, u64)>,
}
impl Default for CliProgressReporter {
fn default() -> Self {
Self::new()
}
}
impl CliProgressReporter {
pub fn new() -> Self {
Self {
printed_header: false,
last_line_len: 0,
last_completed_sum: 0,
last_instant: std::time::Instant::now(),
totals_by_digest: HashMap::new(),
}
}
}
impl PullProgressReporter for CliProgressReporter {
fn on_event(&mut self, event: &PullEvent) -> io::Result<()> {
let mut out = std::io::stderr();
match event {
PullEvent::Status(status) => {
// Avoid noisy manifest messages; otherwise show status inline.
if status.eq_ignore_ascii_case("pulling manifest") {
return Ok(());
}
let pad = self.last_line_len.saturating_sub(status.len());
let line = format!("\r{status}{}", " ".repeat(pad));
self.last_line_len = status.len();
out.write_all(line.as_bytes())?;
out.flush()
}
PullEvent::ChunkProgress {
digest,
total,
completed,
} => {
if let Some(t) = *total {
self.totals_by_digest
.entry(digest.clone())
.or_insert((0, 0))
.0 = t;
}
if let Some(c) = *completed {
self.totals_by_digest
.entry(digest.clone())
.or_insert((0, 0))
.1 = c;
}
let (sum_total, sum_completed) = self
.totals_by_digest
.values()
.fold((0u64, 0u64), |acc, (t, c)| (acc.0 + *t, acc.1 + *c));
if sum_total > 0 {
if !self.printed_header {
let gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0);
let header = format!("Downloading model: total {gb:.2} GB\n");
out.write_all(b"\r\x1b[2K")?;
out.write_all(header.as_bytes())?;
self.printed_header = true;
}
let now = std::time::Instant::now();
let dt = now
.duration_since(self.last_instant)
.as_secs_f64()
.max(0.001);
let dbytes = sum_completed.saturating_sub(self.last_completed_sum) as f64;
let speed_mb_s = dbytes / (1024.0 * 1024.0) / dt;
self.last_completed_sum = sum_completed;
self.last_instant = now;
let done_gb = (sum_completed as f64) / (1024.0 * 1024.0 * 1024.0);
let total_gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0);
let pct = (sum_completed as f64) * 100.0 / (sum_total as f64);
let text =
format!("{done_gb:.2}/{total_gb:.2} GB ({pct:.1}%) {speed_mb_s:.1} MB/s");
let pad = self.last_line_len.saturating_sub(text.len());
let line = format!("\r{text}{}", " ".repeat(pad));
self.last_line_len = text.len();
out.write_all(line.as_bytes())?;
out.flush()
} else {
Ok(())
}
}
PullEvent::Success => {
out.write_all(b"\n")?;
out.flush()
}
}
}
}
/// For now the TUI reporter delegates to the CLI reporter. This keeps UI and
/// CLI behavior aligned until a dedicated TUI integration is implemented.
#[derive(Default)]
pub struct TuiProgressReporter(CliProgressReporter);
impl TuiProgressReporter {
pub fn new() -> Self {
Default::default()
}
}
impl PullProgressReporter for TuiProgressReporter {
fn on_event(&mut self, event: &PullEvent) -> io::Result<()> {
self.0.on_event(event)
}
}
/// Ensure a model is available locally.
///
/// - If the model is already present, ensure it is recorded in config.toml.
/// - If missing and in the default allowlist, pull it with streaming progress
/// and record it in config.toml after success.
/// - If missing and not allowlisted, return an error.
pub async fn ensure_model_available(
model: &str,
client: &OllamaClient,
config_path: &Path,
reporter: &mut dyn PullProgressReporter,
) -> CoreResult<()> {
let mut listed = config::read_ollama_models_list(config_path);
let available = client.fetch_models().await.unwrap_or_default();
if available.iter().any(|m| m == model) {
if !listed.iter().any(|m| m == model) {
listed.push(model.to_string());
listed.sort();
listed.dedup();
let _ = config::write_ollama_models_list(config_path, &listed);
}
return Ok(());
}
if !DEFAULT_PULL_ALLOWLIST.contains(&model) {
return Err(CodexErr::OllamaModelNotFound(model.to_string()));
}
loop {
let _ = client.pull_with_reporter(model, reporter).await;
// After the stream completes (success or early EOF), check again.
let available = client.fetch_models().await.unwrap_or_default();
if available.iter().any(|m| m == model) {
break;
}
// Keep waiting for the model to finish downloading.
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}
listed.push(model.to_string());
listed.sort();
listed.dedup();
let _ = config::write_ollama_models_list(config_path, &listed);
Ok(())
}
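
`PullProgressReporter` is the only integration point a frontend needs for pull progress; a hypothetical sketch of an alternative reporter that forwards events to `tracing` (assuming the caller has the `tracing` crate available) could look like this:

```rust
use std::io;

use codex_core::providers::ollama::PullEvent;
use codex_core::providers::ollama::PullProgressReporter;

/// Hypothetical reporter that forwards pull progress to `tracing` instead of
/// drawing inline terminal output.
struct LogProgressReporter;

impl PullProgressReporter for LogProgressReporter {
    fn on_event(&mut self, event: &PullEvent) -> io::Result<()> {
        match event {
            PullEvent::Status(status) => tracing::info!("ollama pull: {}", status),
            PullEvent::ChunkProgress {
                digest,
                total,
                completed,
            } => tracing::debug!("ollama pull {}: {:?}/{:?} bytes", digest, completed, total),
            PullEvent::Success => tracing::info!("ollama pull: done"),
        }
        Ok(())
    }
}
```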

View File

@@ -0,0 +1,82 @@
use serde_json::Value as JsonValue;
use super::PullEvent;
// Convert a single JSON object representing a pull update into one or more events.
pub(crate) fn pull_events_from_value(value: &JsonValue) -> Vec<PullEvent> {
let mut events = Vec::new();
if let Some(status) = value.get("status").and_then(|s| s.as_str()) {
events.push(PullEvent::Status(status.to_string()));
if status == "success" {
events.push(PullEvent::Success);
}
}
let digest = value
.get("digest")
.and_then(|d| d.as_str())
.unwrap_or("")
.to_string();
let total = value.get("total").and_then(|t| t.as_u64());
let completed = value.get("completed").and_then(|t| t.as_u64());
if total.is_some() || completed.is_some() {
events.push(PullEvent::ChunkProgress {
digest,
total,
completed,
});
}
events
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_pull_events_decoder_status_and_success() {
let v: JsonValue = serde_json::json!({"status":"verifying"});
let events = pull_events_from_value(&v);
assert!(matches!(events.as_slice(), [PullEvent::Status(s)] if s == "verifying"));
let v2: JsonValue = serde_json::json!({"status":"success"});
let events2 = pull_events_from_value(&v2);
assert_eq!(events2.len(), 2);
assert!(matches!(events2[0], PullEvent::Status(ref s) if s == "success"));
assert!(matches!(events2[1], PullEvent::Success));
}
#[test]
fn test_pull_events_decoder_progress() {
let v: JsonValue = serde_json::json!({"digest":"sha256:abc","total":100});
let events = pull_events_from_value(&v);
assert_eq!(events.len(), 1);
match &events[0] {
PullEvent::ChunkProgress {
digest,
total,
completed,
} => {
assert_eq!(digest, "sha256:abc");
assert_eq!(*total, Some(100));
assert_eq!(*completed, None);
}
_ => panic!("expected ChunkProgress"),
}
let v2: JsonValue = serde_json::json!({"digest":"sha256:def","completed":42});
let events2 = pull_events_from_value(&v2);
assert_eq!(events2.len(), 1);
match &events2[0] {
PullEvent::ChunkProgress {
digest,
total,
completed,
} => {
assert_eq!(digest, "sha256:def");
assert_eq!(*total, None);
assert_eq!(*completed, Some(42));
}
_ => panic!("expected ChunkProgress"),
}
}
}

View File

@@ -0,0 +1,83 @@
use crate::error::Result as CoreResult;
/// Identify whether a base_url points at an OpenAI-compatible root (".../v1").
pub(crate) fn is_openai_compatible_base_url(base_url: &str) -> bool {
base_url.trim_end_matches('/').ends_with("/v1")
}
/// Convert a provider base_url into the native Ollama host root.
/// For example, "http://localhost:11434/v1" -> "http://localhost:11434".
pub fn base_url_to_host_root(base_url: &str) -> String {
let trimmed = base_url.trim_end_matches('/');
if trimmed.ends_with("/v1") {
trimmed
.trim_end_matches("/v1")
.trim_end_matches('/')
.to_string()
} else {
trimmed.to_string()
}
}
/// Variant that considers an explicit WireApi value; provided to centralize
/// host root computation in one place for future extension.
pub fn base_url_to_host_root_with_wire(
base_url: &str,
_wire_api: crate::model_provider_info::WireApi,
) -> String {
base_url_to_host_root(base_url)
}
/// Compute the probe URL to verify if an Ollama server is reachable.
/// If the configured base is OpenAI-compatible (/v1), probe "models", otherwise
/// fall back to the native "/api/tags" endpoint.
pub fn probe_url_for_base(base_url: &str) -> String {
if is_openai_compatible_base_url(base_url) {
format!("{}/models", base_url.trim_end_matches('/'))
} else {
format!("{}/api/tags", base_url.trim_end_matches('/'))
}
}
/// Convenience helper to probe an Ollama server given a provider style base URL.
pub async fn probe_ollama_server(base_url: &str) -> CoreResult<bool> {
let url = probe_url_for_base(base_url);
let client = reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(5))
.build()?;
let resp = client.get(url).send().await?;
Ok(resp.status().is_success())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_base_url_to_host_root() {
assert_eq!(
base_url_to_host_root("http://localhost:11434/v1"),
"http://localhost:11434"
);
assert_eq!(
base_url_to_host_root("http://localhost:11434"),
"http://localhost:11434"
);
assert_eq!(
base_url_to_host_root("http://localhost:11434/"),
"http://localhost:11434"
);
}
#[test]
fn test_probe_url_for_base() {
assert_eq!(
probe_url_for_base("http://localhost:11434/v1"),
"http://localhost:11434/v1/models"
);
assert_eq!(
probe_url_for_base("http://localhost:11434"),
"http://localhost:11434/api/tags"
);
}
}

View File

@@ -14,6 +14,12 @@ pub struct Cli {
#[arg(long, short = 'm')]
pub model: Option<String>,
/// Convenience flag to select the local Ollama provider.
/// Equivalent to -c model_provider=ollama; verifies a local Ollama server is running and
/// creates a model_providers.ollama entry in config.toml if missing.
#[arg(long = "ollama", default_value_t = false)]
pub ollama: bool,
/// Select the sandbox policy to use when executing model-generated shell
/// commands.
#[arg(long = "sandbox", short = 's')]

View File

@@ -31,10 +31,17 @@ use tracing_subscriber::EnvFilter;
use crate::event_processor::CodexStatus;
use crate::event_processor::EventProcessor;
// Shared Ollama helpers are centralized in codex_core::providers::ollama.
use codex_core::providers::ollama::CliProgressReporter;
use codex_core::providers::ollama::OllamaClient;
use codex_core::providers::ollama::ensure_configured_and_running;
use codex_core::providers::ollama::ensure_model_available;
pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()> {
let Cli {
images,
model,
ollama,
config_profile,
full_auto,
dangerously_bypass_approvals_and_sandbox,
@@ -48,6 +55,9 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
config_overrides,
} = cli;
// Track whether the user explicitly provided a model via --model.
let user_specified_model = model.is_some();
// Determine the prompt based on CLI arg and/or stdin.
let prompt = match prompt {
Some(p) if p != "-" => p,
@@ -114,6 +124,16 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
sandbox_mode_cli_arg.map(Into::<SandboxMode>::into)
};
// When the user opts into the Ollama provider via `--ollama`, ensure we
// have a configured provider entry and that a local server is running.
if ollama {
if let Err(e) = ensure_configured_and_running().await {
tracing::error!("{e}");
eprintln!("{e}");
std::process::exit(1);
}
}
// Load configuration and determine approval policy
let overrides = ConfigOverrides {
model,
@@ -123,7 +143,11 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
approval_policy: Some(AskForApproval::Never),
sandbox_mode,
cwd: cwd.map(|p| p.canonicalize().unwrap_or(p)),
model_provider: None,
model_provider: if ollama {
Some("ollama".to_string())
} else {
None
},
codex_linux_sandbox_exe,
base_instructions: None,
include_plan_tool: None,
@@ -138,6 +162,22 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
};
let config = Config::load_with_cli_overrides(cli_kv_overrides, overrides)?;
// If the user passed both --ollama and --model, ensure the requested model
// is present locally or pull it automatically (subject to allowlist).
if ollama && user_specified_model {
let model_name = config.model.clone();
let client = OllamaClient::from_provider(&config.model_provider);
let config_path = config.codex_home.join("config.toml");
let mut reporter = CliProgressReporter::new();
if let Err(e) =
ensure_model_available(&model_name, &client, &config_path, &mut reporter).await
{
tracing::error!("{e}");
eprintln!("{e}");
std::process::exit(1);
}
}
let mut event_processor: Box<dyn EventProcessor> = if json_mode {
Box::new(EventProcessorWithJsonOutput::new(last_message_file.clone()))
} else {

View File

@@ -67,7 +67,6 @@ unicode-width = "0.1"
uuid = "1"
[dev-dependencies]
insta = "1.43.1"
pretty_assertions = "1"

View File

@@ -43,6 +43,7 @@ enum AppState<'a> {
},
/// The start-up warning that recommends running codex inside a Git repo.
GitWarning { screen: GitWarningScreen },
// (no additional states)
}
pub(crate) struct App<'a> {

View File

@@ -17,6 +17,12 @@ pub struct Cli {
#[arg(long, short = 'm')]
pub model: Option<String>,
/// Convenience flag to select the local Ollama provider.
/// Equivalent to -c model_provider=ollama; verifies a local Ollama server is running and
/// creates a model_providers.ollama entry in config.toml if missing.
#[arg(long = "ollama", default_value_t = false)]
pub ollama: bool,
/// Configuration profile from config.toml to specify default options.
#[arg(long = "profile", short = 'p')]
pub config_profile: Option<String>,

View File

@@ -9,9 +9,17 @@ use codex_core::config_types::SandboxMode;
use codex_core::protocol::AskForApproval;
use codex_core::util::is_inside_git_repo;
use codex_login::load_auth;
use crossterm::event::Event as CEvent;
use crossterm::event::KeyCode;
use crossterm::event::KeyEvent;
use crossterm::event::KeyModifiers;
use crossterm::event::{self};
use crossterm::terminal::disable_raw_mode;
use crossterm::terminal::enable_raw_mode;
use log_layer::TuiLogLayer;
use std::fs::OpenOptions;
use std::io::Write;
use std::io::{self};
use std::path::PathBuf;
use tracing::error;
use tracing_appender::non_blocking;
@@ -47,6 +55,253 @@ mod updates;
use color_eyre::owo_colors::OwoColorize;
pub use cli::Cli;
// Centralized Ollama helpers from core
use codex_core::providers::ollama::OllamaClient;
use codex_core::providers::ollama::TuiProgressReporter;
use codex_core::providers::ollama::ensure_configured_and_running;
use codex_core::providers::ollama::ensure_model_available;
use codex_core::providers::ollama::read_config_models;
use codex_core::providers::ollama::read_provider_state;
use codex_core::providers::ollama::write_config_models;
fn print_inline_message_no_models(
host_root: &str,
config_path: &std::path::Path,
provider_was_present_before: bool,
) -> io::Result<()> {
let mut out = std::io::stdout();
let path = config_path.display().to_string();
// bold text helper
let b = |s: &str| format!("\x1b[1m{s}\x1b[0m");
// Ensure we start clean at column 0.
out.write_all(b"\r\x1b[2K")?;
out.write_all(
format!(
"{}\n\n",
b("we've discovered no models on your local Ollama instance.")
)
.as_bytes(),
)?;
out.write_all(format!("\rendpoint: {host_root}\n").as_bytes())?;
if provider_was_present_before {
out.write_all(format!("\rconfig: ollama provider already present in {path}\n").as_bytes())?;
} else {
out.write_all(
format!("\rconfig: added ollama as a model provider in {path}\n").as_bytes(),
)?;
}
out.write_all(
b"\rmodels: none recorded in config (pull models with `ollama pull <model>`).\n\n",
)?;
out.flush()
}
fn run_inline_models_picker(
host_root: &str,
available: &[String],
preselected: &[String],
config_path: &std::path::Path,
provider_was_present_before: bool,
) -> io::Result<()> {
let mut out = std::io::stdout();
let mut selected: Vec<bool> = available
.iter()
.map(|m| preselected.iter().any(|x| x == m))
.collect();
let mut cursor: usize = 0;
let mut first = true;
let mut lines_printed: usize = 0;
enable_raw_mode()?;
loop {
// Render block
render_inline_picker(
&mut out,
host_root,
available,
&selected,
cursor,
&mut first,
&mut lines_printed,
)?;
// Wait for key
match event::read()? {
CEvent::Key(KeyEvent {
code: KeyCode::Up, ..
})
| CEvent::Key(KeyEvent {
code: KeyCode::Char('k'),
..
}) => {
cursor = cursor.saturating_sub(1);
}
CEvent::Key(KeyEvent {
code: KeyCode::Down,
..
})
| CEvent::Key(KeyEvent {
code: KeyCode::Char('j'),
..
}) => {
if cursor + 1 < available.len() {
cursor += 1;
}
}
CEvent::Key(KeyEvent {
code: KeyCode::Char(' '),
..
}) => {
if let Some(s) = selected.get_mut(cursor) {
*s = !*s;
}
}
CEvent::Key(KeyEvent {
code: KeyCode::Char('a'),
..
}) => {
let all_sel = selected.iter().all(|s| *s);
selected.fill(!all_sel);
}
// Allow quitting the entire app from the inline picker with Ctrl+C or Ctrl+D.
CEvent::Key(KeyEvent {
code: KeyCode::Char('c'),
modifiers,
..
}) if modifiers.contains(KeyModifiers::CONTROL) => {
// Restore terminal state and exit with SIGINT-like code.
disable_raw_mode()?;
// Start on a clean line before exiting.
out.write_all(b"\r\x1b[2K\n")?;
std::process::exit(130);
}
CEvent::Key(KeyEvent {
code: KeyCode::Char('d'),
modifiers,
..
}) if modifiers.contains(KeyModifiers::CONTROL) => {
// Restore terminal state and exit cleanly.
disable_raw_mode()?;
out.write_all(b"\r\x1b[2K\n")?;
std::process::exit(0);
}
CEvent::Key(KeyEvent {
code: KeyCode::Enter,
..
}) => {
break;
}
CEvent::Key(KeyEvent {
code: KeyCode::Char('q'),
..
})
| CEvent::Key(KeyEvent {
code: KeyCode::Esc, ..
}) => {
// Skip saving; print the summary and continue.
disable_raw_mode()?;
print_config_summary_after_save(config_path, provider_was_present_before, None)?;
return Ok(());
}
_ => {}
}
}
disable_raw_mode()?;
// Ensure the summary starts on a clean, left-aligned new line.
out.write_all(b"\r\x1b[2K\n")?;
// Compute chosen
let chosen: Vec<String> = available
.iter()
.cloned()
.zip(selected.iter())
.filter_map(|(name, sel)| if *sel { Some(name) } else { None })
.collect();
let _ = write_config_models(config_path, &chosen);
print_config_summary_after_save(config_path, provider_was_present_before, Some(chosen.len()))
}
fn render_inline_picker(
out: &mut std::io::Stdout,
host_root: &str,
items: &[String],
selected: &[bool],
cursor: usize,
first: &mut bool,
lines_printed: &mut usize,
) -> io::Result<()> {
// If not first render, move to the start of the block. We will clear each line as we redraw.
if !*first {
out.write_all(format!("\x1b[{}A", *lines_printed).as_bytes())?; // up N lines
// Ensure we start at column 1 for a clean redraw.
out.write_all(b"\r")?;
}
let mut lines = Vec::new();
let bold = |s: &str| format!("\x1b[1m{s}\x1b[0m");
lines.push(bold(&format!("discovered models on ollama ({host_root}):")));
lines
.push("↑/↓ move, space to toggle, 'a' (un)select all, enter confirm, 'q' skip".to_string());
lines.push(String::new());
for (i, name) in items.iter().enumerate() {
let mark = if selected.get(i).copied().unwrap_or(false) {
"\x1b[32m[x]\x1b[0m" // green
} else {
"[ ]"
};
let mut line = format!("{mark} {name}");
if i == cursor {
line = format!("\x1b[7m{line}\x1b[0m"); // reverse video for current row
}
lines.push(line);
}
for l in &lines {
// Move to column 0 and clear the entire line before writing.
out.write_all(b"\r\x1b[2K")?;
out.write_all(l.as_bytes())?;
out.write_all(b"\n")?;
}
out.flush()?;
*first = false;
*lines_printed = lines.len();
Ok(())
}
fn print_config_summary_after_save(
config_path: &std::path::Path,
provider_was_present_before: bool,
models_count_after: Option<usize>,
) -> io::Result<()> {
let mut out = std::io::stdout();
// Start clean and at column 0
out.write_all(b"\r\x1b[2K")?;
let path = config_path.display().to_string();
if provider_was_present_before {
out.write_all(format!("\rconfig: ollama provider already present in {path}\n").as_bytes())?;
} else {
out.write_all(
format!("\rconfig: added ollama as a model provider in {path}\n").as_bytes(),
)?;
}
if let Some(after) = models_count_after {
let names = read_config_models(config_path).unwrap_or_default();
if names.is_empty() {
out.write_all(format!("\rmodels: recorded {after}\n\n").as_bytes())?;
} else {
out.write_all(
format!("\rmodels: recorded {} ({})\n\n", after, names.join(", ")).as_bytes(),
)?;
}
} else {
out.write_all(b"\rmodels: no changes recorded\n\n")?;
}
out.flush()
}
pub async fn run_main(
cli: Cli,
@@ -69,14 +324,42 @@ pub async fn run_main(
)
};
// Track config.toml state for messaging before launching TUI.
let provider_was_present_before = if cli.ollama {
let codex_home = codex_core::config::find_codex_home()?;
let config_path = codex_home.join("config.toml");
let (p, _m) = read_provider_state(&config_path);
p
} else {
false
};
let config = {
// If the user selected the Ollama provider via `--ollama`, verify a
// local server is reachable and ensure a provider entry exists in
// config.toml. Exit early with a helpful message otherwise.
if cli.ollama {
if let Err(e) = ensure_configured_and_running().await {
#[allow(clippy::print_stderr)]
{
tracing::error!("{e}");
eprintln!("{e}");
}
std::process::exit(1);
}
}
// Load configuration and support CLI overrides.
let overrides = ConfigOverrides {
model: cli.model.clone(),
approval_policy,
sandbox_mode,
cwd: cli.cwd.clone().map(|p| p.canonicalize().unwrap_or(p)),
model_provider: None,
model_provider: if cli.ollama {
Some("ollama".to_string())
} else {
None
},
config_profile: cli.config_profile.clone(),
codex_linux_sandbox_exe,
base_instructions: None,
@@ -101,6 +384,73 @@ pub async fn run_main(
}
}
};
// If the user passed --ollama, either ensure an explicitly requested model is
// available (automatic pull if allowlisted) or offer an inline picker when no
// specific model was provided.
if cli.ollama {
// Determine host root for the Ollama native API (e.g. http://localhost:11434).
let base_url = config
.model_provider
.base_url
.clone()
.unwrap_or_else(|| "http://localhost:11434/v1".to_string());
let host_root = base_url
.trim_end_matches('/')
.trim_end_matches("/v1")
.to_string();
let config_path = config.codex_home.join("config.toml");
if let Some(ref model_name) = cli.model {
// Explicit model requested: ensure it is available locally without prompting.
let client = OllamaClient::from_provider(&config.model_provider);
let mut reporter = TuiProgressReporter::new();
if let Err(e) =
ensure_model_available(model_name, &client, &config_path, &mut reporter).await
{
let mut out = std::io::stderr();
tracing::error!("{e}");
let _ = out.write_all(format!("{e}\n").as_bytes());
let _ = out.flush();
std::process::exit(1);
}
} else {
// No specific model was requested: fetch available models from the local instance
// and, if they differ from what is listed in config.toml, display a minimal
// inline selection UI before launching the TUI.
let client = OllamaClient::from_provider(&config.model_provider);
let available_models: Vec<String> = client.fetch_models().await.unwrap_or_default();
// Read existing models in config.
let existing_models: Vec<String> = read_config_models(&config_path).unwrap_or_default();
if available_models.is_empty() {
// Inform the user and continue launching the TUI.
print_inline_message_no_models(
&host_root,
&config_path,
provider_was_present_before,
)?;
} else {
// Compare sets to decide whether to show the prompt.
let set_eq = {
use std::collections::HashSet;
let a: HashSet<_> = available_models.iter().collect();
let b: HashSet<_> = existing_models.iter().collect();
a == b
};
if !set_eq {
run_inline_models_picker(
&host_root,
&available_models,
&existing_models,
&config_path,
provider_was_present_before,
)?;
}
}
}
}
let log_dir = codex_core::config::log_dir(&config)?;
std::fs::create_dir_all(&log_dir)?;

View File

@@ -0,0 +1,2 @@
pub mod ollama_model_picker;

View File

@@ -0,0 +1,155 @@
use crossterm::event::{KeyCode, KeyEvent};
use ratatui::buffer::Buffer;
use ratatui::layout::{Alignment, Constraint, Direction, Layout, Rect};
use ratatui::style::{Modifier, Style};
use ratatui::text::{Line, Span};
use ratatui::widgets::{Block, BorderType, Borders, Paragraph, WidgetRef};
use ratatui::prelude::Widget;
use std::path::PathBuf;
pub enum PickerOutcome {
Submit(Vec<String>),
Cancel,
None,
}
pub struct OllamaModelPickerScreen {
pub host_root: String,
pub config_path: PathBuf,
available: Vec<String>,
selected: Vec<bool>,
cursor: usize,
pub loading: bool,
}
impl OllamaModelPickerScreen {
pub fn new(host_root: String, config_path: PathBuf, preselected: Vec<String>) -> Self {
Self {
host_root,
config_path,
available: Vec::new(),
selected: preselected.into_iter().map(|_| false).collect(),
cursor: 0,
loading: true,
}
}
pub fn desired_height(&self, _width: u16) -> u16 {
18u16
}
pub fn update_available(&mut self, available: Vec<String>) {
// Build selection state using existing selected names where possible.
let prev_selected_names: Vec<String> = self
.available
.iter()
.cloned()
.zip(self.selected.iter().cloned())
.filter_map(|(n, sel)| if sel { Some(n) } else { None })
.collect();
self.available = available.clone();
self.selected = available
.iter()
.map(|n| prev_selected_names.iter().any(|p| p == n))
.collect();
if self.cursor >= self.available.len() {
self.cursor = self.available.len().saturating_sub(1);
}
self.loading = false;
}
pub fn handle_key_event(&mut self, key: KeyEvent) -> PickerOutcome {
match key.code {
KeyCode::Up | KeyCode::Char('k') => {
if self.cursor > 0 { self.cursor -= 1; }
PickerOutcome::None
}
KeyCode::Down | KeyCode::Char('j') => {
if self.cursor + 1 < self.available.len() { self.cursor += 1; }
PickerOutcome::None
}
KeyCode::Char(' ') => {
if let Some(s) = self.selected.get_mut(self.cursor) {
*s = !*s;
}
PickerOutcome::None
}
KeyCode::Char('a') => {
let all = self.selected.iter().all(|s| *s);
self.selected.fill(!all);
PickerOutcome::None
}
KeyCode::Enter => {
let chosen: Vec<String> = self
.available
.iter()
.cloned()
.zip(self.selected.iter().cloned())
.filter_map(|(n, sel)| if sel { Some(n) } else { None })
.collect();
PickerOutcome::Submit(chosen)
}
KeyCode::Esc | KeyCode::Char('q') => PickerOutcome::Cancel,
_ => PickerOutcome::None,
}
}
}
impl WidgetRef for &OllamaModelPickerScreen {
fn render_ref(&self, area: Rect, buf: &mut Buffer) {
const MIN_WIDTH: u16 = 40;
const MIN_HEIGHT: u16 = 15;
let popup_width = std::cmp::max(MIN_WIDTH, (area.width as f32 * 0.7) as u16);
let popup_height = std::cmp::max(MIN_HEIGHT, (area.height as f32 * 0.6) as u16);
let popup_x = area.x + (area.width.saturating_sub(popup_width)) / 2;
let popup_y = area.y + (area.height.saturating_sub(popup_height)) / 2;
let popup_area = Rect::new(popup_x, popup_y, popup_width, popup_height);
let popup_block = Block::default()
.borders(Borders::ALL)
.border_type(BorderType::Plain)
.title(Span::styled(
"Select Ollama models",
Style::default().add_modifier(Modifier::BOLD),
));
let inner = popup_block.inner(popup_area);
popup_block.render(popup_area, buf);
let chunks = Layout::default()
.direction(Direction::Vertical)
.constraints([Constraint::Length(3), Constraint::Min(3), Constraint::Length(3)])
.split(inner);
// Header
let header = format!("endpoint: {}\n↑/↓ move, space toggle, 'a' (un)select all, enter confirm, 'q' skip", self.host_root);
Paragraph::new(header).alignment(Alignment::Left).render(chunks[0], buf);
// Body: list of models or a loading message
if self.loading {
Paragraph::new("discovering models...").alignment(Alignment::Center).render(chunks[1], buf);
} else if self.available.is_empty() {
Paragraph::new("No models discovered on the local Ollama instance.")
.alignment(Alignment::Center)
.render(chunks[1], buf);
} else {
// Render each line manually with highlight for cursor.
let mut lines: Vec<Line> = Vec::with_capacity(self.available.len());
for (i, name) in self.available.iter().enumerate() {
let mark = if self.selected.get(i).copied().unwrap_or(false) { "[x]" } else { "[ ]" };
let content = format!("{mark} {name}");
if i == self.cursor {
lines.push(Line::from(content).style(Style::default().add_modifier(Modifier::REVERSED)));
} else {
lines.push(Line::from(content));
}
}
Paragraph::new(lines).render(chunks[1], buf);
}
// Footer/help
Paragraph::new("press Enter to save, 'q' to continue without changes")
.alignment(Alignment::Center)
.render(chunks[2], buf);
}
}