Skip to main content

semantic_memory/
config.rs

1use crate::error::MemoryError;
2use crate::tokenizer::TokenCounter;
3use serde::{Deserialize, Serialize};
4use std::path::PathBuf;
5use std::sync::Arc;
6use std::time::Duration;
7
/// Configuration for the memory system.
///
/// Bundles every tunable of the store (storage location, embedding provider,
/// search, chunking, SQLite pool and resource limits) into one value. Run it
/// through `MemoryConfig::normalize_and_validate` before use so that
/// out-of-range settings are clamped or rejected.
///
/// `Debug` is implemented by hand rather than derived; the token counter is
/// printed as a placeholder label instead of the trait object itself.
#[derive(Clone, Serialize, Deserialize)]
pub struct MemoryConfig {
    /// Base directory for all storage files (SQLite + HNSW sidecar files).
    /// Replaces the v0.1.0 `database_path` field.
    pub base_dir: PathBuf,

    /// Embedding provider configuration.
    pub embedding: EmbeddingConfig,

    /// Search tuning parameters.
    pub search: SearchConfig,

    /// Chunking parameters.
    pub chunking: ChunkingConfig,

    /// Connection pool configuration.
    pub pool: PoolConfig,

    /// Resource limits.
    pub limits: MemoryLimits,

    /// Custom token counter. None = use EstimateTokenCounter (chars / 4).
    ///
    /// Excluded from (de)serialization via `#[serde(skip)]`.
    #[serde(skip)]
    pub token_counter: Option<Arc<dyn TokenCounter>>,

