Skip to main content

sifs/daemon/protocol.rs

1use crate::index::CacheConfig;
2use crate::model2vec::{EncoderSpec, ModelLoadPolicy, ModelOptions};
3use crate::types::{Chunk, IndexStats, IndexWarning, SearchMode, SearchOptions, SearchResult};
4use crate::utils::is_git_url;
5use anyhow::{Context, Result, bail};
6use serde::{Deserialize, Serialize};
7use std::collections::HashSet;
8use std::path::{Path, PathBuf};
9
/// Protocol version stamped into every request and response envelope built by this module.
pub const DAEMON_PROTOCOL_VERSION: u32 = 1;
11
12pub fn daemon_version() -> &'static str {
13    env!("CARGO_PKG_VERSION")
14}
15
16#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
17#[serde(rename_all = "snake_case")]
18pub enum SourceKind {
19    LocalPath,
20    GitUrl,
21}
22
23#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
24pub struct SourceSpec {
25    pub kind: SourceKind,
26    pub source: String,
27    #[serde(skip_serializing_if = "Option::is_none")]
28    pub ref_name: Option<String>,
29}
30
31impl SourceSpec {
32    pub fn resolve(
33        source: impl AsRef<str>,
34        ref_name: Option<String>,
35        offline: bool,
36    ) -> Result<Self> {
37        let source = source.as_ref();
38        if is_git_url(source) {
39            if offline {
40                bail!("--offline does not allow remote Git sources");
41            }
42            return Ok(Self {
43                kind: SourceKind::GitUrl,
44                source: source.to_owned(),
45                ref_name,
46            });
47        }
48
49        let path = PathBuf::from(source);
50        if !path.exists() {
51            bail!("local source does not exist: {}", path.display());
52        }
53        if !path.is_dir() {
54            bail!("local source is not a directory: {}", path.display());
55        }
56        Ok(Self {
57            kind: SourceKind::LocalPath,
58            source: path
59                .canonicalize()
60                .with_context(|| format!("canonicalize source {}", path.display()))?
61                .to_string_lossy()
62                .into_owned(),
63            ref_name: None,
64        })
65    }
66
67    pub fn current_dir(offline: bool) -> Result<Self> {
68        let cwd = std::env::current_dir().context("resolve current directory")?;
69        Self::resolve(cwd.to_string_lossy(), None, offline)
70    }
71
72    pub fn cache_key(&self) -> String {
73        match (&self.kind, &self.ref_name) {
74            (SourceKind::LocalPath, _) => format!("path:{}", self.source),
75            (SourceKind::GitUrl, Some(ref_name)) => format!("git:{}@{}", self.source, ref_name),
76            (SourceKind::GitUrl, None) => format!("git:{}", self.source),
77        }
78    }
79
80    pub fn display(&self) -> String {
81        match &self.ref_name {
82            Some(ref_name) => format!("{}@{}", self.source, ref_name),
83            None => self.source.clone(),
84        }
85    }
86
87    pub fn as_path(&self) -> Option<&Path> {
88        matches!(self.kind, SourceKind::LocalPath).then(|| Path::new(&self.source))
89    }
90}
91
92#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
93#[serde(rename_all = "snake_case")]
94pub enum CacheConfigSpec {
95    Platform,
96    Project,
97    Custom { path: PathBuf },
98    Disabled,
99}
100
101impl From<&CacheConfig> for CacheConfigSpec {
102    fn from(value: &CacheConfig) -> Self {
103        match value {
104            CacheConfig::Platform => Self::Platform,
105            CacheConfig::Project => Self::Project,
106            CacheConfig::Custom(path) => Self::Custom { path: path.clone() },
107            CacheConfig::Disabled => Self::Disabled,
108        }
109    }
110}
111
112impl From<CacheConfigSpec> for CacheConfig {
113    fn from(value: CacheConfigSpec) -> Self {
114        match value {
115            CacheConfigSpec::Platform => Self::Platform,
116            CacheConfigSpec::Project => Self::Project,
117            CacheConfigSpec::Custom { path } => Self::Custom(path),
118            CacheConfigSpec::Disabled => Self::Disabled,
119        }
120    }
121}
122
123#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
124#[serde(rename_all = "snake_case")]
125pub enum EncoderSpecWire {
126    Model2Vec {
127        model: String,
128        policy: ModelLoadPolicyWire,
129    },
130    Hashing {
131        dim: usize,
132    },
133    Sparse,
134}
135
136impl EncoderSpecWire {
137    pub fn from_encoder_spec(spec: Option<&EncoderSpec>) -> Self {
138        match spec {
139            Some(EncoderSpec::Model2Vec(options)) => Self::Model2Vec {
140                model: options.model.clone(),
141                policy: ModelLoadPolicyWire::from(options.policy),
142            },
143            Some(EncoderSpec::Hashing { dim }) => Self::Hashing { dim: *dim },
144            None => Self::Sparse,
145        }
146    }
147
148    pub fn into_encoder_spec(self) -> Option<EncoderSpec> {
149        match self {
150            Self::Model2Vec { model, policy } => Some(EncoderSpec::Model2Vec(ModelOptions {
151                model,
152                policy: policy.into(),
153            })),
154            Self::Hashing { dim } => Some(EncoderSpec::Hashing { dim }),
155            Self::Sparse => None,
156        }
157    }
158}
159
160#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
161#[serde(rename_all = "snake_case")]
162pub enum ModelLoadPolicyWire {
163    AllowDownload,
164    NoDownload,
165    Offline,
166}
167
168impl From<ModelLoadPolicy> for ModelLoadPolicyWire {
169    fn from(value: ModelLoadPolicy) -> Self {
170        match value {
171            ModelLoadPolicy::AllowDownload => Self::AllowDownload,
172            ModelLoadPolicy::NoDownload => Self::NoDownload,
173            ModelLoadPolicy::Offline => Self::Offline,
174        }
175    }
176}
177
178impl From<ModelLoadPolicyWire> for ModelLoadPolicy {
179    fn from(value: ModelLoadPolicyWire) -> Self {
180        match value {
181            ModelLoadPolicyWire::AllowDownload => Self::AllowDownload,
182            ModelLoadPolicyWire::NoDownload => Self::NoDownload,
183            ModelLoadPolicyWire::Offline => Self::Offline,
184        }
185    }
186}
187
188#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
189pub struct IndexRuntimeOptions {
190    pub encoder: EncoderSpecWire,
191    pub cache: CacheConfigSpec,
192    pub extensions: Option<Vec<String>>,
193    pub ignore: Option<Vec<String>>,
194    pub include_text_files: bool,
195}
196
197impl Default for IndexRuntimeOptions {
198    fn default() -> Self {
199        Self {
200            encoder: EncoderSpecWire::Model2Vec {
201                model: ModelOptions::default().model,
202                policy: ModelLoadPolicyWire::AllowDownload,
203            },
204            cache: CacheConfigSpec::Platform,
205            extensions: None,
206            ignore: None,
207            include_text_files: false,
208        }
209    }
210}
211
212impl IndexRuntimeOptions {
213    pub fn sparse(cache: CacheConfig) -> Self {
214        Self {
215            encoder: EncoderSpecWire::Sparse,
216            cache: CacheConfigSpec::from(&cache),
217            extensions: None,
218            ignore: None,
219            include_text_files: false,
220        }
221    }
222
223    pub fn with_encoder(encoder: EncoderSpec, cache: CacheConfig) -> Self {
224        Self {
225            encoder: EncoderSpecWire::from_encoder_spec(Some(&encoder)),
226            cache: CacheConfigSpec::from(&cache),
227            extensions: None,
228            ignore: None,
229            include_text_files: false,
230        }
231    }
232
233    pub fn extensions_set(&self) -> Option<HashSet<String>> {
234        self.extensions
235            .as_ref()
236            .map(|items| items.iter().cloned().collect())
237    }
238
239    pub fn ignore_set(&self) -> Option<HashSet<String>> {
240        self.ignore
241            .as_ref()
242            .map(|items| items.iter().cloned().collect())
243    }
244}
245
246#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
247pub struct IndexIdentity {
248    pub source: SourceSpec,
249    pub encoder_key: String,
250    pub cache_key: String,
251    pub extensions: Option<Vec<String>>,
252    pub ignore: Option<Vec<String>>,
253    pub include_text_files: bool,
254}
255
256impl IndexIdentity {
257    pub fn new(source: SourceSpec, options: &IndexRuntimeOptions) -> Self {
258        Self {
259            source,
260            encoder_key: encoder_key(&options.encoder),
261            cache_key: cache_key(&options.cache),
262            extensions: normalized_vec(options.extensions.clone()),
263            ignore: normalized_vec(options.ignore.clone()),
264            include_text_files: options.include_text_files,
265        }
266    }
267
268    pub fn key(&self) -> String {
269        serde_json::to_string(self).expect("index identity is serializable")
270    }
271}
272
273fn encoder_key(encoder: &EncoderSpecWire) -> String {
274    match encoder {
275        EncoderSpecWire::Model2Vec { model, policy } => {
276            format!("model2vec:{model}:{}", policy_key(*policy))
277        }
278        EncoderSpecWire::Hashing { dim } => format!("hashing:{dim}"),
279        EncoderSpecWire::Sparse => "sparse".to_owned(),
280    }
281}
282
283fn policy_key(policy: ModelLoadPolicyWire) -> &'static str {
284    match policy {
285        ModelLoadPolicyWire::AllowDownload => "allow-download",
286        ModelLoadPolicyWire::NoDownload => "no-download",
287        ModelLoadPolicyWire::Offline => "offline",
288    }
289}
290
291fn cache_key(cache: &CacheConfigSpec) -> String {
292    match cache {
293        CacheConfigSpec::Platform => "platform".to_owned(),
294        CacheConfigSpec::Project => "project".to_owned(),
295        CacheConfigSpec::Custom { path } => format!("custom:{}", path.display()),
296        CacheConfigSpec::Disabled => "disabled".to_owned(),
297    }
298}
299
/// Sorts and deduplicates an optional list of strings; `None` stays `None` and an empty
/// list stays `Some(vec![])`.
fn normalized_vec(items: Option<Vec<String>>) -> Option<Vec<String>> {
    let mut items = items?;
    // Equal strings are indistinguishable, so the unstable sort yields the same order
    // as a stable one without the temporary allocation.
    items.sort_unstable();
    items.dedup();
    Some(items)
}
306
307#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
308pub struct DaemonRequestEnvelope {
309    pub protocol_version: u32,
310    pub request_id: String,
311    pub request: DaemonRequest,
312}
313
314impl DaemonRequestEnvelope {
315    pub fn new(request_id: impl Into<String>, request: DaemonRequest) -> Self {
316        Self {
317            protocol_version: DAEMON_PROTOCOL_VERSION,
318            request_id: request_id.into(),
319            request,
320        }
321    }
322}
323
324#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
325#[serde(tag = "type", rename_all = "snake_case")]
326pub enum DaemonRequest {
327    Ping,
328    Status,
329    IndexStatus {
330        source: SourceSpec,
331        options: IndexRuntimeOptions,
332    },
333    Search {
334        source: SourceSpec,
335        options: IndexRuntimeOptions,
336        query: String,
337        search: SearchOptionsWire,
338    },
339    FindRelated {
340        source: SourceSpec,
341        options: IndexRuntimeOptions,
342        file_path: String,
343        line: usize,
344        top_k: usize,
345    },
346    ListFiles {
347        source: SourceSpec,
348        options: IndexRuntimeOptions,
349        limit: usize,
350    },
351    GetChunk {
352        source: SourceSpec,
353        options: IndexRuntimeOptions,
354        file_path: String,
355        line: usize,
356    },
357    Refresh {
358        source: SourceSpec,
359        options: IndexRuntimeOptions,
360    },
361    Clear {
362        source: SourceSpec,
363        options: IndexRuntimeOptions,
364    },
365}
366
367#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
368pub struct SearchOptionsWire {
369    pub top_k: usize,
370    pub mode: SearchMode,
371    #[serde(skip_serializing_if = "Option::is_none")]
372    pub alpha: Option<f32>,
373    pub filter_languages: Vec<String>,
374    pub filter_paths: Vec<String>,
375    pub use_query_cache: bool,
376    #[serde(default)]
377    pub explain: bool,
378}
379
380impl From<SearchOptions> for SearchOptionsWire {
381    fn from(value: SearchOptions) -> Self {
382        Self {
383            top_k: value.top_k,
384            mode: value.mode,
385            alpha: value.alpha,
386            filter_languages: value.filter_languages,
387            filter_paths: value.filter_paths,
388            use_query_cache: value.use_query_cache,
389            explain: value.explain,
390        }
391    }
392}
393
394impl From<SearchOptionsWire> for SearchOptions {
395    fn from(value: SearchOptionsWire) -> Self {
396        Self {
397            top_k: value.top_k,
398            mode: value.mode,
399            alpha: value.alpha,
400            filter_languages: value.filter_languages,
401            filter_paths: value.filter_paths,
402            use_query_cache: value.use_query_cache,
403            explain: value.explain,
404        }
405    }
406}
407
408#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
409pub struct DaemonResponseEnvelope {
410    pub protocol_version: u32,
411    pub request_id: String,
412    #[serde(flatten)]
413    pub result: ResultEnvelope,
414}
415
416impl DaemonResponseEnvelope {
417    pub fn ok(request_id: impl Into<String>, result: DaemonResult) -> Self {
418        Self {
419            protocol_version: DAEMON_PROTOCOL_VERSION,
420            request_id: request_id.into(),
421            result: ResultEnvelope::Ok { result },
422        }
423    }
424
425    pub fn error(request_id: impl Into<String>, error: DaemonError) -> Self {
426        Self {
427            protocol_version: DAEMON_PROTOCOL_VERSION,
428            request_id: request_id.into(),
429            result: ResultEnvelope::Error { error },
430        }
431    }
432}
433
434#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
435#[serde(rename_all = "snake_case")]
436pub enum ResultEnvelope {
437    Ok { result: DaemonResult },
438    Error { error: DaemonError },
439}
440
441#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
442#[serde(tag = "type", rename_all = "snake_case")]
443pub enum DaemonResult {
444    Pong {
445        version: String,
446    },
447    Status(DaemonStatus),
448    IndexStatus(IndexStatusResult),
449    Search(SearchResultSet),
450    FindRelated(SearchResultSet),
451    ListFiles {
452        source: SourceSpec,
453        total: usize,
454        files: Vec<String>,
455    },
456    GetChunk {
457        source: SourceSpec,
458        chunk: Chunk,
459    },
460    Refresh(IndexStatusResult),
461    Clear {
462        source: SourceSpec,
463        removed: bool,
464    },
465}
466
467#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
468pub struct DaemonStatus {
469    pub version: String,
470    pub protocol_version: u32,
471    pub pid: u32,
472    pub indexes: Vec<CachedIndexStatus>,
473}
474
475#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
476pub struct CachedIndexStatus {
477    pub source: SourceSpec,
478    pub stats: IndexStats,
479    pub semantic_loaded: bool,
480}
481
482#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
483pub struct IndexStatusResult {
484    pub source: SourceSpec,
485    pub stats: IndexStats,
486    pub semantic_loaded: bool,
487    pub warnings: Vec<IndexWarning>,
488}
489
490#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
491pub struct SearchResultSet {
492    pub source: SourceSpec,
493    pub query: String,
494    pub mode: SearchMode,
495    pub stats: IndexStats,
496    pub elapsed_ms: u64,
497    pub results: Vec<SearchResult>,
498    pub warnings: Vec<IndexWarning>,
499}
500
501#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
502pub struct DaemonError {
503    pub code: String,
504    pub message: String,
505}
506
507impl DaemonError {
508    pub fn new(code: impl Into<String>, message: impl Into<String>) -> Self {
509        Self {
510            code: code.into(),
511            message: message.into(),
512        }
513    }
514}