Skip to main content

sqlite_graphrag/
daemon.rs

1use crate::constants::SQLITE_GRAPHRAG_VERSION;
2use crate::errors::AppError;
3use crate::{embedder, shutdown_requested};
4use interprocess::local_socket::{
5    prelude::LocalSocketStream,
6    traits::{Listener as _, Stream as _},
7    GenericFilePath, GenericNamespaced, ListenerNonblockingMode, ListenerOptions, ToFsName,
8    ToNsName,
9};
10use serde::{Deserialize, Serialize};
11use std::io::{BufRead, BufReader, Write};
12use std::path::Path;
13use std::thread;
14use std::time::{Duration, Instant};
15
16#[derive(Debug, Serialize, Deserialize)]
17#[serde(tag = "request", rename_all = "snake_case")]
18pub enum DaemonRequest {
19    Ping,
20    Shutdown,
21    EmbedPassage {
22        text: String,
23    },
24    EmbedQuery {
25        text: String,
26    },
27    EmbedPassages {
28        texts: Vec<String>,
29        token_counts: Vec<usize>,
30    },
31}
32
33#[derive(Debug, Serialize, Deserialize)]
34#[serde(tag = "status", rename_all = "snake_case")]
35pub enum DaemonResponse {
36    Listening {
37        pid: u32,
38        socket: String,
39        idle_shutdown_secs: u64,
40    },
41    Ok {
42        pid: u32,
43        version: String,
44        handled_embed_requests: u64,
45    },
46    PassageEmbedding {
47        embedding: Vec<f32>,
48        handled_embed_requests: u64,
49    },
50    QueryEmbedding {
51        embedding: Vec<f32>,
52        handled_embed_requests: u64,
53    },
54    PassageEmbeddings {
55        embeddings: Vec<Vec<f32>>,
56        handled_embed_requests: u64,
57    },
58    ShuttingDown {
59        handled_embed_requests: u64,
60    },
61    Error {
62        message: String,
63    },
64}
65
66pub fn daemon_label(models_dir: &Path) -> String {
67    let hash = blake3::hash(models_dir.to_string_lossy().as_bytes())
68        .to_hex()
69        .to_string();
70    format!("sqlite-graphrag-daemon-{}", &hash[..16])
71}
72
73pub fn try_ping(models_dir: &Path) -> Result<Option<DaemonResponse>, AppError> {
74    request_if_available(models_dir, &DaemonRequest::Ping)
75}
76
77pub fn try_shutdown(models_dir: &Path) -> Result<Option<DaemonResponse>, AppError> {
78    request_if_available(models_dir, &DaemonRequest::Shutdown)
79}
80
81pub fn embed_passage_or_local(models_dir: &Path, text: &str) -> Result<Vec<f32>, AppError> {
82    match request_if_available(
83        models_dir,
84        &DaemonRequest::EmbedPassage {
85            text: text.to_string(),
86        },
87    )? {
88        Some(DaemonResponse::PassageEmbedding { embedding, .. }) => Ok(embedding),
89        Some(DaemonResponse::Error { message }) => Err(AppError::Embedding(message)),
90        Some(other) => Err(AppError::Internal(anyhow::anyhow!(
91            "unexpected daemon response for passage embedding: {other:?}"
92        ))),
93        None => {
94            let embedder = embedder::get_embedder(models_dir)?;
95            embedder::embed_passage(embedder, text)
96        }
97    }
98}
99
100pub fn embed_query_or_local(models_dir: &Path, text: &str) -> Result<Vec<f32>, AppError> {
101    match request_if_available(
102        models_dir,
103        &DaemonRequest::EmbedQuery {
104            text: text.to_string(),
105        },
106    )? {
107        Some(DaemonResponse::QueryEmbedding { embedding, .. }) => Ok(embedding),
108        Some(DaemonResponse::Error { message }) => Err(AppError::Embedding(message)),
109        Some(other) => Err(AppError::Internal(anyhow::anyhow!(
110            "unexpected daemon response for query embedding: {other:?}"
111        ))),
112        None => {
113            let embedder = embedder::get_embedder(models_dir)?;
114            embedder::embed_query(embedder, text)
115        }
116    }
117}
118
119pub fn embed_passages_controlled_or_local(
120    models_dir: &Path,
121    texts: &[&str],
122    token_counts: &[usize],
123) -> Result<Vec<Vec<f32>>, AppError> {
124    let request = DaemonRequest::EmbedPassages {
125        texts: texts.iter().map(|t| (*t).to_string()).collect(),
126        token_counts: token_counts.to_vec(),
127    };
128
129    match request_if_available(models_dir, &request)? {
130        Some(DaemonResponse::PassageEmbeddings { embeddings, .. }) => Ok(embeddings),
131        Some(DaemonResponse::Error { message }) => Err(AppError::Embedding(message)),
132        Some(other) => Err(AppError::Internal(anyhow::anyhow!(
133            "unexpected daemon response for batch passage embeddings: {other:?}"
134        ))),
135        None => {
136            let embedder = embedder::get_embedder(models_dir)?;
137            embedder::embed_passages_controlled(embedder, texts, token_counts)
138        }
139    }
140}
141
142pub fn run(models_dir: &Path, idle_shutdown_secs: u64) -> Result<(), AppError> {
143    let socket = daemon_label(models_dir);
144    let name = to_local_socket_name(&socket)?;
145    let listener = ListenerOptions::new()
146        .name(name)
147        .nonblocking(ListenerNonblockingMode::Accept)
148        .try_overwrite(true)
149        .create_sync()
150        .map_err(AppError::Io)?;
151
152    // Warm the model once per daemon process.
153    let _ = embedder::get_embedder(models_dir)?;
154
155    crate::output::emit_json(&DaemonResponse::Listening {
156        pid: std::process::id(),
157        socket,
158        idle_shutdown_secs,
159    })?;
160
161    let mut handled_embed_requests = 0_u64;
162    let mut last_activity = Instant::now();
163
164    loop {
165        if shutdown_requested() {
166            break;
167        }
168
169        match listener.accept() {
170            Ok(stream) => {
171                last_activity = Instant::now();
172                let should_exit = handle_client(stream, models_dir, &mut handled_embed_requests)?;
173                if should_exit {
174                    break;
175                }
176            }
177            Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {
178                if last_activity.elapsed() >= Duration::from_secs(idle_shutdown_secs) {
179                    tracing::info!(
180                        idle_shutdown_secs,
181                        handled_embed_requests,
182                        "daemon idle timeout reached"
183                    );
184                    break;
185                }
186                thread::sleep(Duration::from_millis(50));
187            }
188            Err(err) => return Err(AppError::Io(err)),
189        }
190    }
191
192    Ok(())
193}
194
195fn handle_client(
196    stream: LocalSocketStream,
197    models_dir: &Path,
198    handled_embed_requests: &mut u64,
199) -> Result<bool, AppError> {
200    let mut reader = BufReader::new(stream);
201    let mut line = String::new();
202    reader.read_line(&mut line).map_err(AppError::Io)?;
203
204    if line.trim().is_empty() {
205        write_response(
206            reader.get_mut(),
207            &DaemonResponse::Error {
208                message: "empty daemon request".to_string(),
209            },
210        )?;
211        return Ok(false);
212    }
213
214    let request: DaemonRequest = serde_json::from_str(line.trim()).map_err(AppError::Json)?;
215    let (response, should_exit) = match request {
216        DaemonRequest::Ping => (
217            DaemonResponse::Ok {
218                pid: std::process::id(),
219                version: SQLITE_GRAPHRAG_VERSION.to_string(),
220                handled_embed_requests: *handled_embed_requests,
221            },
222            false,
223        ),
224        DaemonRequest::Shutdown => (
225            DaemonResponse::ShuttingDown {
226                handled_embed_requests: *handled_embed_requests,
227            },
228            true,
229        ),
230        DaemonRequest::EmbedPassage { text } => {
231            let embedder = embedder::get_embedder(models_dir)?;
232            let embedding = embedder::embed_passage(embedder, &text)?;
233            *handled_embed_requests += 1;
234            (
235                DaemonResponse::PassageEmbedding {
236                    embedding,
237                    handled_embed_requests: *handled_embed_requests,
238                },
239                false,
240            )
241        }
242        DaemonRequest::EmbedQuery { text } => {
243            let embedder = embedder::get_embedder(models_dir)?;
244            let embedding = embedder::embed_query(embedder, &text)?;
245            *handled_embed_requests += 1;
246            (
247                DaemonResponse::QueryEmbedding {
248                    embedding,
249                    handled_embed_requests: *handled_embed_requests,
250                },
251                false,
252            )
253        }
254        DaemonRequest::EmbedPassages {
255            texts,
256            token_counts,
257        } => {
258            let embedder = embedder::get_embedder(models_dir)?;
259            let text_refs: Vec<&str> = texts.iter().map(String::as_str).collect();
260            let embeddings =
261                embedder::embed_passages_controlled(embedder, &text_refs, &token_counts)?;
262            *handled_embed_requests += 1;
263            (
264                DaemonResponse::PassageEmbeddings {
265                    embeddings,
266                    handled_embed_requests: *handled_embed_requests,
267                },
268                false,
269            )
270        }
271    };
272
273    write_response(reader.get_mut(), &response)?;
274    Ok(should_exit)
275}
276
277fn write_response(
278    stream: &mut LocalSocketStream,
279    response: &DaemonResponse,
280) -> Result<(), AppError> {
281    serde_json::to_writer(&mut *stream, response).map_err(AppError::Json)?;
282    stream.write_all(b"\n").map_err(AppError::Io)?;
283    stream.flush().map_err(AppError::Io)?;
284    Ok(())
285}
286
287fn request_if_available(
288    models_dir: &Path,
289    request: &DaemonRequest,
290) -> Result<Option<DaemonResponse>, AppError> {
291    let socket = daemon_label(models_dir);
292    let name = match to_local_socket_name(&socket) {
293        Ok(name) => name,
294        Err(err) => return Err(AppError::Io(err)),
295    };
296
297    let mut stream = match LocalSocketStream::connect(name) {
298        Ok(stream) => stream,
299        Err(err)
300            if matches!(
301                err.kind(),
302                std::io::ErrorKind::NotFound
303                    | std::io::ErrorKind::ConnectionRefused
304                    | std::io::ErrorKind::AddrNotAvailable
305                    | std::io::ErrorKind::TimedOut
306            ) =>
307        {
308            return Ok(None);
309        }
310        Err(err) => return Err(AppError::Io(err)),
311    };
312
313    serde_json::to_writer(&mut stream, request).map_err(AppError::Json)?;
314    stream.write_all(b"\n").map_err(AppError::Io)?;
315    stream.flush().map_err(AppError::Io)?;
316
317    let mut reader = BufReader::new(stream);
318    let mut line = String::new();
319    reader.read_line(&mut line).map_err(AppError::Io)?;
320    if line.trim().is_empty() {
321        return Err(AppError::Embedding("daemon returned empty response".into()));
322    }
323
324    let response = serde_json::from_str(line.trim()).map_err(AppError::Json)?;
325    Ok(Some(response))
326}
327
328fn to_local_socket_name(name: &str) -> std::io::Result<interprocess::local_socket::Name<'static>> {
329    if let Ok(ns_name) = name.to_string().to_ns_name::<GenericNamespaced>() {
330        return Ok(ns_name);
331    }
332
333    let path = if cfg!(unix) {
334        format!("/tmp/{name}.sock")
335    } else {
336        format!(r"\\.\pipe\{name}")
337    };
338    path.to_fs_name::<GenericFilePath>()
339}