    /// HNSW index configuration.
    ///
    /// Only present when the `hnsw` feature is enabled, and excluded from
    /// (de)serialization. Its `dimensions` field is overwritten with
    /// `embedding.dimensions` by `MemoryConfig::normalize_and_validate`.
    #[cfg(feature = "hnsw")]
    #[serde(skip)]
    pub hnsw: crate::hnsw::HnswConfig,
}
39
40impl std::fmt::Debug for MemoryConfig {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        let mut s = f.debug_struct("MemoryConfig");
43        s.field("base_dir", &self.base_dir)
44            .field("embedding", &self.embedding)
45            .field("search", &self.search)
46            .field("chunking", &self.chunking)
47            .field("pool", &self.pool)
48            .field("limits", &self.limits)
49            .field(
50                "token_counter",
51                &self.token_counter.as_ref().map(|_| "custom"),
52            );
53        #[cfg(feature = "hnsw")]
54        s.field("hnsw", &self.hnsw);
55        s.finish()
56    }
57}
58
59impl Default for MemoryConfig {
60    fn default() -> Self {
61        Self {
62            base_dir: PathBuf::from("memory"),
63            embedding: EmbeddingConfig::default(),
64            search: SearchConfig::default(),
65            chunking: ChunkingConfig::default(),
66            pool: PoolConfig::default(),
67            limits: MemoryLimits::default(),
68            token_counter: None,
69            #[cfg(feature = "hnsw")]
70            hnsw: crate::hnsw::HnswConfig::default(),
71        }
72    }
73}
74
75impl MemoryConfig {
76    /// Normalize and validate configuration into a concrete runtime shape.
77    ///
78    /// This is the single canonical config entry point used by store creation.
79    pub fn normalize_and_validate(mut self) -> Result<Self, MemoryError> {
80        self.embedding.normalize_and_validate()?;
81        self.limits = self.limits.normalize_and_validate()?;
82        let timeout_cap_secs = self.limits.embedding_timeout.as_secs().max(1);
83        self.embedding.timeout_secs = self.embedding.timeout_secs.min(timeout_cap_secs);
84        self.search
85            .normalize_and_validate(self.embedding.dimensions)?;
86        self.chunking.normalize_and_validate()?;
87        self.pool.normalize_and_validate()?;
88        #[cfg(feature = "hnsw")]
89        {
90            self.hnsw.dimensions = self.embedding.dimensions;
91        }
92        Ok(self)
93    }
94}
95
/// Embedding provider configuration.
///
/// Checked as part of `MemoryConfig::normalize_and_validate`: zero
/// `batch_size` / `timeout_secs` values are raised to 1, while a zero
/// `dimensions` or a malformed `ollama_url` is rejected.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Ollama base URL.
    ///
    /// Must parse as an absolute `http://` or `https://` URL with a host.
    pub ollama_url: String,

    /// Embedding model name.
    pub model: String,

    /// Expected embedding dimensions. Must be at least 1.
    pub dimensions: usize,

    /// Maximum texts to embed in a single API call.
    pub batch_size: usize,

    /// Timeout for embedding requests in seconds.
    ///
    /// `MemoryConfig::normalize_and_validate` also caps this value to the
    /// whole-second `limits.embedding_timeout`.
    pub timeout_secs: u64,
}
114
115impl Default for EmbeddingConfig {
116    fn default() -> Self {
117        Self {
118            ollama_url: "http://localhost:11434".to_string(),
119            model: "nomic-embed-text".to_string(),
120            dimensions: 768,
121            batch_size: 32,
122            timeout_secs: 30,
123        }
124    }
125}
126
127impl EmbeddingConfig {
128    fn normalize_and_validate(&mut self) -> Result<(), MemoryError> {
129        if self.dimensions == 0 {
130            return Err(MemoryError::InvalidConfig {
131                field: "embedding.dimensions",
132                reason: "dimensions must be at least 1".to_string(),
133            });
134        }
135        if self.batch_size == 0 {
136            self.batch_size = 1;
137        }
138        if self.timeout_secs == 0 {
139            self.timeout_secs = 1;
140        }
141        let parsed =
142            reqwest::Url::parse(&self.ollama_url).map_err(|_| MemoryError::InvalidConfig {
143                field: "embedding.ollama_url",
144                reason: "must be an absolute http:// or https:// URL".to_string(),
145            })?;
146        match parsed.scheme() {
147            "http" | "https" if parsed.host_str().is_some() => {}
148            _ => {
149                return Err(MemoryError::InvalidConfig {
150                    field: "embedding.ollama_url",
151                    reason: "must be an absolute http:// or https:// URL".to_string(),
152                })
153            }
154        }
155        Ok(())
156    }
157}
158
/// Search tuning parameters.
///
/// Checked as part of `MemoryConfig::normalize_and_validate`. The fields
/// marked `#[serde(default ...)]` may be omitted from serialized
/// configurations; the remaining fields carry no serde default attribute.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchConfig {
    /// Weight for BM25 score in RRF fusion. Must be finite and >= 0.
    pub bm25_weight: f64,

    /// Weight for vector similarity in RRF fusion. Must be finite and >= 0.
    pub vector_weight: f64,

    /// RRF constant (k). Controls rank importance decay.
    /// Must be finite and > 0.
    pub rrf_k: f64,

    /// Number of candidates from each search method before fusion.
    /// Normalization raises this to at least 1 and at least `default_top_k`.
    pub candidate_pool_size: usize,

    /// Default number of results to return. A value of 0 is raised to 1.
    pub default_top_k: usize,

    /// Minimum cosine similarity threshold for vector candidates.
    /// Must be finite and within [-1.0, 1.0].
    pub min_similarity: f64,

    /// Optional recency boost. If enabled, results are boosted based on how
    /// recently they were created/updated. The value is the half-life in
    /// days: a fact that is `recency_half_life_days` old gets 50% of the
    /// recency boost. When set, the value must be finite and > 0.
    /// None = no recency weighting (current behavior, default).
    pub recency_half_life_days: Option<f64>,

    /// Weight of the recency boost relative to BM25 and vector scores in RRF.
    /// Only used when recency_half_life_days is Some.
    /// Must be finite and >= 0.
    /// Default: 0.5
    pub recency_weight: f64,

    /// When true, rerank top HNSW candidates using exact f32 cosine similarity
    /// from SQLite. Improves recall at the cost of one batched SQL query.
    /// Only applies when HNSW feature is enabled.
    /// Default: true
    pub rerank_from_f32: bool,

    /// Optional derived-vector candidate backend. Disabled by default because
    /// raw f32 embeddings remain authoritative.
    ///
    /// Selecting the TurboQuant variant without the `turbo-quant-codec`
    /// feature is rejected during validation.
    #[serde(default)]
    pub derived_vector_backend: DerivedVectorBackendPolicy,

    /// TurboQuant polar angle bits when the TurboQuant candidate backend is enabled.
    /// Validated to lie within 2..=16 when that backend is selected.
    /// Default: 8.
    #[serde(default = "default_turbo_quant_bits")]
    pub turbo_quant_bits: u8,

    /// TurboQuant QJL projection count when the TurboQuant candidate backend is enabled.
    /// Must be at least 1 when that backend is selected.
    /// Default: 64.
    #[serde(default = "default_turbo_quant_projections")]
    pub turbo_quant_projections: usize,

    /// TurboQuant profile seed when the TurboQuant candidate backend is enabled.
    /// Default: 0.
    #[serde(default)]
    pub turbo_quant_seed: u64,

    /// Require exact f32 rerank for TurboQuant candidates. Defaults to true.
    /// Validation rejects `false` while the TurboQuant backend is selected.
    #[serde(default = "default_true")]
    pub turbo_quant_require_exact_rerank: bool,
}
218
/// Candidate backend policy for rebuildable derived vector artifacts.
///
/// Serialized in snake_case (`disabled`, `turbo_quant_candidate_only`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum DerivedVectorBackendPolicy {
    /// Use authoritative raw f32 embeddings for vector candidate generation.
    /// This is the default variant.
    #[default]
    Disabled,
    /// Use TurboQuant only to generate candidates, then exact rerank by default.
    TurboQuantCandidateOnly,
}
229
/// Serde fallback for `SearchConfig::turbo_quant_bits` (8 bits).
const fn default_turbo_quant_bits() -> u8 {
    const BITS: u8 = 8;
    BITS
}
233
/// Serde fallback for `SearchConfig::turbo_quant_projections` (64 projections).
const fn default_turbo_quant_projections() -> usize {
    const PROJECTIONS: usize = 64;
    PROJECTIONS
}
237
/// Serde fallback for boolean fields that default to enabled.
const fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
241
242impl Default for SearchConfig {
243    fn default() -> Self {
244        Self {
245            bm25_weight: 1.0,
246            vector_weight: 1.0,
247            rrf_k: 60.0,
248            candidate_pool_size: 50,
249            default_top_k: 5,
250            min_similarity: 0.3,
251            recency_half_life_days: None,
252            recency_weight: 0.5,
253            rerank_from_f32: true,
254            derived_vector_backend: DerivedVectorBackendPolicy::Disabled,
255            turbo_quant_bits: default_turbo_quant_bits(),
256            turbo_quant_projections: default_turbo_quant_projections(),
257            turbo_quant_seed: 0,
258            turbo_quant_require_exact_rerank: true,
259        }
260    }
261}
262
263impl SearchConfig {
264    pub(crate) fn uses_turbo_quant_backend(&self) -> bool {
265        self.derived_vector_backend == DerivedVectorBackendPolicy::TurboQuantCandidateOnly
266    }
267
268    fn normalize_and_validate(&mut self, embedding_dimensions: usize) -> Result<(), MemoryError> {
269        #[cfg(not(feature = "turbo-quant-codec"))]
270        let _ = embedding_dimensions;
271        if self.candidate_pool_size == 0 {
272            self.candidate_pool_size = 1;
273        }
274        if self.default_top_k == 0 {
275            self.default_top_k = 1;
276        }
277        self.candidate_pool_size = self.candidate_pool_size.max(self.default_top_k);
278        if !self.rrf_k.is_finite() || self.rrf_k <= 0.0 {
279            return Err(MemoryError::InvalidConfig {
280                field: "search.rrf_k",
281                reason: "rrf_k must be finite and > 0".to_string(),
282            });
283        }
284        if !self.bm25_weight.is_finite() || self.bm25_weight < 0.0 {
285            return Err(MemoryError::InvalidConfig {
286                field: "search.bm25_weight",
287                reason: "bm25_weight must be finite and >= 0".to_string(),
288            });
289        }
290        if !self.vector_weight.is_finite() || self.vector_weight < 0.0 {
291            return Err(MemoryError::InvalidConfig {
292                field: "search.vector_weight",
293                reason: "vector_weight must be finite and >= 0".to_string(),
294            });
295        }
296        if !self.recency_weight.is_finite() || self.recency_weight < 0.0 {
297            return Err(MemoryError::InvalidConfig {
298                field: "search.recency_weight",
299                reason: "recency_weight must be finite and >= 0".to_string(),
300            });
301        }
302        if !self.min_similarity.is_finite() || !(-1.0..=1.0).contains(&self.min_similarity) {
303            return Err(MemoryError::InvalidConfig {
304                field: "search.min_similarity",
305                reason: "min_similarity must be finite and within [-1.0, 1.0]".to_string(),
306            });
307        }
308        if matches!(self.recency_half_life_days, Some(v) if !v.is_finite()) {
309            return Err(MemoryError::InvalidConfig {
310                field: "search.recency_half_life_days",
311                reason: "recency_half_life_days must be finite".to_string(),
312            });
313        }
314        if matches!(self.recency_half_life_days, Some(v) if v <= 0.0) {
315            return Err(MemoryError::InvalidConfig {
316                field: "search.recency_half_life_days",
317                reason: "recency_half_life_days must be > 0 when enabled".to_string(),
318            });
319        }
320        if self.uses_turbo_quant_backend() {
321            #[cfg(not(feature = "turbo-quant-codec"))]
322            {
323                return Err(MemoryError::InvalidConfig {
324                    field: "search.derived_vector_backend",
325                    reason: "turbo_quant_candidate_only requires the turbo-quant-codec feature"
326                        .to_string(),
327                });
328            }
329            #[cfg(feature = "turbo-quant-codec")]
330            {
331                if embedding_dimensions % 2 != 0 {
332                    return Err(MemoryError::InvalidConfig {
333                        field: "embedding.dimensions",
334                        reason: "TurboQuant requires even embedding dimensions".to_string(),
335                    });
336                }
337                if self.turbo_quant_projections == 0 {
338                    return Err(MemoryError::InvalidConfig {
339                        field: "search.turbo_quant_projections",
340                        reason: "TurboQuant projections must be at least 1".to_string(),
341                    });
342                }
343                if !(2..=16).contains(&self.turbo_quant_bits) {
344                    return Err(MemoryError::InvalidConfig {
345                        field: "search.turbo_quant_bits",
346                        reason: "TurboQuant bits must be within 2..=16".to_string(),
347                    });
348                }
349                if !self.turbo_quant_require_exact_rerank {
350                    return Err(MemoryError::InvalidConfig {
351                        field: "search.turbo_quant_require_exact_rerank",
352                        reason: "TurboQuant candidate backend requires exact f32 rerank"
353                            .to_string(),
354                    });
355                }
356            }
357        }
358        Ok(())
359    }
360}
361
/// Text chunking parameters.
///
/// Checked as part of `MemoryConfig::normalize_and_validate`, which enforces
/// `1 <= min_size <= target_size <= max_size` (clamping `target_size` and
/// rejecting a `max_size` that is zero or below `min_size`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkingConfig {
    /// Target chunk size in characters.
    pub target_size: usize,

    /// Minimum chunk size. Chunks smaller than this are merged with neighbors.
    /// A value of 0 is raised to 1 during normalization.
    pub min_size: usize,

    /// Maximum chunk size. Chunks larger than this are force-split.
    pub max_size: usize,

    /// Overlap between adjacent chunks in characters.
    /// Normalization reduces this to `min_size - 1` when it is not already
    /// smaller than `min_size`.
    pub overlap: usize,
}
377
378impl Default for ChunkingConfig {
379    fn default() -> Self {
380        Self {
381            target_size: 1000,
382            min_size: 100,
383            max_size: 2000,
384            overlap: 200,
385        }
386    }
387}
388
389impl ChunkingConfig {
390    fn normalize_and_validate(&mut self) -> Result<(), MemoryError> {
391        if self.min_size == 0 {
392            self.min_size = 1;
393        }
394        if self.max_size == 0 {
395            return Err(MemoryError::InvalidConfig {
396                field: "chunking.max_size",
397                reason: "max_size must be at least 1".to_string(),
398            });
399        }
400        if self.max_size < self.min_size {
401            return Err(MemoryError::InvalidConfig {
402                field: "chunking.max_size",
403                reason: "max_size must be >= min_size".to_string(),
404            });
405        }
406        if self.target_size < self.min_size {
407            self.target_size = self.min_size;
408        }
409        if self.target_size > self.max_size {
410            self.target_size = self.max_size;
411        }
412        if self.overlap >= self.min_size {
413            self.overlap = self.min_size.saturating_sub(1);
414        }
415        Ok(())
416    }
417}
418
/// Connection pool configuration for SQLite.
///
/// Controls busy timeout and WAL checkpoint behavior. These defaults
/// are tuned for a single-process server on local SSD storage.
///
/// Checked as part of `MemoryConfig::normalize_and_validate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PoolConfig {
    /// SQLite busy timeout in milliseconds.
    /// A value of 0 is raised to 1 during normalization.
    /// Default: 5000 (5 seconds).
    pub busy_timeout_ms: u32,

    /// WAL auto-checkpoint threshold in pages.
    /// A value of 0 is raised to 1 during normalization.
    /// Default: 1000 (~4 MB with 4KB pages).
    pub wal_autocheckpoint: u32,

    /// Enable WAL mode. Should almost always be true.
    /// Default: true.
    pub enable_wal: bool,

    /// Number of reader connections kept in the pool.
    /// Writes still flow through a single writer connection because SQLite
    /// allows only one concurrent writer, but readers can proceed concurrently
    /// under WAL semantics.
    /// Must be at least 1; a value of 0 is rejected during validation.
    /// Default: 4.
    pub max_read_connections: usize,

    /// Timeout in seconds for acquiring a reader connection from the pool.
    /// Normalization clamps this into 1..=300.
    /// Default: 30 seconds.
    pub reader_timeout_secs: u64,
}
447
448impl Default for PoolConfig {
449    fn default() -> Self {
450        Self {
451            busy_timeout_ms: 5000,
452            wal_autocheckpoint: 1000,
453            enable_wal: true,
454            max_read_connections: 4,
455            reader_timeout_secs: 30,
456        }
457    }
458}
459
460impl PoolConfig {
461    fn normalize_and_validate(&mut self) -> Result<(), MemoryError> {
462        if self.busy_timeout_ms == 0 {
463            self.busy_timeout_ms = 1;
464        }
465        if self.wal_autocheckpoint == 0 {
466            self.wal_autocheckpoint = 1;
467        }
468        if self.max_read_connections == 0 {
469            return Err(MemoryError::InvalidConfig {
470                field: "pool.max_read_connections",
471                reason: "set pool.max_read_connections to at least 1".to_string(),
472            });
473        }
474        if self.reader_timeout_secs == 0 {
475            self.reader_timeout_secs = 1;
476        }
477        self.reader_timeout_secs = self.reader_timeout_secs.min(300);
478        Ok(())
479    }
480}
481
/// Resource limits for the memory system.
///
/// Prevents runaway resource usage. All limits have defaults tuned for
/// a laptop-class server (8GB RAM, SSD storage).
///
/// Use `MemoryLimits::normalize_and_validate` (or the infallible
/// `MemoryLimits::validated`) to clamp or reject out-of-range values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryLimits {
    /// Maximum number of facts per namespace. Must be at least 1.
    /// Default: 100_000.
    pub max_facts_per_namespace: usize,

    /// Maximum number of chunks per document. Must be at least 1.
    /// Default: 1_000.
    pub max_chunks_per_document: usize,

    /// Maximum content size in bytes for a single fact or message.
    /// Must be at least 1.
    /// Default: 1 MB (1_048_576 bytes).
    pub max_content_bytes: usize,

    /// Maximum number of concurrent embedding requests.
    /// Hard-capped at 32 regardless of config; a value of 0 is raised to 1.
    /// Default: 8.
    pub max_embedding_concurrency: usize,

    /// Maximum total database size in bytes. 0 = unlimited.
    /// Default: 0 (unlimited).
    pub max_db_size_bytes: u64,

    /// Embedding request timeout.
    /// (De)serialized as a whole number of seconds; a zero duration is
    /// raised to 1 second during normalization.
    /// Default: 30 seconds.
    #[serde(with = "duration_secs")]
    pub embedding_timeout: Duration,
}
514
515impl Default for MemoryLimits {
516    fn default() -> Self {
517        Self {
518            max_facts_per_namespace: 100_000,
519            max_chunks_per_document: 1_000,
520            max_content_bytes: 1_048_576,
521            max_embedding_concurrency: 8,
522            max_db_size_bytes: 0,
523            embedding_timeout: Duration::from_secs(30),
524        }
525    }
526}
527
528impl MemoryLimits {
529    /// Normalize and validate limits to hard caps.
530    pub fn normalize_and_validate(mut self) -> Result<Self, MemoryError> {
531        if self.max_facts_per_namespace == 0 {
532            return Err(MemoryError::InvalidConfig {
533                field: "limits.max_facts_per_namespace",
534                reason: "must be at least 1".to_string(),
535            });
536        }
537        if self.max_chunks_per_document == 0 {
538            return Err(MemoryError::InvalidConfig {
539                field: "limits.max_chunks_per_document",
540                reason: "must be at least 1".to_string(),
541            });
542        }
543        if self.max_content_bytes == 0 {
544            return Err(MemoryError::InvalidConfig {
545                field: "limits.max_content_bytes",
546                reason: "must be at least 1".to_string(),
547            });
548        }
549        // Hard cap: concurrency at 32
550        if self.max_embedding_concurrency > 32 {
551            self.max_embedding_concurrency = 32;
552        }
553        if self.max_embedding_concurrency == 0 {
554            self.max_embedding_concurrency = 1;
555        }
556        if self.embedding_timeout.is_zero() {
557            self.embedding_timeout = Duration::from_secs(1);
558        }
559        Ok(self)
560    }
561
562    /// Backward-compatible alias for callers that only need clamped limits.
563    ///
564    /// Falls back to defaults if the caller-provided limits are invalid.
565    /// Default limits are infallible so the fallback path cannot fail.
566    pub fn validated(self) -> Self {
567        self.normalize_and_validate().unwrap_or_else(|err| {
568            tracing::warn!(
569                error = %err,
570                "invalid MemoryLimits supplied to validated(); using defaults"
571            );
572            // Default limits are always valid — this path is infallible.
573            let defaults = Self::default();
574            Self {
575                max_facts_per_namespace: defaults.max_facts_per_namespace,
576                max_chunks_per_document: defaults.max_chunks_per_document,
577                max_content_bytes: defaults.max_content_bytes,
578                max_embedding_concurrency: defaults.max_embedding_concurrency.clamp(1, 32),
579                max_db_size_bytes: defaults.max_db_size_bytes,
580                embedding_timeout: if defaults.embedding_timeout.is_zero() {
581                    std::time::Duration::from_secs(1)
582                } else {
583                    defaults.embedding_timeout
584                },
585            }
586        })
587    }
588}
589
590mod duration_secs {
591    use serde::{Deserialize, Deserializer, Serializer};
592    use std::time::Duration;
593
594    pub fn serialize<S: Serializer>(d: &Duration, s: S) -> Result<S::Ok, S::Error> {
595        s.serialize_u64(d.as_secs())
596    }
597
598    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Duration, D::Error> {
599        let secs = u64::deserialize(d)?;
600        Ok(Duration::from_secs(secs))
601    }
602}