1use crate::index::CacheConfig;
2use crate::model2vec::{EncoderSpec, ModelLoadPolicy, ModelOptions};
3use crate::types::{Chunk, IndexStats, IndexWarning, SearchMode, SearchOptions, SearchResult};
4use crate::utils::is_git_url;
5use anyhow::{Context, Result, bail};
6use serde::{Deserialize, Serialize};
7use std::collections::HashSet;
8use std::path::{Path, PathBuf};
9
10pub const DAEMON_PROTOCOL_VERSION: u32 = 1;
11
12pub fn daemon_version() -> &'static str {
13 env!("CARGO_PKG_VERSION")
14}
15
16#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
17#[serde(rename_all = "snake_case")]
18pub enum SourceKind {
19 LocalPath,
20 GitUrl,
21}
22
23#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
24pub struct SourceSpec {
25 pub kind: SourceKind,
26 pub source: String,
27 #[serde(skip_serializing_if = "Option::is_none")]
28 pub ref_name: Option<String>,
29}
30
31impl SourceSpec {
32 pub fn resolve(
33 source: impl AsRef<str>,
34 ref_name: Option<String>,
35 offline: bool,
36 ) -> Result<Self> {
37 let source = source.as_ref();
38 if is_git_url(source) {
39 if offline {
40 bail!("--offline does not allow remote Git sources");
41 }
42 return Ok(Self {
43 kind: SourceKind::GitUrl,
44 source: source.to_owned(),
45 ref_name,
46 });
47 }
48
49 let path = PathBuf::from(source);
50 if !path.exists() {
51 bail!("local source does not exist: {}", path.display());
52 }
53 if !path.is_dir() {
54 bail!("local source is not a directory: {}", path.display());
55 }
56 Ok(Self {
57 kind: SourceKind::LocalPath,
58 source: path
59 .canonicalize()
60 .with_context(|| format!("canonicalize source {}", path.display()))?
61 .to_string_lossy()
62 .into_owned(),
63 ref_name: None,
64 })
65 }
66
67 pub fn current_dir(offline: bool) -> Result<Self> {
68 let cwd = std::env::current_dir().context("resolve current directory")?;
69 Self::resolve(cwd.to_string_lossy(), None, offline)
70 }
71
72 pub fn cache_key(&self) -> String {
73 match (&self.kind, &self.ref_name) {
74 (SourceKind::LocalPath, _) => format!("path:{}", self.source),
75 (SourceKind::GitUrl, Some(ref_name)) => format!("git:{}@{}", self.source, ref_name),
76 (SourceKind::GitUrl, None) => format!("git:{}", self.source),
77 }
78 }
79
80 pub fn display(&self) -> String {
81 match &self.ref_name {
82 Some(ref_name) => format!("{}@{}", self.source, ref_name),
83 None => self.source.clone(),
84 }
85 }
86
87 pub fn as_path(&self) -> Option<&Path> {
88 matches!(self.kind, SourceKind::LocalPath).then(|| Path::new(&self.source))
89 }
90}
91
92#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
93#[serde(rename_all = "snake_case")]
94pub enum CacheConfigSpec {
95 Platform,
96 Project,
97 Custom { path: PathBuf },
98 Disabled,
99}
100
101impl From<&CacheConfig> for CacheConfigSpec {
102 fn from(value: &CacheConfig) -> Self {
103 match value {
104 CacheConfig::Platform => Self::Platform,
105 CacheConfig::Project => Self::Project,
106 CacheConfig::Custom(path) => Self::Custom { path: path.clone() },
107 CacheConfig::Disabled => Self::Disabled,
108 }
109 }
110}
111
112impl From<CacheConfigSpec> for CacheConfig {
113 fn from(value: CacheConfigSpec) -> Self {
114 match value {
115 CacheConfigSpec::Platform => Self::Platform,
116 CacheConfigSpec::Project => Self::Project,
117 CacheConfigSpec::Custom { path } => Self::Custom(path),
118 CacheConfigSpec::Disabled => Self::Disabled,
119 }
120 }
121}
122
123#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
124#[serde(rename_all = "snake_case")]
125pub enum EncoderSpecWire {
126 Model2Vec {
127 model: String,
128 policy: ModelLoadPolicyWire,
129 },
130 Hashing {
131 dim: usize,
132 },
133 Sparse,
134}
135
136impl EncoderSpecWire {
137 pub fn from_encoder_spec(spec: Option<&EncoderSpec>) -> Self {
138 match spec {
139 Some(EncoderSpec::Model2Vec(options)) => Self::Model2Vec {
140 model: options.model.clone(),
141 policy: ModelLoadPolicyWire::from(options.policy),
142 },
143 Some(EncoderSpec::Hashing { dim }) => Self::Hashing { dim: *dim },
144 None => Self::Sparse,
145 }
146 }
147
148 pub fn into_encoder_spec(self) -> Option<EncoderSpec> {
149 match self {
150 Self::Model2Vec { model, policy } => Some(EncoderSpec::Model2Vec(ModelOptions {
151 model,
152 policy: policy.into(),
153 })),
154 Self::Hashing { dim } => Some(EncoderSpec::Hashing { dim }),
155 Self::Sparse => None,
156 }
157 }
158}
159
160#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
161#[serde(rename_all = "snake_case")]
162pub enum ModelLoadPolicyWire {
163 AllowDownload,
164 NoDownload,
165 Offline,
166}
167
168impl From<ModelLoadPolicy> for ModelLoadPolicyWire {
169 fn from(value: ModelLoadPolicy) -> Self {
170 match value {
171 ModelLoadPolicy::AllowDownload => Self::AllowDownload,
172 ModelLoadPolicy::NoDownload => Self::NoDownload,
173 ModelLoadPolicy::Offline => Self::Offline,
174 }
175 }
176}
177
178impl From<ModelLoadPolicyWire> for ModelLoadPolicy {
179 fn from(value: ModelLoadPolicyWire) -> Self {
180 match value {
181 ModelLoadPolicyWire::AllowDownload => Self::AllowDownload,
182 ModelLoadPolicyWire::NoDownload => Self::NoDownload,
183 ModelLoadPolicyWire::Offline => Self::Offline,
184 }
185 }
186}
187
188#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
189pub struct IndexRuntimeOptions {
190 pub encoder: EncoderSpecWire,
191 pub cache: CacheConfigSpec,
192 pub extensions: Option<Vec<String>>,
193 pub ignore: Option<Vec<String>>,
194 pub include_text_files: bool,
195}
196
197impl Default for IndexRuntimeOptions {
198 fn default() -> Self {
199 Self {
200 encoder: EncoderSpecWire::Model2Vec {
201 model: ModelOptions::default().model,
202 policy: ModelLoadPolicyWire::AllowDownload,
203 },
204 cache: CacheConfigSpec::Platform,
205 extensions: None,
206 ignore: None,
207 include_text_files: false,
208 }
209 }
210}
211
212impl IndexRuntimeOptions {
213 pub fn sparse(cache: CacheConfig) -> Self {
214 Self {
215 encoder: EncoderSpecWire::Sparse,
216 cache: CacheConfigSpec::from(&cache),
217 extensions: None,
218 ignore: None,
219 include_text_files: false,
220 }
221 }
222
223 pub fn with_encoder(encoder: EncoderSpec, cache: CacheConfig) -> Self {
224 Self {
225 encoder: EncoderSpecWire::from_encoder_spec(Some(&encoder)),
226 cache: CacheConfigSpec::from(&cache),
227 extensions: None,
228 ignore: None,
229 include_text_files: false,
230 }
231 }
232
233 pub fn extensions_set(&self) -> Option<HashSet<String>> {
234 self.extensions
235 .as_ref()
236 .map(|items| items.iter().cloned().collect())
237 }
238
239 pub fn ignore_set(&self) -> Option<HashSet<String>> {
240 self.ignore
241 .as_ref()
242 .map(|items| items.iter().cloned().collect())
243 }
244}
245
246#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
247pub struct IndexIdentity {
248 pub source: SourceSpec,
249 pub encoder_key: String,
250 pub cache_key: String,
251 pub extensions: Option<Vec<String>>,
252 pub ignore: Option<Vec<String>>,
253 pub include_text_files: bool,
254}
255
256impl IndexIdentity {
257 pub fn new(source: SourceSpec, options: &IndexRuntimeOptions) -> Self {
258 Self {
259 source,
260 encoder_key: encoder_key(&options.encoder),
261 cache_key: cache_key(&options.cache),
262 extensions: normalized_vec(options.extensions.clone()),
263 ignore: normalized_vec(options.ignore.clone()),
264 include_text_files: options.include_text_files,
265 }
266 }
267
268 pub fn key(&self) -> String {
269 serde_json::to_string(self).expect("index identity is serializable")
270 }
271}
272
273fn encoder_key(encoder: &EncoderSpecWire) -> String {
274 match encoder {
275 EncoderSpecWire::Model2Vec { model, policy } => {
276 format!("model2vec:{model}:{}", policy_key(*policy))
277 }
278 EncoderSpecWire::Hashing { dim } => format!("hashing:{dim}"),
279 EncoderSpecWire::Sparse => "sparse".to_owned(),
280 }
281}
282
283fn policy_key(policy: ModelLoadPolicyWire) -> &'static str {
284 match policy {
285 ModelLoadPolicyWire::AllowDownload => "allow-download",
286 ModelLoadPolicyWire::NoDownload => "no-download",
287 ModelLoadPolicyWire::Offline => "offline",
288 }
289}
290
291fn cache_key(cache: &CacheConfigSpec) -> String {
292 match cache {
293 CacheConfigSpec::Platform => "platform".to_owned(),
294 CacheConfigSpec::Project => "project".to_owned(),
295 CacheConfigSpec::Custom { path } => format!("custom:{}", path.display()),
296 CacheConfigSpec::Disabled => "disabled".to_owned(),
297 }
298}
299
/// Sorts and de-duplicates an optional list so equivalent lists compare equal; `None` stays `None`.
fn normalized_vec(items: Option<Vec<String>>) -> Option<Vec<String>> {
    items.map(|mut values| {
        // Equal strings are indistinguishable, so an unstable sort yields the same result.
        values.sort_unstable();
        values.dedup();
        values
    })
}
306
307#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
308pub struct DaemonRequestEnvelope {
309 pub protocol_version: u32,
310 pub request_id: String,
311 pub request: DaemonRequest,
312}
313
314impl DaemonRequestEnvelope {
315 pub fn new(request_id: impl Into<String>, request: DaemonRequest) -> Self {
316 Self {
317 protocol_version: DAEMON_PROTOCOL_VERSION,
318 request_id: request_id.into(),
319 request,
320 }
321 }
322}
323
/// A request sent from a client to the daemon.
///
/// Internally tagged: the snake_case variant name is stored in a `"type"` field alongside the
/// variant's own fields, e.g. `{"type": "ping"}` or `{"type": "search", "source": …, …}`.
/// Every variant except `Ping` and `Status` carries a `source` plus `IndexRuntimeOptions`. That
/// is the same pair `IndexIdentity::new` takes to identify an index.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum DaemonRequest {
    /// Liveness check.
    Ping,
    /// Daemon-wide status (presumably answered with `DaemonResult::Status`).
    Status,
    /// Status of the index selected by `source` + `options`.
    IndexStatus {
        source: SourceSpec,
        options: IndexRuntimeOptions,
    },
    /// Run `query` against the selected index with the given search options.
    Search {
        source: SourceSpec,
        options: IndexRuntimeOptions,
        query: String,
        search: SearchOptionsWire,
    },
    /// Find results related to the location `file_path` / `line`; `top_k` is the requested
    /// result count.
    FindRelated {
        source: SourceSpec,
        options: IndexRuntimeOptions,
        file_path: String,
        line: usize,
        top_k: usize,
    },
    /// List files of the selected index; `limit` presumably caps how many are returned.
    ListFiles {
        source: SourceSpec,
        options: IndexRuntimeOptions,
        limit: usize,
    },
    /// Fetch the chunk at `file_path` / `line`.
    GetChunk {
        source: SourceSpec,
        options: IndexRuntimeOptions,
        file_path: String,
        line: usize,
    },
    /// Refresh the selected index (NOTE(review): presumably re-indexes changed files).
    Refresh {
        source: SourceSpec,
        options: IndexRuntimeOptions,
    },
    /// Clear the selected index (answered with a `removed` flag, per `DaemonResult::Clear`).
    Clear {
        source: SourceSpec,
        options: IndexRuntimeOptions,
    },
}
366
367#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
368pub struct SearchOptionsWire {
369 pub top_k: usize,
370 pub mode: SearchMode,
371 #[serde(skip_serializing_if = "Option::is_none")]
372 pub alpha: Option<f32>,
373 pub filter_languages: Vec<String>,
374 pub filter_paths: Vec<String>,
375 pub use_query_cache: bool,
376 #[serde(default)]
377 pub explain: bool,
378}
379
380impl From<SearchOptions> for SearchOptionsWire {
381 fn from(value: SearchOptions) -> Self {
382 Self {
383 top_k: value.top_k,
384 mode: value.mode,
385 alpha: value.alpha,
386 filter_languages: value.filter_languages,
387 filter_paths: value.filter_paths,
388 use_query_cache: value.use_query_cache,
389 explain: value.explain,
390 }
391 }
392}
393
394impl From<SearchOptionsWire> for SearchOptions {
395 fn from(value: SearchOptionsWire) -> Self {
396 Self {
397 top_k: value.top_k,
398 mode: value.mode,
399 alpha: value.alpha,
400 filter_languages: value.filter_languages,
401 filter_paths: value.filter_paths,
402 use_query_cache: value.use_query_cache,
403 explain: value.explain,
404 }
405 }
406}
407
408#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
409pub struct DaemonResponseEnvelope {
410 pub protocol_version: u32,
411 pub request_id: String,
412 #[serde(flatten)]
413 pub result: ResultEnvelope,
414}
415
416impl DaemonResponseEnvelope {
417 pub fn ok(request_id: impl Into<String>, result: DaemonResult) -> Self {
418 Self {
419 protocol_version: DAEMON_PROTOCOL_VERSION,
420 request_id: request_id.into(),
421 result: ResultEnvelope::Ok { result },
422 }
423 }
424
425 pub fn error(request_id: impl Into<String>, error: DaemonError) -> Self {
426 Self {
427 protocol_version: DAEMON_PROTOCOL_VERSION,
428 request_id: request_id.into(),
429 result: ResultEnvelope::Error { error },
430 }
431 }
432}
433
/// Outcome of a request: either a result or an error.
///
/// Externally tagged in snake_case: `{"ok": {"result": …}}` or `{"error": {"error": …}}`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ResultEnvelope {
    /// The request succeeded.
    Ok { result: DaemonResult },
    /// The request failed.
    Error { error: DaemonError },
}
440
/// Successful payload returned by the daemon.
///
/// Internally tagged: the snake_case variant name goes in a `"type"` field merged with the
/// variant's fields. Variant names parallel those of `DaemonRequest`
/// (NOTE(review): the request → result pairing is implemented elsewhere — confirm).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum DaemonResult {
    /// Reply to a ping, carrying the daemon's version string.
    Pong {
        version: String,
    },
    /// Daemon-wide status.
    Status(DaemonStatus),
    /// Status of one index.
    IndexStatus(IndexStatusResult),
    /// Search results.
    Search(SearchResultSet),
    /// Related-code results.
    FindRelated(SearchResultSet),
    /// File listing; `total` is reported separately from `files`
    /// (presumably because `files` may be truncated by the request's `limit`).
    ListFiles {
        source: SourceSpec,
        total: usize,
        files: Vec<String>,
    },
    /// A single chunk.
    GetChunk {
        source: SourceSpec,
        chunk: Chunk,
    },
    /// Index status after a refresh.
    Refresh(IndexStatusResult),
    /// Outcome of a clear; `removed` reports whether anything was removed.
    Clear {
        source: SourceSpec,
        removed: bool,
    },
}
466
/// Daemon-wide status report.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct DaemonStatus {
    /// Daemon version string (presumably from `daemon_version()`).
    pub version: String,
    /// Protocol version the daemon speaks.
    pub protocol_version: u32,
    /// Process id (presumably the daemon's own OS pid).
    pub pid: u32,
    /// Per-index status entries (presumably the indexes the daemon currently holds).
    pub indexes: Vec<CachedIndexStatus>,
}
474
/// Status of one index held by the daemon, as listed in `DaemonStatus::indexes`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CachedIndexStatus {
    /// Source the index was built from.
    pub source: SourceSpec,
    /// Index statistics.
    pub stats: IndexStats,
    /// Whether the semantic part of the index is loaded (NOTE(review): confirm meaning).
    pub semantic_loaded: bool,
}
481
/// Status of one index, returned for `IndexStatus` and `Refresh` requests.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct IndexStatusResult {
    /// Source the index was built from.
    pub source: SourceSpec,
    /// Index statistics.
    pub stats: IndexStats,
    /// Whether the semantic part of the index is loaded (NOTE(review): confirm meaning).
    pub semantic_loaded: bool,
    /// Non-fatal warnings gathered while indexing.
    pub warnings: Vec<IndexWarning>,
}
489
/// Results of a `Search` or `FindRelated` request.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct SearchResultSet {
    /// Source that was searched.
    pub source: SourceSpec,
    /// The query string (NOTE(review): what `FindRelated` puts here is decided elsewhere).
    pub query: String,
    /// Search mode used.
    pub mode: SearchMode,
    /// Statistics of the searched index.
    pub stats: IndexStats,
    /// Elapsed time, in milliseconds per the field name.
    pub elapsed_ms: u64,
    /// Matching results.
    pub results: Vec<SearchResult>,
    /// Non-fatal warnings.
    pub warnings: Vec<IndexWarning>,
}
500
501#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
502pub struct DaemonError {
503 pub code: String,
504 pub message: String,
505}
506
507impl DaemonError {
508 pub fn new(code: impl Into<String>, message: impl Into<String>) -> Self {
509 Self {
510 code: code.into(),
511 message: message.into(),
512 }
513 }
514}