1use std::path::Path;
15
16use chrono::{Local, NaiveDate};
17
18use starpod_core::{Result, StarpodError};
19
/// Largest content size, in bytes, that `validate_content_size` accepts (1 MiB).
pub const MAX_WRITE_SIZE: usize = 1024 * 1024;
22
23pub fn validate_path(name: &str, data_dir: &Path) -> Result<()> {
25 if name.is_empty() {
27 return Err(StarpodError::Io(std::io::Error::new(
28 std::io::ErrorKind::InvalidInput,
29 "File name cannot be empty",
30 )));
31 }
32
33 if name.len() > 255 {
35 return Err(StarpodError::Io(std::io::Error::new(
36 std::io::ErrorKind::InvalidInput,
37 "File name exceeds 255 characters",
38 )));
39 }
40
41 if name.contains("..") {
43 return Err(StarpodError::Io(std::io::Error::new(
44 std::io::ErrorKind::InvalidInput,
45 "File name must not contain '..'",
46 )));
47 }
48
49 if Path::new(name).is_absolute() {
51 return Err(StarpodError::Io(std::io::Error::new(
52 std::io::ErrorKind::InvalidInput,
53 "File name must be a relative path",
54 )));
55 }
56
57 if !name.ends_with(".md") {
59 return Err(StarpodError::Io(std::io::Error::new(
60 std::io::ErrorKind::InvalidInput,
61 "Only .md files are allowed",
62 )));
63 }
64
65 let resolved = data_dir.join(name);
67 let canonical_data = data_dir
68 .canonicalize()
69 .unwrap_or_else(|_| data_dir.to_path_buf());
70 let canonical_resolved = if resolved.exists() {
72 resolved.canonicalize().unwrap_or(resolved)
73 } else if let Some(parent) = resolved.parent() {
74 let canonical_parent = if parent.exists() {
75 parent.canonicalize().unwrap_or_else(|_| parent.to_path_buf())
76 } else {
77 parent.to_path_buf()
78 };
79 canonical_parent.join(resolved.file_name().unwrap_or_default())
80 } else {
81 resolved
82 };
83
84 if !canonical_resolved.starts_with(&canonical_data) {
85 return Err(StarpodError::Io(std::io::Error::new(
86 std::io::ErrorKind::InvalidInput,
87 "File path escapes the data directory",
88 )));
89 }
90
91 Ok(())
92}
93
94pub fn validate_content_size(content: &str) -> Result<()> {
96 if content.len() > MAX_WRITE_SIZE {
97 return Err(StarpodError::Io(std::io::Error::new(
98 std::io::ErrorKind::InvalidInput,
99 format!(
100 "Content size ({} bytes) exceeds maximum ({} bytes)",
101 content.len(),
102 MAX_WRITE_SIZE
103 ),
104 )));
105 }
106 Ok(())
107}
108
109pub fn decay_factor(source: &str, half_life_days: f64) -> f64 {
116 let evergreen_prefixes = ["SOUL.md", "USER.md", "MEMORY.md", "HEARTBEAT.md"];
118 if evergreen_prefixes.contains(&source) {
119 return 1.0;
120 }
121
122 if let Some(date_str) = source
124 .strip_prefix("memory/")
125 .and_then(|s| s.strip_suffix(".md"))
126 {
127 if let Ok(file_date) = NaiveDate::parse_from_str(date_str, "%Y-%m-%d") {
128 let today = Local::now().date_naive();
129 let age_days = (today - file_date).num_days().max(0) as f64;
130 return 0.5_f64.powf(age_days / half_life_days);
131 }
132 }
133
134 0.8
136}
137
138pub fn apply_decay(rank: f64, source: &str, half_life_days: f64) -> f64 {
146 let factor = decay_factor(source, half_life_days);
147 if factor <= 0.0 {
148 return rank;
149 }
150 rank * factor
151}
152
153pub fn mmr_rerank(
164 query_embedding: &[f32],
165 candidates: &[(Vec<f32>, usize)], limit: usize,
167 lambda: f64,
168) -> Vec<usize> {
169 use crate::embedder::cosine_similarity;
170
171 if candidates.is_empty() || limit == 0 {
172 return Vec::new();
173 }
174
175 let query_sims: Vec<f64> = candidates
177 .iter()
178 .map(|(emb, _)| cosine_similarity(query_embedding, emb) as f64)
179 .collect();
180
181 let mut selected: Vec<usize> = Vec::with_capacity(limit); let mut remaining: Vec<usize> = (0..candidates.len()).collect();
183
184 for _ in 0..limit {
185 if remaining.is_empty() {
186 break;
187 }
188
189 let mut best_idx = 0;
190 let mut best_score = f64::NEG_INFINITY;
191
192 for (pos, &cand_idx) in remaining.iter().enumerate() {
193 let relevance = query_sims[cand_idx];
194
195 let max_selected_sim = if selected.is_empty() {
197 0.0
198 } else {
199 selected
200 .iter()
201 .map(|&sel_idx| {
202 cosine_similarity(&candidates[sel_idx].0, &candidates[cand_idx].0) as f64
203 })
204 .fold(f64::NEG_INFINITY, f64::max)
205 };
206
207 let mmr_score = lambda * relevance - (1.0 - lambda) * max_selected_sim;
208
209 if mmr_score > best_score {
210 best_score = mmr_score;
211 best_idx = pos;
212 }
213 }
214
215 let chosen = remaining.swap_remove(best_idx);
216 selected.push(chosen);
217 }
218
219 selected
221 .iter()
222 .map(|&cand_idx| candidates[cand_idx].1)
223 .collect()
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a throwaway directory that serves as the data directory.
    fn scratch_dir() -> TempDir {
        TempDir::new().unwrap()
    }

    /// Today's local date moved back by `days` days.
    fn days_ago(days: i64) -> NaiveDate {
        Local::now().date_naive() - chrono::Duration::days(days)
    }

    /// Formats `date` as a daily memory file name: `memory/YYYY-MM-DD.md`.
    fn dated_source(date: NaiveDate) -> String {
        format!("memory/{}.md", date.format("%Y-%m-%d"))
    }

    // --- validate_path -----------------------------------------------------

    #[test]
    fn validate_path_accepts_simple_name() {
        let dir = scratch_dir();
        let outcome = validate_path("notes.md", dir.path());
        assert!(outcome.is_ok());
    }

    #[test]
    fn validate_path_accepts_subdirectory() {
        let dir = scratch_dir();
        let subdir = dir.path().join("subdir");
        std::fs::create_dir_all(subdir).unwrap();
        let outcome = validate_path("subdir/notes.md", dir.path());
        assert!(outcome.is_ok());
    }

    #[test]
    fn validate_path_rejects_traversal() {
        let dir = scratch_dir();
        for name in &["../etc/passwd.md", "subdir/../../secret.md"] {
            assert!(validate_path(name, dir.path()).is_err());
        }
    }

    #[test]
    fn validate_path_rejects_absolute() {
        let dir = scratch_dir();
        let outcome = validate_path("/etc/passwd.md", dir.path());
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_path_rejects_non_md() {
        let dir = scratch_dir();
        for name in &["script.sh", "data.json"] {
            assert!(validate_path(name, dir.path()).is_err());
        }
    }

    #[test]
    fn validate_path_rejects_empty() {
        let dir = scratch_dir();
        let outcome = validate_path("", dir.path());
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_path_rejects_long_name() {
        let dir = scratch_dir();
        // 254 + 3 bytes is past the 255-byte limit.
        let mut too_long = "a".repeat(254);
        too_long.push_str(".md");
        assert!(validate_path(&too_long, dir.path()).is_err());
    }

    // --- validate_content_size ---------------------------------------------

    #[test]
    fn validate_content_accepts_normal_size() {
        let outcome = validate_content_size("hello world");
        assert!(outcome.is_ok());
    }

    #[test]
    fn validate_content_rejects_oversized() {
        let oversized = "x".repeat(MAX_WRITE_SIZE + 1);
        let outcome = validate_content_size(&oversized);
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_content_accepts_exact_limit() {
        let at_limit = "x".repeat(MAX_WRITE_SIZE);
        let outcome = validate_content_size(&at_limit);
        assert!(outcome.is_ok());
    }

    // --- decay_factor / apply_decay ----------------------------------------

    #[test]
    fn decay_factor_evergreen_files() {
        for name in &["SOUL.md", "USER.md", "MEMORY.md", "HEARTBEAT.md"] {
            assert_eq!(decay_factor(name, 30.0), 1.0);
        }
        // An ordinary file outside `memory/` gets the fixed default.
        assert_eq!(decay_factor("notes.md", 30.0), 0.8);
    }

    #[test]
    fn decay_factor_today() {
        let source = dated_source(Local::now().date_naive());
        let factor = decay_factor(&source, 30.0);
        let error = (factor - 1.0).abs();
        assert!(error < 0.01, "Today's factor should be ~1.0, got {}", factor);
    }

    #[test]
    fn decay_factor_30_days_ago() {
        let source = dated_source(days_ago(30));
        let factor = decay_factor(&source, 30.0);
        let error = (factor - 0.5).abs();
        assert!(error < 0.01, "30-day-old factor should be ~0.5, got {}", factor);
    }

    #[test]
    fn decay_factor_60_days_ago() {
        let source = dated_source(days_ago(60));
        let factor = decay_factor(&source, 30.0);
        let error = (factor - 0.25).abs();
        assert!(error < 0.01, "60-day-old factor should be ~0.25, got {}", factor);
    }

    #[test]
    fn decay_factor_non_dated_memory() {
        let factor = decay_factor("memory/notes.md", 30.0);
        assert_eq!(factor, 0.8);
    }

    #[test]
    fn apply_decay_worsens_old_results() {
        let rank = -10.0;
        let source = dated_source(days_ago(30));

        let decayed = apply_decay(rank, &source, 30.0);

        assert!(
            decayed > rank,
            "Decayed rank should be less negative (worse): {} > {}",
            decayed,
            rank
        );
        let error = (decayed - (-5.0)).abs();
        assert!(error < 0.1);
    }

    #[test]
    fn apply_decay_preserves_evergreen() {
        let rank = -5.0;
        assert_eq!(apply_decay(rank, "SOUL.md", 30.0), rank);
    }

    // --- mmr_rerank --------------------------------------------------------

    #[test]
    fn mmr_empty_candidates() {
        let query: Vec<f32> = vec![1.0, 0.0, 0.0];
        let picked = mmr_rerank(&query, &[], 5, 0.7);
        assert!(picked.is_empty());
    }

    #[test]
    fn mmr_single_candidate() {
        let query: Vec<f32> = vec![1.0, 0.0, 0.0];
        let candidates: Vec<(Vec<f32>, usize)> = vec![(vec![1.0, 0.0, 0.0], 0)];
        let picked = mmr_rerank(&query, &candidates, 5, 0.7);
        assert_eq!(picked, vec![0]);
    }

    #[test]
    fn mmr_selects_most_relevant_first() {
        let query: Vec<f32> = vec![1.0, 0.0, 0.0];
        let candidates: Vec<(Vec<f32>, usize)> = vec![
            (vec![0.0, 1.0, 0.0], 0),
            (vec![0.9, 0.1, 0.0], 1),
            (vec![0.5, 0.5, 0.0], 2),
        ];
        // lambda = 1.0 ranks by relevance only.
        let picked = mmr_rerank(&query, &candidates, 3, 1.0);
        assert_eq!(picked[0], 1, "Most relevant should be first");
    }

    #[test]
    fn mmr_promotes_diversity() {
        let query: Vec<f32> = vec![1.0, 0.0, 0.0];
        let candidates: Vec<(Vec<f32>, usize)> = vec![
            (vec![1.0, 0.0, 0.0], 0),
            (vec![0.99, 0.01, 0.0], 1),
            (vec![0.7, 0.7, 0.0], 2),
        ];
        // A low lambda weights diversity more heavily than relevance.
        let picked = mmr_rerank(&query, &candidates, 3, 0.3);
        assert_eq!(picked[0], 0);
        assert_eq!(picked[1], 2, "Diverse candidate should come before near-duplicate");
    }

    #[test]
    fn mmr_respects_limit() {
        let query: Vec<f32> = vec![1.0, 0.0];
        let mut candidates: Vec<(Vec<f32>, usize)> = Vec::new();
        for i in 0..10 {
            candidates.push((vec![1.0, i as f32 * 0.1], i));
        }
        let picked = mmr_rerank(&query, &candidates, 3, 0.7);
        assert_eq!(picked.len(), 3);
    }

    #[test]
    fn mmr_limit_zero_returns_empty() {
        let query: Vec<f32> = vec![1.0, 0.0];
        let candidates: Vec<(Vec<f32>, usize)> = vec![(vec![1.0, 0.0], 0)];
        let picked = mmr_rerank(&query, &candidates, 0, 0.7);
        assert!(picked.is_empty());
    }

    #[test]
    fn mmr_preserves_original_indices() {
        let query: Vec<f32> = vec![1.0, 0.0, 0.0];
        let candidates: Vec<(Vec<f32>, usize)> = vec![
            (vec![0.9, 0.1, 0.0], 42),
            (vec![0.1, 0.9, 0.0], 99),
        ];
        let picked = mmr_rerank(&query, &candidates, 2, 1.0);
        assert_eq!(picked[0], 42, "Should return original index 42");
        assert_eq!(picked[1], 99);
    }

    // --- validate_path edge cases ------------------------------------------

    #[test]
    fn validate_path_accepts_exact_255_chars() {
        let dir = scratch_dir();
        let mut name = "a".repeat(252);
        name.push_str(".md");
        assert_eq!(name.len(), 255);
        assert!(validate_path(&name, dir.path()).is_ok());
    }

    #[test]
    fn validate_path_rejects_hidden_dotdot_in_component() {
        let dir = scratch_dir();
        let outcome = validate_path("foo/../bar.md", dir.path());
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_path_accepts_dotfile() {
        let dir = scratch_dir();
        let outcome = validate_path(".hidden.md", dir.path());
        assert!(outcome.is_ok());
    }

    #[test]
    fn validate_path_rejects_md_in_middle() {
        let dir = scratch_dir();
        let outcome = validate_path("notes.md.bak", dir.path());
        assert!(outcome.is_err());
    }

    // --- decay edge cases --------------------------------------------------

    #[test]
    fn decay_factor_future_date_returns_above_one() {
        let future = Local::now().date_naive() + chrono::Duration::days(5);
        let source = dated_source(future);
        let factor = decay_factor(&source, 30.0);
        assert!(factor >= 1.0, "Future date factor should be >= 1.0, got {}", factor);
    }

    #[test]
    fn decay_factor_custom_half_life() {
        let source = dated_source(days_ago(7));
        let factor = decay_factor(&source, 7.0);
        let error = (factor - 0.5).abs();
        assert!(
            error < 0.01,
            "7-day half-life, 7 days old should be ~0.5, got {}",
            factor
        );
    }

    #[test]
    fn decay_factor_very_old_approaches_zero() {
        let source = dated_source(days_ago(365));
        let factor = decay_factor(&source, 30.0);
        assert!(factor < 0.01, "365-day-old factor should be near 0, got {}", factor);
    }

    #[test]
    fn decay_factor_malformed_date_returns_default() {
        for source in &["memory/not-a-date.md", "memory/2026-13-45.md"] {
            assert_eq!(decay_factor(source, 30.0), 0.8);
        }
    }

    #[test]
    fn apply_decay_with_factor_one_is_identity() {
        let rank = -7.5;
        let decayed = apply_decay(rank, "HEARTBEAT.md", 30.0);
        assert_eq!(decayed, rank);
    }
}