1use std::sync::LazyLock;
12
13use regex::Regex;
14use serde::{Deserialize, Serialize};
15use thiserror::Error;
16use zeph_memory::graph::extractor::ExtractionResult;
17
18static ENTITY_EMAIL_RE: LazyLock<Regex> = LazyLock::new(|| {
24 Regex::new(r"[a-zA-Z0-9._%+\-]{2,}@(?:[a-zA-Z]+\.)+[a-zA-Z]{2,6}")
25 .expect("valid ENTITY_EMAIL_RE")
26});
27
28static ENTITY_SSN_RE: LazyLock<Regex> =
30 LazyLock::new(|| Regex::new(r"\b\d{3}-\d{2}-\d{4}\b").expect("valid ENTITY_SSN_RE"));
31
32#[derive(Debug, Error)]
38pub enum MemoryValidationError {
39 #[error("content too large: {size} bytes exceeds max {max}")]
40 ContentTooLarge { size: usize, max: usize },
41
42 #[error("entity name too long: '{name}' exceeds max {max} bytes")]
43 EntityNameTooLong { name: String, max: usize },
44
45 #[error("fact text too long: exceeds max {max} bytes")]
46 FactTooLong { max: usize },
47
48 #[error("too many entities: {count} exceeds max {max}")]
49 TooManyEntities { count: usize, max: usize },
50
51 #[error("too many edges: {count} exceeds max {max}")]
52 TooManyEdges { count: usize, max: usize },
53
54 #[error("forbidden pattern detected: {pattern}")]
55 ForbiddenPattern { pattern: String },
56
57 #[error("PII detected in entity name: '{entity}'")]
58 SuspiciousPiiInEntityName { entity: String },
59}
60
/// Serde default helper: yields `true` for boolean fields that are on unless
/// explicitly disabled.
fn default_true() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}
68
/// Serde default helper: maximum content size in bytes (4 KiB).
fn default_max_content_bytes() -> usize {
    4 * 1024
}
72
/// Serde default helper: maximum entity-name length in bytes.
fn default_max_entity_name_bytes() -> usize {
    const MAX_ENTITY_NAME_BYTES: usize = 256;
    MAX_ENTITY_NAME_BYTES
}
76
/// Serde default helper: maximum fact-text length in bytes (1 KiB).
fn default_max_fact_bytes() -> usize {
    const MAX_FACT_BYTES: usize = 1024;
    MAX_FACT_BYTES
}
80
/// Serde default helper: maximum number of entities accepted per extraction.
fn default_max_entities() -> usize {
    const MAX_ENTITIES: usize = 50;
    MAX_ENTITIES
}
84
/// Serde default helper: maximum number of edges accepted per extraction.
fn default_max_edges() -> usize {
    const MAX_EDGES: usize = 100;
    MAX_EDGES
}
88
89#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
95pub struct MemoryWriteValidationConfig {
96 #[serde(default = "default_true")]
98 pub enabled: bool,
99 #[serde(default = "default_max_content_bytes")]
101 pub max_content_bytes: usize,
102 #[serde(default = "default_max_entity_name_bytes")]
104 pub max_entity_name_bytes: usize,
105 #[serde(default = "default_max_fact_bytes")]
107 pub max_fact_bytes: usize,
108 #[serde(default = "default_max_entities")]
110 pub max_entities_per_extraction: usize,
111 #[serde(default = "default_max_edges")]
113 pub max_edges_per_extraction: usize,
114 #[serde(default)]
117 pub forbidden_content_patterns: Vec<String>,
118}
119
120impl Default for MemoryWriteValidationConfig {
121 fn default() -> Self {
122 Self {
123 enabled: true,
124 max_content_bytes: default_max_content_bytes(),
125 max_entity_name_bytes: default_max_entity_name_bytes(),
126 max_fact_bytes: default_max_fact_bytes(),
127 max_entities_per_extraction: default_max_entities(),
128 max_edges_per_extraction: default_max_edges(),
129 forbidden_content_patterns: Vec::new(),
130 }
131 }
132}
133
134#[derive(Debug, Clone)]
143pub struct MemoryWriteValidator {
144 config: MemoryWriteValidationConfig,
145}
146
147impl MemoryWriteValidator {
148 #[must_use]
150 pub fn new(config: MemoryWriteValidationConfig) -> Self {
151 Self { config }
152 }
153
154 pub fn validate_memory_save(&self, content: &str) -> Result<(), MemoryValidationError> {
160 if !self.config.enabled {
161 return Ok(());
162 }
163
164 let size = content.len();
165 if size > self.config.max_content_bytes {
166 return Err(MemoryValidationError::ContentTooLarge {
167 size,
168 max: self.config.max_content_bytes,
169 });
170 }
171
172 for pattern in &self.config.forbidden_content_patterns {
173 if content.contains(pattern.as_str()) {
174 return Err(MemoryValidationError::ForbiddenPattern {
175 pattern: pattern.clone(),
176 });
177 }
178 }
179
180 Ok(())
181 }
182
183 pub fn validate_graph_extraction(
191 &self,
192 result: &ExtractionResult,
193 ) -> Result<(), MemoryValidationError> {
194 if !self.config.enabled {
195 return Ok(());
196 }
197
198 let entity_count = result.entities.len();
199 if entity_count > self.config.max_entities_per_extraction {
200 return Err(MemoryValidationError::TooManyEntities {
201 count: entity_count,
202 max: self.config.max_entities_per_extraction,
203 });
204 }
205
206 let edge_count = result.edges.len();
207 if edge_count > self.config.max_edges_per_extraction {
208 return Err(MemoryValidationError::TooManyEdges {
209 count: edge_count,
210 max: self.config.max_edges_per_extraction,
211 });
212 }
213
214 for entity in &result.entities {
215 let name_len = entity.name.len();
216 if name_len > self.config.max_entity_name_bytes {
217 return Err(MemoryValidationError::EntityNameTooLong {
218 name: entity.name.clone(),
219 max: self.config.max_entity_name_bytes,
220 });
221 }
222 if ENTITY_EMAIL_RE.is_match(&entity.name) || ENTITY_SSN_RE.is_match(&entity.name) {
224 return Err(MemoryValidationError::SuspiciousPiiInEntityName {
225 entity: entity.name.clone(),
226 });
227 }
228 }
229
230 for edge in &result.edges {
231 let fact_len = edge.fact.len();
232 if fact_len > self.config.max_fact_bytes {
233 return Err(MemoryValidationError::FactTooLong {
234 max: self.config.max_fact_bytes,
235 });
236 }
237 }
238
239 Ok(())
240 }
241
242 #[must_use]
244 pub fn is_enabled(&self) -> bool {
245 self.config.enabled
246 }
247}
248
#[cfg(test)]
mod tests {
    use zeph_memory::graph::extractor::{ExtractedEdge, ExtractedEntity};

