mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
streaming model download
This commit is contained in:
6
codex-rs/Cargo.lock
generated
6
codex-rs/Cargo.lock
generated
@@ -724,17 +724,21 @@ version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"assert_cmd",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"clap",
|
||||
"codex-arg0",
|
||||
"codex-common",
|
||||
"codex-core",
|
||||
"futures-util",
|
||||
"owo-colors",
|
||||
"predicates",
|
||||
"reqwest",
|
||||
"serde_json",
|
||||
"shlex",
|
||||
"tempfile",
|
||||
"tokio",
|
||||
"toml 0.8.23",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
]
|
||||
@@ -844,6 +848,7 @@ version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"clap",
|
||||
"codex-ansi-escape",
|
||||
@@ -854,6 +859,7 @@ dependencies = [
|
||||
"codex-login",
|
||||
"color-eyre",
|
||||
"crossterm",
|
||||
"futures-util",
|
||||
"image",
|
||||
"insta",
|
||||
"lazy_static",
|
||||
|
||||
@@ -27,6 +27,10 @@ codex-common = { path = "../common", features = [
|
||||
codex-core = { path = "../core" }
|
||||
owo-colors = "4.2.0"
|
||||
serde_json = "1"
|
||||
reqwest = { version = "0.12", features = ["json"] }
|
||||
toml = "0.8"
|
||||
bytes = "1"
|
||||
futures-util = "0.3"
|
||||
shlex = "1.3.0"
|
||||
tokio = { version = "1", features = [
|
||||
"io-std",
|
||||
|
||||
@@ -5,6 +5,7 @@ mod event_processor_with_json_output;
|
||||
|
||||
use std::io::IsTerminal;
|
||||
use std::io::Read;
|
||||
use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -31,6 +32,291 @@ use tracing_subscriber::EnvFilter;
|
||||
use crate::event_processor::CodexStatus;
|
||||
use crate::event_processor::EventProcessor;
|
||||
|
||||
// ----- Ollama model discovery and pull helpers (CLI) -----
|
||||
// These helpers are used when the user passes both --ollama and --model=<name>.
|
||||
// We verify the requested model is recorded in config.toml or present on the
|
||||
// local Ollama instance; if missing we will pull it (subject to an allowlist)
|
||||
// and record it in config.toml without prompting.
|
||||
|
||||
async fn fetch_ollama_models(host_root: &str) -> Vec<String> {
|
||||
let tags_url = format!("{host_root}/api/tags");
|
||||
match reqwest::Client::new().get(&tags_url).send().await {
|
||||
Ok(resp) if resp.status().is_success() => match resp.json::<serde_json::Value>().await {
|
||||
Ok(val) => val
|
||||
.get("models")
|
||||
.and_then(|m| m.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.get("name").and_then(|n| n.as_str()))
|
||||
.map(|s| s.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.unwrap_or_default(),
|
||||
Err(_) => Vec::new(),
|
||||
},
|
||||
_ => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn read_ollama_models_list(config_path: &std::path::Path) -> Vec<String> {
|
||||
match std::fs::read_to_string(config_path)
|
||||
.ok()
|
||||
.and_then(|s| toml::from_str::<toml::Value>(&s).ok())
|
||||
{
|
||||
Some(toml::Value::Table(root)) => root
|
||||
.get("model_providers")
|
||||
.and_then(|v| v.as_table())
|
||||
.and_then(|t| t.get("ollama"))
|
||||
.and_then(|v| v.as_table())
|
||||
.and_then(|t| t.get("models"))
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(|s| s.to_string()))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.unwrap_or_default(),
|
||||
_ => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn save_ollama_models(config_path: &std::path::Path, models: &[String]) -> std::io::Result<()> {
|
||||
use toml::value::Table as TomlTable;
|
||||
let mut root_value = if let Ok(contents) = std::fs::read_to_string(config_path) {
|
||||
toml::from_str::<toml::Value>(&contents).unwrap_or(toml::Value::Table(TomlTable::new()))
|
||||
} else {
|
||||
toml::Value::Table(TomlTable::new())
|
||||
};
|
||||
|
||||
if !matches!(root_value, toml::Value::Table(_)) {
|
||||
root_value = toml::Value::Table(TomlTable::new());
|
||||
}
|
||||
let root_tbl = match root_value.as_table_mut() {
|
||||
Some(t) => t,
|
||||
None => return Err(std::io::Error::other("invalid TOML root value")),
|
||||
};
|
||||
|
||||
let mp_val = root_tbl
|
||||
.entry("model_providers".to_string())
|
||||
.or_insert_with(|| toml::Value::Table(TomlTable::new()));
|
||||
if !mp_val.is_table() {
|
||||
*mp_val = toml::Value::Table(TomlTable::new());
|
||||
}
|
||||
let mp_tbl = match mp_val.as_table_mut() {
|
||||
Some(t) => t,
|
||||
None => return Err(std::io::Error::other("invalid model_providers table")),
|
||||
};
|
||||
|
||||
let ollama_val = mp_tbl
|
||||
.entry("ollama".to_string())
|
||||
.or_insert_with(|| toml::Value::Table(TomlTable::new()));
|
||||
if !ollama_val.is_table() {
|
||||
*ollama_val = toml::Value::Table(TomlTable::new());
|
||||
}
|
||||
let ollama_tbl = match ollama_val.as_table_mut() {
|
||||
Some(t) => t,
|
||||
None => return Err(std::io::Error::other("invalid ollama table")),
|
||||
};
|
||||
let arr = toml::Value::Array(
|
||||
models
|
||||
.iter()
|
||||
.map(|m| toml::Value::String(m.clone()))
|
||||
.collect(),
|
||||
);
|
||||
ollama_tbl.insert("models".to_string(), arr);
|
||||
|
||||
let updated =
|
||||
toml::to_string_pretty(&root_value).map_err(|e| std::io::Error::other(e.to_string()))?;
|
||||
if let Some(parent) = config_path.parent() {
|
||||
std::fs::create_dir_all(parent)?;
|
||||
}
|
||||
std::fs::write(config_path, updated)
|
||||
}
|
||||
|
||||
async fn pull_model_with_progress_cli(host_root: &str, model: &str) -> std::io::Result<()> {
|
||||
use futures_util::StreamExt;
|
||||
let url = format!("{host_root}/api/pull");
|
||||
let client = reqwest::Client::new();
|
||||
let mut resp = client
|
||||
.post(&url)
|
||||
.json(&serde_json::json!({"model": model, "stream": true}))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| std::io::Error::other(e.to_string()))?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(std::io::Error::other(format!(
|
||||
"failed to start pull: HTTP {}",
|
||||
resp.status()
|
||||
)));
|
||||
}
|
||||
|
||||
let mut out = std::io::stderr();
|
||||
let _ = out.write_all(format!("Pulling model {model}...\n").as_bytes());
|
||||
let _ = out.flush();
|
||||
let mut buf = bytes::BytesMut::new();
|
||||
let mut totals: std::collections::HashMap<String, (u64, u64)> = Default::default();
|
||||
let mut last_completed: u64 = 0;
|
||||
let mut last_instant = std::time::Instant::now();
|
||||
|
||||
let mut stream = resp.bytes_stream();
|
||||
let mut printed_header = false;
|
||||
let mut saw_success = false;
|
||||
let mut last_line_len: usize = 0;
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let chunk = chunk.map_err(|e| std::io::Error::other(e.to_string()))?;
|
||||
buf.extend_from_slice(&chunk);
|
||||
loop {
|
||||
if let Some(pos) = buf.iter().position(|b| *b == b'\n') {
|
||||
let line = buf.split_to(pos + 1);
|
||||
if let Ok(text) = std::str::from_utf8(&line) {
|
||||
let text = text.trim();
|
||||
if text.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(text) {
|
||||
let status = value.get("status").and_then(|s| s.as_str()).unwrap_or("");
|
||||
let digest = value
|
||||
.get("digest")
|
||||
.and_then(|d| d.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let total = value.get("total").and_then(|t| t.as_u64());
|
||||
let completed = value.get("completed").and_then(|t| t.as_u64());
|
||||
if let Some(t) = total {
|
||||
let entry = totals.entry(digest.clone()).or_insert((t, 0));
|
||||
entry.0 = t;
|
||||
}
|
||||
if let Some(c) = completed {
|
||||
let entry = totals.entry(digest.clone()).or_insert((0, 0));
|
||||
entry.1 = c;
|
||||
}
|
||||
|
||||
let (sum_total, sum_completed) = totals
|
||||
.values()
|
||||
.fold((0u64, 0u64), |acc, v| (acc.0 + v.0, acc.1 + v.1));
|
||||
|
||||
if sum_total > 0 && !printed_header {
|
||||
let gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0);
|
||||
let header = format!("Downloading {model}: total {gb:.2} GB\n");
|
||||
let _ = out.write_all(header.as_bytes());
|
||||
printed_header = true;
|
||||
}
|
||||
|
||||
if sum_total > 0 {
|
||||
let now = std::time::Instant::now();
|
||||
let dt = now.duration_since(last_instant).as_secs_f64().max(0.001);
|
||||
let dbytes = sum_completed.saturating_sub(last_completed) as f64;
|
||||
let speed_mb_s = dbytes / (1024.0 * 1024.0) / dt;
|
||||
last_completed = sum_completed;
|
||||
last_instant = now;
|
||||
let done_gb = (sum_completed as f64) / (1024.0 * 1024.0 * 1024.0);
|
||||
let total_gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0);
|
||||
let pct = (sum_completed as f64) * 100.0 / (sum_total as f64);
|
||||
let line_text = format!(
|
||||
"{done:.2}/{total:.2} GB ({pct:.1}%) {speed:.1} MB/s",
|
||||
done = done_gb,
|
||||
total = total_gb,
|
||||
pct = pct,
|
||||
speed = speed_mb_s
|
||||
);
|
||||
let pad = last_line_len.saturating_sub(line_text.len());
|
||||
let line = format!(
|
||||
"\r{text}{spaces}",
|
||||
text = line_text,
|
||||
spaces = " ".repeat(pad)
|
||||
);
|
||||
last_line_len = line_text.len();
|
||||
let _ = out.write_all(line.as_bytes());
|
||||
let _ = out.flush();
|
||||
} else if !status.is_empty() {
|
||||
let line_text = status.to_string();
|
||||
let pad = last_line_len.saturating_sub(line_text.len());
|
||||
let line = format!(
|
||||
"\r{text}{spaces}",
|
||||
text = line_text,
|
||||
spaces = " ".repeat(pad)
|
||||
);
|
||||
last_line_len = line_text.len();
|
||||
let _ = out.write_all(line.as_bytes());
|
||||
let _ = out.flush();
|
||||
}
|
||||
|
||||
if status == "success" {
|
||||
let _ = out.write_all(b"\n");
|
||||
let _ = out.flush();
|
||||
saw_success = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if saw_success {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if saw_success {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(std::io::Error::other(
|
||||
"model pull did not complete (no success status)",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
async fn ensure_ollama_model_available_cli(
|
||||
model: &str,
|
||||
host_root: &str,
|
||||
config_path: &std::path::Path,
|
||||
) -> std::io::Result<()> {
|
||||
// 1) Always check the instance to ensure the model is actually available locally.
|
||||
// This avoids relying solely on potentially stale entries in config.toml.
|
||||
let mut listed = read_ollama_models_list(config_path);
|
||||
let available = fetch_ollama_models(host_root).await;
|
||||
if available.iter().any(|m| m == model) {
|
||||
// Ensure the model is recorded in config.toml.
|
||||
if !listed.iter().any(|m| m == model) {
|
||||
listed.push(model.to_string());
|
||||
listed.sort();
|
||||
listed.dedup();
|
||||
let _ = save_ollama_models(config_path, &listed);
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// 2) Pull if allowlisted
|
||||
const ALLOWLIST: &[&str] = &["llama3.2:3b-instruct"];
|
||||
if !ALLOWLIST.iter().any(|&m| m == model) {
|
||||
return Err(std::io::Error::other(format!(
|
||||
"Model `{}` not found locally and not in allowlist for automatic download.",
|
||||
model
|
||||
)));
|
||||
}
|
||||
// Pull with progress; if the streaming connection ends before success, keep
|
||||
// waiting/polling and retry the streaming request until the model appears or succeeds.
|
||||
loop {
|
||||
match pull_model_with_progress_cli(host_root, model).await {
|
||||
Ok(()) => break,
|
||||
Err(_) => {
|
||||
let available = fetch_ollama_models(host_root).await;
|
||||
if available.iter().any(|m| m == model) {
|
||||
break;
|
||||
}
|
||||
eprintln!("waiting for model to finish downloading...");
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
listed.push(model.to_string());
|
||||
listed.sort();
|
||||
listed.dedup();
|
||||
let _ = save_ollama_models(config_path, &listed);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()> {
|
||||
let Cli {
|
||||
images,
|
||||
@@ -49,6 +335,9 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
config_overrides,
|
||||
} = cli;
|
||||
|
||||
// Track whether the user explicitly provided a model via --model.
|
||||
let user_specified_model = model.is_some();
|
||||
|
||||
// Determine the prompt based on CLI arg and/or stdin.
|
||||
let prompt = match prompt {
|
||||
Some(p) if p != "-" => p,
|
||||
@@ -117,7 +406,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
|
||||
// When the user opts into the Ollama provider via `--ollama`, ensure we
|
||||
// have a configured provider entry and that a local server is running.
|
||||
if ollama {
|
||||
if ollama {
|
||||
if let Err(e) = codex_core::config::ensure_ollama_provider_configured_and_running().await {
|
||||
eprintln!("{e}");
|
||||
std::process::exit(1);
|
||||
@@ -152,6 +441,28 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
};
|
||||
|
||||
let config = Config::load_with_cli_overrides(cli_kv_overrides, overrides)?;
|
||||
|
||||
// If the user passed both --ollama and --model, ensure the requested model
|
||||
// is present locally or pull it automatically (subject to allowlist).
|
||||
if ollama && user_specified_model {
|
||||
let model_name = config.model.clone();
|
||||
let base_url = config
|
||||
.model_provider
|
||||
.base_url
|
||||
.clone()
|
||||
.unwrap_or_else(|| "http://localhost:11434/v1".to_string());
|
||||
let host_root = base_url
|
||||
.trim_end_matches('/')
|
||||
.trim_end_matches("/v1")
|
||||
.to_string();
|
||||
let config_path = config.codex_home.join("config.toml");
|
||||
if let Err(e) =
|
||||
ensure_ollama_model_available_cli(&model_name, &host_root, &config_path).await
|
||||
{
|
||||
eprintln!("{e}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
let mut event_processor: Box<dyn EventProcessor> = if json_mode {
|
||||
Box::new(EventProcessorWithJsonOutput::new(last_message_file.clone()))
|
||||
} else {
|
||||
|
||||
@@ -17,6 +17,7 @@ workspace = true
|
||||
[dependencies]
|
||||
anyhow = "1"
|
||||
base64 = "0.22.1"
|
||||
bytes = "1"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
codex-ansi-escape = { path = "../ansi-escape" }
|
||||
@@ -31,6 +32,7 @@ codex-file-search = { path = "../file-search" }
|
||||
codex-login = { path = "../login" }
|
||||
color-eyre = "0.6.3"
|
||||
crossterm = { version = "0.28.1", features = ["bracketed-paste"] }
|
||||
futures-util = "0.3"
|
||||
image = { version = "^0.25.6", default-features = false, features = ["jpeg"] }
|
||||
lazy_static = "1"
|
||||
mcp-types = { path = "../mcp-types" }
|
||||
@@ -45,7 +47,6 @@ regex-lite = "0.1"
|
||||
reqwest = { version = "0.12", features = ["json"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = { version = "1", features = ["preserve_order"] }
|
||||
toml = "0.8"
|
||||
shlex = "1.3.0"
|
||||
strum = "0.27.2"
|
||||
strum_macros = "0.27.2"
|
||||
@@ -58,6 +59,7 @@ tokio = { version = "1", features = [
|
||||
"rt-multi-thread",
|
||||
"signal",
|
||||
] }
|
||||
toml = "0.8"
|
||||
tracing = { version = "0.1.41", features = ["log"] }
|
||||
tracing-appender = "0.2.3"
|
||||
tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
|
||||
@@ -68,7 +70,6 @@ unicode-width = "0.1"
|
||||
uuid = "1"
|
||||
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
insta = "1.43.1"
|
||||
pretty_assertions = "1"
|
||||
|
||||
@@ -9,11 +9,16 @@ use codex_core::config_types::SandboxMode;
|
||||
use codex_core::protocol::AskForApproval;
|
||||
use codex_core::util::is_inside_git_repo;
|
||||
use codex_login::load_auth;
|
||||
use crossterm::event::{self, Event as CEvent, KeyCode, KeyEvent};
|
||||
use crossterm::terminal::{disable_raw_mode, enable_raw_mode};
|
||||
use crossterm::event::Event as CEvent;
|
||||
use crossterm::event::KeyCode;
|
||||
use crossterm::event::KeyEvent;
|
||||
use crossterm::event::{self};
|
||||
use crossterm::terminal::disable_raw_mode;
|
||||
use crossterm::terminal::enable_raw_mode;
|
||||
use log_layer::TuiLogLayer;
|
||||
use std::fs::OpenOptions;
|
||||
use std::io::{self, Write};
|
||||
use std::io::Write;
|
||||
use std::io::{self};
|
||||
use std::path::PathBuf;
|
||||
use toml as _;
|
||||
use tracing::error;
|
||||
@@ -51,6 +56,216 @@ use color_eyre::owo_colors::OwoColorize;
|
||||
|
||||
pub use cli::Cli;
|
||||
|
||||
async fn fetch_ollama_models(host_root: &str) -> Vec<String> {
|
||||
let tags_url = format!("{host_root}/api/tags");
|
||||
match reqwest::Client::new().get(&tags_url).send().await {
|
||||
Ok(resp) if resp.status().is_success() => match resp.json::<serde_json::Value>().await {
|
||||
Ok(val) => val
|
||||
.get("models")
|
||||
.and_then(|m| m.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.get("name").and_then(|n| n.as_str()))
|
||||
.map(|s| s.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.unwrap_or_default(),
|
||||
Err(_) => Vec::new(),
|
||||
},
|
||||
_ => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn pull_model_with_progress_tui(host_root: &str, model: &str) -> std::io::Result<()> {
|
||||
use futures_util::StreamExt;
|
||||
let url = format!("{host_root}/api/pull");
|
||||
let client = reqwest::Client::new();
|
||||
let mut resp = client
|
||||
.post(&url)
|
||||
.json(&serde_json::json!({"model": model, "stream": true}))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| std::io::Error::other(e.to_string()))?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(std::io::Error::other(format!(
|
||||
"failed to start pull: HTTP {}",
|
||||
resp.status()
|
||||
)));
|
||||
}
|
||||
|
||||
let mut out = std::io::stderr();
|
||||
// Print an immediate status line so the user sees activity even before totals are known.
|
||||
let _ = out.write_all(format!("Pulling model {model}...\n").as_bytes());
|
||||
let _ = out.flush();
|
||||
let mut buf = bytes::BytesMut::new();
|
||||
let mut totals: std::collections::HashMap<String, (u64, u64)> = Default::default();
|
||||
let mut last_completed: u64 = 0;
|
||||
let mut last_instant = std::time::Instant::now();
|
||||
|
||||
let mut stream = resp.bytes_stream();
|
||||
// Print a header once we know total size.
|
||||
let mut printed_header = false;
|
||||
let mut saw_success = false;
|
||||
let mut last_line_len: usize = 0;
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let chunk = chunk.map_err(|e| std::io::Error::other(e.to_string()))?;
|
||||
buf.extend_from_slice(&chunk);
|
||||
// split by newlines
|
||||
loop {
|
||||
if let Some(pos) = buf.iter().position(|b| *b == b'\n') {
|
||||
let line = buf.split_to(pos + 1);
|
||||
if let Ok(text) = std::str::from_utf8(&line) {
|
||||
let text = text.trim();
|
||||
if text.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(text) {
|
||||
let status = value.get("status").and_then(|s| s.as_str()).unwrap_or("");
|
||||
let digest = value
|
||||
.get("digest")
|
||||
.and_then(|d| d.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let total = value.get("total").and_then(|t| t.as_u64());
|
||||
let completed = value.get("completed").and_then(|t| t.as_u64());
|
||||
if let Some(t) = total {
|
||||
let entry = totals.entry(digest.clone()).or_insert((t, 0));
|
||||
entry.0 = t;
|
||||
}
|
||||
if let Some(c) = completed {
|
||||
let entry = totals.entry(digest.clone()).or_insert((0, 0));
|
||||
entry.1 = c;
|
||||
}
|
||||
|
||||
let (sum_total, sum_completed) = totals
|
||||
.values()
|
||||
.fold((0u64, 0u64), |acc, v| (acc.0 + v.0, acc.1 + v.1));
|
||||
|
||||
if sum_total > 0 && !printed_header {
|
||||
let gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0);
|
||||
let header = format!("Downloading {model}: total {gb:.2} GB\n");
|
||||
let _ = out.write_all(header.as_bytes());
|
||||
printed_header = true;
|
||||
}
|
||||
|
||||
if sum_total > 0 {
|
||||
let now = std::time::Instant::now();
|
||||
let dt = now.duration_since(last_instant).as_secs_f64().max(0.001);
|
||||
let dbytes = sum_completed.saturating_sub(last_completed) as f64;
|
||||
let speed_mb_s = dbytes / (1024.0 * 1024.0) / dt;
|
||||
last_completed = sum_completed;
|
||||
last_instant = now;
|
||||
let done_gb = (sum_completed as f64) / (1024.0 * 1024.0 * 1024.0);
|
||||
let total_gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0);
|
||||
let pct = (sum_completed as f64) * 100.0 / (sum_total as f64);
|
||||
let line_text = format!(
|
||||
"{done:.2}/{total:.2} GB ({pct:.1}%) {speed:.1} MB/s",
|
||||
done = done_gb,
|
||||
total = total_gb,
|
||||
pct = pct,
|
||||
speed = speed_mb_s
|
||||
);
|
||||
let pad = last_line_len.saturating_sub(line_text.len());
|
||||
let line = format!(
|
||||
"\r{text}{spaces}",
|
||||
text = line_text,
|
||||
spaces = " ".repeat(pad)
|
||||
);
|
||||
last_line_len = line_text.len();
|
||||
let _ = out.write_all(line.as_bytes());
|
||||
let _ = out.flush();
|
||||
} else if !status.is_empty() {
|
||||
// Print status lines like verifying/writing only once.
|
||||
let line_text = status.to_string();
|
||||
let pad = last_line_len.saturating_sub(line_text.len());
|
||||
let line = format!(
|
||||
"\r{text}{spaces}",
|
||||
text = line_text,
|
||||
spaces = " ".repeat(pad)
|
||||
);
|
||||
last_line_len = line_text.len();
|
||||
let _ = out.write_all(line.as_bytes());
|
||||
let _ = out.flush();
|
||||
}
|
||||
|
||||
if status == "success" {
|
||||
let _ = out.write_all(b"\n");
|
||||
let _ = out.flush();
|
||||
saw_success = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if saw_success {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if saw_success {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(std::io::Error::other(
|
||||
"model pull did not complete (no success status)",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
async fn ensure_ollama_model_available_tui(
|
||||
model: &str,
|
||||
host_root: &str,
|
||||
config_path: &std::path::Path,
|
||||
) -> std::io::Result<()> {
|
||||
// 1) Always check the instance to ensure the model is actually available locally.
|
||||
// This avoids relying solely on potentially stale entries in config.toml.
|
||||
let mut listed = read_ollama_models_list(config_path);
|
||||
let available = fetch_ollama_models(host_root).await;
|
||||
if available.iter().any(|m| m == model) {
|
||||
// Ensure the model is recorded in config.toml.
|
||||
if !listed.iter().any(|m| m == model) {
|
||||
listed.push(model.to_string());
|
||||
listed.sort();
|
||||
listed.dedup();
|
||||
let _ = save_ollama_models(config_path, &listed);
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// 2) Pull if allowlisted.
|
||||
const ALLOWLIST: &[&str] = &["llama3.2:3b"];
|
||||
if !ALLOWLIST.iter().any(|&m| m == model) {
|
||||
return Err(std::io::Error::other(format!(
|
||||
"Model `{}` not found locally and not in allowlist for automatic download.",
|
||||
model
|
||||
)));
|
||||
}
|
||||
// Pull with progress; if the streaming connection ends before success, keep
|
||||
// waiting/polling and retry the streaming request until the model appears or succeeds.
|
||||
loop {
|
||||
match pull_model_with_progress_tui(host_root, model).await {
|
||||
Ok(()) => break,
|
||||
Err(_) => {
|
||||
let available = fetch_ollama_models(host_root).await;
|
||||
if available.iter().any(|m| m == model) {
|
||||
break;
|
||||
}
|
||||
let mut out = std::io::stderr();
|
||||
let _ = out.write_all(b"\nwaiting for model to finish downloading...\n");
|
||||
let _ = out.flush();
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
listed.push(model.to_string());
|
||||
listed.sort();
|
||||
listed.dedup();
|
||||
let _ = save_ollama_models(config_path, &listed);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn read_ollama_models_list(config_path: &std::path::Path) -> Vec<String> {
|
||||
match std::fs::read_to_string(config_path)
|
||||
.ok()
|
||||
@@ -145,8 +360,8 @@ fn save_ollama_models(config_path: &std::path::Path, models: &[String]) -> std::
|
||||
);
|
||||
ollama_tbl.insert("models".to_string(), arr);
|
||||
|
||||
let updated = toml::to_string_pretty(&root_value)
|
||||
.map_err(|e| std::io::Error::other(e.to_string()))?;
|
||||
let updated =
|
||||
toml::to_string_pretty(&root_value).map_err(|e| std::io::Error::other(e.to_string()))?;
|
||||
if let Some(parent) = config_path.parent() {
|
||||
std::fs::create_dir_all(parent)?;
|
||||
}
|
||||
@@ -279,6 +494,8 @@ fn run_inline_models_picker(
|
||||
}
|
||||
|
||||
disable_raw_mode()?;
|
||||
// Ensure the summary starts on a clean, left‑aligned new line.
|
||||
out.write_all(b"\r\x1b[2K\n")?;
|
||||
|
||||
// Compute chosen
|
||||
let chosen: Vec<String> = available
|
||||
@@ -308,16 +525,18 @@ fn render_inline_picker(
|
||||
) -> io::Result<()> {
|
||||
// If not first render, move to the start of the block. We will clear each line as we redraw.
|
||||
if !*first {
|
||||
out.write_all(format!("\x1b[{}A", *lines_printed).as_bytes())?; // up N
|
||||
out.write_all(format!("\x1b[{}A", *lines_printed).as_bytes())?; // up N lines
|
||||
// Ensure we start at column 1 for a clean redraw.
|
||||
out.write_all(b"\r")?;
|
||||
}
|
||||
|
||||
let mut lines = Vec::new();
|
||||
let bold = |s: &str| format!("\x1b[1m{s}\x1b[0m");
|
||||
lines.push(bold("we've discovered some models on ollama:").to_string());
|
||||
lines.push(format!("endpoint: {host_root}"));
|
||||
lines.push(bold(&format!(
|
||||
"we've discovered some models on ollama {host_root}:"
|
||||
)));
|
||||
lines.push(
|
||||
"controls: ↑/↓ move, space toggle, 'a' select/unselect all, enter confirm, 'q' skip"
|
||||
.to_string(),
|
||||
"↑/↓ move, space to toggle, 'a' select/unselect all, enter confirm, 'q' skip".to_string(),
|
||||
);
|
||||
lines.push(String::new());
|
||||
for (i, name) in items.iter().enumerate() {
|
||||
@@ -334,7 +553,8 @@ fn render_inline_picker(
|
||||
}
|
||||
|
||||
for l in &lines {
|
||||
out.write_all(b"\x1b[2K")?; // clear current line
|
||||
// Move to column 0 and clear the entire line before writing.
|
||||
out.write_all(b"\r\x1b[2K")?;
|
||||
out.write_all(l.as_bytes())?;
|
||||
out.write_all(b"\n")?;
|
||||
}
|
||||
@@ -351,6 +571,8 @@ fn print_config_summary_after_save(
|
||||
models_count_after: Option<usize>,
|
||||
) -> io::Result<()> {
|
||||
let mut out = std::io::stdout();
|
||||
// Start clean and at column 0
|
||||
out.write_all(b"\r\x1b[2K")?;
|
||||
let path = config_path.display().to_string();
|
||||
if provider_was_present_before {
|
||||
out.write_all(format!("config: ollama provider already present in {path}\n").as_bytes())?;
|
||||
@@ -454,9 +676,9 @@ pub async fn run_main(
|
||||
}
|
||||
}
|
||||
};
|
||||
// If the user passed --ollama, fetch available models from the local
|
||||
// Ollama instance and, if they differ from what is listed in
|
||||
// config.toml, display a minimal inline selection UI before launching the TUI.
|
||||
// If the user passed --ollama, either ensure an explicitly requested model is
|
||||
// available (automatic pull if allowlisted) or offer an inline picker when no
|
||||
// specific model was provided.
|
||||
if cli.ollama {
|
||||
// Determine host root for the Ollama native API (e.g. http://localhost:11434).
|
||||
let base_url = config
|
||||
@@ -468,54 +690,71 @@ pub async fn run_main(
|
||||
.trim_end_matches('/')
|
||||
.trim_end_matches("/v1")
|
||||
.to_string();
|
||||
|
||||
// Query the list of local models via GET /api/tags.
|
||||
let tags_url = format!("{host_root}/api/tags");
|
||||
let available_models: Vec<String> = match reqwest::Client::new().get(&tags_url).send().await
|
||||
{
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
match resp.json::<serde_json::Value>().await {
|
||||
Ok(val) => val
|
||||
.get("models")
|
||||
.and_then(|m| m.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.get("name").and_then(|n| n.as_str()))
|
||||
.map(|s| s.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.unwrap_or_default(),
|
||||
Err(_) => Vec::new(),
|
||||
}
|
||||
}
|
||||
_ => Vec::new(),
|
||||
};
|
||||
|
||||
let config_path = config.codex_home.join("config.toml");
|
||||
// Read existing models in config.
|
||||
let existing_models: Vec<String> = read_ollama_models_list(&config_path);
|
||||
|
||||
if available_models.is_empty() {
|
||||
// Inform the user and continue launching the TUI.
|
||||
print_inline_message_no_models(&host_root, &config_path, provider_was_present_before)?;
|
||||
if let Some(ref model_name) = cli.model {
|
||||
// Explicit model requested: ensure it is available locally without prompting.
|
||||
if let Err(e) =
|
||||
ensure_ollama_model_available_tui(model_name, &host_root, &config_path).await
|
||||
{
|
||||
#[allow(clippy::print_stderr)]
|
||||
eprintln!("{e}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
} else {
|
||||
// Compare sets to decide whether to show the prompt.
|
||||
let set_eq = {
|
||||
use std::collections::HashSet;
|
||||
let a: HashSet<_> = available_models.iter().collect();
|
||||
let b: HashSet<_> = existing_models.iter().collect();
|
||||
a == b
|
||||
};
|
||||
// No specific model was requested: fetch available models from the local instance
|
||||
// and, if they differ from what is listed in config.toml, display a minimal
|
||||
// inline selection UI before launching the TUI.
|
||||
let tags_url = format!("{host_root}/api/tags");
|
||||
let available_models: Vec<String> =
|
||||
match reqwest::Client::new().get(&tags_url).send().await {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
match resp.json::<serde_json::Value>().await {
|
||||
Ok(val) => val
|
||||
.get("models")
|
||||
.and_then(|m| m.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.get("name").and_then(|n| n.as_str()))
|
||||
.map(|s| s.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.unwrap_or_default(),
|
||||
Err(_) => Vec::new(),
|
||||
}
|
||||
}
|
||||
_ => Vec::new(),
|
||||
};
|
||||
|
||||
if !set_eq {
|
||||
run_inline_models_picker(
|
||||
// Read existing models in config.
|
||||
let existing_models: Vec<String> = read_ollama_models_list(&config_path);
|
||||
|
||||
if available_models.is_empty() {
|
||||
// Inform the user and continue launching the TUI.
|
||||
print_inline_message_no_models(
|
||||
&host_root,
|
||||
&available_models,
|
||||
&existing_models,
|
||||
&config_path,
|
||||
provider_was_present_before,
|
||||
models_count_before,
|
||||
)?;
|
||||
} else {
|
||||
// Compare sets to decide whether to show the prompt.
|
||||
let set_eq = {
|
||||
use std::collections::HashSet;
|
||||
let a: HashSet<_> = available_models.iter().collect();
|
||||
let b: HashSet<_> = existing_models.iter().collect();
|
||||
a == b
|
||||
};
|
||||
|
||||
if !set_eq {
|
||||
run_inline_models_picker(
|
||||
&host_root,
|
||||
&available_models,
|
||||
&existing_models,
|
||||
&config_path,
|
||||
provider_was_present_before,
|
||||
models_count_before,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user