Fixed Ctrl+Q / Ctrl+D handling at the config stage; `-c model_provider=ollama` is handled similarly elsewhere.

This commit is contained in:
pap
2025-08-03 23:12:19 +01:00
parent 47e84d5c05
commit 6b6d2c5e00
3 changed files with 244 additions and 230 deletions

View File

@@ -137,7 +137,7 @@ async fn pull_model_with_progress_cli(host_root: &str, model: &str) -> std::io::
use futures_util::StreamExt;
let url = format!("{host_root}/api/pull");
let client = reqwest::Client::new();
let mut resp = client
let resp = client
.post(&url)
.json(&serde_json::json!({"model": model, "stream": true}))
.send()
@@ -165,92 +165,98 @@ async fn pull_model_with_progress_cli(host_root: &str, model: &str) -> std::io::
while let Some(chunk) = stream.next().await {
let chunk = chunk.map_err(|e| std::io::Error::other(e.to_string()))?;
buf.extend_from_slice(&chunk);
loop {
if let Some(pos) = buf.iter().position(|b| *b == b'\n') {
let line = buf.split_to(pos + 1);
if let Ok(text) = std::str::from_utf8(&line) {
let text = text.trim();
if text.is_empty() {
continue;
while let Some(pos) = buf.iter().position(|b| *b == b'\n') {
let line = buf.split_to(pos + 1);
if let Ok(text) = std::str::from_utf8(&line) {
let text = text.trim();
if text.is_empty() {
continue;
}
if let Ok(value) = serde_json::from_str::<serde_json::Value>(text) {
if let Some(err_msg) = value.get("error").and_then(|e| e.as_str()) {
let _ = out.write_all(b"\n");
let _ = out.flush();
if err_msg.contains("file does not exist") {
return Err(std::io::Error::other("model not found"));
} else {
return Err(std::io::Error::other(format!(
"ollama pull error: {err_msg}"
)));
}
}
let status = value.get("status").and_then(|s| s.as_str()).unwrap_or("");
let digest = value
.get("digest")
.and_then(|d| d.as_str())
.unwrap_or("")
.to_string();
let total = value.get("total").and_then(|t| t.as_u64());
let completed = value.get("completed").and_then(|t| t.as_u64());
if let Some(t) = total {
let entry = totals.entry(digest.clone()).or_insert((t, 0));
entry.0 = t;
}
if let Some(c) = completed {
let entry = totals.entry(digest.clone()).or_insert((0, 0));
entry.1 = c;
}
if let Ok(value) = serde_json::from_str::<serde_json::Value>(text) {
let status = value.get("status").and_then(|s| s.as_str()).unwrap_or("");
let digest = value
.get("digest")
.and_then(|d| d.as_str())
.unwrap_or("")
.to_string();
let total = value.get("total").and_then(|t| t.as_u64());
let completed = value.get("completed").and_then(|t| t.as_u64());
if let Some(t) = total {
let entry = totals.entry(digest.clone()).or_insert((t, 0));
entry.0 = t;
}
if let Some(c) = completed {
let entry = totals.entry(digest.clone()).or_insert((0, 0));
entry.1 = c;
}
let (sum_total, sum_completed) = totals
.values()
.fold((0u64, 0u64), |acc, v| (acc.0 + v.0, acc.1 + v.1));
let (sum_total, sum_completed) = totals
.values()
.fold((0u64, 0u64), |acc, v| (acc.0 + v.0, acc.1 + v.1));
if sum_total > 0 && !printed_header {
let gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0);
let header = format!("Downloading {model}: total {gb:.2} GB\n");
let _ = out.write_all(header.as_bytes());
printed_header = true;
}
if sum_total > 0 && !printed_header {
let gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0);
let header = format!("Downloading {model}: total {gb:.2} GB\n");
// Clear any prior inline status text before printing header.
let _ = out.write_all(b"\r\x1b[2K");
let _ = out.write_all(header.as_bytes());
printed_header = true;
}
if sum_total > 0 {
let now = std::time::Instant::now();
let dt = now.duration_since(last_instant).as_secs_f64().max(0.001);
let dbytes = sum_completed.saturating_sub(last_completed) as f64;
let speed_mb_s = dbytes / (1024.0 * 1024.0) / dt;
last_completed = sum_completed;
last_instant = now;
let done_gb = (sum_completed as f64) / (1024.0 * 1024.0 * 1024.0);
let total_gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0);
let pct = (sum_completed as f64) * 100.0 / (sum_total as f64);
let line_text = format!(
"{done:.2}/{total:.2} GB ({pct:.1}%) {speed:.1} MB/s",
done = done_gb,
total = total_gb,
pct = pct,
speed = speed_mb_s
);
let pad = last_line_len.saturating_sub(line_text.len());
let line = format!(
"\r{text}{spaces}",
text = line_text,
spaces = " ".repeat(pad)
);
last_line_len = line_text.len();
let _ = out.write_all(line.as_bytes());
let _ = out.flush();
} else if !status.is_empty() {
let line_text = status.to_string();
let pad = last_line_len.saturating_sub(line_text.len());
let line = format!(
"\r{text}{spaces}",
text = line_text,
spaces = " ".repeat(pad)
);
last_line_len = line_text.len();
let _ = out.write_all(line.as_bytes());
let _ = out.flush();
}
if sum_total > 0 {
let now = std::time::Instant::now();
let dt = now.duration_since(last_instant).as_secs_f64().max(0.001);
let dbytes = sum_completed.saturating_sub(last_completed) as f64;
let speed_mb_s = dbytes / (1024.0 * 1024.0) / dt;
last_completed = sum_completed;
last_instant = now;
let done_gb = (sum_completed as f64) / (1024.0 * 1024.0 * 1024.0);
let total_gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0);
let pct = (sum_completed as f64) * 100.0 / (sum_total as f64);
let line_text = format!(
"{done_gb:.2}/{total_gb:.2} GB ({pct:.1}%) {speed_mb_s:.1} MB/s"
);
let pad = last_line_len.saturating_sub(line_text.len());
let line = format!(
"\r{text}{spaces}",
text = line_text,
spaces = " ".repeat(pad)
);
last_line_len = line_text.len();
let _ = out.write_all(line.as_bytes());
let _ = out.flush();
} else if !status.is_empty() && !status.eq_ignore_ascii_case("pulling manifest")
{
let line_text = status.to_string();
let pad = last_line_len.saturating_sub(line_text.len());
let line = format!(
"\r{text}{spaces}",
text = line_text,
spaces = " ".repeat(pad)
);
last_line_len = line_text.len();
let _ = out.write_all(line.as_bytes());
let _ = out.flush();
}
if status == "success" {
let _ = out.write_all(b"\n");
let _ = out.flush();
saw_success = true;
break;
}
if status == "success" {
let _ = out.write_all(b"\n");
let _ = out.flush();
saw_success = true;
break;
}
}
} else {
break;
}
}
if saw_success {
@@ -287,11 +293,10 @@ async fn ensure_ollama_model_available_cli(
}
// 2) Pull if allowlisted
const ALLOWLIST: &[&str] = &["llama3.2:3b-instruct"];
if !ALLOWLIST.iter().any(|&m| m == model) {
const ALLOWLIST: &[&str] = &["llama3.2:3b-instruct", "llama3.2:3b"];
if !ALLOWLIST.contains(&model) {
return Err(std::io::Error::other(format!(
"Model `{}` not found locally and not in allowlist for automatic download.",
model
"Model `{model}` not found locally and not in allowlist for automatic download."
)));
}
// Pull with progress; if the streaming connection ends before success, keep
@@ -299,7 +304,11 @@ async fn ensure_ollama_model_available_cli(
loop {
match pull_model_with_progress_cli(host_root, model).await {
Ok(()) => break,
Err(_) => {
Err(e) => {
// If the server reported the model manifest does not exist, surface a clear error.
if e.to_string().contains("model not found") {
return Err(std::io::Error::other("model not found"));
}
let available = fetch_ollama_models(host_root).await;
if available.iter().any(|m| m == model) {
break;