Skip to main content

zeph_core/sanitizer/
memory_validation.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Memory write validation: structural checks before content reaches the memory store
5//! or the graph extractor.
6//!
7//! Configured under `[security.memory_validation]` in the agent config file.
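//! Enabled by default — guards against oversized writes and against PII (email addresses,
//! SSNs) leaking into entity names. Injection markers are rejected only when listed in
//! `forbidden_content_patterns`, which is empty by default.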
10
11use std::sync::LazyLock;
12
13use regex::Regex;
14use serde::{Deserialize, Serialize};
15use thiserror::Error;
16use zeph_memory::graph::extractor::ExtractionResult;
17
18// ---------------------------------------------------------------------------
19// PII patterns for entity name scanning (subset — email and SSN only)
20// ---------------------------------------------------------------------------
21
22/// Email pattern kept in sync with `pii.rs`: domain labels must be purely alphabetic.
23static ENTITY_EMAIL_RE: LazyLock<Regex> = LazyLock::new(|| {
24    Regex::new(r"[a-zA-Z0-9._%+\-]{2,}@(?:[a-zA-Z]+\.)+[a-zA-Z]{2,6}")
25        .expect("valid ENTITY_EMAIL_RE")
26});
27
28/// SSN pattern for entity name scanning.
29static ENTITY_SSN_RE: LazyLock<Regex> =
30    LazyLock::new(|| Regex::new(r"\b\d{3}-\d{2}-\d{4}\b").expect("valid ENTITY_SSN_RE"));
31
32// ---------------------------------------------------------------------------
33// Error
34// ---------------------------------------------------------------------------
35
36/// Validation failure reported by [`MemoryWriteValidator`].
37#[derive(Debug, Error)]
38pub enum MemoryValidationError {
39    #[error("content too large: {size} bytes exceeds max {max}")]
40    ContentTooLarge { size: usize, max: usize },
41
42    #[error("entity name too long: '{name}' exceeds max {max} bytes")]
43    EntityNameTooLong { name: String, max: usize },
44
45    #[error("fact text too long: exceeds max {max} bytes")]
46    FactTooLong { max: usize },
47
48    #[error("too many entities: {count} exceeds max {max}")]
49    TooManyEntities { count: usize, max: usize },
50
51    #[error("too many edges: {count} exceeds max {max}")]
52    TooManyEdges { count: usize, max: usize },
53
54    #[error("forbidden pattern detected: {pattern}")]
55    ForbiddenPattern { pattern: String },
56
57    #[error("PII detected in entity name: '{entity}'")]
58    SuspiciousPiiInEntityName { entity: String },
59}
60
61// ---------------------------------------------------------------------------
62// Config
63// ---------------------------------------------------------------------------
64
/// Serde default for flags that are on unless configured otherwise (used for `enabled`).
fn default_true() -> bool {
    true
}

/// Serde default for `max_content_bytes`: 4 KiB.
fn default_max_content_bytes() -> usize {
    4 * 1024
}

/// Serde default for `max_entity_name_bytes`.
fn default_max_entity_name_bytes() -> usize {
    256
}

/// Serde default for `max_fact_bytes`: 1 KiB.
fn default_max_fact_bytes() -> usize {
    1024
}

/// Serde default for `max_entities_per_extraction`.
fn default_max_entities() -> usize {
    50
}

/// Serde default for `max_edges_per_extraction`.
fn default_max_edges() -> usize {
    100
}
88
89/// Configuration for memory write validation, nested under `[security.memory_validation]`.
90///
91/// Enabled by default with conservative limits. All values correspond to existing
92/// capacity constraints already enforced elsewhere; the validator makes them explicit
93/// and configurable.
94#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
95pub struct MemoryWriteValidationConfig {
96    /// Master switch. When `false`, validation is a no-op.
97    #[serde(default = "default_true")]
98    pub enabled: bool,
99    /// Maximum byte length of content passed to `memory_save`.
100    #[serde(default = "default_max_content_bytes")]
101    pub max_content_bytes: usize,
102    /// Maximum byte length of a single entity name in graph extraction.
103    #[serde(default = "default_max_entity_name_bytes")]
104    pub max_entity_name_bytes: usize,
105    /// Maximum byte length of an edge fact string in graph extraction.
106    #[serde(default = "default_max_fact_bytes")]
107    pub max_fact_bytes: usize,
108    /// Maximum number of entities allowed per graph extraction result.
109    #[serde(default = "default_max_entities")]
110    pub max_entities_per_extraction: usize,
111    /// Maximum number of edges allowed per graph extraction result.
112    #[serde(default = "default_max_edges")]
113    pub max_edges_per_extraction: usize,
114    /// Forbidden substring patterns. Content containing any of these is rejected.
115    /// Default empty — users can add custom patterns (e.g., `"<script"`, `"javascript:"`).
116    #[serde(default)]
117    pub forbidden_content_patterns: Vec<String>,
118}
119
120impl Default for MemoryWriteValidationConfig {
121    fn default() -> Self {
122        Self {
123            enabled: true,
124            max_content_bytes: default_max_content_bytes(),
125            max_entity_name_bytes: default_max_entity_name_bytes(),
126            max_fact_bytes: default_max_fact_bytes(),
127            max_entities_per_extraction: default_max_entities(),
128            max_edges_per_extraction: default_max_edges(),
129            forbidden_content_patterns: Vec::new(),
130        }
131    }
132}
133
134// ---------------------------------------------------------------------------
135// Validator
136// ---------------------------------------------------------------------------
137
138/// Validates content before it is written to the memory store or graph extractor.
139///
140/// Construct once from [`MemoryWriteValidationConfig`] and store on the agent.
141/// Cheap to clone.
142#[derive(Debug, Clone)]
143pub struct MemoryWriteValidator {
144    config: MemoryWriteValidationConfig,
145}
146
147impl MemoryWriteValidator {
148    /// Create a validator from the given configuration.
149    #[must_use]
150    pub fn new(config: MemoryWriteValidationConfig) -> Self {
151        Self { config }
152    }
153
154    /// Validate content before it is written via the `memory_save` tool.
155    ///
156    /// # Errors
157    ///
158    /// Returns [`MemoryValidationError`] if any validation check fails.
159    pub fn validate_memory_save(&self, content: &str) -> Result<(), MemoryValidationError> {
160        if !self.config.enabled {
161            return Ok(());
162        }
163
164        let size = content.len();
165        if size > self.config.max_content_bytes {
166            return Err(MemoryValidationError::ContentTooLarge {
167                size,
168                max: self.config.max_content_bytes,
169            });
170        }
171
172        for pattern in &self.config.forbidden_content_patterns {
173            if content.contains(pattern.as_str()) {
174                return Err(MemoryValidationError::ForbiddenPattern {
175                    pattern: pattern.clone(),
176                });
177            }
178        }
179
180        Ok(())
181    }
182
183    /// Validate a graph extraction result before entities and edges are upserted.
184    ///
185    /// Called inside the spawned extraction task, after `GraphExtractor::extract()` returns.
186    ///
187    /// # Errors
188    ///
189    /// Returns [`MemoryValidationError`] if any validation check fails.
190    pub fn validate_graph_extraction(
191        &self,
192        result: &ExtractionResult,
193    ) -> Result<(), MemoryValidationError> {
194        if !self.config.enabled {
195            return Ok(());
196        }
197
198        let entity_count = result.entities.len();
199        if entity_count > self.config.max_entities_per_extraction {
200            return Err(MemoryValidationError::TooManyEntities {
201                count: entity_count,
202                max: self.config.max_entities_per_extraction,
203            });
204        }
205
206        let edge_count = result.edges.len();
207        if edge_count > self.config.max_edges_per_extraction {
208            return Err(MemoryValidationError::TooManyEdges {
209                count: edge_count,
210                max: self.config.max_edges_per_extraction,
211            });
212        }
213
214        for entity in &result.entities {
215            let name_len = entity.name.len();
216            if name_len > self.config.max_entity_name_bytes {
217                return Err(MemoryValidationError::EntityNameTooLong {
218                    name: entity.name.clone(),
219                    max: self.config.max_entity_name_bytes,
220                });
221            }
222            // Guard against PII leaking into entity names (email and SSN).
223            if ENTITY_EMAIL_RE.is_match(&entity.name) || ENTITY_SSN_RE.is_match(&entity.name) {
224                return Err(MemoryValidationError::SuspiciousPiiInEntityName {
225                    entity: entity.name.clone(),
226                });
227            }
228        }
229
230        for edge in &result.edges {
231            let fact_len = edge.fact.len();
232            if fact_len > self.config.max_fact_bytes {
233                return Err(MemoryValidationError::FactTooLong {
234                    max: self.config.max_fact_bytes,
235                });
236            }
237        }
238
239        Ok(())
240    }
241
242    /// Returns `true` when validation is enabled.
243    #[must_use]
244    pub fn is_enabled(&self) -> bool {
245        self.config.enabled
246    }
247}
248
249// ---------------------------------------------------------------------------
250// Tests
251// ---------------------------------------------------------------------------
252
#[cfg(test)]
mod tests {
    use zeph_memory::graph::extractor::{ExtractedEdge, ExtractedEntity};

    use super::*;

    /// Builds a validator from the default config after applying `tweak` to it.
    fn custom(tweak: impl FnOnce(&mut MemoryWriteValidationConfig)) -> MemoryWriteValidator {
        let mut cfg = MemoryWriteValidationConfig::default();
        tweak(&mut cfg);
        MemoryWriteValidator::new(cfg)
    }

    /// Validator with every default limit in place.
    fn validator() -> MemoryWriteValidator {
        custom(|_| {})
    }

    /// Validator with the master switch turned off.
    fn validator_disabled() -> MemoryWriteValidator {
        custom(|cfg| cfg.enabled = false)
    }

    /// Extracts the error from a validation call, failing the test if validation passed.
    fn rejected(outcome: Result<(), MemoryValidationError>) -> MemoryValidationError {
        outcome.expect_err("validation should have failed")
    }

    fn entity(name: &str) -> ExtractedEntity {
        ExtractedEntity {
            name: name.to_owned(),
            entity_type: "person".to_owned(),
            summary: None,
        }
    }

    fn edge(fact: &str) -> ExtractedEdge {
        ExtractedEdge {
            source: "A".to_owned(),
            target: "B".to_owned(),
            relation: "knows".to_owned(),
            fact: fact.to_owned(),
            temporal_hint: None,
        }
    }

    fn result_with(entities: Vec<ExtractedEntity>, edges: Vec<ExtractedEdge>) -> ExtractionResult {
        ExtractionResult { entities, edges }
    }

    // --- memory_save validation ---

    #[test]
    fn valid_content_passes() {
        validator()
            .validate_memory_save("hello world")
            .expect("plain text is within every default limit");
    }

    #[test]
    fn oversized_content_rejected() {
        let oversized = "x".repeat(5_000);
        let err = rejected(validator().validate_memory_save(&oversized));
        assert!(matches!(err, MemoryValidationError::ContentTooLarge { .. }));
    }

    #[test]
    fn forbidden_pattern_rejected() {
        let v = custom(|cfg| cfg.forbidden_content_patterns = vec!["<script".to_owned()]);
        let err = rejected(v.validate_memory_save("text <script>alert(1)</script>"));
        assert!(matches!(err, MemoryValidationError::ForbiddenPattern { .. }));
    }

    #[test]
    fn disabled_skips_validation() {
        let oversized = "x".repeat(9_999);
        assert!(validator_disabled().validate_memory_save(&oversized).is_ok());
    }

    // --- graph extraction validation ---

    #[test]
    fn valid_extraction_passes() {
        let extraction = result_with(vec![entity("Rust"), entity("Alice")], vec![edge("fact")]);
        assert!(validator().validate_graph_extraction(&extraction).is_ok());
    }

    #[test]
    fn too_many_entities_rejected() {
        let v = custom(|cfg| cfg.max_entities_per_extraction = 2);
        let extraction = result_with(vec![entity("A"), entity("B"), entity("C")], vec![]);
        let err = rejected(v.validate_graph_extraction(&extraction));
        assert!(matches!(err, MemoryValidationError::TooManyEntities { .. }));
    }

    #[test]
    fn too_many_edges_rejected() {
        let v = custom(|cfg| cfg.max_edges_per_extraction = 1);
        let extraction = result_with(vec![], vec![edge("a"), edge("b")]);
        let err = rejected(v.validate_graph_extraction(&extraction));
        assert!(matches!(err, MemoryValidationError::TooManyEdges { .. }));
    }

    #[test]
    fn entity_name_too_long_rejected() {
        let v = custom(|cfg| cfg.max_entity_name_bytes = 5);
        let extraction = result_with(vec![entity("TooLongName")], vec![]);
        let err = rejected(v.validate_graph_extraction(&extraction));
        assert!(matches!(err, MemoryValidationError::EntityNameTooLong { .. }));
    }

    #[test]
    fn fact_too_long_rejected() {
        let v = custom(|cfg| cfg.max_fact_bytes = 10);
        let extraction = result_with(vec![], vec![edge("this fact is longer than ten chars")]);
        let err = rejected(v.validate_graph_extraction(&extraction));
        assert!(matches!(err, MemoryValidationError::FactTooLong { .. }));
    }

    #[test]
    fn email_in_entity_name_rejected() {
        let extraction = result_with(vec![entity("user@example.com")], vec![]);
        let err = rejected(validator().validate_graph_extraction(&extraction));
        assert!(matches!(err, MemoryValidationError::SuspiciousPiiInEntityName { .. }));
    }

    #[test]
    fn ssn_in_entity_name_rejected() {
        let extraction = result_with(vec![entity("123-45-6789")], vec![]);
        let err = rejected(validator().validate_graph_extraction(&extraction));
        assert!(matches!(err, MemoryValidationError::SuspiciousPiiInEntityName { .. }));
    }

    #[test]
    fn disabled_skips_graph_validation() {
        let many: Vec<_> = (0..200).map(|i| entity(&format!("E{i}"))).collect();
        let extraction = result_with(many, vec![]);
        assert!(validator_disabled().validate_graph_extraction(&extraction).is_ok());
    }

    // --- exact boundary: max_content_bytes ---

    #[test]
    fn content_exactly_at_limit_passes() {
        let v = custom(|cfg| cfg.max_content_bytes = 10);
        // Exactly 10 bytes: the limit is inclusive.
        assert!(v.validate_memory_save("1234567890").is_ok());
    }

    #[test]
    fn content_one_byte_over_limit_rejected() {
        let v = custom(|cfg| cfg.max_content_bytes = 10);
        // 11 bytes: one past the limit.
        let err = rejected(v.validate_memory_save("12345678901"));
        assert!(matches!(err, MemoryValidationError::ContentTooLarge { .. }));
    }

    // --- multiple forbidden patterns: any match blocks ---

    #[test]
    fn multiple_forbidden_patterns_first_match_blocks() {
        let v = custom(|cfg| {
            cfg.forbidden_content_patterns = vec!["<script".to_owned(), "javascript:".to_owned()];
        });
        let err = rejected(v.validate_memory_save("javascript:alert(1)"));
        assert!(matches!(err, MemoryValidationError::ForbiddenPattern { .. }));
    }

    #[test]
    fn content_without_forbidden_pattern_passes() {
        let v = custom(|cfg| cfg.forbidden_content_patterns = vec!["<script".to_owned()]);
        assert!(v.validate_memory_save("safe content here").is_ok());
    }

    // --- is_enabled ---

    #[test]
    fn is_enabled_true_by_default() {
        assert!(validator().is_enabled());
    }

    #[test]
    fn is_enabled_false_when_disabled() {
        assert!(!validator_disabled().is_enabled());
    }

    // --- empty ExtractionResult passes ---

    #[test]
    fn empty_extraction_passes() {
        let extraction = result_with(Vec::new(), Vec::new());
        assert!(validator().validate_graph_extraction(&extraction).is_ok());
    }

    // --- exact boundary: entity name ---

    #[test]
    fn entity_name_exactly_at_limit_passes() {
        let v = custom(|cfg| cfg.max_entity_name_bytes = 5);
        let extraction = result_with(vec![entity("Alice")], vec![]); // 5 bytes exactly
        assert!(v.validate_graph_extraction(&extraction).is_ok());
    }

    #[test]
    fn entity_name_one_byte_over_limit_rejected() {
        let v = custom(|cfg| cfg.max_entity_name_bytes = 5);
        let extraction = result_with(vec![entity("AliceX")], vec![]); // 6 bytes
        let err = rejected(v.validate_graph_extraction(&extraction));
        assert!(matches!(err, MemoryValidationError::EntityNameTooLong { .. }));
    }

    // --- exact boundary: entities count ---

    #[test]
    fn entities_exactly_at_limit_passes() {
        let v = custom(|cfg| cfg.max_entities_per_extraction = 3);
        let extraction = result_with(vec![entity("A"), entity("B"), entity("C")], vec![]);
        assert!(v.validate_graph_extraction(&extraction).is_ok());
    }

    // --- error message content ---

    #[test]
    fn content_too_large_error_message() {
        let msg = rejected(validator().validate_memory_save(&"x".repeat(5_000))).to_string();
        assert!(msg.contains("5000"), "error must include actual size");
        assert!(msg.contains("4096"), "error must include max size");
    }
}