1use std::path::Path;
15
16use chrono::{Local, NaiveDate};
17
18use starpod_core::{Result, StarpodError};
19
/// Largest content, in bytes, accepted by [`validate_content_size`] (1 MiB).
pub const MAX_WRITE_SIZE: usize = 1024 * 1024;
22
23pub fn validate_path(name: &str, data_dir: &Path) -> Result<()> {
25 if name.is_empty() {
27 return Err(StarpodError::Io(std::io::Error::new(
28 std::io::ErrorKind::InvalidInput,
29 "File name cannot be empty",
30 )));
31 }
32
33 if name.len() > 255 {
35 return Err(StarpodError::Io(std::io::Error::new(
36 std::io::ErrorKind::InvalidInput,
37 "File name exceeds 255 characters",
38 )));
39 }
40
41 if name.contains("..") {
43 return Err(StarpodError::Io(std::io::Error::new(
44 std::io::ErrorKind::InvalidInput,
45 "File name must not contain '..'",
46 )));
47 }
48
49 if Path::new(name).is_absolute() {
51 return Err(StarpodError::Io(std::io::Error::new(
52 std::io::ErrorKind::InvalidInput,
53 "File name must be a relative path",
54 )));
55 }
56
57 if !name.ends_with(".md") {
59 return Err(StarpodError::Io(std::io::Error::new(
60 std::io::ErrorKind::InvalidInput,
61 "Only .md files are allowed",
62 )));
63 }
64
65 let resolved = data_dir.join(name);
67 let canonical_data = data_dir
68 .canonicalize()
69 .unwrap_or_else(|_| data_dir.to_path_buf());
70 let canonical_resolved = if resolved.exists() {
72 resolved.canonicalize().unwrap_or(resolved)
73 } else if let Some(parent) = resolved.parent() {
74 let canonical_parent = if parent.exists() {
75 parent
76 .canonicalize()
77 .unwrap_or_else(|_| parent.to_path_buf())
78 } else {
79 parent.to_path_buf()
80 };
81 canonical_parent.join(resolved.file_name().unwrap_or_default())
82 } else {
83 resolved
84 };
85
86 if !canonical_resolved.starts_with(&canonical_data) {
87 return Err(StarpodError::Io(std::io::Error::new(
88 std::io::ErrorKind::InvalidInput,
89 "File path escapes the data directory",
90 )));
91 }
92
93 Ok(())
94}
95
96pub fn validate_content_size(content: &str) -> Result<()> {
98 if content.len() > MAX_WRITE_SIZE {
99 return Err(StarpodError::Io(std::io::Error::new(
100 std::io::ErrorKind::InvalidInput,
101 format!(
102 "Content size ({} bytes) exceeds maximum ({} bytes)",
103 content.len(),
104 MAX_WRITE_SIZE
105 ),
106 )));
107 }
108 Ok(())
109}
110
111pub fn decay_factor(source: &str, half_life_days: f64) -> f64 {
118 let evergreen_prefixes = ["SOUL.md", "USER.md", "MEMORY.md", "HEARTBEAT.md"];
120 if evergreen_prefixes.contains(&source) {
121 return 1.0;
122 }
123
124 if let Some(date_str) = source
126 .strip_prefix("memory/")
127 .and_then(|s| s.strip_suffix(".md"))
128 {
129 if let Ok(file_date) = NaiveDate::parse_from_str(date_str, "%Y-%m-%d") {
130 let today = Local::now().date_naive();
131 let age_days = (today - file_date).num_days().max(0) as f64;
132 return 0.5_f64.powf(age_days / half_life_days);
133 }
134 }
135
136 0.8
138}
139
140pub fn apply_decay(rank: f64, source: &str, half_life_days: f64) -> f64 {
148 let factor = decay_factor(source, half_life_days);
149 if factor <= 0.0 {
150 return rank;
151 }
152 rank * factor
153}
154
155pub fn mmr_rerank(
166 query_embedding: &[f32],
167 candidates: &[(Vec<f32>, usize)], limit: usize,
169 lambda: f64,
170) -> Vec<usize> {
171 use crate::embedder::cosine_similarity;
172
173 if candidates.is_empty() || limit == 0 {
174 return Vec::new();
175 }
176
177 let query_sims: Vec<f64> = candidates
179 .iter()
180 .map(|(emb, _)| cosine_similarity(query_embedding, emb) as f64)
181 .collect();
182
183 let mut selected: Vec<usize> = Vec::with_capacity(limit); let mut remaining: Vec<usize> = (0..candidates.len()).collect();
185
186 for _ in 0..limit {
187 if remaining.is_empty() {
188 break;
189 }
190
191 let mut best_idx = 0;
192 let mut best_score = f64::NEG_INFINITY;
193
194 for (pos, &cand_idx) in remaining.iter().enumerate() {
195 let relevance = query_sims[cand_idx];
196
197 let max_selected_sim = if selected.is_empty() {
199 0.0
200 } else {
201 selected
202 .iter()
203 .map(|&sel_idx| {
204 cosine_similarity(&candidates[sel_idx].0, &candidates[cand_idx].0) as f64
205 })
206 .fold(f64::NEG_INFINITY, f64::max)
207 };
208
209 let mmr_score = lambda * relevance - (1.0 - lambda) * max_selected_sim;
210
211 if mmr_score > best_score {
212 best_score = mmr_score;
213 best_idx = pos;
214 }
215 }
216
217 let chosen = remaining.swap_remove(best_idx);
218 selected.push(chosen);
219 }
220
221 selected
223 .iter()
224 .map(|&cand_idx| candidates[cand_idx].1)
225 .collect()
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a `memory/YYYY-MM-DD.md` source name for a note `days_ago` days old
    /// relative to today's local date; negative values give future dates.
    fn dated_source(days_ago: i64) -> String {
        let date = Local::now().date_naive() - chrono::Duration::days(days_ago);
        format!("memory/{}.md", date.format("%Y-%m-%d"))
    }

    // ---- validate_path -------------------------------------------------

    #[test]
    fn validate_path_accepts_simple_name() {
        let tmp = TempDir::new().unwrap();
        assert!(validate_path("notes.md", tmp.path()).is_ok());
    }

    #[test]
    fn validate_path_accepts_subdirectory() {
        let tmp = TempDir::new().unwrap();
        std::fs::create_dir_all(tmp.path().join("subdir")).unwrap();
        assert!(validate_path("subdir/notes.md", tmp.path()).is_ok());
    }

    #[test]
    fn validate_path_rejects_traversal() {
        let tmp = TempDir::new().unwrap();
        assert!(validate_path("../etc/passwd.md", tmp.path()).is_err());
        assert!(validate_path("subdir/../../secret.md", tmp.path()).is_err());
    }

    #[test]
    fn validate_path_rejects_absolute() {
        let tmp = TempDir::new().unwrap();
        assert!(validate_path("/etc/passwd.md", tmp.path()).is_err());
    }

    #[test]
    fn validate_path_rejects_non_md() {
        let tmp = TempDir::new().unwrap();
        assert!(validate_path("script.sh", tmp.path()).is_err());
        assert!(validate_path("data.json", tmp.path()).is_err());
    }

    #[test]
    fn validate_path_rejects_empty() {
        let tmp = TempDir::new().unwrap();
        assert!(validate_path("", tmp.path()).is_err());
    }

    #[test]
    fn validate_path_rejects_long_name() {
        let tmp = TempDir::new().unwrap();
        // 254 + ".md" = 257 bytes, over the 255-byte limit.
        let long_name = format!("{}.md", "a".repeat(254));
        assert!(validate_path(&long_name, tmp.path()).is_err());
    }

    #[test]
    fn validate_path_accepts_exact_255_chars() {
        let tmp = TempDir::new().unwrap();
        let name = format!("{}.md", "a".repeat(252));
        assert_eq!(name.len(), 255);
        assert!(validate_path(&name, tmp.path()).is_ok());
    }

    #[test]
    fn validate_path_rejects_hidden_dotdot_in_component() {
        let tmp = TempDir::new().unwrap();
        assert!(validate_path("foo/../bar.md", tmp.path()).is_err());
    }

    #[test]
    fn validate_path_accepts_dotfile() {
        let tmp = TempDir::new().unwrap();
        assert!(validate_path(".hidden.md", tmp.path()).is_ok());
    }

    #[test]
    fn validate_path_rejects_md_in_middle() {
        let tmp = TempDir::new().unwrap();
        assert!(validate_path("notes.md.bak", tmp.path()).is_err());
    }

    #[cfg(unix)]
    #[test]
    fn validate_path_rejects_symlink_escape() {
        let tmp = TempDir::new().unwrap();
        let outside = TempDir::new().unwrap();
        std::os::unix::fs::symlink(outside.path(), tmp.path().join("link")).unwrap();
        assert!(validate_path("link/notes.md", tmp.path()).is_err());
    }

    // ---- validate_content_size -----------------------------------------

    #[test]
    fn validate_content_accepts_normal_size() {
        assert!(validate_content_size("hello world").is_ok());
    }

    #[test]
    fn validate_content_rejects_oversized() {
        let big = "x".repeat(MAX_WRITE_SIZE + 1);
        assert!(validate_content_size(&big).is_err());
    }

    #[test]
    fn validate_content_accepts_exact_limit() {
        let exact = "x".repeat(MAX_WRITE_SIZE);
        assert!(validate_content_size(&exact).is_ok());
    }

    // ---- decay_factor ---------------------------------------------------

    #[test]
    fn decay_factor_evergreen_files() {
        assert_eq!(decay_factor("SOUL.md", 30.0), 1.0);
        assert_eq!(decay_factor("USER.md", 30.0), 1.0);
        assert_eq!(decay_factor("MEMORY.md", 30.0), 1.0);
        assert_eq!(decay_factor("HEARTBEAT.md", 30.0), 1.0);
        // Non-evergreen, undated files get the flat default.
        assert_eq!(decay_factor("notes.md", 30.0), 0.8);
    }

    #[test]
    fn decay_factor_today() {
        let factor = decay_factor(&dated_source(0), 30.0);
        assert!(
            (factor - 1.0).abs() < 0.01,
            "Today's factor should be ~1.0, got {}",
            factor
        );
    }

    #[test]
    fn decay_factor_30_days_ago() {
        let factor = decay_factor(&dated_source(30), 30.0);
        assert!(
            (factor - 0.5).abs() < 0.01,
            "30-day-old factor should be ~0.5, got {}",
            factor
        );
    }

    #[test]
    fn decay_factor_60_days_ago() {
        let factor = decay_factor(&dated_source(60), 30.0);
        assert!(
            (factor - 0.25).abs() < 0.01,
            "60-day-old factor should be ~0.25, got {}",
            factor
        );
    }

    #[test]
    fn decay_factor_non_dated_memory() {
        assert_eq!(decay_factor("memory/notes.md", 30.0), 0.8);
    }

    #[test]
    fn decay_factor_nested_memory_path_is_not_dated() {
        assert_eq!(decay_factor("memory/archive/2024-01-01.md", 30.0), 0.8);
    }

    #[test]
    fn decay_factor_future_date_is_not_boosted() {
        // Future dates are treated as age 0, so the factor is exactly 1.0.
        let factor = decay_factor(&dated_source(-5), 30.0);
        assert_eq!(factor, 1.0, "Future date factor should be 1.0, got {}", factor);
    }

    #[test]
    fn decay_factor_custom_half_life() {
        let factor = decay_factor(&dated_source(7), 7.0);
        assert!(
            (factor - 0.5).abs() < 0.01,
            "7-day half-life, 7 days old should be ~0.5, got {}",
            factor
        );
    }

    #[test]
    fn decay_factor_very_old_approaches_zero() {
        let factor = decay_factor(&dated_source(365), 30.0);
        assert!(
            factor < 0.01,
            "365-day-old factor should be near 0, got {}",
            factor
        );
    }

    #[test]
    fn decay_factor_malformed_date_returns_default() {
        assert_eq!(decay_factor("memory/not-a-date.md", 30.0), 0.8);
        assert_eq!(decay_factor("memory/2026-13-45.md", 30.0), 0.8);
    }

    // ---- apply_decay ------------------------------------------------------

    #[test]
    fn apply_decay_worsens_old_results() {
        // Ranks here follow a "more negative is better" convention.
        let rank = -10.0;
        let decayed = apply_decay(rank, &dated_source(30), 30.0);
        assert!(
            decayed > rank,
            "Decayed rank should be less negative (worse): {} > {}",
            decayed,
            rank
        );
        assert!((decayed - (-5.0)).abs() < 0.1);
    }

    #[test]
    fn apply_decay_preserves_evergreen() {
        let rank = -5.0;
        assert_eq!(apply_decay(rank, "SOUL.md", 30.0), rank);
    }

    #[test]
    fn apply_decay_with_factor_one_is_identity() {
        let rank = -7.5;
        assert_eq!(apply_decay(rank, "HEARTBEAT.md", 30.0), rank);
    }

    // ---- mmr_rerank -------------------------------------------------------

    #[test]
    fn mmr_empty_candidates() {
        let query = vec![1.0, 0.0, 0.0];
        assert!(mmr_rerank(&query, &[], 5, 0.7).is_empty());
    }

    #[test]
    fn mmr_single_candidate() {
        let query = vec![1.0, 0.0, 0.0];
        let candidates = vec![(vec![1.0, 0.0, 0.0], 0)];
        assert_eq!(mmr_rerank(&query, &candidates, 5, 0.7), vec![0]);
    }

    #[test]
    fn mmr_selects_most_relevant_first() {
        let query = vec![1.0, 0.0, 0.0];
        let candidates = vec![
            (vec![0.0, 1.0, 0.0], 0),
            (vec![0.9, 0.1, 0.0], 1),
            (vec![0.5, 0.5, 0.0], 2),
        ];
        // lambda = 1.0: pure relevance ranking.
        let selected = mmr_rerank(&query, &candidates, 3, 1.0);
        assert_eq!(selected[0], 1, "Most relevant should be first");
    }

    #[test]
    fn mmr_promotes_diversity() {
        let query = vec![1.0, 0.0, 0.0];
        let candidates = vec![
            (vec![1.0, 0.0, 0.0], 0),
            (vec![0.99, 0.01, 0.0], 1),
            (vec![0.7, 0.7, 0.0], 2),
        ];
        // Low lambda weights diversity heavily.
        let selected = mmr_rerank(&query, &candidates, 3, 0.3);
        assert_eq!(selected[0], 0);
        assert_eq!(
            selected[1], 2,
            "Diverse candidate should come before near-duplicate"
        );
    }

    #[test]
    fn mmr_respects_limit() {
        let query = vec![1.0, 0.0];
        let candidates: Vec<(Vec<f32>, usize)> =
            (0..10).map(|i| (vec![1.0, i as f32 * 0.1], i)).collect();
        assert_eq!(mmr_rerank(&query, &candidates, 3, 0.7).len(), 3);
    }

    #[test]
    fn mmr_limit_above_candidate_count_returns_all() {
        let query = vec![1.0, 0.0];
        let candidates = vec![(vec![1.0, 0.0], 0), (vec![0.0, 1.0], 1)];
        assert_eq!(mmr_rerank(&query, &candidates, 5, 0.7).len(), 2);
    }

    #[test]
    fn mmr_limit_zero_returns_empty() {
        let query = vec![1.0, 0.0];
        let candidates = vec![(vec![1.0, 0.0], 0)];
        assert!(mmr_rerank(&query, &candidates, 0, 0.7).is_empty());
    }

    #[test]
    fn mmr_preserves_original_indices() {
        let query = vec![1.0, 0.0, 0.0];
        let candidates = vec![(vec![0.9, 0.1, 0.0], 42), (vec![0.1, 0.9, 0.0], 99)];
        let selected = mmr_rerank(&query, &candidates, 2, 1.0);
        assert_eq!(selected[0], 42, "Should return original index 42");
        assert_eq!(selected[1], 99);
    }
}