    use super::*;

    /// Validator built from the default configuration after `tweak` adjusted it.
    fn validator_with(
        tweak: impl FnOnce(&mut MemoryWriteValidationConfig),
    ) -> MemoryWriteValidator {
        let mut config = MemoryWriteValidationConfig::default();
        tweak(&mut config);
        MemoryWriteValidator::new(config)
    }

    /// Validator using the default configuration unchanged.
    fn validator() -> MemoryWriteValidator {
        validator_with(|_| {})
    }

    /// Validator with the master switch turned off.
    fn validator_disabled() -> MemoryWriteValidator {
        validator_with(|cfg| cfg.enabled = false)
    }

    /// Entity of type `"person"` with the given name and no summary.
    fn entity(name: &str) -> ExtractedEntity {
        ExtractedEntity {
            name: name.to_owned(),
            entity_type: "person".to_owned(),
            summary: None,
        }
    }

    /// Edge `A -knows-> B` carrying the given fact text.
    fn edge(fact: &str) -> ExtractedEdge {
        ExtractedEdge {
            source: "A".to_owned(),
            target: "B".to_owned(),
            relation: "knows".to_owned(),
            fact: fact.to_owned(),
            temporal_hint: None,
        }
    }

    fn result_with(entities: Vec<ExtractedEntity>, edges: Vec<ExtractedEdge>) -> ExtractionResult {
        ExtractionResult { entities, edges }
    }

