Skip to main content

starpod_memory/
scoring.rs

1//! Scoring, validation, and re-ranking utilities for the memory search pipeline.
2//!
3//! This module provides:
4//!
5//! - **Path validation** ([`validate_path`], [`validate_content_size`]) — security
6//!   checks preventing directory traversal, non-`.md` writes, and oversized content.
7//! - **Temporal decay** ([`decay_factor`], [`apply_decay`]) — penalizes older daily
8//!   logs in search results using exponential half-life decay while leaving evergreen
9//!   files (SOUL.md, HEARTBEAT.md, etc.) unaffected.
10//! - **MMR re-ranking** ([`mmr_rerank`]) — Maximal Marginal Relevance diversifies
11//!   search results by balancing query relevance against redundancy with
12//!   already-selected results.
13
14use std::path::Path;
15
16use chrono::{Local, NaiveDate};
17
18use starpod_core::{Result, StarpodError};
19
/// Upper bound, in bytes, on the content accepted for a single memory write (1 MiB).
pub const MAX_WRITE_SIZE: usize = 1024 * 1024;
22
23/// Validate a memory file path, rejecting traversal attacks and unsafe names.
24pub fn validate_path(name: &str, data_dir: &Path) -> Result<()> {
25    // Reject empty names
26    if name.is_empty() {
27        return Err(StarpodError::Io(std::io::Error::new(
28            std::io::ErrorKind::InvalidInput,
29            "File name cannot be empty",
30        )));
31    }
32
33    // Reject names that are too long
34    if name.len() > 255 {
35        return Err(StarpodError::Io(std::io::Error::new(
36            std::io::ErrorKind::InvalidInput,
37            "File name exceeds 255 characters",
38        )));
39    }
40
41    // Reject path traversal
42    if name.contains("..") {
43        return Err(StarpodError::Io(std::io::Error::new(
44            std::io::ErrorKind::InvalidInput,
45            "File name must not contain '..'",
46        )));
47    }
48
49    // Reject absolute paths
50    if Path::new(name).is_absolute() {
51        return Err(StarpodError::Io(std::io::Error::new(
52            std::io::ErrorKind::InvalidInput,
53            "File name must be a relative path",
54        )));
55    }
56
57    // Require .md extension
58    if !name.ends_with(".md") {
59        return Err(StarpodError::Io(std::io::Error::new(
60            std::io::ErrorKind::InvalidInput,
61            "Only .md files are allowed",
62        )));
63    }
64
65    // Ensure the resolved path stays under data_dir
66    let resolved = data_dir.join(name);
67    let canonical_data = data_dir
68        .canonicalize()
69        .unwrap_or_else(|_| data_dir.to_path_buf());
70    // Use the parent's canonical path if the file doesn't exist yet
71    let canonical_resolved = if resolved.exists() {
72        resolved.canonicalize().unwrap_or(resolved)
73    } else if let Some(parent) = resolved.parent() {
74        let canonical_parent = if parent.exists() {
75            parent.canonicalize().unwrap_or_else(|_| parent.to_path_buf())
76        } else {
77            parent.to_path_buf()
78        };
79        canonical_parent.join(resolved.file_name().unwrap_or_default())
80    } else {
81        resolved
82    };
83
84    if !canonical_resolved.starts_with(&canonical_data) {
85        return Err(StarpodError::Io(std::io::Error::new(
86            std::io::ErrorKind::InvalidInput,
87            "File path escapes the data directory",
88        )));
89    }
90
91    Ok(())
92}
93
94/// Validate content size for writes.
95pub fn validate_content_size(content: &str) -> Result<()> {
96    if content.len() > MAX_WRITE_SIZE {
97        return Err(StarpodError::Io(std::io::Error::new(
98            std::io::ErrorKind::InvalidInput,
99            format!(
100                "Content size ({} bytes) exceeds maximum ({} bytes)",
101                content.len(),
102                MAX_WRITE_SIZE
103            ),
104        )));
105    }
106    Ok(())
107}
108
109/// Compute a temporal decay factor for a memory source.
110///
111/// Daily log files (`memory/YYYY-MM-DD.md`) decay with a half-life:
112///   `decay = 0.5^(age_days / half_life_days)`
113///
114/// Evergreen files (SOUL.md, USER.md, MEMORY.md, HEARTBEAT.md, etc.) return 1.0.
115pub fn decay_factor(source: &str, half_life_days: f64) -> f64 {
116    // Evergreen files don't decay
117    let evergreen_prefixes = ["SOUL.md", "USER.md", "MEMORY.md", "HEARTBEAT.md"];
118    if evergreen_prefixes.contains(&source) {
119        return 1.0;
120    }
121
122    // Try to parse date from memory/YYYY-MM-DD.md pattern
123    if let Some(date_str) = source
124        .strip_prefix("memory/")
125        .and_then(|s| s.strip_suffix(".md"))
126    {
127        if let Ok(file_date) = NaiveDate::parse_from_str(date_str, "%Y-%m-%d") {
128            let today = Local::now().date_naive();
129            let age_days = (today - file_date).num_days().max(0) as f64;
130            return 0.5_f64.powf(age_days / half_life_days);
131        }
132    }
133
134    // Non-daily, non-evergreen files (e.g. memory/notes.md) — slight decay
135    0.8
136}
137
138/// Apply temporal decay to an FTS5 rank score.
139///
140/// FTS5 ranks are negative (more negative = better match). Multiplying by a
141/// decay factor < 1.0 makes the score less negative (closer to zero = worse),
142/// effectively penalizing older content.
143///
144/// Example: rank = -10.0, decay = 0.5 → adjusted = -5.0 (worse match).
145pub fn apply_decay(rank: f64, source: &str, half_life_days: f64) -> f64 {
146    let factor = decay_factor(source, half_life_days);
147    if factor <= 0.0 {
148        return rank;
149    }
150    rank * factor
151}
152
153/// Maximal Marginal Relevance (MMR) re-ranking for result diversity.
154///
155/// Given a query embedding and candidate results with their embeddings,
156/// iteratively selects candidates that maximize:
157///   `lambda * sim(candidate, query) - (1 - lambda) * max_sim(candidate, selected)`
158///
159/// This balances relevance (similarity to query) with diversity (dissimilarity
160/// from already-selected results).
161///
162/// `lambda`: 0.0 = maximum diversity, 1.0 = pure relevance. Default: 0.7.
163pub fn mmr_rerank(
164    query_embedding: &[f32],
165    candidates: &[(Vec<f32>, usize)], // (embedding, index into results)
166    limit: usize,
167    lambda: f64,
168) -> Vec<usize> {
169    use crate::embedder::cosine_similarity;
170
171    if candidates.is_empty() || limit == 0 {
172        return Vec::new();
173    }
174
175    // Pre-compute similarities to query
176    let query_sims: Vec<f64> = candidates
177        .iter()
178        .map(|(emb, _)| cosine_similarity(query_embedding, emb) as f64)
179        .collect();
180
181    let mut selected: Vec<usize> = Vec::with_capacity(limit); // indices into candidates
182    let mut remaining: Vec<usize> = (0..candidates.len()).collect();
183
184    for _ in 0..limit {
185        if remaining.is_empty() {
186            break;
187        }
188
189        let mut best_idx = 0;
190        let mut best_score = f64::NEG_INFINITY;
191
192        for (pos, &cand_idx) in remaining.iter().enumerate() {
193            let relevance = query_sims[cand_idx];
194
195            // Max similarity to already-selected results
196            let max_selected_sim = if selected.is_empty() {
197                0.0
198            } else {
199                selected
200                    .iter()
201                    .map(|&sel_idx| {
202                        cosine_similarity(&candidates[sel_idx].0, &candidates[cand_idx].0) as f64
203                    })
204                    .fold(f64::NEG_INFINITY, f64::max)
205            };
206
207            let mmr_score = lambda * relevance - (1.0 - lambda) * max_selected_sim;
208
209            if mmr_score > best_score {
210                best_score = mmr_score;
211                best_idx = pos;
212            }
213        }
214
215        let chosen = remaining.swap_remove(best_idx);
216        selected.push(chosen);
217    }
218
219    // Return the original result indices
220    selected
221        .iter()
222        .map(|&cand_idx| candidates[cand_idx].1)
223        .collect()
224}
225
/// Unit tests for path validation, content-size limits, temporal decay and
/// MMR re-ranking. Date-based tests build their sources relative to the local
/// current date so they stay valid whenever they run.
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // ── Path validation tests ───────────────────────────────────────────

    #[test]
    fn validate_path_accepts_simple_name() {
        let tmp = TempDir::new().unwrap();
        assert!(validate_path("notes.md", tmp.path()).is_ok());
    }

    #[test]
    fn validate_path_accepts_subdirectory() {
        let tmp = TempDir::new().unwrap();
        std::fs::create_dir_all(tmp.path().join("subdir")).unwrap();
        assert!(validate_path("subdir/notes.md", tmp.path()).is_ok());
    }

    #[test]
    fn validate_path_rejects_traversal() {
        let tmp = TempDir::new().unwrap();
        assert!(validate_path("../etc/passwd.md", tmp.path()).is_err());
        assert!(validate_path("subdir/../../secret.md", tmp.path()).is_err());
    }

    #[test]
    fn validate_path_rejects_absolute() {
        let tmp = TempDir::new().unwrap();
        assert!(validate_path("/etc/passwd.md", tmp.path()).is_err());
    }

    #[test]
    fn validate_path_rejects_non_md() {
        let tmp = TempDir::new().unwrap();
        assert!(validate_path("script.sh", tmp.path()).is_err());
        assert!(validate_path("data.json", tmp.path()).is_err());
    }

    #[test]
    fn validate_path_rejects_empty() {
        let tmp = TempDir::new().unwrap();
        assert!(validate_path("", tmp.path()).is_err());
    }

    #[test]
    fn validate_path_rejects_long_name() {
        let tmp = TempDir::new().unwrap();
        // 254 'a's + ".md" = 257 bytes — over the 255-byte limit
        let long_name = format!("{}.md", "a".repeat(254));
        assert!(validate_path(&long_name, tmp.path()).is_err());
    }

    // ── Content size tests ──────────────────────────────────────────────

    #[test]
    fn validate_content_accepts_normal_size() {
        assert!(validate_content_size("hello world").is_ok());
    }

    #[test]
    fn validate_content_rejects_oversized() {
        let big = "x".repeat(MAX_WRITE_SIZE + 1);
        assert!(validate_content_size(&big).is_err());
    }

    #[test]
    fn validate_content_accepts_exact_limit() {
        let exact = "x".repeat(MAX_WRITE_SIZE);
        assert!(validate_content_size(&exact).is_ok());
    }

    // ── Temporal decay tests ────────────────────────────────────────────

    #[test]
    fn decay_factor_evergreen_files() {
        assert_eq!(decay_factor("SOUL.md", 30.0), 1.0);
        assert_eq!(decay_factor("USER.md", 30.0), 1.0);
        assert_eq!(decay_factor("MEMORY.md", 30.0), 1.0);
        assert_eq!(decay_factor("HEARTBEAT.md", 30.0), 1.0);
        // Non-evergreen, non-daily files get a slight decay
        assert_eq!(decay_factor("notes.md", 30.0), 0.8);
    }

    #[test]
    fn decay_factor_today() {
        let today = Local::now().format("memory/%Y-%m-%d.md").to_string();
        let factor = decay_factor(&today, 30.0);
        assert!((factor - 1.0).abs() < 0.01, "Today's factor should be ~1.0, got {}", factor);
    }

    #[test]
    fn decay_factor_30_days_ago() {
        let date = Local::now().date_naive() - chrono::Duration::days(30);
        let source = format!("memory/{}.md", date.format("%Y-%m-%d"));
        let factor = decay_factor(&source, 30.0);
        assert!((factor - 0.5).abs() < 0.01, "30-day-old factor should be ~0.5, got {}", factor);
    }

    #[test]
    fn decay_factor_60_days_ago() {
        let date = Local::now().date_naive() - chrono::Duration::days(60);
        let source = format!("memory/{}.md", date.format("%Y-%m-%d"));
        let factor = decay_factor(&source, 30.0);
        assert!((factor - 0.25).abs() < 0.01, "60-day-old factor should be ~0.25, got {}", factor);
    }

    #[test]
    fn decay_factor_non_dated_memory() {
        assert_eq!(decay_factor("memory/notes.md", 30.0), 0.8);
    }

    #[test]
    fn apply_decay_worsens_old_results() {
        // FTS5 rank of -10.0 (a decent match)
        let rank = -10.0;
        let date = Local::now().date_naive() - chrono::Duration::days(30);
        let source = format!("memory/{}.md", date.format("%Y-%m-%d"));

        let decayed = apply_decay(rank, &source, 30.0);
        // Multiplying -10.0 by 0.5 = -5.0 (less negative = worse rank)
        assert!(decayed > rank, "Decayed rank should be less negative (worse): {} > {}", decayed, rank);
        assert!((decayed - (-5.0)).abs() < 0.1);
    }

    #[test]
    fn apply_decay_preserves_evergreen() {
        let rank = -5.0;
        let decayed = apply_decay(rank, "SOUL.md", 30.0);
        assert_eq!(decayed, rank);
    }

    // ── MMR re-ranking tests ────────────────────────────────────────────

    #[test]
    fn mmr_empty_candidates() {
        let query = vec![1.0, 0.0, 0.0];
        assert!(mmr_rerank(&query, &[], 5, 0.7).is_empty());
    }

    #[test]
    fn mmr_single_candidate() {
        let query = vec![1.0, 0.0, 0.0];
        let candidates = vec![(vec![1.0, 0.0, 0.0], 0)];
        let selected = mmr_rerank(&query, &candidates, 5, 0.7);
        assert_eq!(selected, vec![0]);
    }

    #[test]
    fn mmr_selects_most_relevant_first() {
        let query = vec![1.0, 0.0, 0.0];
        let candidates = vec![
            (vec![0.0, 1.0, 0.0], 0), // orthogonal to query
            (vec![0.9, 0.1, 0.0], 1), // very similar to query
            (vec![0.5, 0.5, 0.0], 2), // moderate similarity
        ];
        let selected = mmr_rerank(&query, &candidates, 3, 1.0); // lambda=1.0 = pure relevance
        assert_eq!(selected[0], 1, "Most relevant should be first");
    }

    #[test]
    fn mmr_promotes_diversity() {
        let query = vec![1.0, 0.0, 0.0];
        // Two near-identical candidates and one diverse one with moderate relevance
        let candidates = vec![
            (vec![1.0, 0.0, 0.0], 0),  // identical to query
            (vec![0.99, 0.01, 0.0], 1), // near-duplicate of candidate 0
            (vec![0.7, 0.7, 0.0], 2),   // different direction but still relevant
        ];
        let selected = mmr_rerank(&query, &candidates, 3, 0.3); // lambda=0.3 = diversity-heavy
        // First should be most relevant, second should be the diverse one
        assert_eq!(selected[0], 0);
        assert_eq!(selected[1], 2, "Diverse candidate should come before near-duplicate");
    }

    #[test]
    fn mmr_respects_limit() {
        let query = vec![1.0, 0.0];
        let candidates: Vec<(Vec<f32>, usize)> = (0..10)
            .map(|i| (vec![1.0, i as f32 * 0.1], i))
            .collect();
        let selected = mmr_rerank(&query, &candidates, 3, 0.7);
        assert_eq!(selected.len(), 3);
    }

    #[test]
    fn mmr_limit_zero_returns_empty() {
        let query = vec![1.0, 0.0];
        let candidates = vec![(vec![1.0, 0.0], 0)];
        assert!(mmr_rerank(&query, &candidates, 0, 0.7).is_empty());
    }

    #[test]
    fn mmr_preserves_original_indices() {
        let query = vec![1.0, 0.0, 0.0];
        // Indices 42 and 99 are the original result positions
        let candidates = vec![
            (vec![0.9, 0.1, 0.0], 42),
            (vec![0.1, 0.9, 0.0], 99),
        ];
        let selected = mmr_rerank(&query, &candidates, 2, 1.0);
        assert_eq!(selected[0], 42, "Should return original index 42");
        assert_eq!(selected[1], 99);
    }

    // ── Path validation edge cases ──────────────────────────────────────

    #[test]
    fn validate_path_accepts_exact_255_chars() {
        let tmp = TempDir::new().unwrap();
        // 252 'a's + ".md" = 255 bytes — exactly at the limit
        let name = format!("{}.md", "a".repeat(252));
        assert_eq!(name.len(), 255);
        assert!(validate_path(&name, tmp.path()).is_ok());
    }

    #[test]
    fn validate_path_rejects_hidden_dotdot_in_component() {
        let tmp = TempDir::new().unwrap();
        // "foo/../bar.md" contains ".." even though it looks like a subdirectory
        assert!(validate_path("foo/../bar.md", tmp.path()).is_err());
    }

    #[test]
    fn validate_path_accepts_dotfile() {
        let tmp = TempDir::new().unwrap();
        // A single dot in a filename is fine (not traversal)
        assert!(validate_path(".hidden.md", tmp.path()).is_ok());
    }

    #[test]
    fn validate_path_rejects_md_in_middle() {
        let tmp = TempDir::new().unwrap();
        // Must end with .md, not just contain it
        assert!(validate_path("notes.md.bak", tmp.path()).is_err());
    }

    // ── Temporal decay edge cases ───────────────────────────────────────

    #[test]
    fn decay_factor_future_date_is_clamped_to_one() {
        // The age is clamped at zero days, so a future-dated log is neither
        // penalized nor boosted: 0.5^0 is exactly 1.0.
        let date = Local::now().date_naive() + chrono::Duration::days(5);
        let source = format!("memory/{}.md", date.format("%Y-%m-%d"));
        let factor = decay_factor(&source, 30.0);
        assert_eq!(factor, 1.0, "Future date factor should be exactly 1.0, got {}", factor);
    }

    #[test]
    fn decay_factor_custom_half_life() {
        // With half-life of 7 days, 7 days ago should give ~0.5
        let date = Local::now().date_naive() - chrono::Duration::days(7);
        let source = format!("memory/{}.md", date.format("%Y-%m-%d"));
        let factor = decay_factor(&source, 7.0);
        assert!((factor - 0.5).abs() < 0.01, "7-day half-life, 7 days old should be ~0.5, got {}", factor);
    }

    #[test]
    fn decay_factor_very_old_approaches_zero() {
        let date = Local::now().date_naive() - chrono::Duration::days(365);
        let source = format!("memory/{}.md", date.format("%Y-%m-%d"));
        let factor = decay_factor(&source, 30.0);
        assert!(factor < 0.01, "365-day-old factor should be near 0, got {}", factor);
    }

    #[test]
    fn decay_factor_malformed_date_returns_default() {
        assert_eq!(decay_factor("memory/not-a-date.md", 30.0), 0.8);
        assert_eq!(decay_factor("memory/2026-13-45.md", 30.0), 0.8);
    }

    #[test]
    fn apply_decay_with_factor_one_is_identity() {
        // Evergreen files have factor 1.0 — rank should be unchanged
        let rank = -7.5;
        assert_eq!(apply_decay(rank, "HEARTBEAT.md", 30.0), rank);
    }
}
503}