shodh_memory/
validation.rs1use anyhow::{anyhow, Result};
5use regex::Regex;
6
7pub const MAX_USER_ID_LENGTH: usize = 128;
9pub const MAX_CONTENT_LENGTH: usize = 50_000; pub const MAX_PATTERN_LENGTH: usize = 256; pub const MAX_ENTITY_LENGTH: usize = 256; #[allow(unused)] pub const MAX_METADATA_SIZE: usize = 10_000; #[allow(unused)] pub const MAX_ENTITIES_PER_MEMORY: usize = 50; pub fn validate_user_id(user_id: &str) -> Result<()> {
19 if user_id.is_empty() {
20 return Err(anyhow!("user_id cannot be empty"));
21 }
22
23 if user_id.len() > MAX_USER_ID_LENGTH {
24 return Err(anyhow!(
25 "user_id too long: {} chars (max: {})",
26 user_id.len(),
27 MAX_USER_ID_LENGTH
28 ));
29 }
30
31 if !user_id
33 .chars()
34 .all(|c| c.is_alphanumeric() || c == '-' || c == '_' || c == '@' || c == '.')
35 {
36 return Err(anyhow!(
37 "user_id contains invalid characters (allowed: alphanumeric, -, _, @, .)"
38 ));
39 }
40
41 if user_id.contains("..") {
43 return Err(anyhow!(
44 "user_id contains invalid path traversal sequence (..)"
45 ));
46 }
47
48 if user_id.starts_with('.') || user_id.ends_with('.') {
50 return Err(anyhow!("user_id cannot start or end with a dot"));
51 }
52
53 if std::path::Path::new(user_id).is_absolute() {
55 return Err(anyhow!("user_id cannot be an absolute path"));
56 }
57
58 {
61 let upper = user_id.to_uppercase();
62 let stem = upper.split('.').next().unwrap_or(&upper);
64 const DEVICE_NAMES: &[&str] = &[
65 "CON", "PRN", "AUX", "NUL", "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7",
66 "COM8", "COM9", "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9",
67 ];
68 if DEVICE_NAMES.contains(&stem) {
69 return Err(anyhow!(
70 "user_id cannot be a Windows reserved device name: {}",
71 user_id
72 ));
73 }
74 }
75
76 Ok(())
77}
78
79pub fn validate_memory_id(memory_id: &str) -> Result<uuid::Uuid> {
81 uuid::Uuid::parse_str(memory_id).map_err(|e| anyhow!("Invalid memory_id UUID format: {e}"))
82}
83
84pub fn validate_memory_id_or_prefix(memory_id: &str) -> Result<Option<uuid::Uuid>> {
89 if let Ok(uuid) = uuid::Uuid::parse_str(memory_id) {
91 return Ok(Some(uuid));
92 }
93
94 let trimmed = memory_id.trim();
96 if trimmed.len() < 8 {
97 return Err(anyhow!(
98 "Memory ID must be a full UUID or at least 8 hex characters, got {} chars",
99 trimmed.len()
100 ));
101 }
102
103 if !trimmed.chars().all(|c| c.is_ascii_hexdigit()) {
104 return Err(anyhow!(
105 "Memory ID prefix contains invalid characters (only hex digits 0-9, a-f allowed)"
106 ));
107 }
108
109 Ok(None)
110}
111
112pub fn validate_content(content: &str, allow_empty: bool) -> Result<()> {
114 if !allow_empty && content.trim().is_empty() {
115 return Err(anyhow!("content cannot be empty"));
116 }
117
118 if content.len() > MAX_CONTENT_LENGTH {
119 return Err(anyhow!(
120 "content too long: {} bytes (max: {})",
121 content.len(),
122 MAX_CONTENT_LENGTH
123 ));
124 }
125
126 Ok(())
127}
128
129pub fn validate_embeddings(embeddings: &[f32]) -> Result<()> {
131 if embeddings.is_empty() {
132 return Err(anyhow!("embeddings cannot be empty"));
133 }
134
135 let valid_dims = [128, 256, 384, 512, 768, 1024, 1536, 2048];
137 if !valid_dims.contains(&embeddings.len()) {
138 return Err(anyhow!(
139 "Unusual embedding dimension: {}. Common dimensions: {:?}",
140 embeddings.len(),
141 valid_dims
142 ));
143 }
144
145 if embeddings.iter().any(|&v| !v.is_finite()) {
147 return Err(anyhow!("embeddings contain NaN or Inf values"));
148 }
149
150 Ok(())
151}
152
153pub fn validate_importance_threshold(threshold: f32) -> Result<()> {
155 if !(0.0..=1.0).contains(&threshold) {
156 return Err(anyhow!(
157 "importance_threshold must be between 0.0 and 1.0, got: {threshold}"
158 ));
159 }
160 Ok(())
161}
162
163pub fn validate_max_results(max_results: usize) -> Result<()> {
165 if max_results == 0 {
166 return Err(anyhow!("max_results must be greater than 0"));
167 }
168
169 if max_results > 10_000 {
170 return Err(anyhow!(
171 "max_results too large: {max_results} (max: 10,000)"
172 ));
173 }
174
175 Ok(())
176}
177
178pub fn validate_and_compile_pattern(pattern: &str) -> Result<Regex> {
186 if pattern.is_empty() {
187 return Err(anyhow!("Pattern cannot be empty"));
188 }
189
190 if pattern.len() > MAX_PATTERN_LENGTH {
191 return Err(anyhow!(
192 "Pattern too long: {} chars (max: {})",
193 pattern.len(),
194 MAX_PATTERN_LENGTH
195 ));
196 }
197
198 Regex::new(pattern).map_err(|e| anyhow!("Invalid regex pattern: {e}"))
200}
201
202pub fn validate_entity(entity: &str) -> Result<()> {
204 if entity.is_empty() {
205 return Err(anyhow!("Entity name cannot be empty"));
206 }
207
208 if entity.len() > MAX_ENTITY_LENGTH {
209 return Err(anyhow!(
210 "Entity name too long: {} chars (max: {})",
211 entity.len(),
212 MAX_ENTITY_LENGTH
213 ));
214 }
215
216 if entity.chars().any(|c| c.is_control()) {
218 return Err(anyhow!("Entity name contains invalid control characters"));
219 }
220
221 if entity.contains("..") || entity.contains('/') || entity.contains('\\') {
223 return Err(anyhow!("Entity name contains invalid path characters"));
224 }
225
226 Ok(())
227}
228
229#[allow(unused)] pub fn validate_entities(entities: &[String]) -> Result<()> {
232 if entities.len() > MAX_ENTITIES_PER_MEMORY {
233 return Err(anyhow!(
234 "Too many entities: {} (max: {})",
235 entities.len(),
236 MAX_ENTITIES_PER_MEMORY
237 ));
238 }
239
240 for entity in entities {
241 validate_entity(entity)?;
242 }
243
244 Ok(())
245}
246
247#[allow(unused)] pub fn validate_metadata(metadata: &serde_json::Value) -> Result<()> {
250 let size = metadata.to_string().len();
251 if size > MAX_METADATA_SIZE {
252 return Err(anyhow!(
253 "Metadata too large: {size} bytes (max: {MAX_METADATA_SIZE})"
254 ));
255 }
256 Ok(())
257}
258
259pub fn validate_relationship_strength(strength: f32) -> Result<()> {
261 if !(0.0..=1.0).contains(&strength) {
262 return Err(anyhow!(
263 "Relationship strength must be between 0.0 and 1.0, got: {strength}"
264 ));
265 }
266 Ok(())
267}
268
269pub fn validate_weight(name: &str, value: f32) -> Result<()> {
271 if !value.is_finite() || !(0.0..=1.0).contains(&value) {
272 return Err(anyhow!("{name} must be between 0.0 and 1.0, got: {value}"));
273 }
274 Ok(())
275}
276
277pub fn validate_reminder_timestamp(at: &chrono::DateTime<chrono::Utc>) -> Result<()> {
279 let now = chrono::Utc::now();
280 let max_future = now + chrono::Duration::days(365 * 5); let max_past = now - chrono::Duration::hours(1); if *at < max_past {
284 return Err(anyhow!("Reminder timestamp is in the past: {at}"));
285 }
286
287 if *at > max_future {
288 return Err(anyhow!(
289 "Reminder timestamp is too far in the future (max 5 years): {at}"
290 ));
291 }
292
293 Ok(())
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_valid_user_id() {
        for id in &["alice", "user-123", "test_user", "user@example.com"] {
            assert!(validate_user_id(id).is_ok(), "expected {id:?} to be accepted");
        }
    }

    #[test]
    fn test_invalid_user_id() {
        let oversized = "a".repeat(200);
        // Empty, forbidden character, and over the length limit.
        for id in &["", "user/123", oversized.as_str()] {
            assert!(validate_user_id(id).is_err(), "expected {id:?} to be rejected");
        }
    }

    #[test]
    fn test_path_traversal_prevention() {
        // Double dots anywhere, and leading or trailing dots, are rejected.
        for id in &["user..admin", "..", "a..b..c", ".hidden", "user."] {
            assert!(validate_user_id(id).is_err(), "expected {id:?} to be rejected");
        }

        // Single interior dots remain valid.
        for id in &["user.name@example.com", "first.last"] {
            assert!(validate_user_id(id).is_ok(), "expected {id:?} to be accepted");
        }
    }

    #[test]
    fn test_valid_content() {
        let cases = [("Hello world", false), ("", true)];
        for (content, allow_empty) in cases.iter().copied() {
            assert!(validate_content(content, allow_empty).is_ok());
        }
    }

    #[test]
    fn test_invalid_content() {
        let oversized = "x".repeat(100_000);
        for content in &["", oversized.as_str()] {
            assert!(validate_content(content, false).is_err());
        }
    }

    #[test]
    fn test_valid_embeddings() {
        for dim in [384_usize, 768].iter().copied() {
            let embedding = vec![0.5_f32; dim];
            assert!(
                validate_embeddings(&embedding).is_ok(),
                "expected dimension {dim} to be accepted"
            );
        }
    }

    #[test]
    fn test_invalid_embeddings() {
        let empty: [f32; 0] = [];
        assert!(validate_embeddings(&empty).is_err());

        let with_nan = [f32::NAN, 0.5];
        assert!(validate_embeddings(&with_nan).is_err());

        let odd_dimension = vec![0.5_f32; 999];
        assert!(validate_embeddings(&odd_dimension).is_err());
    }

    #[test]
    fn test_importance_threshold() {
        for value in [0.0_f32, 0.5, 1.0].iter().copied() {
            assert!(validate_importance_threshold(value).is_ok());
        }
        for value in [-0.1_f32, 1.5].iter().copied() {
            assert!(validate_importance_threshold(value).is_err());
        }
    }

    #[test]
    fn test_max_results() {
        for count in [1_usize, 100, 10_000].iter().copied() {
            assert!(validate_max_results(count).is_ok());
        }
        for count in [0_usize, 20_000].iter().copied() {
            assert!(validate_max_results(count).is_err());
        }
    }

    #[test]
    fn test_valid_patterns() {
        for pattern in &["hello", "user.*", "[a-z]+", "^start", "end$"] {
            assert!(
                validate_and_compile_pattern(pattern).is_ok(),
                "expected {pattern:?} to compile"
            );
        }
    }

    #[test]
    fn test_regex_edge_cases() {
        // Nested quantifiers are accepted by the compiler.
        for pattern in &["(a+)+", "(.*)*", "(.+)+"] {
            assert!(
                validate_and_compile_pattern(pattern).is_ok(),
                "expected {pattern:?} to compile"
            );
        }

        let oversized = "a".repeat(300);
        assert!(validate_and_compile_pattern(&oversized).is_err());
        assert!(validate_and_compile_pattern("").is_err());
    }

    #[test]
    fn test_valid_entity() {
        for name in &["user", "John Doe", "entity-123"] {
            assert!(validate_entity(name).is_ok(), "expected {name:?} to be accepted");
        }
    }

    #[test]
    fn test_invalid_entity() {
        let oversized = "a".repeat(300);
        for name in &["", oversized.as_str(), "../etc/passwd", "entity\x00null"] {
            assert!(validate_entity(name).is_err(), "expected {name:?} to be rejected");
        }
    }

    #[test]
    fn test_entities_list() {
        let valid = vec![String::from("a"), String::from("b")];
        assert!(validate_entities(&valid).is_ok());

        let too_many: Vec<String> = (0..100).map(|i| format!("entity{i}")).collect();
        assert!(validate_entities(&too_many).is_err());
    }

    #[test]
    fn test_memory_id_or_prefix_full_uuid() {
        let outcome = validate_memory_id_or_prefix("c77bb954-1234-5678-abcd-ef0123456789");
        assert!(outcome.is_ok());
        assert!(outcome.unwrap().is_some());
    }

    #[test]
    fn test_memory_id_or_prefix_valid_prefix() {
        let outcome = validate_memory_id_or_prefix("c77bb954");
        assert!(outcome.is_ok());
        assert!(outcome.unwrap().is_none());
    }

    #[test]
    fn test_memory_id_or_prefix_long_prefix() {
        let outcome = validate_memory_id_or_prefix("c77bb9541234abcd");
        assert!(outcome.is_ok());
        assert!(outcome.unwrap().is_none());
    }

    #[test]
    fn test_memory_id_or_prefix_too_short() {
        assert!(validate_memory_id_or_prefix("c77bb").is_err());
    }

    #[test]
    fn test_memory_id_or_prefix_invalid_chars() {
        assert!(validate_memory_id_or_prefix("c77bb95z").is_err());
    }

    #[test]
    fn test_memory_id_or_prefix_empty() {
        assert!(validate_memory_id_or_prefix("").is_err());
    }

    #[test]
    fn test_relationship_strength() {
        for value in [0.0_f32, 0.5, 1.0].iter().copied() {
            assert!(validate_relationship_strength(value).is_ok());
        }
        for value in [-0.1_f32, 1.1].iter().copied() {
            assert!(validate_relationship_strength(value).is_err());
        }
    }

    #[test]
    fn test_validate_weight() {
        for value in [0.0_f32, 0.5, 1.0].iter().copied() {
            assert!(validate_weight("test", value).is_ok());
        }
        for value in [-0.1_f32, 1.1, f32::NAN, f32::INFINITY].iter().copied() {
            assert!(validate_weight("test", value).is_err());
        }
    }

    #[test]
    fn test_validate_reminder_timestamp() {
        let now = chrono::Utc::now();

        let in_one_hour = now + chrono::Duration::hours(1);
        assert!(validate_reminder_timestamp(&in_one_hour).is_ok());

        // Within the one-hour grace window for past timestamps.
        let half_hour_ago = now - chrono::Duration::minutes(30);
        assert!(validate_reminder_timestamp(&half_hour_ago).is_ok());

        let two_hours_ago = now - chrono::Duration::hours(2);
        assert!(validate_reminder_timestamp(&two_hours_ago).is_err());

        let in_ten_years = now + chrono::Duration::days(365 * 10);
        assert!(validate_reminder_timestamp(&in_ten_years).is_err());
    }
}