    // --- validate_memory_save ---

    #[test]
    fn valid_content_passes() {
        let outcome = validator().validate_memory_save("hello world");
        assert!(outcome.is_ok());
    }

    #[test]
    fn oversized_content_rejected() {
        let oversized = "x".repeat(5000);
        let outcome = validator().validate_memory_save(&oversized);
        assert!(matches!(
            outcome,
            Err(MemoryValidationError::ContentTooLarge { .. })
        ));
    }

    #[test]
    fn forbidden_pattern_rejected() {
        let v = validator_with(|cfg| {
            cfg.forbidden_content_patterns = vec!["<script".to_owned()];
        });
        let outcome = v.validate_memory_save("text <script>alert(1)</script>");
        assert!(matches!(
            outcome,
            Err(MemoryValidationError::ForbiddenPattern { .. })
        ));
    }

    #[test]
    fn disabled_skips_validation() {
        let oversized = "x".repeat(9999);
        let outcome = validator_disabled().validate_memory_save(&oversized);
        assert!(outcome.is_ok());
    }

    // --- validate_graph_extraction ---

    #[test]
    fn valid_extraction_passes() {
        let extraction = result_with(vec![entity("Rust"), entity("Alice")], vec![edge("fact")]);
        assert!(validator().validate_graph_extraction(&extraction).is_ok());
    }

    #[test]
    fn too_many_entities_rejected() {
        let v = validator_with(|cfg| cfg.max_entities_per_extraction = 2);
        let extraction = result_with(vec![entity("A"), entity("B"), entity("C")], vec![]);
        let outcome = v.validate_graph_extraction(&extraction);
        assert!(matches!(
            outcome,
            Err(MemoryValidationError::TooManyEntities { .. })
        ));
    }

    #[test]
    fn too_many_edges_rejected() {
        let v = validator_with(|cfg| cfg.max_edges_per_extraction = 1);
        let extraction = result_with(vec![], vec![edge("a"), edge("b")]);
        let outcome = v.validate_graph_extraction(&extraction);
        assert!(matches!(
            outcome,
            Err(MemoryValidationError::TooManyEdges { .. })
        ));
    }

    #[test]
    fn entity_name_too_long_rejected() {
        let v = validator_with(|cfg| cfg.max_entity_name_bytes = 5);
        let extraction = result_with(vec![entity("TooLongName")], vec![]);
        let outcome = v.validate_graph_extraction(&extraction);
        assert!(matches!(
            outcome,
            Err(MemoryValidationError::EntityNameTooLong { .. })
        ));
    }

    #[test]
    fn fact_too_long_rejected() {
        let v = validator_with(|cfg| cfg.max_fact_bytes = 10);
        let extraction = result_with(vec![], vec![edge("this fact is longer than ten chars")]);
        let outcome = v.validate_graph_extraction(&extraction);
        assert!(matches!(
            outcome,
            Err(MemoryValidationError::FactTooLong { .. })
        ));
    }

    #[test]
    fn email_in_entity_name_rejected() {
        let extraction = result_with(vec![entity("user@example.com")], vec![]);
        let outcome = validator().validate_graph_extraction(&extraction);
        assert!(matches!(
            outcome,
            Err(MemoryValidationError::SuspiciousPiiInEntityName { .. })
        ));
    }

