Skip to main content

shodh_memory/
validation.rs

1//! Input validation for enterprise security
2//! Prevents injection attacks, ensures data integrity, protects against ReDoS
3
4use anyhow::{anyhow, Result};
5use regex::Regex;
6
// Size limits enforced by the validators in this module.
// String limits are compared against `str::len()`, i.e. they are byte counts.

/// Maximum accepted `user_id` length (compared against the byte length).
pub const MAX_USER_ID_LENGTH: usize = 128;
/// Maximum accepted content length in bytes.
pub const MAX_CONTENT_LENGTH: usize = 50_000; // 50KB
/// Maximum accepted regex pattern length (compared against the byte length).
pub const MAX_PATTERN_LENGTH: usize = 256; // Max regex pattern length
/// Maximum accepted entity name length (compared against the byte length).
pub const MAX_ENTITY_LENGTH: usize = 256; // Max entity name length
/// Maximum size of metadata once serialized to a JSON string, in bytes.
#[allow(unused)] // Public API - available for validation
pub const MAX_METADATA_SIZE: usize = 10_000; // Max metadata JSON size (10KB)
/// Maximum number of entities accepted in a single entity list.
#[allow(unused)] // Public API - available for validation
pub const MAX_ENTITIES_PER_MEMORY: usize = 50; // Max entities per memory
16
17/// Validate user_id
18pub fn validate_user_id(user_id: &str) -> Result<()> {
19    if user_id.is_empty() {
20        return Err(anyhow!("user_id cannot be empty"));
21    }
22
23    if user_id.len() > MAX_USER_ID_LENGTH {
24        return Err(anyhow!(
25            "user_id too long: {} chars (max: {})",
26            user_id.len(),
27            MAX_USER_ID_LENGTH
28        ));
29    }
30
31    // Only allow alphanumeric, dash, underscore
32    if !user_id
33        .chars()
34        .all(|c| c.is_alphanumeric() || c == '-' || c == '_' || c == '@' || c == '.')
35    {
36        return Err(anyhow!(
37            "user_id contains invalid characters (allowed: alphanumeric, -, _, @, .)"
38        ));
39    }
40
41    // Prevent path traversal attacks (.. sequences)
42    if user_id.contains("..") {
43        return Err(anyhow!(
44            "user_id contains invalid path traversal sequence (..)"
45        ));
46    }
47
48    // Reject leading/trailing dots which could be problematic on some filesystems
49    if user_id.starts_with('.') || user_id.ends_with('.') {
50        return Err(anyhow!("user_id cannot start or end with a dot"));
51    }
52
53    // Reject absolute paths — PathBuf::join with an absolute path ignores the base
54    if std::path::Path::new(user_id).is_absolute() {
55        return Err(anyhow!("user_id cannot be an absolute path"));
56    }
57
58    // Reject Windows reserved device names (CON, PRN, AUX, NUL, COM1-9, LPT1-9)
59    // These cause issues when used as directory names on Windows
60    {
61        let upper = user_id.to_uppercase();
62        // Strip any extension (e.g., "CON.txt" is still reserved on Windows)
63        let stem = upper.split('.').next().unwrap_or(&upper);
64        const DEVICE_NAMES: &[&str] = &[
65            "CON", "PRN", "AUX", "NUL", "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7",
66            "COM8", "COM9", "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9",
67        ];
68        if DEVICE_NAMES.contains(&stem) {
69            return Err(anyhow!(
70                "user_id cannot be a Windows reserved device name: {}",
71                user_id
72            ));
73        }
74    }
75
76    Ok(())
77}
78
79/// Validate memory_id (UUID format)
80pub fn validate_memory_id(memory_id: &str) -> Result<uuid::Uuid> {
81    uuid::Uuid::parse_str(memory_id).map_err(|e| anyhow!("Invalid memory_id UUID format: {e}"))
82}
83
84/// Validate memory_id as either a full UUID or a hex prefix (8+ chars).
85///
86/// Returns `Ok(Some(uuid))` for valid full UUIDs, `Ok(None)` for valid hex prefixes
87/// that require resolution against stored memories.
88pub fn validate_memory_id_or_prefix(memory_id: &str) -> Result<Option<uuid::Uuid>> {
89    // Fast path: try full UUID first
90    if let Ok(uuid) = uuid::Uuid::parse_str(memory_id) {
91        return Ok(Some(uuid));
92    }
93
94    // Validate as hex prefix: minimum 8 chars, all hex digits
95    let trimmed = memory_id.trim();
96    if trimmed.len() < 8 {
97        return Err(anyhow!(
98            "Memory ID must be a full UUID or at least 8 hex characters, got {} chars",
99            trimmed.len()
100        ));
101    }
102
103    if !trimmed.chars().all(|c| c.is_ascii_hexdigit()) {
104        return Err(anyhow!(
105            "Memory ID prefix contains invalid characters (only hex digits 0-9, a-f allowed)"
106        ));
107    }
108
109    Ok(None)
110}
111
112/// Validate content
113pub fn validate_content(content: &str, allow_empty: bool) -> Result<()> {
114    if !allow_empty && content.trim().is_empty() {
115        return Err(anyhow!("content cannot be empty"));
116    }
117
118    if content.len() > MAX_CONTENT_LENGTH {
119        return Err(anyhow!(
120            "content too long: {} bytes (max: {})",
121            content.len(),
122            MAX_CONTENT_LENGTH
123        ));
124    }
125
126    Ok(())
127}
128
129/// Validate embeddings vector
130pub fn validate_embeddings(embeddings: &[f32]) -> Result<()> {
131    if embeddings.is_empty() {
132        return Err(anyhow!("embeddings cannot be empty"));
133    }
134
135    // Common embedding dimensions: 384, 512, 768, 1024, 1536
136    let valid_dims = [128, 256, 384, 512, 768, 1024, 1536, 2048];
137    if !valid_dims.contains(&embeddings.len()) {
138        return Err(anyhow!(
139            "Unusual embedding dimension: {}. Common dimensions: {:?}",
140            embeddings.len(),
141            valid_dims
142        ));
143    }
144
145    // Check for NaN or Inf
146    if embeddings.iter().any(|&v| !v.is_finite()) {
147        return Err(anyhow!("embeddings contain NaN or Inf values"));
148    }
149
150    Ok(())
151}
152
153/// Validate importance threshold
154pub fn validate_importance_threshold(threshold: f32) -> Result<()> {
155    if !(0.0..=1.0).contains(&threshold) {
156        return Err(anyhow!(
157            "importance_threshold must be between 0.0 and 1.0, got: {threshold}"
158        ));
159    }
160    Ok(())
161}
162
163/// Validate max_results
164pub fn validate_max_results(max_results: usize) -> Result<()> {
165    if max_results == 0 {
166        return Err(anyhow!("max_results must be greater than 0"));
167    }
168
169    if max_results > 10_000 {
170        return Err(anyhow!(
171            "max_results too large: {max_results} (max: 10,000)"
172        ));
173    }
174
175    Ok(())
176}
177
178/// Validate and compile a regex pattern with ReDoS protection
179///
180/// Validates and compiles a regex pattern safely.
181///
182/// The `regex` crate guarantees linear-time matching by construction
183/// (no backtracking engine), so ReDoS is not a concern. We only enforce
184/// length limits and delegate to the crate's built-in size/complexity limits.
185pub fn validate_and_compile_pattern(pattern: &str) -> Result<Regex> {
186    if pattern.is_empty() {
187        return Err(anyhow!("Pattern cannot be empty"));
188    }
189
190    if pattern.len() > MAX_PATTERN_LENGTH {
191        return Err(anyhow!(
192            "Pattern too long: {} chars (max: {})",
193            pattern.len(),
194            MAX_PATTERN_LENGTH
195        ));
196    }
197
198    // regex crate has built-in size/complexity limits and guarantees linear-time matching
199    Regex::new(pattern).map_err(|e| anyhow!("Invalid regex pattern: {e}"))
200}
201
202/// Validate entity name
203pub fn validate_entity(entity: &str) -> Result<()> {
204    if entity.is_empty() {
205        return Err(anyhow!("Entity name cannot be empty"));
206    }
207
208    if entity.len() > MAX_ENTITY_LENGTH {
209        return Err(anyhow!(
210            "Entity name too long: {} chars (max: {})",
211            entity.len(),
212            MAX_ENTITY_LENGTH
213        ));
214    }
215
216    // Only allow printable characters, no control characters
217    if entity.chars().any(|c| c.is_control()) {
218        return Err(anyhow!("Entity name contains invalid control characters"));
219    }
220
221    // No path traversal patterns
222    if entity.contains("..") || entity.contains('/') || entity.contains('\\') {
223        return Err(anyhow!("Entity name contains invalid path characters"));
224    }
225
226    Ok(())
227}
228
229/// Validate entities list
230#[allow(unused)] // Public API - available for validation
231pub fn validate_entities(entities: &[String]) -> Result<()> {
232    if entities.len() > MAX_ENTITIES_PER_MEMORY {
233        return Err(anyhow!(
234            "Too many entities: {} (max: {})",
235            entities.len(),
236            MAX_ENTITIES_PER_MEMORY
237        ));
238    }
239
240    for entity in entities {
241        validate_entity(entity)?;
242    }
243
244    Ok(())
245}
246
247/// Validate metadata JSON size
248#[allow(unused)] // Public API - available for validation
249pub fn validate_metadata(metadata: &serde_json::Value) -> Result<()> {
250    let size = metadata.to_string().len();
251    if size > MAX_METADATA_SIZE {
252        return Err(anyhow!(
253            "Metadata too large: {size} bytes (max: {MAX_METADATA_SIZE})"
254        ));
255    }
256    Ok(())
257}
258
259/// Validate relationship strength
260pub fn validate_relationship_strength(strength: f32) -> Result<()> {
261    if !(0.0..=1.0).contains(&strength) {
262        return Err(anyhow!(
263            "Relationship strength must be between 0.0 and 1.0, got: {strength}"
264        ));
265    }
266    Ok(())
267}
268
269/// Validate a scoring weight (0.0 to 1.0 inclusive, must be finite)
270pub fn validate_weight(name: &str, value: f32) -> Result<()> {
271    if !value.is_finite() || !(0.0..=1.0).contains(&value) {
272        return Err(anyhow!("{name} must be between 0.0 and 1.0, got: {value}"));
273    }
274    Ok(())
275}
276
277/// Validate a reminder timestamp is not unreasonably far in the past or future
278pub fn validate_reminder_timestamp(at: &chrono::DateTime<chrono::Utc>) -> Result<()> {
279    let now = chrono::Utc::now();
280    let max_future = now + chrono::Duration::days(365 * 5); // 5 years
281    let max_past = now - chrono::Duration::hours(1); // Allow up to 1 hour in the past (clock skew)
282
283    if *at < max_past {
284        return Err(anyhow!("Reminder timestamp is in the past: {at}"));
285    }
286
287    if *at > max_future {
288        return Err(anyhow!(
289            "Reminder timestamp is too far in the future (max 5 years): {at}"
290        ));
291    }
292
293    Ok(())
294}
295
/// Unit tests for the validators in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_valid_user_id() {
        assert!(validate_user_id("alice").is_ok());
        assert!(validate_user_id("user-123").is_ok());
        assert!(validate_user_id("test_user").is_ok());
        assert!(validate_user_id("user@example.com").is_ok());
    }

    #[test]
    fn test_invalid_user_id() {
        assert!(validate_user_id("").is_err()); // empty
        assert!(validate_user_id("user/123").is_err()); // invalid char
        assert!(validate_user_id(&"a".repeat(200)).is_err()); // too long
    }

    #[test]
    fn test_user_id_length_boundary() {
        assert!(validate_user_id(&"a".repeat(MAX_USER_ID_LENGTH)).is_ok());
        assert!(validate_user_id(&"a".repeat(MAX_USER_ID_LENGTH + 1)).is_err());
    }

    #[test]
    fn test_path_traversal_prevention() {
        assert!(validate_user_id("user..admin").is_err()); // path traversal
        assert!(validate_user_id("..").is_err()); // pure traversal
        assert!(validate_user_id("a..b..c").is_err()); // multiple traversal
        assert!(validate_user_id(".hidden").is_err()); // leading dot
        assert!(validate_user_id("user.").is_err()); // trailing dot

        // Valid uses of single dots in email-style user_ids
        assert!(validate_user_id("user.name@example.com").is_ok());
        assert!(validate_user_id("first.last").is_ok());
    }

    #[test]
    fn test_windows_reserved_device_names() {
        assert!(validate_user_id("CON").is_err());
        assert!(validate_user_id("nul").is_err()); // case-insensitive
        assert!(validate_user_id("LPT9").is_err());
        assert!(validate_user_id("com1.txt").is_err()); // extension does not help

        // Names that merely start with a reserved name are fine
        assert!(validate_user_id("console").is_ok());
        assert!(validate_user_id("com10").is_ok());
    }

    #[test]
    fn test_memory_id() {
        assert!(validate_memory_id("c77bb954-1234-5678-abcd-ef0123456789").is_ok());
        assert!(validate_memory_id("not-a-uuid").is_err());
        assert!(validate_memory_id("").is_err());
    }

    #[test]
    fn test_valid_content() {
        assert!(validate_content("Hello world", false).is_ok());
        assert!(validate_content("", true).is_ok()); // allowed when allow_empty=true
        assert!(validate_content("   ", true).is_ok()); // blank allowed when allow_empty=true
        assert!(validate_content(&"x".repeat(MAX_CONTENT_LENGTH), false).is_ok()); // at limit
    }

    #[test]
    fn test_invalid_content() {
        assert!(validate_content("", false).is_err()); // empty not allowed
        assert!(validate_content(" \n\t", false).is_err()); // whitespace-only counts as empty
        assert!(validate_content(&"x".repeat(100_000), false).is_err()); // too long
        assert!(validate_content(&"x".repeat(MAX_CONTENT_LENGTH + 1), true).is_err()); // over limit
    }

    #[test]
    fn test_valid_embeddings() {
        let emb_384 = vec![0.5_f32; 384];
        assert!(validate_embeddings(&emb_384).is_ok());

        let emb_768 = vec![0.5_f32; 768];
        assert!(validate_embeddings(&emb_768).is_ok());
    }

    #[test]
    fn test_invalid_embeddings() {
        assert!(validate_embeddings(&[]).is_err()); // empty
        assert!(validate_embeddings(&[f32::NAN, 0.5]).is_err()); // NaN
        assert!(validate_embeddings(&vec![0.5; 999]).is_err()); // unusual dimension

        // Valid dimension but a non-finite component
        let mut with_inf = vec![0.5_f32; 384];
        with_inf[10] = f32::INFINITY;
        assert!(validate_embeddings(&with_inf).is_err());
    }

    #[test]
    fn test_importance_threshold() {
        assert!(validate_importance_threshold(0.0).is_ok());
        assert!(validate_importance_threshold(0.5).is_ok());
        assert!(validate_importance_threshold(1.0).is_ok());
        assert!(validate_importance_threshold(-0.1).is_err());
        assert!(validate_importance_threshold(1.5).is_err());
        assert!(validate_importance_threshold(f32::NAN).is_err());
    }

    #[test]
    fn test_max_results() {
        assert!(validate_max_results(1).is_ok());
        assert!(validate_max_results(100).is_ok());
        assert!(validate_max_results(10_000).is_ok());
        assert!(validate_max_results(0).is_err());
        assert!(validate_max_results(20_000).is_err());
    }

    #[test]
    fn test_valid_patterns() {
        // Simple patterns should work
        assert!(validate_and_compile_pattern("hello").is_ok());
        assert!(validate_and_compile_pattern("user.*").is_ok());
        assert!(validate_and_compile_pattern("[a-z]+").is_ok());
        assert!(validate_and_compile_pattern("^start").is_ok());
        assert!(validate_and_compile_pattern("end$").is_ok());
    }

    #[test]
    fn test_regex_edge_cases() {
        // regex crate handles these safely (linear-time, no backtracking)
        assert!(validate_and_compile_pattern("(a+)+").is_ok());
        assert!(validate_and_compile_pattern("(.*)*").is_ok());
        assert!(validate_and_compile_pattern("(.+)+").is_ok());
        // Pattern too long
        assert!(validate_and_compile_pattern(&"a".repeat(300)).is_err());
        // Empty pattern
        assert!(validate_and_compile_pattern("").is_err());
        // Syntactically invalid pattern
        assert!(validate_and_compile_pattern("(unclosed").is_err());
    }

    #[test]
    fn test_valid_entity() {
        assert!(validate_entity("user").is_ok());
        assert!(validate_entity("John Doe").is_ok());
        assert!(validate_entity("entity-123").is_ok());
    }

    #[test]
    fn test_invalid_entity() {
        assert!(validate_entity("").is_err()); // empty
        assert!(validate_entity(&"a".repeat(300)).is_err()); // too long
        assert!(validate_entity("../etc/passwd").is_err()); // path traversal
        assert!(validate_entity("entity\x00null").is_err()); // control char
        assert!(validate_entity("a/b").is_err()); // forward slash
        assert!(validate_entity("a\\b").is_err()); // backslash
    }

    #[test]
    fn test_entities_list() {
        let valid: Vec<String> = vec!["a".to_string(), "b".to_string()];
        assert!(validate_entities(&valid).is_ok());

        // Too many entities
        let too_many: Vec<String> = (0..100).map(|i| format!("entity{i}")).collect();
        assert!(validate_entities(&too_many).is_err());

        // One invalid entry fails the whole list
        let with_bad: Vec<String> = vec!["ok".to_string(), String::new()];
        assert!(validate_entities(&with_bad).is_err());
    }

    #[test]
    fn test_metadata_size() {
        assert!(validate_metadata(&serde_json::Value::Null).is_ok());

        let oversized = serde_json::Value::String("x".repeat(MAX_METADATA_SIZE + 1));
        assert!(validate_metadata(&oversized).is_err());
    }

    #[test]
    fn test_memory_id_or_prefix_full_uuid() {
        let result = validate_memory_id_or_prefix("c77bb954-1234-5678-abcd-ef0123456789");
        assert!(result.is_ok());
        assert!(result.unwrap().is_some());
    }

    #[test]
    fn test_memory_id_or_prefix_valid_prefix() {
        let result = validate_memory_id_or_prefix("c77bb954");
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn test_memory_id_or_prefix_long_prefix() {
        let result = validate_memory_id_or_prefix("c77bb9541234abcd");
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn test_memory_id_or_prefix_too_short() {
        assert!(validate_memory_id_or_prefix("c77bb").is_err());
    }

    #[test]
    fn test_memory_id_or_prefix_invalid_chars() {
        assert!(validate_memory_id_or_prefix("c77bb95z").is_err());
    }

    #[test]
    fn test_memory_id_or_prefix_empty() {
        assert!(validate_memory_id_or_prefix("").is_err());
    }

    #[test]
    fn test_relationship_strength() {
        assert!(validate_relationship_strength(0.0).is_ok());
        assert!(validate_relationship_strength(0.5).is_ok());
        assert!(validate_relationship_strength(1.0).is_ok());
        assert!(validate_relationship_strength(-0.1).is_err());
        assert!(validate_relationship_strength(1.1).is_err());
        assert!(validate_relationship_strength(f32::NAN).is_err());
    }

    #[test]
    fn test_validate_weight() {
        assert!(validate_weight("test", 0.0).is_ok());
        assert!(validate_weight("test", 0.5).is_ok());
        assert!(validate_weight("test", 1.0).is_ok());
        assert!(validate_weight("test", -0.1).is_err());
        assert!(validate_weight("test", 1.1).is_err());
        assert!(validate_weight("test", f32::NAN).is_err());
        assert!(validate_weight("test", f32::INFINITY).is_err());
    }

    #[test]
    fn test_validate_reminder_timestamp() {
        let now = chrono::Utc::now();

        // Valid: 1 hour from now
        assert!(validate_reminder_timestamp(&(now + chrono::Duration::hours(1))).is_ok());

        // Valid: 30 minutes ago (within 1 hour tolerance)
        assert!(validate_reminder_timestamp(&(now - chrono::Duration::minutes(30))).is_ok());

        // Invalid: 2 hours ago
        assert!(validate_reminder_timestamp(&(now - chrono::Duration::hours(2))).is_err());

        // Invalid: 10 years from now
        assert!(validate_reminder_timestamp(&(now + chrono::Duration::days(365 * 10))).is_err());
    }
}