Skip to main content

sqlite_graphrag/
daemon.rs

1use crate::constants::{
2    DAEMON_AUTO_START_INITIAL_BACKOFF_MS, DAEMON_AUTO_START_MAX_BACKOFF_MS,
3    DAEMON_AUTO_START_MAX_WAIT_MS, DAEMON_IDLE_SHUTDOWN_SECS, DAEMON_PING_TIMEOUT_MS,
4    DAEMON_SPAWN_BACKOFF_BASE_MS, DAEMON_SPAWN_LOCK_WAIT_MS, SQLITE_GRAPHRAG_VERSION,
5};
6use crate::errors::AppError;
7use crate::{embedder, shutdown_requested};
8use fs4::fs_std::FileExt;
9use interprocess::local_socket::{
10    prelude::LocalSocketStream,
11    traits::{Listener as _, Stream as _},
12    GenericFilePath, GenericNamespaced, ListenerNonblockingMode, ListenerOptions, ToFsName,
13    ToNsName,
14};
15use serde::{Deserialize, Serialize};
16use std::fs::{File, OpenOptions};
17use std::io::{BufRead, BufReader, Write};
18use std::path::{Path, PathBuf};
19use std::process::Stdio;
20use std::thread;
21use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
22
/// Requests a client can send to the embedding daemon.
///
/// Serialized as JSON with an internal `request` tag whose value is the
/// snake_case variant name (for example `{"request":"ping"}`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "request", rename_all = "snake_case")]
pub enum DaemonRequest {
    /// Health check; answered with [`DaemonResponse::Ok`].
    Ping,
    /// Asks the daemon to stop; answered with [`DaemonResponse::ShuttingDown`].
    Shutdown,
    /// Embed a single passage; answered with [`DaemonResponse::PassageEmbedding`].
    EmbedPassage {
        /// Text to embed.
        text: String,
    },
    /// Embed a single query; answered with [`DaemonResponse::QueryEmbedding`].
    EmbedQuery {
        /// Text to embed.
        text: String,
    },
    /// Embed a batch of passages; answered with [`DaemonResponse::PassageEmbeddings`].
    EmbedPassages {
        /// Texts to embed.
        texts: Vec<String>,
        /// Forwarded unchanged to `embedder::embed_passages_controlled`.
        // NOTE(review): presumably one token count per entry of `texts` — confirm against the embedder.
        token_counts: Vec<usize>,
    },
}
39
/// Messages produced by the embedding daemon.
///
/// Serialized as JSON with an internal `status` tag whose value is the
/// snake_case variant name (for example `{"status":"ok", ...}`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum DaemonResponse {
    /// Emitted once by `run` after the listener exists and the model is warmed.
    Listening {
        /// Process id of the daemon.
        pid: u32,
        /// Socket label the daemon listens on (see `daemon_label`).
        socket: String,
        /// Idle timeout, in seconds, the daemon was started with.
        idle_shutdown_secs: u64,
    },
    /// Reply to [`DaemonRequest::Ping`].
    Ok {
        /// Process id of the daemon.
        pid: u32,
        /// Value of `SQLITE_GRAPHRAG_VERSION` in the daemon process.
        version: String,
        /// Number of embedding requests served successfully so far.
        handled_embed_requests: u64,
    },
    /// Reply to [`DaemonRequest::EmbedPassage`].
    PassageEmbedding {
        /// Embedding vector for the passage.
        embedding: Vec<f32>,
        /// Number of embedding requests served, including this one.
        handled_embed_requests: u64,
    },
    /// Reply to [`DaemonRequest::EmbedQuery`].
    QueryEmbedding {
        /// Embedding vector for the query.
        embedding: Vec<f32>,
        /// Number of embedding requests served, including this one.
        handled_embed_requests: u64,
    },
    /// Reply to [`DaemonRequest::EmbedPassages`].
    PassageEmbeddings {
        /// Embedding vectors returned by `embedder::embed_passages_controlled`.
        embeddings: Vec<Vec<f32>>,
        /// Number of embedding requests served, including this one (a batch counts once).
        handled_embed_requests: u64,
    },
    /// Reply to [`DaemonRequest::Shutdown`]; the daemon exits after sending it.
    ShuttingDown {
        /// Number of embedding requests served before shutdown.
        handled_embed_requests: u64,
    },
    /// The request could not be served.
    Error {
        /// Human-readable description of the failure.
        message: String,
    },
}
72
/// Autostart backoff bookkeeping persisted as JSON at `spawn_state_path`.
///
/// The `Default` value (all zero / `None`) means "no failure recorded".
#[derive(Debug, Default, Serialize, Deserialize)]
struct DaemonSpawnState {
    /// Number of autostart failures recorded since the state was last cleared.
    consecutive_failures: u32,
    /// Autostart is suppressed while the current Unix time in milliseconds is below this value.
    not_before_epoch_ms: u64,
    /// Message describing the most recent autostart failure, if any.
    last_error: Option<String>,
}
79
80pub fn daemon_label(models_dir: &Path) -> String {
81    let hash = blake3::hash(models_dir.to_string_lossy().as_bytes())
82        .to_hex()
83        .to_string();
84    format!("sqlite-graphrag-daemon-{}", &hash[..16])
85}
86
87pub fn try_ping(models_dir: &Path) -> Result<Option<DaemonResponse>, AppError> {
88    request_if_available(models_dir, &DaemonRequest::Ping)
89}
90
91pub fn try_shutdown(models_dir: &Path) -> Result<Option<DaemonResponse>, AppError> {
92    request_if_available(models_dir, &DaemonRequest::Shutdown)
93}
94
95pub fn embed_passage_or_local(models_dir: &Path, text: &str) -> Result<Vec<f32>, AppError> {
96    match request_or_autostart(
97        models_dir,
98        &DaemonRequest::EmbedPassage {
99            text: text.to_string(),
100        },
101    )? {
102        Some(DaemonResponse::PassageEmbedding { embedding, .. }) => Ok(embedding),
103        Some(DaemonResponse::Error { message }) => Err(AppError::Embedding(message)),
104        Some(other) => Err(AppError::Internal(anyhow::anyhow!(
105            "unexpected daemon response for passage embedding: {other:?}"
106        ))),
107        None => {
108            let embedder = embedder::get_embedder(models_dir)?;
109            embedder::embed_passage(embedder, text)
110        }
111    }
112}
113
114pub fn embed_query_or_local(models_dir: &Path, text: &str) -> Result<Vec<f32>, AppError> {
115    match request_or_autostart(
116        models_dir,
117        &DaemonRequest::EmbedQuery {
118            text: text.to_string(),
119        },
120    )? {
121        Some(DaemonResponse::QueryEmbedding { embedding, .. }) => Ok(embedding),
122        Some(DaemonResponse::Error { message }) => Err(AppError::Embedding(message)),
123        Some(other) => Err(AppError::Internal(anyhow::anyhow!(
124            "unexpected daemon response for query embedding: {other:?}"
125        ))),
126        None => {
127            let embedder = embedder::get_embedder(models_dir)?;
128            embedder::embed_query(embedder, text)
129        }
130    }
131}
132
133pub fn embed_passages_controlled_or_local(
134    models_dir: &Path,
135    texts: &[&str],
136    token_counts: &[usize],
137) -> Result<Vec<Vec<f32>>, AppError> {
138    let request = DaemonRequest::EmbedPassages {
139        texts: texts.iter().map(|t| (*t).to_string()).collect(),
140        token_counts: token_counts.to_vec(),
141    };
142
143    match request_or_autostart(models_dir, &request)? {
144        Some(DaemonResponse::PassageEmbeddings { embeddings, .. }) => Ok(embeddings),
145        Some(DaemonResponse::Error { message }) => Err(AppError::Embedding(message)),
146        Some(other) => Err(AppError::Internal(anyhow::anyhow!(
147            "unexpected daemon response for batch passage embeddings: {other:?}"
148        ))),
149        None => {
150            let embedder = embedder::get_embedder(models_dir)?;
151            embedder::embed_passages_controlled(embedder, texts, token_counts)
152        }
153    }
154}
155
156pub fn run(models_dir: &Path, idle_shutdown_secs: u64) -> Result<(), AppError> {
157    let socket = daemon_label(models_dir);
158    let name = to_local_socket_name(&socket)?;
159    let listener = ListenerOptions::new()
160        .name(name)
161        .nonblocking(ListenerNonblockingMode::Accept)
162        .try_overwrite(true)
163        .create_sync()
164        .map_err(AppError::Io)?;
165
166    // Warm the model once per daemon process.
167    let _ = embedder::get_embedder(models_dir)?;
168
169    crate::output::emit_json(&DaemonResponse::Listening {
170        pid: std::process::id(),
171        socket,
172        idle_shutdown_secs,
173    })?;
174
175    let mut handled_embed_requests = 0_u64;
176    let mut last_activity = Instant::now();
177
178    loop {
179        if shutdown_requested() {
180            break;
181        }
182
183        match listener.accept() {
184            Ok(stream) => {
185                last_activity = Instant::now();
186                let should_exit = handle_client(stream, models_dir, &mut handled_embed_requests)?;
187                if should_exit {
188                    break;
189                }
190            }
191            Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {
192                if last_activity.elapsed() >= Duration::from_secs(idle_shutdown_secs) {
193                    tracing::info!(
194                        idle_shutdown_secs,
195                        handled_embed_requests,
196                        "daemon idle timeout reached"
197                    );
198                    break;
199                }
200                thread::sleep(Duration::from_millis(50));
201            }
202            Err(err) => return Err(AppError::Io(err)),
203        }
204    }
205
206    Ok(())
207}
208
209fn handle_client(
210    stream: LocalSocketStream,
211    models_dir: &Path,
212    handled_embed_requests: &mut u64,
213) -> Result<bool, AppError> {
214    let mut reader = BufReader::new(stream);
215    let mut line = String::new();
216    reader.read_line(&mut line).map_err(AppError::Io)?;
217
218    if line.trim().is_empty() {
219        write_response(
220            reader.get_mut(),
221            &DaemonResponse::Error {
222                message: "empty daemon request".to_string(),
223            },
224        )?;
225        return Ok(false);
226    }
227
228    let request: DaemonRequest = serde_json::from_str(line.trim()).map_err(AppError::Json)?;
229    let (response, should_exit) = match request {
230        DaemonRequest::Ping => (
231            DaemonResponse::Ok {
232                pid: std::process::id(),
233                version: SQLITE_GRAPHRAG_VERSION.to_string(),
234                handled_embed_requests: *handled_embed_requests,
235            },
236            false,
237        ),
238        DaemonRequest::Shutdown => (
239            DaemonResponse::ShuttingDown {
240                handled_embed_requests: *handled_embed_requests,
241            },
242            true,
243        ),
244        DaemonRequest::EmbedPassage { text } => {
245            let embedder = embedder::get_embedder(models_dir)?;
246            let embedding = embedder::embed_passage(embedder, &text)?;
247            *handled_embed_requests += 1;
248            (
249                DaemonResponse::PassageEmbedding {
250                    embedding,
251                    handled_embed_requests: *handled_embed_requests,
252                },
253                false,
254            )
255        }
256        DaemonRequest::EmbedQuery { text } => {
257            let embedder = embedder::get_embedder(models_dir)?;
258            let embedding = embedder::embed_query(embedder, &text)?;
259            *handled_embed_requests += 1;
260            (
261                DaemonResponse::QueryEmbedding {
262                    embedding,
263                    handled_embed_requests: *handled_embed_requests,
264                },
265                false,
266            )
267        }
268        DaemonRequest::EmbedPassages {
269            texts,
270            token_counts,
271        } => {
272            let embedder = embedder::get_embedder(models_dir)?;
273            let text_refs: Vec<&str> = texts.iter().map(String::as_str).collect();
274            let embeddings =
275                embedder::embed_passages_controlled(embedder, &text_refs, &token_counts)?;
276            *handled_embed_requests += 1;
277            (
278                DaemonResponse::PassageEmbeddings {
279                    embeddings,
280                    handled_embed_requests: *handled_embed_requests,
281                },
282                false,
283            )
284        }
285    };
286
287    write_response(reader.get_mut(), &response)?;
288    Ok(should_exit)
289}
290
291fn write_response(
292    stream: &mut LocalSocketStream,
293    response: &DaemonResponse,
294) -> Result<(), AppError> {
295    serde_json::to_writer(&mut *stream, response).map_err(AppError::Json)?;
296    stream.write_all(b"\n").map_err(AppError::Io)?;
297    stream.flush().map_err(AppError::Io)?;
298    Ok(())
299}
300
301fn request_if_available(
302    models_dir: &Path,
303    request: &DaemonRequest,
304) -> Result<Option<DaemonResponse>, AppError> {
305    let socket = daemon_label(models_dir);
306    let name = match to_local_socket_name(&socket) {
307        Ok(name) => name,
308        Err(err) => return Err(AppError::Io(err)),
309    };
310
311    let mut stream = match LocalSocketStream::connect(name) {
312        Ok(stream) => stream,
313        Err(err)
314            if matches!(
315                err.kind(),
316                std::io::ErrorKind::NotFound
317                    | std::io::ErrorKind::ConnectionRefused
318                    | std::io::ErrorKind::AddrNotAvailable
319                    | std::io::ErrorKind::TimedOut
320            ) =>
321        {
322            return Ok(None);
323        }
324        Err(err) => return Err(AppError::Io(err)),
325    };
326
327    serde_json::to_writer(&mut stream, request).map_err(AppError::Json)?;
328    stream.write_all(b"\n").map_err(AppError::Io)?;
329    stream.flush().map_err(AppError::Io)?;
330
331    let mut reader = BufReader::new(stream);
332    let mut line = String::new();
333    reader.read_line(&mut line).map_err(AppError::Io)?;
334    if line.trim().is_empty() {
335        return Err(AppError::Embedding("daemon returned empty response".into()));
336    }
337
338    let response = serde_json::from_str(line.trim()).map_err(AppError::Json)?;
339    Ok(Some(response))
340}
341
342fn request_or_autostart(
343    models_dir: &Path,
344    request: &DaemonRequest,
345) -> Result<Option<DaemonResponse>, AppError> {
346    if let Some(response) = request_if_available(models_dir, request)? {
347        clear_spawn_backoff_state(models_dir).ok();
348        return Ok(Some(response));
349    }
350
351    if autostart_disabled() {
352        return Ok(None);
353    }
354
355    if !ensure_daemon_running(models_dir)? {
356        return Ok(None);
357    }
358
359    request_if_available(models_dir, request)
360}
361
362fn ensure_daemon_running(models_dir: &Path) -> Result<bool, AppError> {
363    if let Some(_) = try_ping(models_dir)? {
364        clear_spawn_backoff_state(models_dir).ok();
365        return Ok(true);
366    }
367
368    if spawn_backoff_active(models_dir)? {
369        tracing::warn!("daemon autostart suppressed by backoff window");
370        return Ok(false);
371    }
372
373    let spawn_lock = match try_acquire_spawn_lock(models_dir)? {
374        Some(lock) => lock,
375        None => return wait_for_daemon_ready(models_dir),
376    };
377
378    if let Some(_) = try_ping(models_dir)? {
379        clear_spawn_backoff_state(models_dir).ok();
380        drop(spawn_lock);
381        return Ok(true);
382    }
383
384    let exe = match std::env::current_exe() {
385        Ok(path) => path,
386        Err(err) => {
387            record_spawn_failure(models_dir, format!("current_exe failed: {err}"))?;
388            drop(spawn_lock);
389            return Ok(false);
390        }
391    };
392
393    let mut child = std::process::Command::new(exe);
394    child
395        .arg("daemon")
396        .arg("--idle-shutdown-secs")
397        .arg(DAEMON_IDLE_SHUTDOWN_SECS.to_string())
398        .env("SQLITE_GRAPHRAG_DAEMON_CHILD", "1")
399        .stdin(Stdio::null())
400        .stdout(Stdio::null())
401        .stderr(Stdio::null());
402
403    match child.spawn() {
404        Ok(_) => {
405            let ready = wait_for_daemon_ready(models_dir)?;
406            if ready {
407                clear_spawn_backoff_state(models_dir).ok();
408            } else {
409                record_spawn_failure(
410                    models_dir,
411                    "daemon did not become healthy after autostart".to_string(),
412                )?;
413            }
414            drop(spawn_lock);
415            Ok(ready)
416        }
417        Err(err) => {
418            record_spawn_failure(models_dir, format!("daemon spawn failed: {err}"))?;
419            drop(spawn_lock);
420            Ok(false)
421        }
422    }
423}
424
425fn wait_for_daemon_ready(models_dir: &Path) -> Result<bool, AppError> {
426    let deadline = Instant::now() + Duration::from_millis(DAEMON_AUTO_START_MAX_WAIT_MS);
427    let mut sleep_ms = DAEMON_AUTO_START_INITIAL_BACKOFF_MS.max(DAEMON_PING_TIMEOUT_MS);
428
429    while Instant::now() < deadline {
430        if let Some(_) = try_ping(models_dir)? {
431            return Ok(true);
432        }
433        thread::sleep(Duration::from_millis(sleep_ms));
434        sleep_ms = (sleep_ms * 2).min(DAEMON_AUTO_START_MAX_BACKOFF_MS);
435    }
436
437    Ok(false)
438}
439
/// Reports whether this process must not auto-start a daemon.
///
/// True when either `SQLITE_GRAPHRAG_DAEMON_CHILD` or
/// `SQLITE_GRAPHRAG_DAEMON_DISABLE_AUTOSTART` is set to exactly `1`.
fn autostart_disabled() -> bool {
    const DISABLING_FLAGS: [&str; 2] = [
        "SQLITE_GRAPHRAG_DAEMON_CHILD",
        "SQLITE_GRAPHRAG_DAEMON_DISABLE_AUTOSTART",
    ];

    DISABLING_FLAGS
        .iter()
        .any(|flag| matches!(std::env::var(flag), Ok(value) if value == "1"))
}
444
/// Directory that holds the daemon's lock and state files.
///
/// This is the parent of `models_dir`, or `models_dir` itself when it has no
/// parent (for example the filesystem root or an empty path).
fn daemon_control_dir(models_dir: &Path) -> PathBuf {
    models_dir.parent().unwrap_or(models_dir).to_path_buf()
}
451
452fn spawn_lock_path(models_dir: &Path) -> PathBuf {
453    daemon_control_dir(models_dir).join("daemon-spawn.lock")
454}
455
456fn spawn_state_path(models_dir: &Path) -> PathBuf {
457    daemon_control_dir(models_dir).join("daemon-spawn-state.json")
458}
459
460fn try_acquire_spawn_lock(models_dir: &Path) -> Result<Option<File>, AppError> {
461    let path = spawn_lock_path(models_dir);
462    std::fs::create_dir_all(path.parent().unwrap()).map_err(AppError::Io)?;
463    let file = OpenOptions::new()
464        .read(true)
465        .write(true)
466        .create(true)
467        .truncate(false)
468        .open(path)
469        .map_err(AppError::Io)?;
470
471    let deadline = Instant::now() + Duration::from_millis(DAEMON_SPAWN_LOCK_WAIT_MS);
472    loop {
473        match file.try_lock_exclusive() {
474            Ok(()) => return Ok(Some(file)),
475            Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {
476                if Instant::now() >= deadline {
477                    return Ok(None);
478                }
479                thread::sleep(Duration::from_millis(50));
480            }
481            Err(err) => return Err(AppError::Io(err)),
482        }
483    }
484}
485
486fn spawn_backoff_active(models_dir: &Path) -> Result<bool, AppError> {
487    let state = load_spawn_state(models_dir)?;
488    Ok(now_epoch_ms() < state.not_before_epoch_ms)
489}
490
491fn record_spawn_failure(models_dir: &Path, message: String) -> Result<(), AppError> {
492    let mut state = load_spawn_state(models_dir)?;
493    state.consecutive_failures = state.consecutive_failures.saturating_add(1);
494    let exponent = state.consecutive_failures.saturating_sub(1).min(6);
495    let backoff_ms =
496        (DAEMON_SPAWN_BACKOFF_BASE_MS * (1_u64 << exponent)).min(DAEMON_AUTO_START_MAX_BACKOFF_MS);
497    state.not_before_epoch_ms = now_epoch_ms() + backoff_ms;
498    state.last_error = Some(message);
499    save_spawn_state(models_dir, &state)
500}
501
502fn clear_spawn_backoff_state(models_dir: &Path) -> Result<(), AppError> {
503    let path = spawn_state_path(models_dir);
504    if path.exists() {
505        std::fs::remove_file(path).map_err(AppError::Io)?;
506    }
507    Ok(())
508}
509
510fn load_spawn_state(models_dir: &Path) -> Result<DaemonSpawnState, AppError> {
511    let path = spawn_state_path(models_dir);
512    if !path.exists() {
513        return Ok(DaemonSpawnState::default());
514    }
515
516    let bytes = std::fs::read(path).map_err(AppError::Io)?;
517    serde_json::from_slice(&bytes).map_err(AppError::Json)
518}
519
520fn save_spawn_state(models_dir: &Path, state: &DaemonSpawnState) -> Result<(), AppError> {
521    let path = spawn_state_path(models_dir);
522    std::fs::create_dir_all(path.parent().unwrap()).map_err(AppError::Io)?;
523    let bytes = serde_json::to_vec(state).map_err(AppError::Json)?;
524    std::fs::write(path, bytes).map_err(AppError::Io)
525}
526
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Returns `0` when the system clock is set before the epoch, and saturates at
/// `u64::MAX` rather than silently wrapping if the millisecond count ever
/// exceeds 64 bits.
fn now_epoch_ms() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();

    // Clamp before narrowing so the cast below cannot lose information.
    since_epoch.as_millis().min(u128::from(u64::MAX)) as u64
}
533
#[cfg(test)]
mod tests {
    use super::*;

