1use crate::error::MemoryError;
2use crate::tokenizer::TokenCounter;
3use serde::{Deserialize, Serialize};
4use std::path::PathBuf;
5use std::sync::Arc;
6use std::time::Duration;
7
8#[derive(Clone, Serialize, Deserialize)]
10pub struct MemoryConfig {
11 pub base_dir: PathBuf,
14
15 pub embedding: EmbeddingConfig,
17
18 pub search: SearchConfig,
20
21 pub chunking: ChunkingConfig,
23
24 pub pool: PoolConfig,
26
27 pub limits: MemoryLimits,
29
30 #[serde(skip)]
32 pub token_counter: Option<Arc<dyn TokenCounter>>,
33
34 #[cfg(feature = "hnsw")]
36 #[serde(skip)]
37 pub hnsw: crate::hnsw::HnswConfig,
38}
39
40impl std::fmt::Debug for MemoryConfig {
41 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 let mut s = f.debug_struct("MemoryConfig");
43 s.field("base_dir", &self.base_dir)
44 .field("embedding", &self.embedding)
45 .field("search", &self.search)
46 .field("chunking", &self.chunking)
47 .field("pool", &self.pool)
48 .field("limits", &self.limits)
49 .field(
50 "token_counter",
51 &self.token_counter.as_ref().map(|_| "custom"),
52 );
53 #[cfg(feature = "hnsw")]
54 s.field("hnsw", &self.hnsw);
55 s.finish()
56 }
57}
58
59impl Default for MemoryConfig {
60 fn default() -> Self {
61 Self {
62 base_dir: PathBuf::from("memory"),
63 embedding: EmbeddingConfig::default(),
64 search: SearchConfig::default(),
65 chunking: ChunkingConfig::default(),
66 pool: PoolConfig::default(),
67 limits: MemoryLimits::default(),
68 token_counter: None,
69 #[cfg(feature = "hnsw")]
70 hnsw: crate::hnsw::HnswConfig::default(),
71 }
72 }
73}
74
75impl MemoryConfig {
76 pub fn normalize_and_validate(mut self) -> Result<Self, MemoryError> {
80 self.embedding.normalize_and_validate()?;
81 self.limits = self.limits.normalize_and_validate()?;
82 let timeout_cap_secs = self.limits.embedding_timeout.as_secs().max(1);
83 self.embedding.timeout_secs = self.embedding.timeout_secs.min(timeout_cap_secs);
84 self.search.normalize_and_validate()?;
85 self.chunking.normalize_and_validate()?;
86 self.pool.normalize_and_validate()?;
87 #[cfg(feature = "hnsw")]
88 {
89 self.hnsw.dimensions = self.embedding.dimensions;
90 }
91 Ok(self)
92 }
93}
94
95#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct EmbeddingConfig {
98 pub ollama_url: String,
100
101 pub model: String,
103
104 pub dimensions: usize,
106
107 pub batch_size: usize,
109
110 pub timeout_secs: u64,
112}
113
114impl Default for EmbeddingConfig {
115 fn default() -> Self {
116 Self {
117 ollama_url: "http://localhost:11434".to_string(),
118 model: "nomic-embed-text".to_string(),
119 dimensions: 768,
120 batch_size: 32,
121 timeout_secs: 30,
122 }
123 }
124}
125
126impl EmbeddingConfig {
127 fn normalize_and_validate(&mut self) -> Result<(), MemoryError> {
128 if self.dimensions == 0 {
129 return Err(MemoryError::InvalidConfig {
130 field: "embedding.dimensions",
131 reason: "dimensions must be at least 1".to_string(),
132 });
133 }
134 if self.batch_size == 0 {
135 self.batch_size = 1;
136 }
137 if self.timeout_secs == 0 {
138 self.timeout_secs = 1;
139 }
140 let parsed =
141 reqwest::Url::parse(&self.ollama_url).map_err(|_| MemoryError::InvalidConfig {
142 field: "embedding.ollama_url",
143 reason: "must be an absolute http:// or https:// URL".to_string(),
144 })?;
145 match parsed.scheme() {
146 "http" | "https" if parsed.host_str().is_some() => {}
147 _ => {
148 return Err(MemoryError::InvalidConfig {
149 field: "embedding.ollama_url",
150 reason: "must be an absolute http:// or https:// URL".to_string(),
151 })
152 }
153 }
154 Ok(())
155 }
156}
157
158#[derive(Debug, Clone, Serialize, Deserialize)]
160pub struct SearchConfig {
161 pub bm25_weight: f64,
163
164 pub vector_weight: f64,
166
167 pub rrf_k: f64,
169
170 pub candidate_pool_size: usize,
172
173 pub default_top_k: usize,
175
176 pub min_similarity: f64,
178
179 pub recency_half_life_days: Option<f64>,
184
185 pub recency_weight: f64,
189
190 pub rerank_from_f32: bool,
195}
196
197impl Default for SearchConfig {
198 fn default() -> Self {
199 Self {
200 bm25_weight: 1.0,
201 vector_weight: 1.0,
202 rrf_k: 60.0,
203 candidate_pool_size: 50,
204 default_top_k: 5,
205 min_similarity: 0.3,
206 recency_half_life_days: None,
207 recency_weight: 0.5,
208 rerank_from_f32: true,
209 }
210 }
211}
212
213impl SearchConfig {
214 fn normalize_and_validate(&mut self) -> Result<(), MemoryError> {
215 if self.candidate_pool_size == 0 {
216 self.candidate_pool_size = 1;
217 }
218 if self.default_top_k == 0 {
219 self.default_top_k = 1;
220 }
221 self.candidate_pool_size = self.candidate_pool_size.max(self.default_top_k);
222 if !self.rrf_k.is_finite() || self.rrf_k <= 0.0 {
223 return Err(MemoryError::InvalidConfig {
224 field: "search.rrf_k",
225 reason: "rrf_k must be finite and > 0".to_string(),
226 });
227 }
228 if !self.bm25_weight.is_finite() || self.bm25_weight < 0.0 {
229 return Err(MemoryError::InvalidConfig {
230 field: "search.bm25_weight",
231 reason: "bm25_weight must be finite and >= 0".to_string(),
232 });
233 }
234 if !self.vector_weight.is_finite() || self.vector_weight < 0.0 {
235 return Err(MemoryError::InvalidConfig {
236 field: "search.vector_weight",
237 reason: "vector_weight must be finite and >= 0".to_string(),
238 });
239 }
240 if !self.recency_weight.is_finite() || self.recency_weight < 0.0 {
241 return Err(MemoryError::InvalidConfig {
242 field: "search.recency_weight",
243 reason: "recency_weight must be finite and >= 0".to_string(),
244 });
245 }
246 if !self.min_similarity.is_finite() || !(-1.0..=1.0).contains(&self.min_similarity) {
247 return Err(MemoryError::InvalidConfig {
248 field: "search.min_similarity",
249 reason: "min_similarity must be finite and within [-1.0, 1.0]".to_string(),
250 });
251 }
252 if matches!(self.recency_half_life_days, Some(v) if !v.is_finite()) {
253 return Err(MemoryError::InvalidConfig {
254 field: "search.recency_half_life_days",
255 reason: "recency_half_life_days must be finite".to_string(),
256 });
257 }
258 if matches!(self.recency_half_life_days, Some(v) if v <= 0.0) {
259 return Err(MemoryError::InvalidConfig {
260 field: "search.recency_half_life_days",
261 reason: "recency_half_life_days must be > 0 when enabled".to_string(),
262 });
263 }
264 Ok(())
265 }
266}
267
268#[derive(Debug, Clone, Serialize, Deserialize)]
270pub struct ChunkingConfig {
271 pub target_size: usize,
273
274 pub min_size: usize,
276
277 pub max_size: usize,
279
280 pub overlap: usize,
282}
283
284impl Default for ChunkingConfig {
285 fn default() -> Self {
286 Self {
287 target_size: 1000,
288 min_size: 100,
289 max_size: 2000,
290 overlap: 200,
291 }
292 }
293}
294
295impl ChunkingConfig {
296 fn normalize_and_validate(&mut self) -> Result<(), MemoryError> {
297 if self.min_size == 0 {
298 self.min_size = 1;
299 }
300 if self.max_size == 0 {
301 return Err(MemoryError::InvalidConfig {
302 field: "chunking.max_size",
303 reason: "max_size must be at least 1".to_string(),
304 });
305 }
306 if self.max_size < self.min_size {
307 return Err(MemoryError::InvalidConfig {
308 field: "chunking.max_size",
309 reason: "max_size must be >= min_size".to_string(),
310 });
311 }
312 if self.target_size < self.min_size {
313 self.target_size = self.min_size;
314 }
315 if self.target_size > self.max_size {
316 self.target_size = self.max_size;
317 }
318 if self.overlap >= self.min_size {
319 self.overlap = self.min_size.saturating_sub(1);
320 }
321 Ok(())
322 }
323}
324
325#[derive(Debug, Clone, Serialize, Deserialize)]
330pub struct PoolConfig {
331 pub busy_timeout_ms: u32,
334
335 pub wal_autocheckpoint: u32,
338
339 pub enable_wal: bool,
342
343 pub max_read_connections: usize,
348
349 pub reader_timeout_secs: u64,
352}
353
354impl Default for PoolConfig {
355 fn default() -> Self {
356 Self {
357 busy_timeout_ms: 5000,
358 wal_autocheckpoint: 1000,
359 enable_wal: true,
360 max_read_connections: 4,
361 reader_timeout_secs: 30,
362 }
363 }
364}
365
366impl PoolConfig {
367 fn normalize_and_validate(&mut self) -> Result<(), MemoryError> {
368 if self.busy_timeout_ms == 0 {
369 self.busy_timeout_ms = 1;
370 }
371 if self.wal_autocheckpoint == 0 {
372 self.wal_autocheckpoint = 1;
373 }
374 if self.max_read_connections == 0 {
375 return Err(MemoryError::InvalidConfig {
376 field: "pool.max_read_connections",
377 reason: "set pool.max_read_connections to at least 1".to_string(),
378 });
379 }
380 if self.reader_timeout_secs == 0 {
381 self.reader_timeout_secs = 1;
382 }
383 self.reader_timeout_secs = self.reader_timeout_secs.min(300);
384 Ok(())
385 }
386}
387
388#[derive(Debug, Clone, Serialize, Deserialize)]
393pub struct MemoryLimits {
394 pub max_facts_per_namespace: usize,
397
398 pub max_chunks_per_document: usize,
401
402 pub max_content_bytes: usize,
405
406 pub max_embedding_concurrency: usize,
410
411 pub max_db_size_bytes: u64,
414
415 #[serde(with = "duration_secs")]
418 pub embedding_timeout: Duration,
419}
420
421impl Default for MemoryLimits {
422 fn default() -> Self {
423 Self {
424 max_facts_per_namespace: 100_000,
425 max_chunks_per_document: 1_000,
426 max_content_bytes: 1_048_576,
427 max_embedding_concurrency: 8,
428 max_db_size_bytes: 0,
429 embedding_timeout: Duration::from_secs(30),
430 }
431 }
432}
433
434impl MemoryLimits {
435 pub fn normalize_and_validate(mut self) -> Result<Self, MemoryError> {
437 if self.max_facts_per_namespace == 0 {
438 return Err(MemoryError::InvalidConfig {
439 field: "limits.max_facts_per_namespace",
440 reason: "must be at least 1".to_string(),
441 });
442 }
443 if self.max_chunks_per_document == 0 {
444 return Err(MemoryError::InvalidConfig {
445 field: "limits.max_chunks_per_document",
446 reason: "must be at least 1".to_string(),
447 });
448 }
449 if self.max_content_bytes == 0 {
450 return Err(MemoryError::InvalidConfig {
451 field: "limits.max_content_bytes",
452 reason: "must be at least 1".to_string(),
453 });
454 }
455 if self.max_embedding_concurrency > 32 {
457 self.max_embedding_concurrency = 32;
458 }
459 if self.max_embedding_concurrency == 0 {
460 self.max_embedding_concurrency = 1;
461 }
462 if self.embedding_timeout.is_zero() {
463 self.embedding_timeout = Duration::from_secs(1);
464 }
465 Ok(self)
466 }
467
468 pub fn validated(self) -> Self {
473 self.normalize_and_validate().unwrap_or_else(|err| {
474 tracing::warn!(
475 error = %err,
476 "invalid MemoryLimits supplied to validated(); using defaults"
477 );
478 let defaults = Self::default();
480 Self {
481 max_facts_per_namespace: defaults.max_facts_per_namespace,
482 max_chunks_per_document: defaults.max_chunks_per_document,
483 max_content_bytes: defaults.max_content_bytes,
484 max_embedding_concurrency: defaults.max_embedding_concurrency.clamp(1, 32),
485 max_db_size_bytes: defaults.max_db_size_bytes,
486 embedding_timeout: if defaults.embedding_timeout.is_zero() {
487 std::time::Duration::from_secs(1)
488 } else {
489 defaults.embedding_timeout
490 },
491 }
492 })
493 }
494}
495
496mod duration_secs {
497 use serde::{Deserialize, Deserializer, Serializer};
498 use std::time::Duration;
499
500 pub fn serialize<S: Serializer>(d: &Duration, s: S) -> Result<S::Ok, S::Error> {
501 s.serialize_u64(d.as_secs())
502 }
503
504 pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Duration, D::Error> {
505 let secs = u64::deserialize(d)?;
506 Ok(Duration::from_secs(secs))
507 }
508}