Skip to main content

starpod_memory/
scoring.rs

1//! Scoring, validation, and re-ranking utilities for the memory search pipeline.
2//!
3//! This module provides:
4//!
5//! - **Path validation** ([`validate_path`], [`validate_content_size`]) — security
6//!   checks preventing directory traversal, non-`.md` writes, and oversized content.
7//! - **Temporal decay** ([`decay_factor`], [`apply_decay`]) — penalizes older daily
8//!   logs in search results using exponential half-life decay while leaving evergreen
9//!   files (SOUL.md, HEARTBEAT.md, etc.) unaffected.
10//! - **MMR re-ranking** ([`mmr_rerank`]) — Maximal Marginal Relevance diversifies
11//!   search results by balancing query relevance against redundancy with
12//!   already-selected results.
13
14use std::path::Path;
15
16use chrono::{Local, NaiveDate};
17
18use starpod_core::{Result, StarpodError};
19
/// Upper bound, in bytes, on the content accepted by a single memory write
/// (1 MiB = 1024 * 1024 bytes).
pub const MAX_WRITE_SIZE: usize = 1024 * 1024;
22
23/// Validate a memory file path, rejecting traversal attacks and unsafe names.
24pub fn validate_path(name: &str, data_dir: &Path) -> Result<()> {
25    // Reject empty names
26    if name.is_empty() {
27        return Err(StarpodError::Io(std::io::Error::new(
28            std::io::ErrorKind::InvalidInput,
29            "File name cannot be empty",
30        )));
31    }
32
33    // Reject names that are too long
34    if name.len() > 255 {
35        return Err(StarpodError::Io(std::io::Error::new(
36            std::io::ErrorKind::InvalidInput,
37            "File name exceeds 255 characters",
38        )));
39    }
40
41    // Reject path traversal
42    if name.contains("..") {
43        return Err(StarpodError::Io(std::io::Error::new(
44            std::io::ErrorKind::InvalidInput,
45            "File name must not contain '..'",
46        )));
47    }
48
49    // Reject absolute paths
50    if Path::new(name).is_absolute() {
51        return Err(StarpodError::Io(std::io::Error::new(
52            std::io::ErrorKind::InvalidInput,
53            "File name must be a relative path",
54        )));
55    }
56
57    // Require .md extension
58    if !name.ends_with(".md") {
59        return Err(StarpodError::Io(std::io::Error::new(
60            std::io::ErrorKind::InvalidInput,
61            "Only .md files are allowed",
62        )));
63    }
64
65    // Ensure the resolved path stays under data_dir
66    let resolved = data_dir.join(name);
67    let canonical_data = data_dir
68        .canonicalize()
69        .unwrap_or_else(|_| data_dir.to_path_buf());
70    // Use the parent's canonical path if the file doesn't exist yet
71    let canonical_resolved = if resolved.exists() {
72        resolved.canonicalize().unwrap_or(resolved)
73    } else if let Some(parent) = resolved.parent() {
74        let canonical_parent = if parent.exists() {
75            parent
76                .canonicalize()
77                .unwrap_or_else(|_| parent.to_path_buf())
78        } else {
79            parent.to_path_buf()
80        };
81        canonical_parent.join(resolved.file_name().unwrap_or_default())
82    } else {
83        resolved
84    };
85
86    if !canonical_resolved.starts_with(&canonical_data) {
87        return Err(StarpodError::Io(std::io::Error::new(
88            std::io::ErrorKind::InvalidInput,
89            "File path escapes the data directory",
90        )));
91    }
92
93    Ok(())
94}
95
96/// Validate content size for writes.
97pub fn validate_content_size(content: &str) -> Result<()> {
98    if content.len() > MAX_WRITE_SIZE {
99        return Err(StarpodError::Io(std::io::Error::new(
100            std::io::ErrorKind::InvalidInput,
101            format!(
102                "Content size ({} bytes) exceeds maximum ({} bytes)",
103                content.len(),
104                MAX_WRITE_SIZE
105            ),
106        )));
107    }
108    Ok(())
109}
110
111/// Compute a temporal decay factor for a memory source.
112///
113/// Daily log files (`memory/YYYY-MM-DD.md`) decay with a half-life:
114///   `decay = 0.5^(age_days / half_life_days)`
115///
116/// Evergreen files (SOUL.md, USER.md, MEMORY.md, HEARTBEAT.md, etc.) return 1.0.
117pub fn decay_factor(source: &str, half_life_days: f64) -> f64 {
118    // Evergreen files don't decay
119    let evergreen_prefixes = ["SOUL.md", "USER.md", "MEMORY.md", "HEARTBEAT.md"];
120    if evergreen_prefixes.contains(&source) {
121        return 1.0;
122    }
123
124    // Try to parse date from memory/YYYY-MM-DD.md pattern
125    if let Some(date_str) = source
126        .strip_prefix("memory/")
127        .and_then(|s| s.strip_suffix(".md"))
128    {
129        if let Ok(file_date) = NaiveDate::parse_from_str(date_str, "%Y-%m-%d") {
130            let today = Local::now().date_naive();
131            let age_days = (today - file_date).num_days().max(0) as f64;
132            return 0.5_f64.powf(age_days / half_life_days);
133        }
134    }
135
136    // Non-daily, non-evergreen files (e.g. memory/notes.md) — slight decay
137    0.8
138}
139
140/// Apply temporal decay to an FTS5 rank score.
141///
142/// FTS5 ranks are negative (more negative = better match). Multiplying by a
143/// decay factor < 1.0 makes the score less negative (closer to zero = worse),
144/// effectively penalizing older content.
145///
146/// Example: rank = -10.0, decay = 0.5 → adjusted = -5.0 (worse match).
147pub fn apply_decay(rank: f64, source: &str, half_life_days: f64) -> f64 {
148    let factor = decay_factor(source, half_life_days);
149    if factor <= 0.0 {
150        return rank;
151    }
152    rank * factor
153}
154
155/// Maximal Marginal Relevance (MMR) re-ranking for result diversity.
156///
157/// Given a query embedding and candidate results with their embeddings,
158/// iteratively selects candidates that maximize:
159///   `lambda * sim(candidate, query) - (1 - lambda) * max_sim(candidate, selected)`
160///
161/// This balances relevance (similarity to query) with diversity (dissimilarity
162/// from already-selected results).
163///
164/// `lambda`: 0.0 = maximum diversity, 1.0 = pure relevance. Default: 0.7.
165pub fn mmr_rerank(
166    query_embedding: &[f32],
167    candidates: &[(Vec<f32>, usize)], // (embedding, index into results)
168    limit: usize,
169    lambda: f64,
170) -> Vec<usize> {
171    use crate::embedder::cosine_similarity;
172
173    if candidates.is_empty() || limit == 0 {
174        return Vec::new();
175    }
176
177    // Pre-compute similarities to query
178    let query_sims: Vec<f64> = candidates
179        .iter()
180        .map(|(emb, _)| cosine_similarity(query_embedding, emb) as f64)
181        .collect();
182
183    let mut selected: Vec<usize> = Vec::with_capacity(limit); // indices into candidates
184    let mut remaining: Vec<usize> = (0..candidates.len()).collect();
185
186    for _ in 0..limit {
187        if remaining.is_empty() {
188            break;
189        }
190
191        let mut best_idx = 0;
192        let mut best_score = f64::NEG_INFINITY;
193
194        for (pos, &cand_idx) in remaining.iter().enumerate() {
195            let relevance = query_sims[cand_idx];
196
197            // Max similarity to already-selected results
198            let max_selected_sim = if selected.is_empty() {
199                0.0
200            } else {
201                selected
202                    .iter()
203                    .map(|&sel_idx| {
204                        cosine_similarity(&candidates[sel_idx].0, &candidates[cand_idx].0) as f64
205                    })
206                    .fold(f64::NEG_INFINITY, f64::max)
207            };
208
209            let mmr_score = lambda * relevance - (1.0 - lambda) * max_selected_sim;
210
211            if mmr_score > best_score {
212                best_score = mmr_score;
213                best_idx = pos;
214            }
215        }
216
217        let chosen = remaining.swap_remove(best_idx);
218        selected.push(chosen);
219    }
220
221    // Return the original result indices
222    selected
223        .iter()
224        .map(|&cand_idx| candidates[cand_idx].1)
225        .collect()
226}
227
/// Unit tests for path/content validation, temporal decay and MMR re-ranking.
///
/// Date-based tests build their sources relative to the current local date,
/// so they do not depend on when they are run.
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // ── Path validation tests ───────────────────────────────────────────

    #[test]
    fn validate_path_accepts_simple_name() {
        let tmp = TempDir::new().unwrap();
        assert!(validate_path("notes.md", tmp.path()).is_ok());
    }

    #[test]
    fn validate_path_accepts_subdirectory() {
        let tmp = TempDir::new().unwrap();
        std::fs::create_dir_all(tmp.path().join("subdir")).unwrap();
        assert!(validate_path("subdir/notes.md", tmp.path()).is_ok());
    }

    #[test]
    fn validate_path_rejects_traversal() {
        let tmp = TempDir::new().unwrap();
        assert!(validate_path("../etc/passwd.md", tmp.path()).is_err());
        assert!(validate_path("subdir/../../secret.md", tmp.path()).is_err());
    }

    #[test]
    fn validate_path_rejects_absolute() {
        let tmp = TempDir::new().unwrap();
        assert!(validate_path("/etc/passwd.md", tmp.path()).is_err());
    }

    #[test]
    fn validate_path_rejects_non_md() {
        let tmp = TempDir::new().unwrap();
        assert!(validate_path("script.sh", tmp.path()).is_err());
        assert!(validate_path("data.json", tmp.path()).is_err());
    }

    #[test]
    fn validate_path_rejects_empty() {
        let tmp = TempDir::new().unwrap();
        assert!(validate_path("", tmp.path()).is_err());
    }

    #[test]
    fn validate_path_rejects_long_name() {
        let tmp = TempDir::new().unwrap();
        // 254 'a's + ".md" = 257 bytes, over the 255-byte limit
        let long_name = format!("{}.md", "a".repeat(254));
        assert!(validate_path(&long_name, tmp.path()).is_err());
    }

    // ── Content size tests ──────────────────────────────────────────────

    #[test]
    fn validate_content_accepts_normal_size() {
        assert!(validate_content_size("hello world").is_ok());
    }

    #[test]
    fn validate_content_rejects_oversized() {
        let big = "x".repeat(MAX_WRITE_SIZE + 1);
        assert!(validate_content_size(&big).is_err());
    }

    #[test]
    fn validate_content_accepts_exact_limit() {
        let exact = "x".repeat(MAX_WRITE_SIZE);
        assert!(validate_content_size(&exact).is_ok());
    }

    // ── Temporal decay tests ────────────────────────────────────────────

    #[test]
    fn decay_factor_evergreen_files() {
        assert_eq!(decay_factor("SOUL.md", 30.0), 1.0);
        assert_eq!(decay_factor("USER.md", 30.0), 1.0);
        assert_eq!(decay_factor("MEMORY.md", 30.0), 1.0);
        assert_eq!(decay_factor("HEARTBEAT.md", 30.0), 1.0);
        // Non-evergreen, non-daily files get a slight decay
        assert_eq!(decay_factor("notes.md", 30.0), 0.8);
    }

    #[test]
    fn decay_factor_today() {
        let today = Local::now().format("memory/%Y-%m-%d.md").to_string();
        let factor = decay_factor(&today, 30.0);
        assert!(
            (factor - 1.0).abs() < 0.01,
            "Today's factor should be ~1.0, got {}",
            factor
        );
    }

    #[test]
    fn decay_factor_30_days_ago() {
        let date = Local::now().date_naive() - chrono::Duration::days(30);
        let source = format!("memory/{}.md", date.format("%Y-%m-%d"));
        let factor = decay_factor(&source, 30.0);
        assert!(
            (factor - 0.5).abs() < 0.01,
            "30-day-old factor should be ~0.5, got {}",
            factor
        );
    }

    #[test]
    fn decay_factor_60_days_ago() {
        let date = Local::now().date_naive() - chrono::Duration::days(60);
        let source = format!("memory/{}.md", date.format("%Y-%m-%d"));
        let factor = decay_factor(&source, 30.0);
        assert!(
            (factor - 0.25).abs() < 0.01,
            "60-day-old factor should be ~0.25, got {}",
            factor
        );
    }

    #[test]
    fn decay_factor_non_dated_memory() {
        assert_eq!(decay_factor("memory/notes.md", 30.0), 0.8);
    }

    #[test]
    fn apply_decay_worsens_old_results() {
        // FTS5 rank of -10.0 (a decent match)
        let rank = -10.0;
        let date = Local::now().date_naive() - chrono::Duration::days(30);
        let source = format!("memory/{}.md", date.format("%Y-%m-%d"));

        let decayed = apply_decay(rank, &source, 30.0);
        // Multiplying -10.0 by 0.5 = -5.0 (less negative = worse rank)
        assert!(
            decayed > rank,
            "Decayed rank should be less negative (worse): {} > {}",
            decayed,
            rank
        );
        assert!((decayed - (-5.0)).abs() < 0.1);
    }

    #[test]
    fn apply_decay_preserves_evergreen() {
        let rank = -5.0;
        let decayed = apply_decay(rank, "SOUL.md", 30.0);
        assert_eq!(decayed, rank);
    }

    // ── MMR re-ranking tests ────────────────────────────────────────────

    #[test]
    fn mmr_empty_candidates() {
        let query = vec![1.0, 0.0, 0.0];
        assert!(mmr_rerank(&query, &[], 5, 0.7).is_empty());
    }

    #[test]
    fn mmr_single_candidate() {
        let query = vec![1.0, 0.0, 0.0];
        let candidates = vec![(vec![1.0, 0.0, 0.0], 0)];
        let selected = mmr_rerank(&query, &candidates, 5, 0.7);
        assert_eq!(selected, vec![0]);
    }

    #[test]
    fn mmr_selects_most_relevant_first() {
        let query = vec![1.0, 0.0, 0.0];
        let candidates = vec![
            (vec![0.0, 1.0, 0.0], 0), // orthogonal to query
            (vec![0.9, 0.1, 0.0], 1), // very similar to query
            (vec![0.5, 0.5, 0.0], 2), // moderate similarity
        ];
        let selected = mmr_rerank(&query, &candidates, 3, 1.0); // lambda=1.0 = pure relevance
        assert_eq!(selected[0], 1, "Most relevant should be first");
    }

    #[test]
    fn mmr_promotes_diversity() {
        let query = vec![1.0, 0.0, 0.0];
        // Two near-identical candidates and one diverse one with moderate relevance
        let candidates = vec![
            (vec![1.0, 0.0, 0.0], 0),   // identical to query
            (vec![0.99, 0.01, 0.0], 1), // near-duplicate of candidate 0
            (vec![0.7, 0.7, 0.0], 2),   // different direction but still relevant
        ];
        // lambda=0.3 = diversity-heavy
        let selected = mmr_rerank(&query, &candidates, 3, 0.3);
        // First should be most relevant, second should be the diverse one
        assert_eq!(selected[0], 0);
        assert_eq!(
            selected[1], 2,
            "Diverse candidate should come before near-duplicate"
        );
    }

    #[test]
    fn mmr_respects_limit() {
        let query = vec![1.0, 0.0];
        let candidates: Vec<(Vec<f32>, usize)> =
            (0..10).map(|i| (vec![1.0, i as f32 * 0.1], i)).collect();
        let selected = mmr_rerank(&query, &candidates, 3, 0.7);
        assert_eq!(selected.len(), 3);
    }

    #[test]
    fn mmr_limit_zero_returns_empty() {
        let query = vec![1.0, 0.0];
        let candidates = vec![(vec![1.0, 0.0], 0)];
        assert!(mmr_rerank(&query, &candidates, 0, 0.7).is_empty());
    }

    #[test]
    fn mmr_preserves_original_indices() {
        let query = vec![1.0, 0.0, 0.0];
        // Indices 42 and 99 are the original result positions
        let candidates = vec![(vec![0.9, 0.1, 0.0], 42), (vec![0.1, 0.9, 0.0], 99)];
        let selected = mmr_rerank(&query, &candidates, 2, 1.0);
        assert_eq!(selected[0], 42, "Should return original index 42");
        assert_eq!(selected[1], 99);
    }

    // ── Path validation edge cases ──────────────────────────────────────

    #[test]
    fn validate_path_accepts_exact_255_chars() {
        let tmp = TempDir::new().unwrap();
        // 252 'a's + ".md" (3 bytes) = 255 — exactly at the limit
        let name = format!("{}.md", "a".repeat(252));
        assert_eq!(name.len(), 255);
        assert!(validate_path(&name, tmp.path()).is_ok());
    }

    #[test]
    fn validate_path_rejects_hidden_dotdot_in_component() {
        let tmp = TempDir::new().unwrap();
        // "foo/../bar.md" contains ".." even though it looks like a subdirectory
        assert!(validate_path("foo/../bar.md", tmp.path()).is_err());
    }

    #[test]
    fn validate_path_accepts_dotfile() {
        let tmp = TempDir::new().unwrap();
        // A single dot in a filename is fine (not traversal)
        assert!(validate_path(".hidden.md", tmp.path()).is_ok());
    }

    #[test]
    fn validate_path_rejects_md_in_middle() {
        let tmp = TempDir::new().unwrap();
        // Must end with .md, not just contain it
        assert!(validate_path("notes.md.bak", tmp.path()).is_err());
    }

    // ── Temporal decay edge cases ───────────────────────────────────────

    #[test]
    fn decay_factor_future_date_is_not_penalized() {
        // A future date has its age clamped to 0 days, so it is neither
        // penalized nor boosted: the factor is exactly 1.0.
        let date = Local::now().date_naive() + chrono::Duration::days(5);
        let source = format!("memory/{}.md", date.format("%Y-%m-%d"));
        let factor = decay_factor(&source, 30.0);
        assert_eq!(
            factor, 1.0,
            "Future date factor should be exactly 1.0, got {}",
            factor
        );
    }

    #[test]
    fn decay_factor_custom_half_life() {
        // With half-life of 7 days, 7 days ago should give ~0.5
        let date = Local::now().date_naive() - chrono::Duration::days(7);
        let source = format!("memory/{}.md", date.format("%Y-%m-%d"));
        let factor = decay_factor(&source, 7.0);
        assert!(
            (factor - 0.5).abs() < 0.01,
            "7-day half-life, 7 days old should be ~0.5, got {}",
            factor
        );
    }

    #[test]
    fn decay_factor_very_old_approaches_zero() {
        let date = Local::now().date_naive() - chrono::Duration::days(365);
        let source = format!("memory/{}.md", date.format("%Y-%m-%d"));
        let factor = decay_factor(&source, 30.0);
        assert!(
            factor < 0.01,
            "365-day-old factor should be near 0, got {}",
            factor
        );
    }

    #[test]
    fn decay_factor_malformed_date_returns_default() {
        assert_eq!(decay_factor("memory/not-a-date.md", 30.0), 0.8);
        assert_eq!(decay_factor("memory/2026-13-45.md", 30.0), 0.8);
    }

    #[test]
    fn apply_decay_with_factor_one_is_identity() {
        // Evergreen files have factor 1.0 — rank should be unchanged
        let rank = -7.5;
        assert_eq!(apply_decay(rank, "HEARTBEAT.md", 30.0), rank);
    }
}