    /// Recording a failure opens a backoff window and persists the error;
    /// clearing the state closes the window again.
    #[test]
    fn record_and_clear_spawn_backoff_state() {
        let scratch = tempfile::tempdir().unwrap();
        let models_dir = scratch.path().join("cache").join("models");
        std::fs::create_dir_all(&models_dir).unwrap();

        // Nothing has been recorded yet, so no backoff applies.
        assert!(!spawn_backoff_active(&models_dir).unwrap());

        record_spawn_failure(&models_dir, String::from("spawn failed")).unwrap();
        assert!(spawn_backoff_active(&models_dir).unwrap());

        let DaemonSpawnState {
            consecutive_failures,
            last_error,
            ..
        } = load_spawn_state(&models_dir).unwrap();
        assert_eq!(consecutive_failures, 1);
        assert_eq!(last_error.as_deref(), Some("spawn failed"));

        clear_spawn_backoff_state(&models_dir).unwrap();
        assert!(!spawn_backoff_active(&models_dir).unwrap());
    }

    /// The control directory is the parent of the models directory.
    #[test]
    fn daemon_control_dir_usa_pai_de_models() {
        let expected_dir = PathBuf::from("/tmp/sqlite-graphrag-cache-test");
        let control_dir = daemon_control_dir(&expected_dir.join("models"));
        assert_eq!(control_dir, expected_dir);
    }
}
564
565fn to_local_socket_name(name: &str) -> std::io::Result<interprocess::local_socket::Name<'static>> {
566    if let Ok(ns_name) = name.to_string().to_ns_name::<GenericNamespaced>() {
567        return Ok(ns_name);
568    }
569
570    let path = if cfg!(unix) {
571        format!("/tmp/{name}.sock")
572    } else {
573        format!(r"\\.\pipe\{name}")
574    };
575    path.to_fs_name::<GenericFilePath>()
576}