Mirror of https://github.com/openai/codex.git, synced 2026-02-02 06:57:03 +00:00

Compare commits: plan-defau...ollama (12 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 5225cec3a1 |  |
|  | c2cf4a3cb9 |  |
|  | 304d01c099 |  |
|  | 5df778471c |  |
|  | 3c97fc1423 |  |
|  | d5c48cd049 |  |
|  | 42c726be79 |  |
|  | b2ed15430f |  |
|  | 75eec73fcc |  |
|  | 6b6d2c5e00 |  |
|  | 47e84d5c05 |  |
|  | 2cfb2a2265 |  |
codex-rs/Cargo.lock (generated, 9 lines changed)
@@ -671,6 +671,7 @@ dependencies = [
 "anyhow",
 "assert_cmd",
 "async-channel",
 "async-stream",
 "base64 0.22.1",
 "bytes",
 "chrono",
@@ -708,6 +709,7 @@ dependencies = [
 "tokio-test",
 "tokio-util",
 "toml 0.9.2",
 "toml_edit",
 "tracing",
 "tree-sitter",
 "tree-sitter-bash",
@@ -4814,6 +4816,7 @@ dependencies = [
 "serde",
 "serde_spanned 0.6.9",
 "toml_datetime 0.6.11",
 "toml_write",
 "winnow",
]

@@ -4826,6 +4829,12 @@ dependencies = [
 "winnow",
]

[[package]]
name = "toml_write"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801"

[[package]]
name = "toml_writer"
version = "1.0.2"
@@ -3,6 +3,7 @@
Codex supports several mechanisms for setting config values:

- Config-specific command-line flags, such as `--model o3` (highest precedence).
- Convenience provider flags, such as `--ollama` (equivalent to `-c model_provider=ollama`).
- A generic `-c`/`--config` flag that takes a `key=value` pair, such as `--config model="o3"`.
  - The key can contain dots to set a value deeper than the root, e.g. `--config model_providers.openai.wire_api="chat"`.
  - Values can contain objects, such as `--config shell_environment_policy.include_only=["PATH", "HOME", "USER"]`.
@@ -56,6 +57,13 @@ name = "Ollama"
base_url = "http://localhost:11434/v1"
```

Alternatively, you can pass `--ollama` on the CLI, which is equivalent to `-c model_provider=ollama`.
When using `--ollama`, Codex will verify that an Ollama server is running locally and
will create a `[model_providers.ollama]` entry in your `config.toml` with sensible defaults
(`base_url = "http://localhost:11434/v1"`, `wire_api = "chat"`) if one does not already exist.
If no running Ollama server is detected, Codex will print instructions to install/start Ollama
and exit: https://github.com/ollama/ollama?tab=readme-ov-file#ollama
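For reference, the entry written by `--ollama` looks like the following sketch, assembled from the defaults above (the `name` value matches the `upsert_provider_ollama` helper introduced in this change):

```toml
# Sketch of the entry `--ollama` creates when none exists.
[model_providers.ollama]
name = "Ollama"
base_url = "http://localhost:11434/v1"
wire_api = "chat"
```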
Or a third-party provider (using a distinct environment variable for the API key):

```toml
@@ -13,6 +13,7 @@ workspace = true
[dependencies]
anyhow = "1"
async-channel = "2.3.1"
async-stream = "0.3"
base64 = "0.22"
bytes = "1.10.1"
chrono = { version = "0.4", features = ["serde"] }
@@ -30,8 +31,8 @@ mime_guess = "2.0"
rand = "0.9"
reqwest = { version = "0.12", features = ["json", "stream"] }
serde = { version = "1", features = ["derive"] }
-serde_json = "1"
+serde_bytes = "0.11"
serde_json = "1"
sha1 = "0.10.6"
shlex = "1.3.0"
similar = "2.7.0"
@@ -47,6 +48,7 @@ tokio = { version = "1", features = [
] }
tokio-util = "0.7.14"
toml = "0.9.2"
toml_edit = "0.22"
tracing = { version = "0.1.41", features = ["log"] }
tree-sitter = "0.25.8"
tree-sitter-bash = "0.25.0"
@@ -62,10 +62,14 @@ impl ModelClient {
        summary: ReasoningSummaryConfig,
        session_id: Uuid,
    ) -> Self {
        let client = reqwest::Client::builder()
            .connect_timeout(Duration::from_secs(5))
            .build()
            .unwrap_or_else(|_| reqwest::Client::new());
        Self {
            config,
            auth,
-           client: reqwest::Client::new(),
+           client,
            provider,
            session_id,
            effort,
@@ -428,6 +428,11 @@ impl Config {
            .or(config_profile.model_provider)
            .or(cfg.model_provider)
            .unwrap_or_else(|| "openai".to_string());
        // Do not implicitly inject an Ollama provider when selected via
        // `-c model_provider=ollama`. Only the `--ollama` flag path sets up the
        // provider entry and performs discovery. This ensures parity with
        // other providers: if a provider is not defined in config.toml, we
        // return a clear error below.
        let model_provider = model_providers
            .get(&model_provider_id)
            .ok_or_else(|| {
@@ -102,6 +102,20 @@ pub enum CodexErr {

    #[error("{0}")]
    EnvVar(EnvVarError),

    // ------------------------------
    // Ollama‑specific errors
    // ------------------------------
    #[error(
        "No running Ollama server detected. Start it with: `ollama serve` (after installing). Install instructions: https://github.com/ollama/ollama?tab=readme-ov-file#ollama"
    )]
    OllamaServerUnreachable,

    #[error("ollama model not found: {0}")]
    OllamaModelNotFound(String),

    #[error("ollama pull failed: {0}")]
    OllamaPullFailed(String),
}

#[derive(Debug)]
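A hedged sketch (illustration only, not part of the diff) of how a caller might special-case the new variants; the exec and tui frontends later in this change simply print the error and exit(1):

```rust
use crate::error::CodexErr;

// Hypothetical helper: report Ollama-related errors with a hint where useful.
fn report_ollama_error(err: &CodexErr) {
    match err {
        // The Display impl already carries the install/start instructions.
        CodexErr::OllamaServerUnreachable => eprintln!("{err}"),
        // Suggest the manual pull command for a model Codex refuses to auto-pull.
        CodexErr::OllamaModelNotFound(model) => eprintln!("{err}; try `ollama pull {model}`"),
        _ => eprintln!("{err}"),
    }
}
```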
@@ -37,6 +37,7 @@ mod openai_tools;
pub mod plan_tool;
mod project_doc;
pub mod protocol;
pub mod providers;
mod rollout;
pub(crate) mod safety;
pub mod seatbelt;
codex-rs/core/src/providers/mod.rs (new file, 1 line)
@@ -0,0 +1 @@
pub mod ollama;
codex-rs/core/src/providers/ollama/client.rs (new file, 259 lines)
@@ -0,0 +1,259 @@
use bytes::BytesMut;
use futures::StreamExt;
use futures::stream::BoxStream;
use serde_json::Value as JsonValue;
use std::collections::VecDeque;
use std::io;

use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::WireApi;

use super::DEFAULT_BASE_URL;
use super::PullEvent;
use super::PullProgressReporter;
use super::parser::pull_events_from_value;
use super::url::base_url_to_host_root;
use super::url::is_openai_compatible_base_url;

/// Client for interacting with a local Ollama instance.
pub struct OllamaClient {
    client: reqwest::Client,
    host_root: String,
    uses_openai_compat: bool,
}

impl OllamaClient {
    /// Build a client from a provider definition. Falls back to the default
    /// local URL if no base_url is configured.
    pub fn from_provider(provider: &ModelProviderInfo) -> Self {
        let base_url = provider
            .base_url
            .clone()
            .unwrap_or_else(|| DEFAULT_BASE_URL.to_string());
        let uses_openai_compat = is_openai_compatible_base_url(&base_url)
            || matches!(provider.wire_api, WireApi::Chat)
                && is_openai_compatible_base_url(&base_url);
        let host_root = base_url_to_host_root(&base_url);
        let client = reqwest::Client::builder()
            .connect_timeout(std::time::Duration::from_secs(5))
            .build()
            .unwrap_or_else(|_| reqwest::Client::new());
        Self {
            client,
            host_root,
            uses_openai_compat,
        }
    }

    /// Low-level constructor given a raw host root, e.g. "http://localhost:11434".
    pub fn from_host_root(host_root: impl Into<String>) -> Self {
        let client = reqwest::Client::builder()
            .connect_timeout(std::time::Duration::from_secs(5))
            .build()
            .unwrap_or_else(|_| reqwest::Client::new());
        Self {
            client,
            host_root: host_root.into(),
            uses_openai_compat: false,
        }
    }

    /// Probe whether the server is reachable by hitting the appropriate health endpoint.
    pub async fn probe_server(&self) -> io::Result<bool> {
        let url = if self.uses_openai_compat {
            format!("{}/v1/models", self.host_root.trim_end_matches('/'))
        } else {
            format!("{}/api/tags", self.host_root.trim_end_matches('/'))
        };
        let resp = self.client.get(url).send().await;
        Ok(matches!(resp, Ok(r) if r.status().is_success()))
    }

    /// Return the list of model names known to the local Ollama instance.
    pub async fn fetch_models(&self) -> io::Result<Vec<String>> {
        let tags_url = format!("{}/api/tags", self.host_root.trim_end_matches('/'));
        let resp = self
            .client
            .get(tags_url)
            .send()
            .await
            .map_err(io::Error::other)?;
        if !resp.status().is_success() {
            return Ok(Vec::new());
        }
        let val = resp.json::<JsonValue>().await.map_err(io::Error::other)?;
        let names = val
            .get("models")
            .and_then(|m| m.as_array())
            .map(|arr| {
                arr.iter()
                    .filter_map(|v| v.get("name").and_then(|n| n.as_str()))
                    .map(|s| s.to_string())
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();
        Ok(names)
    }

    /// Start a model pull and emit streaming events. The returned stream ends when
    /// a Success event is observed or the server closes the connection.
    pub async fn pull_model_stream(
        &self,
        model: &str,
    ) -> io::Result<BoxStream<'static, PullEvent>> {
        let url = format!("{}/api/pull", self.host_root.trim_end_matches('/'));
        let resp = self
            .client
            .post(url)
            .json(&serde_json::json!({"model": model, "stream": true}))
            .send()
            .await
            .map_err(io::Error::other)?;
        if !resp.status().is_success() {
            return Err(io::Error::other(format!(
                "failed to start pull: HTTP {}",
                resp.status()
            )));
        }

        let mut stream = resp.bytes_stream();
        let mut buf = BytesMut::new();
        let _pending: VecDeque<PullEvent> = VecDeque::new();

        // Using an async stream adaptor backed by an unfold-like manual loop.
        let s = async_stream::stream! {
            while let Some(chunk) = stream.next().await {
                match chunk {
                    Ok(bytes) => {
                        buf.extend_from_slice(&bytes);
                        while let Some(pos) = buf.iter().position(|b| *b == b'\n') {
                            let line = buf.split_to(pos + 1);
                            if let Ok(text) = std::str::from_utf8(&line) {
                                let text = text.trim();
                                if text.is_empty() { continue; }
                                if let Ok(value) = serde_json::from_str::<JsonValue>(text) {
                                    for ev in pull_events_from_value(&value) { yield ev; }
                                    if let Some(err_msg) = value.get("error").and_then(|e| e.as_str()) {
                                        yield PullEvent::Status(format!("error: {err_msg}"));
                                        return;
                                    }
                                    if let Some(status) = value.get("status").and_then(|s| s.as_str()) {
                                        if status == "success" { yield PullEvent::Success; return; }
                                    }
                                }
                            }
                        }
                    }
                    Err(_) => {
                        // Connection error: end the stream.
                        return;
                    }
                }
            }
        };

        Ok(Box::pin(s))
    }

    /// High-level helper to pull a model and drive a progress reporter.
    pub async fn pull_with_reporter(
        &self,
        model: &str,
        reporter: &mut dyn PullProgressReporter,
    ) -> io::Result<()> {
        reporter.on_event(&PullEvent::Status(format!("Pulling model {model}...")))?;
        let mut stream = self.pull_model_stream(model).await?;
        while let Some(event) = stream.next().await {
            reporter.on_event(&event)?;
            if matches!(event, PullEvent::Success) {
                break;
            }
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    #![allow(clippy::expect_used, clippy::unwrap_used)]
    use super::*;
    use crate::model_provider_info::ModelProviderInfo;
    use crate::model_provider_info::WireApi;

    // Happy-path tests using a mock HTTP server; skip if sandbox network is disabled.
    #[tokio::test]
    async fn test_fetch_models_happy_path() {
        if std::env::var(crate::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            tracing::info!(
                "{} is set; skipping test_fetch_models_happy_path",
                crate::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
            );
            return;
        }

        let server = wiremock::MockServer::start().await;
        wiremock::Mock::given(wiremock::matchers::method("GET"))
            .and(wiremock::matchers::path("/api/tags"))
            .respond_with(
                wiremock::ResponseTemplate::new(200).set_body_raw(
                    serde_json::json!({
                        "models": [ {"name": "llama3.2:3b"}, {"name": "mistral"} ]
                    })
                    .to_string(),
                    "application/json",
                ),
            )
            .mount(&server)
            .await;

        let client = OllamaClient::from_host_root(server.uri());
        let models = client.fetch_models().await.expect("fetch models");
        assert!(models.contains(&"llama3.2:3b".to_string()));
        assert!(models.contains(&"mistral".to_string()));
    }

    #[tokio::test]
    async fn test_probe_server_happy_path_openai_compat_and_native() {
        if std::env::var(crate::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            tracing::info!(
                "{} set; skipping test_probe_server_happy_path_openai_compat_and_native",
                crate::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
            );
            return;
        }

        let server = wiremock::MockServer::start().await;

        // Native endpoint
        wiremock::Mock::given(wiremock::matchers::method("GET"))
            .and(wiremock::matchers::path("/api/tags"))
            .respond_with(wiremock::ResponseTemplate::new(200))
            .mount(&server)
            .await;
        let native = OllamaClient::from_host_root(server.uri());
        assert!(native.probe_server().await.expect("probe native"));

        // OpenAI compatibility endpoint
        wiremock::Mock::given(wiremock::matchers::method("GET"))
            .and(wiremock::matchers::path("/v1/models"))
            .respond_with(wiremock::ResponseTemplate::new(200))
            .mount(&server)
            .await;
        let provider = ModelProviderInfo {
            name: "Ollama".to_string(),
            base_url: Some(format!("{}/v1", server.uri())),
            env_key: None,
            env_key_instructions: None,
            wire_api: WireApi::Chat,
            query_params: None,
            http_headers: None,
            env_http_headers: None,
            request_max_retries: None,
            stream_max_retries: None,
            stream_idle_timeout_ms: None,
            requires_auth: false,
        };
        let compat = OllamaClient::from_provider(&provider);
        assert!(compat.probe_server().await.expect("probe compat"));
    }
}
codex-rs/core/src/providers/ollama/config.rs (new file, 243 lines)
@@ -0,0 +1,243 @@
use std::io;
use std::path::Path;
use std::str::FromStr;

use toml_edit::DocumentMut as Document;
use toml_edit::Item;
use toml_edit::Table;
use toml_edit::Value as TomlValueEdit;

use super::DEFAULT_BASE_URL;

/// Read the list of models recorded under [model_providers.ollama].models.
pub fn read_ollama_models_list(config_path: &Path) -> Vec<String> {
    match std::fs::read_to_string(config_path)
        .ok()
        .and_then(|s| toml::from_str::<toml::Value>(&s).ok())
    {
        Some(toml::Value::Table(root)) => root
            .get("model_providers")
            .and_then(|v| v.as_table())
            .and_then(|t| t.get("ollama"))
            .and_then(|v| v.as_table())
            .and_then(|t| t.get("models"))
            .and_then(|v| v.as_array())
            .map(|arr| {
                arr.iter()
                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default(),
        _ => Vec::new(),
    }
}

/// Convenience wrapper that returns the models list as an io::Result for callers
/// that want a uniform Result-based API.
pub fn read_config_models(config_path: &Path) -> io::Result<Vec<String>> {
    Ok(read_ollama_models_list(config_path))
}

/// Overwrite the recorded models list under [model_providers.ollama].models using toml_edit.
pub fn write_ollama_models_list(config_path: &Path, models: &[String]) -> io::Result<()> {
    let mut doc = read_document(config_path)?;
    {
        let tbl = upsert_provider_ollama(&mut doc);
        let mut arr = toml_edit::Array::new();
        for m in models {
            arr.push(TomlValueEdit::from(m.clone()));
        }
        tbl["models"] = Item::Value(TomlValueEdit::Array(arr));
    }
    write_document(config_path, &doc)
}

/// Write models list via a uniform name expected by higher layers.
pub fn write_config_models(config_path: &Path, models: &[String]) -> io::Result<()> {
    write_ollama_models_list(config_path, models)
}

/// Ensure `[model_providers.ollama]` exists with sensible defaults on disk.
/// Returns true if it created/updated the entry.
pub fn ensure_ollama_provider_entry(codex_home: &Path) -> io::Result<bool> {
    let config_path = codex_home.join("config.toml");
    let mut doc = read_document(&config_path)?;
    let before = doc.to_string();
    let _tbl = upsert_provider_ollama(&mut doc);
    let after = doc.to_string();
    if before != after {
        write_document(&config_path, &doc)?;
        Ok(true)
    } else {
        Ok(false)
    }
}

/// Alias name mirroring the refactor plan wording.
pub fn ensure_provider_entry_and_defaults(codex_home: &Path) -> io::Result<bool> {
    ensure_ollama_provider_entry(codex_home)
}

/// Read whether the provider exists and how many models are recorded under it.
pub fn read_provider_state(config_path: &Path) -> (bool, usize) {
    match std::fs::read_to_string(config_path)
        .ok()
        .and_then(|s| toml::from_str::<toml::Value>(&s).ok())
    {
        Some(toml::Value::Table(root)) => {
            let provider_present = root
                .get("model_providers")
                .and_then(|v| v.as_table())
                .and_then(|t| t.get("ollama"))
                .map(|_| true)
                .unwrap_or(false);
            let models_count = root
                .get("model_providers")
                .and_then(|v| v.as_table())
                .and_then(|t| t.get("ollama"))
                .and_then(|v| v.as_table())
                .and_then(|t| t.get("models"))
                .and_then(|v| v.as_array())
                .map(|arr| arr.len())
                .unwrap_or(0);
            (provider_present, models_count)
        }
        _ => (false, 0),
    }
}

// ---------- toml_edit helpers ----------

fn read_document(path: &Path) -> io::Result<Document> {
    match std::fs::read_to_string(path) {
        Ok(s) => Document::from_str(&s).map_err(io::Error::other),
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(Document::new()),
        Err(e) => Err(e),
    }
}

fn write_document(path: &Path, doc: &Document) -> io::Result<()> {
    if let Some(parent) = path.parent() {
        std::fs::create_dir_all(parent)?;
    }
    std::fs::write(path, doc.to_string())
}

pub fn upsert_provider_ollama(doc: &mut Document) -> &mut Table {
    // Ensure "model_providers" exists and is a table.
    let needs_init = match doc.get("model_providers") {
        None => true,
        Some(item) => !item.is_table(),
    };
    if needs_init {
        doc.as_table_mut()
            .insert("model_providers", Item::Table(Table::new()));
    }

    // Now, get a mutable reference to the "model_providers" table without `expect`/`unwrap`.
    let providers: &mut Table = {
        // Insert if missing.
        if doc.as_table().get("model_providers").is_none() {
            doc.as_table_mut()
                .insert("model_providers", Item::Table(Table::new()));
        }
        match doc.as_table_mut().get_mut("model_providers") {
            Some(item) => {
                if !item.is_table() {
                    *item = Item::Table(Table::new());
                }
                match item.as_table_mut() {
                    Some(t) => t,
                    None => unreachable!("model_providers was set to a table"),
                }
            }
            None => unreachable!("model_providers should exist after insertion"),
        }
    };

    // Ensure "ollama" exists and is a table.
    let needs_ollama_init = match providers.get("ollama") {
        None => true,
        Some(item) => !item.is_table(),
    };
    if needs_ollama_init {
        providers.insert("ollama", Item::Table(Table::new()));
    }

    // Get a mutable reference to the "ollama" table without `expect`/`unwrap`.
    let tbl: &mut Table = {
        let needs_set = match providers.get("ollama") {
            None => true,
            Some(item) => !item.is_table(),
        };
        if needs_set {
            providers.insert("ollama", Item::Table(Table::new()));
        }
        match providers.get_mut("ollama") {
            Some(item) => {
                if !item.is_table() {
                    *item = Item::Table(Table::new());
                }
                match item.as_table_mut() {
                    Some(t) => t,
                    None => unreachable!("ollama was set to a table"),
                }
            }
            None => unreachable!("ollama should exist after insertion"),
        }
    };

    if !tbl.contains_key("name") {
        tbl["name"] = Item::Value(TomlValueEdit::from("Ollama"));
    }
    if !tbl.contains_key("base_url") {
        tbl["base_url"] = Item::Value(TomlValueEdit::from(DEFAULT_BASE_URL));
    }
    if !tbl.contains_key("wire_api") {
        tbl["wire_api"] = Item::Value(TomlValueEdit::from("chat"));
    }
    tbl
}

pub fn set_ollama_models(doc: &mut Document, models: &[String]) {
    let tbl = upsert_provider_ollama(doc);
    let mut arr = toml_edit::Array::new();
    for m in models {
        arr.push(TomlValueEdit::from(m.clone()));
    }
    tbl["models"] = Item::Value(TomlValueEdit::Array(arr));
}

#[cfg(test)]
mod tests {
    use super::*;
    use toml_edit::DocumentMut as Document;

    #[test]
    fn test_upsert_provider_and_models() {
        let mut doc = Document::new();
        let tbl = upsert_provider_ollama(&mut doc);
        assert!(tbl.contains_key("name"));
        assert!(tbl.contains_key("base_url"));
        assert!(tbl.contains_key("wire_api"));
        set_ollama_models(&mut doc, &[String::from("llama3.2:3b")]);
        let root = doc.as_table();
        let mp = match root.get("model_providers").and_then(|i| i.as_table()) {
            Some(t) => t,
            None => panic!("model_providers"),
        };
        let ollama = match mp.get("ollama").and_then(|i| i.as_table()) {
            Some(t) => t,
            None => panic!("ollama"),
        };
        let arr = match ollama.get("models") {
            Some(v) => v,
            None => panic!("models array"),
        };
        assert!(arr.is_array(), "models should be an array");
        let s = doc.to_string();
        assert!(s.contains("model_providers"));
        assert!(s.contains("ollama"));
        assert!(s.contains("models"));
    }
}
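Taken together, these helpers maintain a config section shaped like the following sketch (model names are illustrative, borrowed from the tests in this change):

```toml
# Sketch of the section read and written by the helpers above.
[model_providers.ollama]
name = "Ollama"
base_url = "http://localhost:11434/v1"
wire_api = "chat"
models = ["llama3.2:3b", "mistral"]
```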
codex-rs/core/src/providers/ollama/mod.rs (new file, 291 lines)
@@ -0,0 +1,291 @@
use crate::error::CodexErr;
use crate::error::Result as CoreResult;
use std::collections::HashMap;
use std::io;
use std::io::Write;
use std::path::Path;

pub const DEFAULT_BASE_URL: &str = "http://localhost:11434/v1";
pub const DEFAULT_WIRE_API: crate::model_provider_info::WireApi =
    crate::model_provider_info::WireApi::Chat;
pub const DEFAULT_PULL_ALLOWLIST: &[&str] = &["llama3.2:3b"];

pub mod client;
pub mod config;
pub mod parser;
pub mod url;

pub use client::OllamaClient;
pub use config::read_config_models;
pub use config::read_provider_state;
pub use config::write_config_models;
pub use url::base_url_to_host_root;
pub use url::base_url_to_host_root_with_wire;
pub use url::probe_ollama_server;
pub use url::probe_url_for_base;

/// Coordinator wrapper used by frontends when responding to `--ollama`.
///
/// - Probes the server using the configured base_url when present, otherwise
///   falls back to DEFAULT_BASE_URL.
/// - If the server is reachable, ensures an `[model_providers.ollama]` entry
///   exists in `config.toml` with sensible defaults.
/// - If no server is reachable, returns an error.
pub async fn ensure_configured_and_running() -> CoreResult<()> {
    use crate::config::find_codex_home;
    use toml::Value as TomlValue;

    let codex_home = find_codex_home()?;
    let config_path = codex_home.join("config.toml");
    // Try to read a configured base_url if present.
    let base_url = match std::fs::read_to_string(&config_path) {
        Ok(contents) => match toml::from_str::<TomlValue>(&contents) {
            Ok(TomlValue::Table(root)) => root
                .get("model_providers")
                .and_then(|v| v.as_table())
                .and_then(|t| t.get("ollama"))
                .and_then(|v| v.as_table())
                .and_then(|t| t.get("base_url"))
                .and_then(|v| v.as_str())
                .unwrap_or(DEFAULT_BASE_URL)
                .to_string(),
            _ => DEFAULT_BASE_URL.to_string(),
        },
        Err(_) => DEFAULT_BASE_URL.to_string(),
    };

    // Probe reachability; map any probe error to a friendly unreachable message.
    let ok: bool = url::probe_ollama_server(&base_url)
        .await
        .unwrap_or_default();
    if !ok {
        return Err(CodexErr::OllamaServerUnreachable);
    }

    // Ensure provider entry exists with defaults.
    let _ = config::ensure_ollama_provider_entry(&codex_home)?;
    Ok(())
}

#[cfg(test)]
mod ensure_tests {
    #![allow(clippy::expect_used, clippy::unwrap_used)]
    use super::*;

    #[tokio::test]
    async fn test_ensure_configured_returns_friendly_error_when_unreachable() {
        // Skip in CI sandbox environments without network to avoid false negatives.
        if std::env::var(crate::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            tracing::info!(
                "{} is set; skipping test_ensure_configured_returns_friendly_error_when_unreachable",
                crate::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
            );
            return;
        }

        let tmpdir = tempfile::TempDir::new().expect("tempdir");
        let config_path = tmpdir.path().join("config.toml");
        std::fs::create_dir_all(tmpdir.path()).unwrap();
        std::fs::write(
            &config_path,
            r#"[model_providers.ollama]
name = "Ollama"
base_url = "http://127.0.0.1:1/v1"
wire_api = "chat"
"#,
        )
        .unwrap();
        unsafe {
            std::env::set_var("CODEX_HOME", tmpdir.path());
        }

        let err = ensure_configured_and_running()
            .await
            .expect_err("should report unreachable server as friendly error");
        assert!(matches!(err, CodexErr::OllamaServerUnreachable));
    }
}

/// Events emitted while pulling a model from Ollama.
#[derive(Debug, Clone)]
pub enum PullEvent {
    /// A human-readable status message (e.g., "verifying", "writing").
    Status(String),
    /// Byte-level progress update for a specific layer digest.
    ChunkProgress {
        digest: String,
        total: Option<u64>,
        completed: Option<u64>,
    },
    /// The pull finished successfully.
    Success,
}

/// A simple observer for pull progress events. Implementations decide how to
/// render progress (CLI, TUI, logs, ...).
pub trait PullProgressReporter {
    fn on_event(&mut self, event: &PullEvent) -> io::Result<()>;
}
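The trait is deliberately small; a sketch of a custom implementation (illustrative only, not part of the diff) could route events to `tracing` instead of drawing inline progress:

```rust
use std::io;

/// Hypothetical reporter that logs pull events instead of rendering them.
struct LogReporter;

impl PullProgressReporter for LogReporter {
    fn on_event(&mut self, event: &PullEvent) -> io::Result<()> {
        // `PullEvent` derives Debug, so events can be logged directly.
        tracing::debug!("ollama pull event: {event:?}");
        Ok(())
    }
}
```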
/// A minimal CLI reporter that writes inline progress to stderr.
pub struct CliProgressReporter {
    printed_header: bool,
    last_line_len: usize,
    last_completed_sum: u64,
    last_instant: std::time::Instant,
    totals_by_digest: HashMap<String, (u64, u64)>,
}

impl Default for CliProgressReporter {
    fn default() -> Self {
        Self::new()
    }
}

impl CliProgressReporter {
    pub fn new() -> Self {
        Self {
            printed_header: false,
            last_line_len: 0,
            last_completed_sum: 0,
            last_instant: std::time::Instant::now(),
            totals_by_digest: HashMap::new(),
        }
    }
}

impl PullProgressReporter for CliProgressReporter {
    fn on_event(&mut self, event: &PullEvent) -> io::Result<()> {
        let mut out = std::io::stderr();
        match event {
            PullEvent::Status(status) => {
                // Avoid noisy manifest messages; otherwise show status inline.
                if status.eq_ignore_ascii_case("pulling manifest") {
                    return Ok(());
                }
                let pad = self.last_line_len.saturating_sub(status.len());
                let line = format!("\r{status}{}", " ".repeat(pad));
                self.last_line_len = status.len();
                out.write_all(line.as_bytes())?;
                out.flush()
            }
            PullEvent::ChunkProgress {
                digest,
                total,
                completed,
            } => {
                if let Some(t) = *total {
                    self.totals_by_digest
                        .entry(digest.clone())
                        .or_insert((0, 0))
                        .0 = t;
                }
                if let Some(c) = *completed {
                    self.totals_by_digest
                        .entry(digest.clone())
                        .or_insert((0, 0))
                        .1 = c;
                }

                let (sum_total, sum_completed) = self
                    .totals_by_digest
                    .values()
                    .fold((0u64, 0u64), |acc, (t, c)| (acc.0 + *t, acc.1 + *c));
                if sum_total > 0 {
                    if !self.printed_header {
                        let gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0);
                        let header = format!("Downloading model: total {gb:.2} GB\n");
                        out.write_all(b"\r\x1b[2K")?;
                        out.write_all(header.as_bytes())?;
                        self.printed_header = true;
                    }
                    let now = std::time::Instant::now();
                    let dt = now
                        .duration_since(self.last_instant)
                        .as_secs_f64()
                        .max(0.001);
                    let dbytes = sum_completed.saturating_sub(self.last_completed_sum) as f64;
                    let speed_mb_s = dbytes / (1024.0 * 1024.0) / dt;
                    self.last_completed_sum = sum_completed;
                    self.last_instant = now;

                    let done_gb = (sum_completed as f64) / (1024.0 * 1024.0 * 1024.0);
                    let total_gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0);
                    let pct = (sum_completed as f64) * 100.0 / (sum_total as f64);
                    let text =
                        format!("{done_gb:.2}/{total_gb:.2} GB ({pct:.1}%) {speed_mb_s:.1} MB/s");
                    let pad = self.last_line_len.saturating_sub(text.len());
                    let line = format!("\r{text}{}", " ".repeat(pad));
                    self.last_line_len = text.len();
                    out.write_all(line.as_bytes())?;
                    out.flush()
                } else {
                    Ok(())
                }
            }
            PullEvent::Success => {
                out.write_all(b"\n")?;
                out.flush()
            }
        }
    }
}

/// For now the TUI reporter delegates to the CLI reporter. This keeps UI and
/// CLI behavior aligned until a dedicated TUI integration is implemented.
#[derive(Default)]
pub struct TuiProgressReporter(CliProgressReporter);
impl TuiProgressReporter {
    pub fn new() -> Self {
        Default::default()
    }
}
impl PullProgressReporter for TuiProgressReporter {
    fn on_event(&mut self, event: &PullEvent) -> io::Result<()> {
        self.0.on_event(event)
    }
}

/// Ensure a model is available locally.
///
/// - If the model is already present, ensure it is recorded in config.toml.
/// - If missing and in the default allowlist, pull it with streaming progress
///   and record it in config.toml after success.
/// - If missing and not allowlisted, return an error.
pub async fn ensure_model_available(
    model: &str,
    client: &OllamaClient,
    config_path: &Path,
    reporter: &mut dyn PullProgressReporter,
) -> CoreResult<()> {
    let mut listed = config::read_ollama_models_list(config_path);
    let available = client.fetch_models().await.unwrap_or_default();
    if available.iter().any(|m| m == model) {
        if !listed.iter().any(|m| m == model) {
            listed.push(model.to_string());
            listed.sort();
            listed.dedup();
            let _ = config::write_ollama_models_list(config_path, &listed);
        }
        return Ok(());
    }

    if !DEFAULT_PULL_ALLOWLIST.contains(&model) {
        return Err(CodexErr::OllamaModelNotFound(model.to_string()));
    }

    loop {
        let _ = client.pull_with_reporter(model, reporter).await;
        // After the stream completes (success or early EOF), check again.
        let available = client.fetch_models().await.unwrap_or_default();
        if available.iter().any(|m| m == model) {
            break;
        }
        // Keep waiting for the model to finish downloading.
        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
    }

    listed.push(model.to_string());
    listed.sort();
    listed.dedup();
    let _ = config::write_ollama_models_list(config_path, &listed);
    Ok(())
}
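A hedged usage sketch for `ensure_model_available`, mirroring how the exec frontend later in this diff wires it up (assumes a loaded `Config` named `config` is in scope):

```rust
// Sketch based on the exec frontend below: make sure the requested model is
// present locally (pulling it if allowlisted) before starting a session.
let client = OllamaClient::from_provider(&config.model_provider);
let config_path = config.codex_home.join("config.toml");
let mut reporter = CliProgressReporter::new();
ensure_model_available(&config.model, &client, &config_path, &mut reporter).await?;
```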
codex-rs/core/src/providers/ollama/parser.rs (new file, 82 lines)
@@ -0,0 +1,82 @@
use serde_json::Value as JsonValue;

use super::PullEvent;

// Convert a single JSON object representing a pull update into one or more events.
pub(crate) fn pull_events_from_value(value: &JsonValue) -> Vec<PullEvent> {
    let mut events = Vec::new();
    if let Some(status) = value.get("status").and_then(|s| s.as_str()) {
        events.push(PullEvent::Status(status.to_string()));
        if status == "success" {
            events.push(PullEvent::Success);
        }
    }
    let digest = value
        .get("digest")
        .and_then(|d| d.as_str())
        .unwrap_or("")
        .to_string();
    let total = value.get("total").and_then(|t| t.as_u64());
    let completed = value.get("completed").and_then(|t| t.as_u64());
    if total.is_some() || completed.is_some() {
        events.push(PullEvent::ChunkProgress {
            digest,
            total,
            completed,
        });
    }
    events
}
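For orientation, `/api/pull` responds with newline-delimited JSON, one object per update; a sketch of how such lines flow through this decoder (the sample lines are illustrative, but the field names match what the parser reads):

```rust
// Illustrative pull-stream lines; each becomes one serde_json::Value fed to
// `pull_events_from_value`. Exact status strings vary by Ollama version.
let lines = [
    r#"{"status":"pulling manifest"}"#,
    r#"{"digest":"sha256:abc","total":104857600,"completed":52428800}"#,
    r#"{"status":"success"}"#,
];
for line in lines {
    let value: JsonValue = serde_json::from_str(line).expect("valid JSON");
    for event in pull_events_from_value(&value) {
        println!("{event:?}");
    }
}
```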
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pull_events_decoder_status_and_success() {
        let v: JsonValue = serde_json::json!({"status":"verifying"});
        let events = pull_events_from_value(&v);
        assert!(matches!(events.as_slice(), [PullEvent::Status(s)] if s == "verifying"));

        let v2: JsonValue = serde_json::json!({"status":"success"});
        let events2 = pull_events_from_value(&v2);
        assert_eq!(events2.len(), 2);
        assert!(matches!(events2[0], PullEvent::Status(ref s) if s == "success"));
        assert!(matches!(events2[1], PullEvent::Success));
    }

    #[test]
    fn test_pull_events_decoder_progress() {
        let v: JsonValue = serde_json::json!({"digest":"sha256:abc","total":100});
        let events = pull_events_from_value(&v);
        assert_eq!(events.len(), 1);
        match &events[0] {
            PullEvent::ChunkProgress {
                digest,
                total,
                completed,
            } => {
                assert_eq!(digest, "sha256:abc");
                assert_eq!(*total, Some(100));
                assert_eq!(*completed, None);
            }
            _ => panic!("expected ChunkProgress"),
        }

        let v2: JsonValue = serde_json::json!({"digest":"sha256:def","completed":42});
        let events2 = pull_events_from_value(&v2);
        assert_eq!(events2.len(), 1);
        match &events2[0] {
            PullEvent::ChunkProgress {
                digest,
                total,
                completed,
            } => {
                assert_eq!(digest, "sha256:def");
                assert_eq!(*total, None);
                assert_eq!(*completed, Some(42));
            }
            _ => panic!("expected ChunkProgress"),
        }
    }
}
codex-rs/core/src/providers/ollama/url.rs (new file, 83 lines)
@@ -0,0 +1,83 @@
use crate::error::Result as CoreResult;

/// Identify whether a base_url points at an OpenAI-compatible root (".../v1").
pub(crate) fn is_openai_compatible_base_url(base_url: &str) -> bool {
    base_url.trim_end_matches('/').ends_with("/v1")
}

/// Convert a provider base_url into the native Ollama host root.
/// For example, "http://localhost:11434/v1" -> "http://localhost:11434".
pub fn base_url_to_host_root(base_url: &str) -> String {
    let trimmed = base_url.trim_end_matches('/');
    if trimmed.ends_with("/v1") {
        trimmed
            .trim_end_matches("/v1")
            .trim_end_matches('/')
            .to_string()
    } else {
        trimmed.to_string()
    }
}

/// Variant that considers an explicit WireApi value; provided to centralize
/// host root computation in one place for future extension.
pub fn base_url_to_host_root_with_wire(
    base_url: &str,
    _wire_api: crate::model_provider_info::WireApi,
) -> String {
    base_url_to_host_root(base_url)
}

/// Compute the probe URL to verify if an Ollama server is reachable.
/// If the configured base is OpenAI-compatible (/v1), probe "models", otherwise
/// fall back to the native "/api/tags" endpoint.
pub fn probe_url_for_base(base_url: &str) -> String {
    if is_openai_compatible_base_url(base_url) {
        format!("{}/models", base_url.trim_end_matches('/'))
    } else {
        format!("{}/api/tags", base_url.trim_end_matches('/'))
    }
}

/// Convenience helper to probe an Ollama server given a provider style base URL.
pub async fn probe_ollama_server(base_url: &str) -> CoreResult<bool> {
    let url = probe_url_for_base(base_url);
    let client = reqwest::Client::builder()
        .connect_timeout(std::time::Duration::from_secs(5))
        .build()?;
    let resp = client.get(url).send().await?;
    Ok(resp.status().is_success())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_base_url_to_host_root() {
        assert_eq!(
            base_url_to_host_root("http://localhost:11434/v1"),
            "http://localhost:11434"
        );
        assert_eq!(
            base_url_to_host_root("http://localhost:11434"),
            "http://localhost:11434"
        );
        assert_eq!(
            base_url_to_host_root("http://localhost:11434/"),
            "http://localhost:11434"
        );
    }

    #[test]
    fn test_probe_url_for_base() {
        assert_eq!(
            probe_url_for_base("http://localhost:11434/v1"),
            "http://localhost:11434/v1/models"
        );
        assert_eq!(
            probe_url_for_base("http://localhost:11434"),
            "http://localhost:11434/api/tags"
        );
    }
}
@@ -14,6 +14,12 @@ pub struct Cli {
    #[arg(long, short = 'm')]
    pub model: Option<String>,

    /// Convenience flag to select the local Ollama provider.
    /// Equivalent to -c model_provider=ollama; verifies a local Ollama server is running and
    /// creates a model_providers.ollama entry in config.toml if missing.
    #[arg(long = "ollama", default_value_t = false)]
    pub ollama: bool,

    /// Select the sandbox policy to use when executing model-generated shell
    /// commands.
    #[arg(long = "sandbox", short = 's')]
@@ -31,10 +31,17 @@ use tracing_subscriber::EnvFilter;
use crate::event_processor::CodexStatus;
use crate::event_processor::EventProcessor;

// Shared Ollama helpers are centralized in codex_core::providers::ollama.
use codex_core::providers::ollama::CliProgressReporter;
use codex_core::providers::ollama::OllamaClient;
use codex_core::providers::ollama::ensure_configured_and_running;
use codex_core::providers::ollama::ensure_model_available;

pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()> {
    let Cli {
        images,
        model,
        ollama,
        config_profile,
        full_auto,
        dangerously_bypass_approvals_and_sandbox,
@@ -48,6 +55,9 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
        config_overrides,
    } = cli;

    // Track whether the user explicitly provided a model via --model.
    let user_specified_model = model.is_some();

    // Determine the prompt based on CLI arg and/or stdin.
    let prompt = match prompt {
        Some(p) if p != "-" => p,
@@ -114,6 +124,16 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
        sandbox_mode_cli_arg.map(Into::<SandboxMode>::into)
    };

    // When the user opts into the Ollama provider via `--ollama`, ensure we
    // have a configured provider entry and that a local server is running.
    if ollama {
        if let Err(e) = ensure_configured_and_running().await {
            tracing::error!("{e}");
            eprintln!("{e}");
            std::process::exit(1);
        }
    }

    // Load configuration and determine approval policy
    let overrides = ConfigOverrides {
        model,
@@ -123,7 +143,11 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
        approval_policy: Some(AskForApproval::Never),
        sandbox_mode,
        cwd: cwd.map(|p| p.canonicalize().unwrap_or(p)),
-       model_provider: None,
+       model_provider: if ollama {
+           Some("ollama".to_string())
+       } else {
+           None
+       },
        codex_linux_sandbox_exe,
        base_instructions: None,
        include_plan_tool: None,
@@ -138,6 +162,22 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
    };

    let config = Config::load_with_cli_overrides(cli_kv_overrides, overrides)?;

    // If the user passed both --ollama and --model, ensure the requested model
    // is present locally or pull it automatically (subject to allowlist).
    if ollama && user_specified_model {
        let model_name = config.model.clone();
        let client = OllamaClient::from_provider(&config.model_provider);
        let config_path = config.codex_home.join("config.toml");
        let mut reporter = CliProgressReporter::new();
        if let Err(e) =
            ensure_model_available(&model_name, &client, &config_path, &mut reporter).await
        {
            tracing::error!("{e}");
            eprintln!("{e}");
            std::process::exit(1);
        }
    }
    let mut event_processor: Box<dyn EventProcessor> = if json_mode {
        Box::new(EventProcessorWithJsonOutput::new(last_message_file.clone()))
    } else {
@@ -67,7 +67,6 @@ unicode-width = "0.1"
uuid = "1"



[dev-dependencies]
insta = "1.43.1"
pretty_assertions = "1"
@@ -43,6 +43,7 @@ enum AppState<'a> {
    },
    /// The start-up warning that recommends running codex inside a Git repo.
    GitWarning { screen: GitWarningScreen },
    // (no additional states)
}

pub(crate) struct App<'a> {
@@ -17,6 +17,12 @@ pub struct Cli {
    #[arg(long, short = 'm')]
    pub model: Option<String>,

    /// Convenience flag to select the local Ollama provider.
    /// Equivalent to -c model_provider=ollama; verifies a local Ollama server is running and
    /// creates a model_providers.ollama entry in config.toml if missing.
    #[arg(long = "ollama", default_value_t = false)]
    pub ollama: bool,

    /// Configuration profile from config.toml to specify default options.
    #[arg(long = "profile", short = 'p')]
    pub config_profile: Option<String>,
@@ -9,9 +9,17 @@ use codex_core::config_types::SandboxMode;
use codex_core::protocol::AskForApproval;
use codex_core::util::is_inside_git_repo;
use codex_login::load_auth;
use crossterm::event::Event as CEvent;
use crossterm::event::KeyCode;
use crossterm::event::KeyEvent;
use crossterm::event::KeyModifiers;
use crossterm::event::{self};
use crossterm::terminal::disable_raw_mode;
use crossterm::terminal::enable_raw_mode;
use log_layer::TuiLogLayer;
use std::fs::OpenOptions;
use std::io::Write;
use std::io::{self};
use std::path::PathBuf;
use tracing::error;
use tracing_appender::non_blocking;
@@ -47,6 +55,253 @@ mod updates;
use color_eyre::owo_colors::OwoColorize;

pub use cli::Cli;
// Centralized Ollama helpers from core
use codex_core::providers::ollama::OllamaClient;
use codex_core::providers::ollama::TuiProgressReporter;
use codex_core::providers::ollama::ensure_configured_and_running;
use codex_core::providers::ollama::ensure_model_available;
use codex_core::providers::ollama::read_config_models;
use codex_core::providers::ollama::read_provider_state;
use codex_core::providers::ollama::write_config_models;

fn print_inline_message_no_models(
    host_root: &str,
    config_path: &std::path::Path,
    provider_was_present_before: bool,
) -> io::Result<()> {
    let mut out = std::io::stdout();
    let path = config_path.display().to_string();
    // bold text helper (ANSI escapes)
    let b = |s: &str| format!("\x1b[1m{s}\x1b[0m");
    // Ensure we start clean at column 0.
    out.write_all(b"\r\x1b[2K")?;
    out.write_all(
        format!(
            "{}\n\n",
            b("we've discovered no models on your local Ollama instance.")
        )
        .as_bytes(),
    )?;
    out.write_all(format!("\rendpoint: {host_root}\n").as_bytes())?;
    if provider_was_present_before {
        out.write_all(format!("\rconfig: ollama provider already present in {path}\n").as_bytes())?;
    } else {
        out.write_all(
            format!("\rconfig: added ollama as a model provider in {path}\n").as_bytes(),
        )?;
    }
    out.write_all(
        b"\rmodels: none recorded in config (pull models with `ollama pull <model>`).\n\n",
    )?;
    out.flush()
}

fn run_inline_models_picker(
    host_root: &str,
    available: &[String],
    preselected: &[String],
    config_path: &std::path::Path,
    provider_was_present_before: bool,
) -> io::Result<()> {
    let mut out = std::io::stdout();
    let mut selected: Vec<bool> = available
        .iter()
        .map(|m| preselected.iter().any(|x| x == m))
        .collect();
    let mut cursor: usize = 0;

    let mut first = true;
    let mut lines_printed: usize = 0;

    enable_raw_mode()?;

    loop {
        // Render block
        render_inline_picker(
            &mut out,
            host_root,
            available,
            &selected,
            cursor,
            &mut first,
            &mut lines_printed,
        )?;

        // Wait for key
        match event::read()? {
            CEvent::Key(KeyEvent {
                code: KeyCode::Up, ..
            })
            | CEvent::Key(KeyEvent {
                code: KeyCode::Char('k'),
                ..
            }) => {
                cursor = cursor.saturating_sub(1);
            }
            CEvent::Key(KeyEvent {
                code: KeyCode::Down,
                ..
            })
            | CEvent::Key(KeyEvent {
                code: KeyCode::Char('j'),
                ..
            }) => {
                if cursor + 1 < available.len() {
                    cursor += 1;
                }
            }
            CEvent::Key(KeyEvent {
                code: KeyCode::Char(' '),
                ..
            }) => {
                if let Some(s) = selected.get_mut(cursor) {
                    *s = !*s;
                }
            }
            CEvent::Key(KeyEvent {
                code: KeyCode::Char('a'),
                ..
            }) => {
                let all_sel = selected.iter().all(|s| *s);
                selected.fill(!all_sel);
            }
            // Allow quitting the entire app from the inline picker with Ctrl+C or Ctrl+D.
            CEvent::Key(KeyEvent {
                code: KeyCode::Char('c'),
                modifiers,
                ..
            }) if modifiers.contains(KeyModifiers::CONTROL) => {
                // Restore terminal state and exit with SIGINT-like code.
                disable_raw_mode()?;
                // Start on a clean line before exiting.
                out.write_all(b"\r\x1b[2K\n")?;
                std::process::exit(130);
            }
            CEvent::Key(KeyEvent {
                code: KeyCode::Char('d'),
                modifiers,
                ..
            }) if modifiers.contains(KeyModifiers::CONTROL) => {
                // Restore terminal state and exit cleanly.
                disable_raw_mode()?;
                out.write_all(b"\r\x1b[2K\n")?;
                std::process::exit(0);
            }
            CEvent::Key(KeyEvent {
                code: KeyCode::Enter,
                ..
            }) => {
                break;
            }
            CEvent::Key(KeyEvent {
                code: KeyCode::Char('q'),
                ..
            })
            | CEvent::Key(KeyEvent {
                code: KeyCode::Esc, ..
            }) => {
                // Skip saving: print a summary and continue.
                disable_raw_mode()?;
                print_config_summary_after_save(config_path, provider_was_present_before, None)?;
                return Ok(());
            }
            _ => {}
        }
    }

    disable_raw_mode()?;
    // Ensure the summary starts on a clean, left‑aligned new line.
    out.write_all(b"\r\x1b[2K\n")?;

    // Compute chosen
    let chosen: Vec<String> = available
        .iter()
        .cloned()
        .zip(selected.iter())
        .filter_map(|(name, sel)| if *sel { Some(name) } else { None })
        .collect();

    let _ = write_config_models(config_path, &chosen);
    print_config_summary_after_save(config_path, provider_was_present_before, Some(chosen.len()))
}

fn render_inline_picker(
    out: &mut std::io::Stdout,
    host_root: &str,
    items: &[String],
    selected: &[bool],
    cursor: usize,
    first: &mut bool,
    lines_printed: &mut usize,
) -> io::Result<()> {
    // If not the first render, move to the start of the block. We will clear each line as we redraw.
    if !*first {
        out.write_all(format!("\x1b[{}A", *lines_printed).as_bytes())?; // up N lines
        // Ensure we start at column 1 for a clean redraw.
        out.write_all(b"\r")?;
    }

    let mut lines = Vec::new();
    let bold = |s: &str| format!("\x1b[1m{s}\x1b[0m");
    lines.push(bold(&format!("discovered models on ollama ({host_root}):")));
    lines
        .push("↑/↓ move, space to toggle, 'a' (un)select all, enter confirm, 'q' skip".to_string());
    lines.push(String::new());
    for (i, name) in items.iter().enumerate() {
        let mark = if selected.get(i).copied().unwrap_or(false) {
            "\x1b[32m[x]\x1b[0m" // green
        } else {
            "[ ]"
        };
        let mut line = format!("{mark} {name}");
        if i == cursor {
            line = format!("\x1b[7m{line}\x1b[0m"); // reverse video for current row
        }
        lines.push(line);
    }

    for l in &lines {
        // Move to column 0 and clear the entire line before writing.
        out.write_all(b"\r\x1b[2K")?;
        out.write_all(l.as_bytes())?;
        out.write_all(b"\n")?;
    }
    out.flush()?;
    *first = false;
    *lines_printed = lines.len();
    Ok(())
}

fn print_config_summary_after_save(
    config_path: &std::path::Path,
    provider_was_present_before: bool,
    models_count_after: Option<usize>,
) -> io::Result<()> {
    let mut out = std::io::stdout();
    // Start clean and at column 0
    out.write_all(b"\r\x1b[2K")?;
    let path = config_path.display().to_string();
    if provider_was_present_before {
        out.write_all(format!("\rconfig: ollama provider already present in {path}\n").as_bytes())?;
    } else {
        out.write_all(
            format!("\rconfig: added ollama as a model provider in {path}\n").as_bytes(),
        )?;
    }
    if let Some(after) = models_count_after {
        let names = read_config_models(config_path).unwrap_or_default();
        if names.is_empty() {
            out.write_all(format!("\rmodels: recorded {after}\n\n").as_bytes())?;
        } else {
            out.write_all(
                format!("\rmodels: recorded {} ({})\n\n", after, names.join(", ")).as_bytes(),
            )?;
        }
    } else {
        out.write_all(b"\rmodels: no changes recorded\n\n")?;
    }
    out.flush()
}

pub async fn run_main(
    cli: Cli,
@@ -69,14 +324,42 @@ pub async fn run_main(
        )
    };

    // Track config.toml state for messaging before launching TUI.
    let provider_was_present_before = if cli.ollama {
        let codex_home = codex_core::config::find_codex_home()?;
        let config_path = codex_home.join("config.toml");
        let (p, _m) = read_provider_state(&config_path);
        p
    } else {
        false
    };

    let config = {
        // If the user selected the Ollama provider via `--ollama`, verify a
        // local server is reachable and ensure a provider entry exists in
        // config.toml. Exit early with a helpful message otherwise.
        if cli.ollama {
            if let Err(e) = ensure_configured_and_running().await {
                #[allow(clippy::print_stderr)]
                {
                    tracing::error!("{e}");
                    eprintln!("{e}");
                }
                std::process::exit(1);
            }
        }

        // Load configuration and support CLI overrides.
        let overrides = ConfigOverrides {
            model: cli.model.clone(),
            approval_policy,
            sandbox_mode,
            cwd: cli.cwd.clone().map(|p| p.canonicalize().unwrap_or(p)),
-           model_provider: None,
+           model_provider: if cli.ollama {
+               Some("ollama".to_string())
+           } else {
+               None
+           },
            config_profile: cli.config_profile.clone(),
            codex_linux_sandbox_exe,
            base_instructions: None,
@@ -101,6 +384,73 @@ pub async fn run_main(
            }
        }
    };
    // If the user passed --ollama, either ensure an explicitly requested model is
    // available (automatic pull if allowlisted) or offer an inline picker when no
    // specific model was provided.
    if cli.ollama {
        // Determine host root for the Ollama native API (e.g. http://localhost:11434).
        let base_url = config
            .model_provider
            .base_url
            .clone()
            .unwrap_or_else(|| "http://localhost:11434/v1".to_string());
        let host_root = base_url
            .trim_end_matches('/')
            .trim_end_matches("/v1")
            .to_string();
        let config_path = config.codex_home.join("config.toml");

        if let Some(ref model_name) = cli.model {
            // Explicit model requested: ensure it is available locally without prompting.
            let client = OllamaClient::from_provider(&config.model_provider);
            let mut reporter = TuiProgressReporter::new();
            if let Err(e) =
                ensure_model_available(model_name, &client, &config_path, &mut reporter).await
            {
                let mut out = std::io::stderr();
                tracing::error!("{e}");
                let _ = out.write_all(format!("{e}\n").as_bytes());
                let _ = out.flush();
                std::process::exit(1);
            }
        } else {
            // No specific model was requested: fetch available models from the local instance
            // and, if they differ from what is listed in config.toml, display a minimal
            // inline selection UI before launching the TUI.
            let client = OllamaClient::from_provider(&config.model_provider);
            let available_models: Vec<String> = client.fetch_models().await.unwrap_or_default();

            // Read existing models in config.
            let existing_models: Vec<String> = read_config_models(&config_path).unwrap_or_default();

            if available_models.is_empty() {
                // Inform the user and continue launching the TUI.
                print_inline_message_no_models(
                    &host_root,
                    &config_path,
                    provider_was_present_before,
                )?;
            } else {
                // Compare sets to decide whether to show the prompt.
                let set_eq = {
                    use std::collections::HashSet;
                    let a: HashSet<_> = available_models.iter().collect();
                    let b: HashSet<_> = existing_models.iter().collect();
                    a == b
                };

                if !set_eq {
                    run_inline_models_picker(
                        &host_root,
                        &available_models,
                        &existing_models,
                        &config_path,
                        provider_was_present_before,
                    )?;
                }
            }
        }
    }

    let log_dir = codex_core::config::log_dir(&config)?;
    std::fs::create_dir_all(&log_dir)?;
codex-rs/tui/src/screens/mod.rs (new file, 2 lines)
@@ -0,0 +1,2 @@
pub mod ollama_model_picker;
codex-rs/tui/src/screens/ollama_model_picker.rs (new file, 155 lines)
@@ -0,0 +1,155 @@
use crossterm::event::{KeyCode, KeyEvent};
use ratatui::buffer::Buffer;
use ratatui::layout::{Alignment, Constraint, Direction, Layout, Rect};
use ratatui::style::{Modifier, Style};
use ratatui::text::{Line, Span};
use ratatui::widgets::{Block, BorderType, Borders, Paragraph, WidgetRef};
use ratatui::prelude::Widget;
use std::path::PathBuf;

pub enum PickerOutcome {
    Submit(Vec<String>),
    Cancel,
    None,
}

pub struct OllamaModelPickerScreen {
    pub host_root: String,
    pub config_path: PathBuf,
    available: Vec<String>,
    selected: Vec<bool>,
    cursor: usize,
    pub loading: bool,
}

impl OllamaModelPickerScreen {
    pub fn new(host_root: String, config_path: PathBuf, preselected: Vec<String>) -> Self {
        Self {
            host_root,
            config_path,
            available: Vec::new(),
            selected: preselected.into_iter().map(|_| false).collect(),
            cursor: 0,
            loading: true,
        }
    }

    pub fn desired_height(&self, _width: u16) -> u16 {
        18u16
    }

    pub fn update_available(&mut self, available: Vec<String>) {
        // Build selection state using existing selected names where possible.
        let prev_selected_names: Vec<String> = self
            .available
            .iter()
            .cloned()
            .zip(self.selected.iter().cloned())
            .filter_map(|(n, sel)| if sel { Some(n) } else { None })
            .collect();

        self.available = available.clone();
        self.selected = available
            .iter()
            .map(|n| prev_selected_names.iter().any(|p| p == n))
            .collect();
        if self.cursor >= self.available.len() {
            self.cursor = self.available.len().saturating_sub(1);
        }
        self.loading = false;
    }

    pub fn handle_key_event(&mut self, key: KeyEvent) -> PickerOutcome {
        match key.code {
            KeyCode::Up | KeyCode::Char('k') => {
                if self.cursor > 0 { self.cursor -= 1; }
                PickerOutcome::None
            }
            KeyCode::Down | KeyCode::Char('j') => {
                if self.cursor + 1 < self.available.len() { self.cursor += 1; }
                PickerOutcome::None
            }
            KeyCode::Char(' ') => {
                if let Some(s) = self.selected.get_mut(self.cursor) {
                    *s = !*s;
                }
                PickerOutcome::None
            }
            KeyCode::Char('a') => {
                let all = self.selected.iter().all(|s| *s);
                self.selected.fill(!all);
                PickerOutcome::None
            }
            KeyCode::Enter => {
                let chosen: Vec<String> = self
                    .available
                    .iter()
                    .cloned()
                    .zip(self.selected.iter().cloned())
                    .filter_map(|(n, sel)| if sel { Some(n) } else { None })
                    .collect();
                PickerOutcome::Submit(chosen)
            }
            KeyCode::Esc | KeyCode::Char('q') => PickerOutcome::Cancel,
            _ => PickerOutcome::None,
        }
    }
}

impl WidgetRef for &OllamaModelPickerScreen {
    fn render_ref(&self, area: Rect, buf: &mut Buffer) {
        const MIN_WIDTH: u16 = 40;
        const MIN_HEIGHT: u16 = 15;
        let popup_width = std::cmp::max(MIN_WIDTH, (area.width as f32 * 0.7) as u16);
        let popup_height = std::cmp::max(MIN_HEIGHT, (area.height as f32 * 0.6) as u16);
        let popup_x = area.x + (area.width.saturating_sub(popup_width)) / 2;
        let popup_y = area.y + (area.height.saturating_sub(popup_height)) / 2;
        let popup_area = Rect::new(popup_x, popup_y, popup_width, popup_height);

        let popup_block = Block::default()
            .borders(Borders::ALL)
            .border_type(BorderType::Plain)
            .title(Span::styled(
                "Select Ollama models",
                Style::default().add_modifier(Modifier::BOLD),
            ));
        let inner = popup_block.inner(popup_area);
        popup_block.render(popup_area, buf);

        let chunks = Layout::default()
            .direction(Direction::Vertical)
            .constraints([Constraint::Length(3), Constraint::Min(3), Constraint::Length(3)])
            .split(inner);

        // Header
        let header = format!("endpoint: {}\n↑/↓ move, space toggle, 'a' (un)select all, enter confirm, 'q' skip", self.host_root);
        Paragraph::new(header).alignment(Alignment::Left).render(chunks[0], buf);

        // Body: list of models or a loading message
        if self.loading {
            Paragraph::new("discovering models...").alignment(Alignment::Center).render(chunks[1], buf);
        } else if self.available.is_empty() {
            Paragraph::new("No models discovered on the local Ollama instance.")
                .alignment(Alignment::Center)
                .render(chunks[1], buf);
        } else {
            // Render each line manually with highlight for cursor.
            let mut lines: Vec<Line> = Vec::with_capacity(self.available.len());
            for (i, name) in self.available.iter().enumerate() {
                let mark = if self.selected.get(i).copied().unwrap_or(false) { "[x]" } else { "[ ]" };
                let content = format!("{mark} {name}");
                if i == self.cursor {
                    lines.push(Line::from(content).style(Style::default().add_modifier(Modifier::REVERSED)));
                } else {
                    lines.push(Line::from(content));
                }
            }
            Paragraph::new(lines).render(chunks[1], buf);
        }

        // Footer/help
        Paragraph::new("press Enter to save, 'q' to continue without changes")
            .alignment(Alignment::Center)
            .render(chunks[2], buf);
    }
}