vtcode_core/llm/providers/ollama/
client.rs1use std::io;
4use std::time::Duration;
5
6use futures::StreamExt;
7use futures::stream::BoxStream;
8use semver::Version;
9use serde_json::Value as JsonValue;
10
11use super::OLLAMA_CONNECTION_ERROR;
12use super::ollama_model_name_from_fields;
13use super::pull::{OllamaPullEvent, OllamaPullProgressReporter};
14use super::url::base_url_to_host_root;
15
16pub struct OllamaClient {
18 client: reqwest::Client,
19 host_root: String,
20}
21
22impl OllamaClient {
23 pub async fn try_from_base_url(base_url: &str) -> io::Result<Self> {
25 let host_root = base_url_to_host_root(base_url);
26 let client = reqwest::Client::builder()
27 .connect_timeout(Duration::from_secs(5))
28 .build()
29 .unwrap_or_else(|_| reqwest::Client::new());
30
31 let instance = Self { client, host_root };
32 instance.probe_server().await?;
33 Ok(instance)
34 }
35
36 async fn probe_server(&self) -> io::Result<()> {
38 let url = format!("{}/api/tags", self.host_root.trim_end_matches('/'));
39 let resp = self.client.get(url).send().await.map_err(|err| {
40 tracing::warn!("Failed to connect to Ollama server: {err:?}");
41 io::Error::other(OLLAMA_CONNECTION_ERROR)
42 })?;
43
44 if resp.status().is_success() {
45 Ok(())
46 } else {
47 tracing::warn!(
48 "Failed to probe server at {}: HTTP {}",
49 self.host_root,
50 resp.status()
51 );
52 Err(io::Error::other(OLLAMA_CONNECTION_ERROR))
53 }
54 }
55
56 pub async fn fetch_models(&self) -> io::Result<Vec<String>> {
58 let tags_url = format!("{}/api/tags", self.host_root.trim_end_matches('/'));
59 let resp = self
60 .client
61 .get(tags_url)
62 .send()
63 .await
64 .map_err(io::Error::other)?;
65
66 if !resp.status().is_success() {
67 return Ok(Vec::new());
68 }
69
70 let val = resp.json::<JsonValue>().await.map_err(io::Error::other)?;
71 let names = val
72 .get("models")
73 .and_then(|m| m.as_array())
74 .map(|arr| {
75 arr.iter()
76 .filter_map(|v| {
77 ollama_model_name_from_fields(
78 v.get("name").and_then(|n| n.as_str()),
79 v.get("model").and_then(|n| n.as_str()),
80 )
81 })
82 .map(str::to_string)
83 .collect::<Vec<_>>()
84 })
85 .unwrap_or_default();
86
87 Ok(names)
88 }
89
90 pub async fn fetch_version(&self) -> io::Result<Option<Version>> {
93 let version_url = format!("{}/api/version", self.host_root.trim_end_matches('/'));
94 let resp = self
95 .client
96 .get(version_url)
97 .send()
98 .await
99 .map_err(io::Error::other)?;
100
101 if !resp.status().is_success() {
102 return Ok(None);
103 }
104
105 let val = resp.json::<JsonValue>().await.map_err(io::Error::other)?;
106 let Some(version_str) = val.get("version").and_then(|v| v.as_str()).map(str::trim) else {
107 return Ok(None);
108 };
109
110 let normalized = version_str.trim_start_matches('v');
111 match Version::parse(normalized) {
112 Ok(version) => Ok(Some(version)),
113 Err(err) => {
114 tracing::warn!("Failed to parse Ollama version `{version_str}`: {err}");
115 Ok(None)
116 }
117 }
118 }
119
120 pub async fn pull_model_stream(
122 &self,
123 model: &str,
124 ) -> io::Result<BoxStream<'static, OllamaPullEvent>> {
125 let url = format!("{}/api/pull", self.host_root.trim_end_matches('/'));
126 let resp = self
127 .client
128 .post(url)
129 .json(&serde_json::json!({"model": model, "stream": true}))
130 .send()
131 .await
132 .map_err(io::Error::other)?;
133
134 if !resp.status().is_success() {
135 return Err(io::Error::other(format!(
136 "failed to start pull: HTTP {}",
137 resp.status()
138 )));
139 }
140
141 let mut stream = resp.bytes_stream();
142 let mut buf = String::new();
143
144 let s = async_stream::stream! {
145 while let Some(chunk) = stream.next().await {
146 match chunk {
147 Ok(bytes) => {
148 if let Ok(text) = std::str::from_utf8(&bytes) {
149 buf.push_str(text);
150 while let Some(pos) = buf.find('\n') {
151 let line = buf.drain(..=pos).collect::<String>();
152 let text = line.trim();
153 if text.is_empty() { continue; }
154 if let Ok(value) = serde_json::from_str::<JsonValue>(text) {
155 for ev in super::parser::pull_events_from_value(&value) {
156 yield ev;
157 }
158 if let Some(err_msg) = value.get("error").and_then(|e| e.as_str()) {
159 yield OllamaPullEvent::Error(err_msg.to_string());
160 return;
161 }
162 }
163 }
164 }
165 }
166 Err(_) => {
167 return;
169 }
170 }
171 }
172 };
173
174 Ok(Box::pin(s))
175 }
176
177 pub async fn pull_with_reporter(
180 &self,
181 model: &str,
182 reporter: &mut dyn OllamaPullProgressReporter,
183 ) -> io::Result<()> {
184 reporter.on_event(&OllamaPullEvent::Status(format!(
185 "Pulling model {model}..."
186 )))?;
187 let mut stream = self.pull_model_stream(model).await?;
188
189 while let Some(event) = stream.next().await {
190 reporter.on_event(&event)?;
191 match event {
192 OllamaPullEvent::Success => {
193 return Ok(());
194 }
195 OllamaPullEvent::Error(err) => {
196 return Err(io::Error::other(err));
204 }
205 _ => {}
206 }
207 }
208
209 Err(io::Error::other("Pull stream ended unexpectedly"))
211 }
212}
213
#[cfg(test)]
mod tests {
    use semver::Version;

    /// Placeholder that currently asserts nothing; `OllamaClient::try_from_base_url`
    /// probes the server over HTTP, so a real check would need a reachable server.
    #[test]
    fn test_client_creation_requires_valid_base_url() {}

    /// A plain `major.minor.patch` string parses into the matching components.
    #[test]
    fn test_version_parsing() {
        let parsed = Version::parse("0.14.1").expect("parse version");
        assert_eq!((parsed.major, parsed.minor, parsed.patch), (0, 14, 1));
    }

    /// Stripping leading `v` characters makes a prefixed version string parseable.
    #[test]
    fn test_version_parsing_strips_v_prefix() {
        let raw = "v0.13.4";
        let parsed = Version::parse(raw.trim_start_matches('v')).expect("parse version");
        assert_eq!(parsed, Version::new(0, 13, 4));
    }
}