1use crate::embeddings::{EmbeddingConfig, ProviderConfig, probe_provider_dimension};
9use anyhow::{Result, anyhow};
10use reqwest::Client;
11use std::path::PathBuf;
12use std::time::Duration;
13
/// Outcome of a single health check.
///
/// `Fail` carries a human-readable reason; `Running` and `Pending` mark checks
/// that have not produced a final result yet.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CheckStatus {
    /// The check succeeded.
    Pass,
    /// The check failed (or was skipped); the string explains why.
    Fail(String),
    /// The check is in progress.
    Running,
    /// The check has not started.
    Pending,
}
26
27impl CheckStatus {
28 pub fn icon(&self) -> &'static str {
29 match self {
30 CheckStatus::Pass => "[OK]",
31 CheckStatus::Fail(_) => "[ERR]",
32 CheckStatus::Running => "[...]",
33 CheckStatus::Pending => "[ ]",
34 }
35 }
36
37 pub fn is_pass(&self) -> bool {
38 matches!(self, CheckStatus::Pass)
39 }
40
41 pub fn is_fail(&self) -> bool {
42 matches!(self, CheckStatus::Fail(_))
43 }
44}
45
46#[derive(Debug, Clone)]
48pub struct HealthCheckItem {
49 pub name: String,
50 pub description: String,
51 pub status: CheckStatus,
52}
53
54impl HealthCheckItem {
55 pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
56 Self {
57 name: name.into(),
58 description: description.into(),
59 status: CheckStatus::Pending,
60 }
61 }
62
63 pub fn pass(mut self) -> Self {
64 self.status = CheckStatus::Pass;
65 self
66 }
67
68 pub fn fail(mut self, msg: impl Into<String>) -> Self {
69 self.status = CheckStatus::Fail(msg.into());
70 self
71 }
72
73 pub fn running(mut self) -> Self {
74 self.status = CheckStatus::Running;
75 self
76 }
77}
78
79#[derive(Debug, Clone)]
81pub struct HealthCheckResult {
82 pub items: Vec<HealthCheckItem>,
83 pub connected_provider: Option<String>,
84 pub verified_dimension: Option<usize>,
85}
86
87impl HealthCheckResult {
88 pub fn new() -> Self {
89 Self {
90 items: Vec::new(),
91 connected_provider: None,
92 verified_dimension: None,
93 }
94 }
95
96 pub fn all_passed(&self) -> bool {
97 self.items.iter().all(|i| i.status.is_pass())
98 }
99
100 pub fn any_failed(&self) -> bool {
101 self.items.iter().any(|i| i.status.is_fail())
102 }
103
104 pub fn is_finished(&self) -> bool {
105 self.items
106 .iter()
107 .all(|i| !matches!(i.status, CheckStatus::Pending | CheckStatus::Running))
108 }
109}
110
111impl Default for HealthCheckResult {
112 fn default() -> Self {
113 Self::new()
114 }
115}
116
117pub struct HealthChecker {
119 client: Client,
120}
121
122impl HealthChecker {
123 pub fn new() -> Self {
124 let client = Client::builder()
125 .timeout(Duration::from_secs(30))
126 .connect_timeout(Duration::from_secs(10))
127 .build()
128 .unwrap_or_default();
129
130 Self { client }
131 }
132
133 pub async fn run_all(
135 &self,
136 embedding_config: &EmbeddingConfig,
137 db_path: &str,
138 ) -> HealthCheckResult {
139 let mut result = HealthCheckResult::new();
140
141 let db_check = self.check_db_path(db_path);
143 result.items.push(db_check);
144
145 let (embedder_check, provider_name) =
147 self.check_embedder_connectivity(embedding_config).await;
148 result.items.push(embedder_check);
149 result.connected_provider = provider_name.clone();
150
151 if provider_name.is_some() {
153 let (embed_check, dimension) = self.check_embedding_generation(embedding_config).await;
154 result.items.push(embed_check);
155 result.verified_dimension = dimension;
156
157 if let Some(dim) = dimension {
159 let dim_check =
160 self.check_dimension_match(dim, embedding_config.required_dimension);
161 result.items.push(dim_check);
162 }
163 } else {
164 result.items.push(
166 HealthCheckItem::new("Test Embedding", "Send test text and verify response")
167 .fail("Skipped: No embedder available"),
168 );
169 result.items.push(
170 HealthCheckItem::new(
171 "Dimension Match",
172 format!("Verify dimension = {}", embedding_config.required_dimension),
173 )
174 .fail("Skipped: No embedding to verify"),
175 );
176 }
177
178 result
179 }
180
181 fn check_db_path(&self, db_path: &str) -> HealthCheckItem {
183 let mut item = HealthCheckItem::new("DB Path Writable", format!("Check {}", db_path));
184
185 let expanded = shellexpand::tilde(db_path).to_string();
186 let path = PathBuf::from(&expanded);
187
188 if path.exists() {
190 if path.is_dir() {
191 let test_file = path.join(".rust_memex_write_test");
193 match std::fs::write(&test_file, "test") {
194 Ok(_) => {
195 let _ = std::fs::remove_file(&test_file);
196 item = item.pass();
197 item.description = format!("Writable: {}", expanded);
198 }
199 Err(e) => {
200 item = item.fail(format!("Not writable: {}", e));
201 }
202 }
203 } else {
204 item = item.fail("Path exists but is not a directory");
205 }
206 } else {
207 if let Some(parent) = path.parent() {
209 if parent.exists() || std::fs::create_dir_all(parent).is_ok() {
210 item = item.pass();
211 item.description = format!("Will create: {}", expanded);
212 } else {
213 item = item.fail("Cannot create parent directories");
214 }
215 } else {
216 item = item.fail("Invalid path");
217 }
218 }
219
220 item
221 }
222
223 async fn check_embedder_connectivity(
225 &self,
226 config: &EmbeddingConfig,
227 ) -> (HealthCheckItem, Option<String>) {
228 let mut item = HealthCheckItem::new("Embedder Connection", "Connect to embedding provider");
229
230 if config.providers.is_empty() {
231 return (item.fail("No embedding providers configured"), None);
232 }
233
234 let mut providers = config.providers.clone();
236 providers.sort_by_key(|p| p.priority);
237
238 let mut tried = Vec::new();
239
240 for provider in &providers {
241 match self.try_provider_health(provider).await {
242 Ok(()) => {
243 item = item.pass();
244 item.description =
245 format!("Connected to {} ({})", provider.name, provider.base_url);
246 return (item, Some(provider.name.clone()));
247 }
248 Err(e) => {
249 tried.push(format!("{}: {}", provider.name, e));
250 }
251 }
252 }
253
254 (
255 item.fail(format!("All providers failed:\n {}", tried.join("\n "))),
256 None,
257 )
258 }
259
260 async fn try_provider_health(&self, provider: &ProviderConfig) -> Result<()> {
262 let base_url = provider.base_url.trim_end_matches('/');
263
264 let url = format!("{}/v1/models", base_url);
266 let response = self.client.get(&url).send().await;
267
268 match response {
269 Ok(resp) if resp.status().is_success() => Ok(()),
270 Ok(resp) if resp.status().as_u16() == 404 => {
271 let ollama_url = format!("{}/api/tags", base_url);
273 let ollama_resp = self.client.get(&ollama_url).send().await?;
274 if ollama_resp.status().is_success() {
275 return Ok(());
276 }
277 Err(anyhow!("No compatible endpoint found"))
278 }
279 Ok(resp) => Err(anyhow!("HTTP {}", resp.status())),
280 Err(e) => Err(anyhow!("Connection failed: {}", e)),
281 }
282 }
283
284 async fn check_embedding_generation(
286 &self,
287 config: &EmbeddingConfig,
288 ) -> (HealthCheckItem, Option<usize>) {
289 let mut item =
290 HealthCheckItem::new("Test Embedding", "Generate embedding for 'hello world'");
291
292 let mut providers = config.providers.clone();
294 providers.sort_by_key(|p| p.priority);
295 let mut failures = Vec::new();
296
297 for provider in &providers {
298 match probe_provider_dimension(&self.client, provider).await {
299 Ok(dim) => {
300 item = item.pass();
301 item.description = format!("Got {}-dim vector from {}", dim, provider.name);
302 return (item, Some(dim));
303 }
304 Err(e) => {
305 failures.push(format!("{}: {}", provider.name, e));
306 }
307 }
308 }
309
310 let message = if failures.is_empty() {
311 "No provider returned a valid embedding".to_string()
312 } else {
313 format!(
314 "No provider returned a valid embedding:\n {}",
315 failures.join("\n ")
316 )
317 };
318
319 (item.fail(message), None)
320 }
321
322 fn check_dimension_match(&self, actual: usize, required: usize) -> HealthCheckItem {
324 let mut item = HealthCheckItem::new(
325 "Dimension Match",
326 format!("Verify {} = {}", actual, required),
327 );
328
329 if actual == required {
330 item = item.pass();
331 item.description = format!("Dimension matches: {}", required);
332 } else {
333 item = item.fail(format!(
334 "Dimension mismatch! Got {} but config requires {}. \
335 This would corrupt the database!",
336 actual, required
337 ));
338 }
339
340 item
341 }
342}
343
344impl Default for HealthChecker {
345 fn default() -> Self {
346 Self::new()
347 }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_check_status_icon() {
        assert_eq!(CheckStatus::Pass.icon(), "[OK]");
        assert_eq!(CheckStatus::Fail("test".into()).icon(), "[ERR]");
        assert_eq!(CheckStatus::Running.icon(), "[...]");
        assert_eq!(CheckStatus::Pending.icon(), "[ ]");
    }

    #[test]
    fn test_health_check_result() {
        let mut result = HealthCheckResult::new();
        assert!(result.items.is_empty());
        assert!(!result.any_failed());
        assert!(result.is_finished());

        result
            .items
            .push(HealthCheckItem::new("Test", "Desc").pass());
        assert!(result.all_passed());
        assert!(!result.any_failed());

        result
            .items
            .push(HealthCheckItem::new("Test2", "Desc2").fail("error"));
        assert!(!result.all_passed());
        assert!(result.any_failed());
        assert!(result.is_finished());

        // Any pending or running item means the run is not finished yet.
        result.items.push(HealthCheckItem::new("Test3", "Desc3"));
        assert!(!result.is_finished());
        result.items.pop();
        result
            .items
            .push(HealthCheckItem::new("Test4", "Desc4").running());
        assert!(!result.is_finished());
    }

    #[test]
    fn test_db_path_check_missing_path_in_writable_dir() {
        let checker = HealthChecker::new();

        // Keep `tmp` alive so the directory is not deleted mid-test.
        let tmp = tempfile::tempdir().unwrap();
        let temp_path = tmp.path().join("rust_memex_test");
        let item = checker.check_db_path(temp_path.to_str().unwrap());
        assert!(item.status.is_pass(), "unexpected status: {:?}", item.status);
        assert!(item.description.starts_with("Will create: "));
    }

    #[test]
    fn test_db_path_check_existing_dir() {
        let checker = HealthChecker::new();
        let tmp = tempfile::tempdir().unwrap();

        let item = checker.check_db_path(tmp.path().to_str().unwrap());
        assert!(item.status.is_pass(), "unexpected status: {:?}", item.status);
        assert!(item.description.starts_with("Writable: "));
        // The write probe must clean up after itself.
        assert!(!tmp.path().join(".rust_memex_write_test").exists());
    }

    #[test]
    fn test_db_path_check_rejects_regular_file() {
        let checker = HealthChecker::new();
        let tmp = tempfile::tempdir().unwrap();
        let file = tmp.path().join("not_a_dir");
        std::fs::write(&file, "x").unwrap();

        let item = checker.check_db_path(file.to_str().unwrap());
        assert_eq!(
            item.status,
            CheckStatus::Fail("Path exists but is not a directory".to_string())
        );
    }

    #[test]
    fn test_dimension_match() {
        let checker = HealthChecker::new();
        assert!(checker.check_dimension_match(768, 768).status.is_pass());
        assert!(checker.check_dimension_match(384, 768).status.is_fail());
    }
}