    #[test]
    fn ssn_in_entity_name_rejected() {
        let extraction = result_with(vec![entity("123-45-6789")], vec![]);
        let outcome = validator().validate_graph_extraction(&extraction);
        assert!(matches!(
            outcome,
            Err(MemoryValidationError::SuspiciousPiiInEntityName { .. })
        ));
    }

    #[test]
    fn disabled_skips_graph_validation() {
        let many_entities: Vec<_> = (0..200).map(|i| entity(&format!("E{i}"))).collect();
        let extraction = result_with(many_entities, vec![]);
        let outcome = validator_disabled().validate_graph_extraction(&extraction);
        assert!(outcome.is_ok());
    }

    // --- content size boundaries ---

    #[test]
    fn content_exactly_at_limit_passes() {
        let v = validator_with(|cfg| cfg.max_content_bytes = 10);
        // Ten bytes: equal to the limit, so accepted.
        assert!(v.validate_memory_save("1234567890").is_ok());
    }

    #[test]
    fn content_one_byte_over_limit_rejected() {
        let v = validator_with(|cfg| cfg.max_content_bytes = 10);
        // Eleven bytes: one past the limit.
        let outcome = v.validate_memory_save("12345678901");
        assert!(matches!(
            outcome,
            Err(MemoryValidationError::ContentTooLarge { .. })
        ));
    }

    // --- forbidden patterns ---

    #[test]
    fn multiple_forbidden_patterns_first_match_blocks() {
        let v = validator_with(|cfg| {
            cfg.forbidden_content_patterns = vec!["<script".to_owned(), "javascript:".to_owned()];
        });
        let outcome = v.validate_memory_save("javascript:alert(1)");
        assert!(matches!(
            outcome,
            Err(MemoryValidationError::ForbiddenPattern { .. })
        ));
    }

    #[test]
    fn content_without_forbidden_pattern_passes() {
        let v = validator_with(|cfg| {
            cfg.forbidden_content_patterns = vec!["<script".to_owned()];
        });
        assert!(v.validate_memory_save("safe content here").is_ok());
    }

    // --- is_enabled ---

    #[test]
    fn is_enabled_true_by_default() {
        assert!(validator().is_enabled());
    }

    #[test]
    fn is_enabled_false_when_disabled() {
        assert!(!validator_disabled().is_enabled());
    }

    // --- graph extraction boundaries ---

    #[test]
    fn empty_extraction_passes() {
        let extraction = result_with(vec![], vec![]);
        assert!(validator().validate_graph_extraction(&extraction).is_ok());
    }

    #[test]
    fn entity_name_exactly_at_limit_passes() {
        let v = validator_with(|cfg| cfg.max_entity_name_bytes = 5);
        // "Alice" is exactly five bytes.
        let extraction = result_with(vec![entity("Alice")], vec![]);
        assert!(v.validate_graph_extraction(&extraction).is_ok());
    }

    #[test]
    fn entity_name_one_byte_over_limit_rejected() {
        let v = validator_with(|cfg| cfg.max_entity_name_bytes = 5);
        // "AliceX" is six bytes, one past the limit.
        let extraction = result_with(vec![entity("AliceX")], vec![]);
        let outcome = v.validate_graph_extraction(&extraction);
        assert!(matches!(
            outcome,
            Err(MemoryValidationError::EntityNameTooLong { .. })
        ));
    }

    #[test]
    fn entities_exactly_at_limit_passes() {
        let v = validator_with(|cfg| cfg.max_entities_per_extraction = 3);
        let extraction = result_with(vec![entity("A"), entity("B"), entity("C")], vec![]);
        assert!(v.validate_graph_extraction(&extraction).is_ok());
    }

    // --- error rendering ---

    #[test]
    fn content_too_large_error_message() {
        let oversized = "x".repeat(5000);
        let rendered = validator()
            .validate_memory_save(&oversized)
            .unwrap_err()
            .to_string();
        assert!(rendered.contains("5000"), "error must include actual size");
        assert!(rendered.contains("4096"), "error must include max size");
    }
}