refactor into providers

pap
2025-08-04 00:18:52 +01:00
parent 6b6d2c5e00
commit 75eec73fcc
12 changed files with 823 additions and 800 deletions

@@ -5,7 +5,6 @@ mod event_processor_with_json_output;
use std::io::IsTerminal;
use std::io::Read;
use std::io::Write;
use std::path::PathBuf;
use std::sync::Arc;
@@ -32,299 +31,11 @@ use tracing_subscriber::EnvFilter;
use crate::event_processor::CodexStatus;
use crate::event_processor::EventProcessor;
// ----- Ollama model discovery and pull helpers (CLI) -----
// These helpers are used when the user passes both --ollama and --model=<name>.
// We verify the requested model is recorded in config.toml or present on the
// local Ollama instance; if missing we will pull it (subject to an allowlist)
// and record it in config.toml without prompting.
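// Query the local Ollama server's `/api/tags` endpoint and collect the `name`
// of every entry under `models`; returns an empty list on any network or parse error.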
async fn fetch_ollama_models(host_root: &str) -> Vec<String> {
let tags_url = format!("{host_root}/api/tags");
match reqwest::Client::new().get(&tags_url).send().await {
Ok(resp) if resp.status().is_success() => match resp.json::<serde_json::Value>().await {
Ok(val) => val
.get("models")
.and_then(|m| m.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.get("name").and_then(|n| n.as_str()))
.map(|s| s.to_string())
.collect::<Vec<_>>()
})
.unwrap_or_default(),
Err(_) => Vec::new(),
},
_ => Vec::new(),
}
}
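// Read the model names previously recorded under `model_providers.ollama.models`
// in config.toml, returning an empty list if the file or key is missing.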
fn read_ollama_models_list(config_path: &std::path::Path) -> Vec<String> {
match std::fs::read_to_string(config_path)
.ok()
.and_then(|s| toml::from_str::<toml::Value>(&s).ok())
{
Some(toml::Value::Table(root)) => root
.get("model_providers")
.and_then(|v| v.as_table())
.and_then(|t| t.get("ollama"))
.and_then(|v| v.as_table())
.and_then(|t| t.get("models"))
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect::<Vec<_>>()
})
.unwrap_or_default(),
_ => Vec::new(),
}
}
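// Persist the model list under `model_providers.ollama.models` in config.toml,
// creating the intermediate tables as needed. The written section is shaped roughly like:
//
//   [model_providers.ollama]
//   models = ["llama3.2:3b", ...]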
fn save_ollama_models(config_path: &std::path::Path, models: &[String]) -> std::io::Result<()> {
use toml::value::Table as TomlTable;
let mut root_value = if let Ok(contents) = std::fs::read_to_string(config_path) {
toml::from_str::<toml::Value>(&contents).unwrap_or(toml::Value::Table(TomlTable::new()))
} else {
toml::Value::Table(TomlTable::new())
};
if !matches!(root_value, toml::Value::Table(_)) {
root_value = toml::Value::Table(TomlTable::new());
}
let root_tbl = match root_value.as_table_mut() {
Some(t) => t,
None => return Err(std::io::Error::other("invalid TOML root value")),
};
let mp_val = root_tbl
.entry("model_providers".to_string())
.or_insert_with(|| toml::Value::Table(TomlTable::new()));
if !mp_val.is_table() {
*mp_val = toml::Value::Table(TomlTable::new());
}
let mp_tbl = match mp_val.as_table_mut() {
Some(t) => t,
None => return Err(std::io::Error::other("invalid model_providers table")),
};
let ollama_val = mp_tbl
.entry("ollama".to_string())
.or_insert_with(|| toml::Value::Table(TomlTable::new()));
if !ollama_val.is_table() {
*ollama_val = toml::Value::Table(TomlTable::new());
}
let ollama_tbl = match ollama_val.as_table_mut() {
Some(t) => t,
None => return Err(std::io::Error::other("invalid ollama table")),
};
let arr = toml::Value::Array(
models
.iter()
.map(|m| toml::Value::String(m.clone()))
.collect(),
);
ollama_tbl.insert("models".to_string(), arr);
let updated =
toml::to_string_pretty(&root_value).map_err(|e| std::io::Error::other(e.to_string()))?;
if let Some(parent) = config_path.parent() {
std::fs::create_dir_all(parent)?;
}
std::fs::write(config_path, updated)
}
async fn pull_model_with_progress_cli(host_root: &str, model: &str) -> std::io::Result<()> {
use futures_util::StreamExt;
let url = format!("{host_root}/api/pull");
let client = reqwest::Client::new();
let resp = client
.post(&url)
.json(&serde_json::json!({"model": model, "stream": true}))
.send()
.await
.map_err(|e| std::io::Error::other(e.to_string()))?;
if !resp.status().is_success() {
return Err(std::io::Error::other(format!(
"failed to start pull: HTTP {}",
resp.status()
)));
}
let mut out = std::io::stderr();
let _ = out.write_all(format!("Pulling model {model}...\n").as_bytes());
let _ = out.flush();
let mut buf = bytes::BytesMut::new();
let mut totals: std::collections::HashMap<String, (u64, u64)> = Default::default();
let mut last_completed: u64 = 0;
let mut last_instant = std::time::Instant::now();
let mut stream = resp.bytes_stream();
let mut printed_header = false;
let mut saw_success = false;
let mut last_line_len: usize = 0;
while let Some(chunk) = stream.next().await {
let chunk = chunk.map_err(|e| std::io::Error::other(e.to_string()))?;
buf.extend_from_slice(&chunk);
while let Some(pos) = buf.iter().position(|b| *b == b'\n') {
let line = buf.split_to(pos + 1);
if let Ok(text) = std::str::from_utf8(&line) {
let text = text.trim();
if text.is_empty() {
continue;
}
if let Ok(value) = serde_json::from_str::<serde_json::Value>(text) {
if let Some(err_msg) = value.get("error").and_then(|e| e.as_str()) {
let _ = out.write_all(b"\n");
let _ = out.flush();
if err_msg.contains("file does not exist") {
return Err(std::io::Error::other("model not found"));
} else {
return Err(std::io::Error::other(format!(
"ollama pull error: {err_msg}"
)));
}
}
let status = value.get("status").and_then(|s| s.as_str()).unwrap_or("");
let digest = value
.get("digest")
.and_then(|d| d.as_str())
.unwrap_or("")
.to_string();
let total = value.get("total").and_then(|t| t.as_u64());
let completed = value.get("completed").and_then(|t| t.as_u64());
if let Some(t) = total {
let entry = totals.entry(digest.clone()).or_insert((t, 0));
entry.0 = t;
}
if let Some(c) = completed {
let entry = totals.entry(digest.clone()).or_insert((0, 0));
entry.1 = c;
}
let (sum_total, sum_completed) = totals
.values()
.fold((0u64, 0u64), |acc, v| (acc.0 + v.0, acc.1 + v.1));
if sum_total > 0 && !printed_header {
let gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0);
let header = format!("Downloading {model}: total {gb:.2} GB\n");
// Clear any prior inline status text before printing header.
let _ = out.write_all(b"\r\x1b[2K");
let _ = out.write_all(header.as_bytes());
printed_header = true;
}
if sum_total > 0 {
let now = std::time::Instant::now();
let dt = now.duration_since(last_instant).as_secs_f64().max(0.001);
let dbytes = sum_completed.saturating_sub(last_completed) as f64;
let speed_mb_s = dbytes / (1024.0 * 1024.0) / dt;
last_completed = sum_completed;
last_instant = now;
let done_gb = (sum_completed as f64) / (1024.0 * 1024.0 * 1024.0);
let total_gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0);
let pct = (sum_completed as f64) * 100.0 / (sum_total as f64);
let line_text = format!(
"{done_gb:.2}/{total_gb:.2} GB ({pct:.1}%) {speed_mb_s:.1} MB/s"
);
let pad = last_line_len.saturating_sub(line_text.len());
let line = format!(
"\r{text}{spaces}",
text = line_text,
spaces = " ".repeat(pad)
);
last_line_len = line_text.len();
let _ = out.write_all(line.as_bytes());
let _ = out.flush();
} else if !status.is_empty() && !status.eq_ignore_ascii_case("pulling manifest")
{
let line_text = status.to_string();
let pad = last_line_len.saturating_sub(line_text.len());
let line = format!(
"\r{text}{spaces}",
text = line_text,
spaces = " ".repeat(pad)
);
last_line_len = line_text.len();
let _ = out.write_all(line.as_bytes());
let _ = out.flush();
}
if status == "success" {
let _ = out.write_all(b"\n");
let _ = out.flush();
saw_success = true;
break;
}
}
}
}
if saw_success {
break;
}
}
if saw_success {
Ok(())
} else {
Err(std::io::Error::other(
"model pull did not complete (no success status)",
))
}
}
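// Ensure `model` is usable with the local Ollama instance: if it is already
// installed, just record it in config.toml; otherwise pull it (allowlist-gated)
// with inline progress reporting, then record it.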
async fn ensure_ollama_model_available_cli(
model: &str,
host_root: &str,
config_path: &std::path::Path,
) -> std::io::Result<()> {
// 1) Always check the instance to ensure the model is actually available locally.
// This avoids relying solely on potentially stale entries in config.toml.
let mut listed = read_ollama_models_list(config_path);
let available = fetch_ollama_models(host_root).await;
if available.iter().any(|m| m == model) {
// Ensure the model is recorded in config.toml.
if !listed.iter().any(|m| m == model) {
listed.push(model.to_string());
listed.sort();
listed.dedup();
let _ = save_ollama_models(config_path, &listed);
}
return Ok(());
}
// 2) Pull if allowlisted
const ALLOWLIST: &[&str] = &["llama3.2:3b-instruct", "llama3.2:3b"];
if !ALLOWLIST.contains(&model) {
return Err(std::io::Error::other(format!(
"Model `{model}` not found locally and not in allowlist for automatic download."
)));
}
// Pull with progress; if the streaming connection ends before success, keep
// waiting/polling and retry the streaming request until the model appears or succeeds.
loop {
match pull_model_with_progress_cli(host_root, model).await {
Ok(()) => break,
Err(e) => {
// If the server reported the model manifest does not exist, surface a clear error.
if e.to_string().contains("model not found") {
return Err(std::io::Error::other("model not found"));
}
let available = fetch_ollama_models(host_root).await;
if available.iter().any(|m| m == model) {
break;
}
eprintln!("waiting for model to finish downloading...");
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
continue;
}
}
}
listed.push(model.to_string());
listed.sort();
listed.dedup();
let _ = save_ollama_models(config_path, &listed);
Ok(())
}
// Shared Ollama helpers are centralized in codex_core::providers::ollama.
use codex_core::providers::ollama::CliProgressReporter;
use codex_core::providers::ollama::OllamaClient;
use codex_core::providers::ollama::ensure_configured_and_running;
use codex_core::providers::ollama::ensure_model_available;
pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()> {
let Cli {
@@ -416,7 +127,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
// When the user opts into the Ollama provider via `--ollama`, ensure we
// have a configured provider entry and that a local server is running.
if ollama {
if let Err(e) = codex_core::config::ensure_ollama_provider_configured_and_running().await {
if let Err(e) = ensure_configured_and_running().await {
eprintln!("{e}");
std::process::exit(1);
}
@@ -455,18 +166,11 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
// is present locally or pull it automatically (subject to allowlist).
if ollama && user_specified_model {
let model_name = config.model.clone();
let base_url = config
.model_provider
.base_url
.clone()
.unwrap_or_else(|| "http://localhost:11434/v1".to_string());
let host_root = base_url
.trim_end_matches('/')
.trim_end_matches("/v1")
.to_string();
let client = OllamaClient::from_provider(&config.model_provider);
let config_path = config.codex_home.join("config.toml");
let mut reporter = CliProgressReporter::new();
if let Err(e) =
ensure_ollama_model_available_cli(&model_name, &host_root, &config_path).await
ensure_model_available(&model_name, &client, &config_path, &mut reporter).await
{
eprintln!("{e}");
std::process::exit(1);