Skip to main content

synapse_pingora/profiler/
schema_learner.rs

1//! Schema learning engine for API endpoints.
2//!
3//! Automatically learns JSON schema structure from request/response bodies:
4//! - Extracts field types and constraints
5//! - Builds schema maps per endpoint
6//! - Validates requests against learned schemas
7//!
8//! ## Limitations
9//!
10//! **Array-root bodies are not supported**: Only JSON object bodies are processed.
11//! Array-root bodies (e.g., `[{...}, {...}]`) are silently skipped. This is a known
12//! limitation for APIs that use arrays as the root element in request/response bodies.
13//! Such APIs will not benefit from schema learning or validation.
14//!
15//! ## Performance
16//! - Learn from request: ~5us typical
17//! - Validate request: ~3us typical
18//! - Thread-safe via DashMap
19//! - O(1) amortized LRU eviction via generation-tracked queue
20
21use std::collections::{HashMap, VecDeque};
22use std::time::{SystemTime, UNIX_EPOCH};
23
24use dashmap::DashMap;
25use parking_lot::Mutex;
26use serde::{Deserialize, Serialize};
27
28use crate::profiler::patterns::detect_pattern;
29use crate::profiler::schema_types::{
30    EndpointSchema, FieldSchema, FieldType, SchemaViolation, ValidationResult,
31};
32
33// ============================================================================
34// Configuration
35// ============================================================================
36
/// Configuration for the schema learner.
///
/// The learner does not validate this configuration on construction; call
/// [`SchemaLearnerConfig::validate`] before handing it to the learner.
///
/// # Security Considerations
///
/// The tolerance values (`string_length_tolerance` and `number_value_tolerance`) directly
/// impact the security posture of schema validation. These multipliers determine how much
/// deviation from learned baselines is permitted before a request is flagged as anomalous.
///
/// ## Tolerance Trade-offs
///
/// - **Lower values (1.0-1.5)**: Stricter validation, higher security, but may cause
///   false positives if legitimate traffic has natural variance.
/// - **Higher values (2.0+)**: More permissive, fewer false positives, but allows
///   attackers more room to inject oversized payloads or extreme values.
///
/// ## Recommendations
///
/// - Start with the default tolerance (1.5) and monitor for false positives
/// - For high-security APIs: consider 1.2-1.3
/// - For APIs with high variance: consider 1.5-2.0
/// - Never set below 1.0 (would reject valid baseline data)
///
/// # Example
///
/// ```
/// use synapse_pingora::profiler::SchemaLearnerConfig;
///
/// // Stricter configuration for sensitive APIs
/// let config = SchemaLearnerConfig {
///     string_length_tolerance: 1.3,  // 30% buffer above learned max
///     number_value_tolerance: 1.25,  // 25% buffer above learned max
///     ..Default::default()
/// };
///
/// // Validate config before use
/// config.validate().expect("Invalid configuration");
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaLearnerConfig {
    /// Maximum number of endpoint schemas to track.
    ///
    /// When a new endpoint arrives while this many schemas are already stored, the
    /// least recently used (LRU) schema is evicted to make room.
    ///
    /// Default: 5000
    pub max_schemas: usize,

    /// Minimum samples required before validation is active.
    ///
    /// While an endpoint's sample count is below this value, validation returns an
    /// empty result. This prevents false positives during the initial learning phase.
    ///
    /// Default: 10
    pub min_samples_for_validation: u32,

    /// Maximum depth for nested object learning and validation.
    ///
    /// Objects nested at this depth or deeper (the root object is depth 0) are neither
    /// learned nor validated, which bounds recursion and memory use on deeply nested JSON.
    ///
    /// Default: 10
    pub max_nesting_depth: usize,

    /// Maximum fields per schema (memory protection).
    ///
    /// Once the request or response field map of an endpoint holds this many entries,
    /// no further fields are added to it. Nested fields are stored in the same map under
    /// dotted paths, so they count towards this limit as well.
    ///
    /// Default: 100
    pub max_fields_per_schema: usize,

    /// String length tolerance multiplier for validation.
    ///
    /// When validating string fields, the maximum allowed length is
    /// `learned_max_length * string_length_tolerance`, truncated to an integer.
    /// The learned minimum length is enforced as-is, without tolerance.
    ///
    /// # Security Impact
    ///
    /// - **Lower values (1.0-1.3)**: Catches buffer overflow attempts more aggressively
    ///   but may flag legitimate variance as anomalous.
    /// - **Higher values (1.5-2.0)**: More permissive, reducing false positives but
    ///   allowing larger payloads that could exploit vulnerabilities.
    ///
    /// Default: 1.5 (50% buffer above learned maximum)
    ///
    /// # Constraints
    ///
    /// Must be >= 1.0. Values below 1.0 would reject strings that were seen in the
    /// baseline training data, causing immediate false positives.
    pub string_length_tolerance: f64,

    /// Number value tolerance multiplier for validation.
    ///
    /// The learned numeric range is relaxed by this factor. For non-negative bounds:
    /// - Maximum allowed: `learned_max * number_value_tolerance`
    /// - Minimum allowed: `learned_min / number_value_tolerance`
    ///
    /// # Security Impact
    ///
    /// - **Lower values (1.0-1.3)**: Catches integer overflow attempts and extreme
    ///   value injection more aggressively.
    /// - **Higher values (1.5-2.0)**: More permissive for APIs with high numeric variance.
    ///
    /// Default: 1.5 (50% buffer on max values, 33% reduction on min values)
    ///
    /// # Constraints
    ///
    /// Must be >= 1.0. Values below 1.0 would reject values that were seen in the
    /// baseline training data, causing immediate false positives.
    pub number_value_tolerance: f64,

    /// Required field threshold, as a fraction of observed samples in `[0.0, 1.0]`.
    ///
    /// A field is treated as "required" when its seen count is at least
    /// `sample_count * required_field_threshold` (truncated to an integer); the absence
    /// of a required field triggers a MissingField violation.
    ///
    /// Default: 0.9 (fields present in at least 90% of samples are required)
    pub required_field_threshold: f64,
}
146
/// Error produced when a `SchemaLearnerConfig` value is outside its permitted range.
#[derive(Debug, Clone, PartialEq)]
pub struct ConfigValidationError {
    /// Name of the configuration field that was rejected.
    pub field: &'static str,
    /// Human-readable explanation of why the value was rejected.
    pub message: String,
}

impl std::error::Error for ConfigValidationError {}

impl std::fmt::Display for ConfigValidationError {
    /// Renders the error as `Invalid <field>: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { field, message } = self;
        f.write_fmt(format_args!("Invalid {}: {}", field, message))
    }
}
163
164impl SchemaLearnerConfig {
165    /// Validates the configuration, ensuring all values are within acceptable ranges.
166    ///
167    /// # Errors
168    ///
169    /// Returns `ConfigValidationError` if:
170    /// - `string_length_tolerance` < 1.0
171    /// - `number_value_tolerance` < 1.0
172    /// - `required_field_threshold` is not in range [0.0, 1.0]
173    ///
174    /// # Example
175    ///
176    /// ```
177    /// use synapse_pingora::profiler::SchemaLearnerConfig;
178    ///
179    /// let config = SchemaLearnerConfig {
180    ///     string_length_tolerance: 0.5, // Invalid!
181    ///     ..Default::default()
182    /// };
183    ///
184    /// assert!(config.validate().is_err());
185    /// ```
186    pub fn validate(&self) -> Result<(), ConfigValidationError> {
187        if self.string_length_tolerance < 1.0 {
188            return Err(ConfigValidationError {
189                field: "string_length_tolerance",
190                message: format!(
191                    "must be >= 1.0 to avoid rejecting baseline data (got {})",
192                    self.string_length_tolerance
193                ),
194            });
195        }
196
197        if self.number_value_tolerance < 1.0 {
198            return Err(ConfigValidationError {
199                field: "number_value_tolerance",
200                message: format!(
201                    "must be >= 1.0 to avoid rejecting baseline data (got {})",
202                    self.number_value_tolerance
203                ),
204            });
205        }
206
207        if !(0.0..=1.0).contains(&self.required_field_threshold) {
208            return Err(ConfigValidationError {
209                field: "required_field_threshold",
210                message: format!(
211                    "must be between 0.0 and 1.0 (got {})",
212                    self.required_field_threshold
213                ),
214            });
215        }
216
217        Ok(())
218    }
219}
220
221impl Default for SchemaLearnerConfig {
222    fn default() -> Self {
223        Self {
224            max_schemas: 5000,
225            min_samples_for_validation: 10,
226            max_nesting_depth: 10,
227            max_fields_per_schema: 100,
228            string_length_tolerance: 1.5,
229            number_value_tolerance: 1.5,
230            required_field_threshold: 0.9,
231        }
232    }
233}
234
235// ============================================================================
236// LRU Tracker (O(1) amortized eviction)
237// ============================================================================
238
/// One access record in the LRU queue.
#[derive(Debug, Clone)]
struct LruEntry {
    /// The template key
    key: String,
    /// Generation assigned when this record was pushed
    generation: u64,
}

/// LRU tracker using a generation-stamped FIFO queue.
///
/// Every `touch()` pushes a fresh `(key, generation)` record to the back of the queue
/// and stores that generation as the key's current one. A queue record is *live* only
/// if its generation still equals the key's current generation; all other records are
/// stale and are skipped during eviction.
///
/// The tracker itself is not synchronized; callers wrap it in a lock.
///
/// ## Complexity
/// - `touch()`: O(1) amortized - push to back, occasional compaction
/// - `evict_oldest()`: O(1) amortized - pop from front, skip stale entries
///
/// ## Memory
/// Stale records are dropped both during eviction and by a compaction pass in
/// `touch()`. After every `touch()` the queue holds at most
/// `max(2 * tracked_keys, MIN_COMPACT_LEN)` records, so repeatedly touching a small,
/// stable set of keys (where eviction never runs) cannot grow the queue without bound.
struct LruTracker {
    /// FIFO queue of (key, generation) records, oldest first
    queue: VecDeque<LruEntry>,
    /// Current generation for each tracked key
    generations: HashMap<String, u64>,
    /// Counter for assigning unique generations
    next_generation: u64,
}

impl LruTracker {
    /// Queue length below which compaction is never attempted, so tiny key sets do not
    /// trigger a compaction on almost every touch.
    const MIN_COMPACT_LEN: usize = 64;

    /// Create a new LRU tracker with initial capacity for `capacity` keys.
    fn new(capacity: usize) -> Self {
        Self {
            queue: VecDeque::with_capacity(capacity),
            generations: HashMap::with_capacity(capacity),
            next_generation: 0,
        }
    }

    /// Touch a key, marking it as most recently used.
    ///
    /// Returns true if this is a new key, false if it was already tracked.
    fn touch(&mut self, key: &str) -> bool {
        let generation = self.next_generation;
        self.next_generation = self.next_generation.wrapping_add(1);

        // `insert` reports whether the key was already present, so no separate lookup.
        let is_new = self
            .generations
            .insert(key.to_string(), generation)
            .is_none();
        self.queue.push_back(LruEntry {
            key: key.to_string(),
            generation,
        });

        self.compact_if_bloated();
        is_new
    }

    /// Drop stale queue records once they outnumber the live ones.
    ///
    /// Compaction walks the whole queue, but it only runs after the queue has at least
    /// doubled relative to the number of tracked keys, which keeps `touch()` amortized O(1).
    fn compact_if_bloated(&mut self) {
        let limit = self
            .generations
            .len()
            .saturating_mul(2)
            .max(Self::MIN_COMPACT_LEN);
        if self.queue.len() <= limit {
            return;
        }

        let generations = &self.generations;
        // `retain` preserves order, so the relative recency of live records is unchanged.
        self.queue
            .retain(|entry| generations.get(&entry.key) == Some(&entry.generation));
    }

    /// Remove a key from tracking (used when schema is evicted).
    ///
    /// The key's queue records become stale and are skipped or compacted away later.
    // Not called from the code visible in this file; kept as part of the tracker API.
    #[allow(dead_code)]
    fn remove(&mut self, key: &str) {
        self.generations.remove(key);
    }

    /// Evict the least recently used key that is still tracked.
    ///
    /// Returns the evicted key, or None if no tracked key remains.
    fn evict_oldest(&mut self) -> Option<String> {
        while let Some(entry) = self.queue.pop_front() {
            // Only the record carrying the key's current generation is live.
            if self.generations.get(&entry.key) == Some(&entry.generation) {
                self.generations.remove(&entry.key);
                return Some(entry.key);
            }
            // Stale record (key was touched again or removed): keep scanning.
        }
        None
    }

    /// Get number of tracked keys.
    // Not called from the code visible in this file; kept as part of the tracker API.
    #[allow(dead_code)]
    fn len(&self) -> usize {
        self.generations.len()
    }

    /// Clear all entries and restart generation numbering (used during import).
    fn clear(&mut self) {
        self.queue.clear();
        self.generations.clear();
        self.next_generation = 0;
    }
}
335
336// ============================================================================
337// SchemaLearner
338// ============================================================================
339
/// Schema learner for API endpoints that can be shared between threads.
///
/// Endpoint schemas live in a `DashMap` (a concurrent map), so learning and validation
/// take `&self`. Recency bookkeeping for eviction is kept in a separate `LruTracker`
/// guarded by a mutex. When a new endpoint arrives while `max_schemas` schemas are
/// stored, the least recently used one is evicted (amortized O(1)).
pub struct SchemaLearner {
    /// Endpoint schemas indexed by template path
    schemas: DashMap<String, EndpointSchema>,

    /// Recency tracker used to pick the eviction victim (guarded by a mutex)
    lru: Mutex<LruTracker>,

    /// Configuration supplied at construction time
    config: SchemaLearnerConfig,
}
354
355impl Default for SchemaLearner {
356    fn default() -> Self {
357        Self::new()
358    }
359}
360
361impl SchemaLearner {
362    /// Create a new schema learner with default configuration.
363    pub fn new() -> Self {
364        Self::with_config(SchemaLearnerConfig::default())
365    }
366
367    /// Create a new schema learner with custom configuration.
368    pub fn with_config(config: SchemaLearnerConfig) -> Self {
369        Self {
370            schemas: DashMap::with_capacity(config.max_schemas),
371            lru: Mutex::new(LruTracker::new(config.max_schemas)),
372            config,
373        }
374    }
375
    /// Returns a shared reference to the configuration this learner was built with.
    pub fn config(&self) -> &SchemaLearnerConfig {
        &self.config
    }
380
    /// Returns the number of endpoint templates that currently have a stored schema.
    pub fn len(&self) -> usize {
        self.schemas.len()
    }
385
    /// Returns `true` when no endpoint schema is stored.
    pub fn is_empty(&self) -> bool {
        self.schemas.is_empty()
    }
390
391    /// Get current timestamp in milliseconds.
392    fn now_ms() -> u64 {
393        SystemTime::now()
394            .duration_since(UNIX_EPOCH)
395            .map(|d| d.as_millis() as u64)
396            .unwrap_or(0)
397    }
398
399    // ========================================================================
400    // Learning
401    // ========================================================================
402
    /// Learn schema from a request body.
    ///
    /// Updates the endpoint's request schema with the field types and constraints
    /// found in the JSON body, creating the endpoint schema if it does not exist yet,
    /// and counts the body as one sample for `template`.
    ///
    /// # Note
    ///
    /// Only JSON object bodies are processed. Any other root value (e.g. an array such
    /// as `[{...}]`) is silently skipped: nothing is learned and no sample is counted.
    /// This is a known limitation for APIs that use arrays as the root element.
    pub fn learn_from_request(&self, template: &str, request_body: &serde_json::Value) {
        self.learn_internal(template, request_body, SchemaTarget::Request);
    }
416
    /// Learn schema from a response body.
    ///
    /// Updates the endpoint's response schema, creating the endpoint schema if needed.
    /// Unlike request learning, this does **not** increase the endpoint's sample count
    /// (to avoid counting one exchange twice), so an endpoint that is only ever fed
    /// responses through this method never accumulates samples.
    ///
    /// # Note
    ///
    /// Only JSON object bodies are processed. Any other root value (e.g. an array such
    /// as `[{...}]`) is silently skipped. This is a known limitation for APIs that use
    /// arrays as the root element in request/response bodies.
    pub fn learn_from_response(&self, template: &str, response_body: &serde_json::Value) {
        self.learn_internal(template, response_body, SchemaTarget::Response);
    }
427
428    /// Learn from both request and response.
429    ///
430    /// # Note
431    ///
432    /// Only JSON object bodies are processed. Array-root bodies (e.g., `[{...}]`)
433    /// are silently skipped. This is a known limitation for APIs that use arrays
434    /// as the root element in request/response bodies.
435    pub fn learn_from_pair(
436        &self,
437        template: &str,
438        request_body: Option<&serde_json::Value>,
439        response_body: Option<&serde_json::Value>,
440    ) {
441        let now = Self::now_ms();
442
443        // Ensure schema exists and update sample count
444        self.ensure_schema(template, now);
445
446        if let Some(req) = request_body {
447            if req.is_object() {
448                self.update_schema_fields(template, req, SchemaTarget::Request, "", 0);
449            }
450        }
451
452        if let Some(resp) = response_body {
453            if resp.is_object() {
454                self.update_schema_fields(template, resp, SchemaTarget::Response, "", 0);
455            }
456        }
457
458        // Increment sample count
459        if let Some(mut schema) = self.schemas.get_mut(template) {
460            schema.sample_count += 1;
461            schema.last_updated_ms = now;
462        }
463    }
464
465    /// Internal learning implementation.
466    fn learn_internal(&self, template: &str, body: &serde_json::Value, target: SchemaTarget) {
467        if !body.is_object() {
468            return;
469        }
470
471        let now = Self::now_ms();
472        self.ensure_schema(template, now);
473        self.update_schema_fields(template, body, target, "", 0);
474
475        // Update sample count only for request bodies (avoid double counting)
476        if matches!(target, SchemaTarget::Request) {
477            if let Some(mut schema) = self.schemas.get_mut(template) {
478                schema.sample_count += 1;
479                schema.last_updated_ms = now;
480            }
481        }
482    }
483
484    /// Ensure a schema exists for the template.
485    fn ensure_schema(&self, template: &str, now: u64) {
486        // Fast path: schema already exists, just touch in LRU
487        if self.schemas.contains_key(template) {
488            // Touch the key in LRU to mark as recently used
489            let mut lru = self.lru.lock();
490            lru.touch(template);
491            return;
492        }
493
494        // Slow path: need to insert new schema
495        let mut lru = self.lru.lock();
496
497        // Double-check after acquiring lock (another thread may have inserted)
498        if self.schemas.contains_key(template) {
499            lru.touch(template);
500            return;
501        }
502
503        // LRU eviction if at capacity (O(1) amortized)
504        if self.schemas.len() >= self.config.max_schemas {
505            if let Some(evict_key) = lru.evict_oldest() {
506                self.schemas.remove(&evict_key);
507            }
508        }
509
510        // Insert new schema and track in LRU
511        lru.touch(template);
512        self.schemas.insert(
513            template.to_string(),
514            EndpointSchema::new(template.to_string(), now),
515        );
516    }
517
518    /// Update schema fields from JSON value.
519    /// Optimized to collect nested objects in a single pass, avoiding double iteration.
520    fn update_schema_fields(
521        &self,
522        template: &str,
523        value: &serde_json::Value,
524        target: SchemaTarget,
525        prefix: &str,
526        depth: usize,
527    ) {
528        // Guard against deep nesting (depth is 0-indexed, so >= ensures max_nesting_depth levels)
529        if depth >= self.config.max_nesting_depth {
530            return;
531        }
532
533        let obj = match value.as_object() {
534            Some(o) => o,
535            None => return,
536        };
537
538        // Collect nested objects in the same pass (avoiding double iteration)
539        let mut nested_objects: Vec<(String, &serde_json::Value)> = Vec::new();
540
541        {
542            let mut schema_guard = match self.schemas.get_mut(template) {
543                Some(s) => s,
544                None => return,
545            };
546
547            let schema_map = match target {
548                SchemaTarget::Request => &mut schema_guard.request_schema,
549                SchemaTarget::Response => &mut schema_guard.response_schema,
550            };
551
552            for (key, val) in obj {
553                // Memory protection: check before adding each field
554                if schema_map.len() >= self.config.max_fields_per_schema {
555                    break;
556                }
557
558                let field_name = if prefix.is_empty() {
559                    key.clone()
560                } else {
561                    format!("{}.{}", prefix, key)
562                };
563
564                let field_type = FieldType::from_json_value(val);
565
566                // Get or create field schema
567                let field_schema = schema_map
568                    .entry(field_name.clone())
569                    .or_insert_with(|| FieldSchema::new(field_name.clone()));
570
571                // Record type
572                field_schema.record_type(field_type);
573
574                // Update constraints based on type
575                match val {
576                    serde_json::Value::String(s) => {
577                        let pattern = detect_pattern(s);
578                        field_schema.update_string_constraints(s.len() as u32, pattern);
579                    }
580                    serde_json::Value::Number(n) => {
581                        if let Some(f) = n.as_f64() {
582                            field_schema.update_number_constraints(f);
583                        }
584                    }
585                    serde_json::Value::Array(arr) => {
586                        for item in arr {
587                            let item_type = FieldType::from_json_value(item);
588                            field_schema.add_array_item_type(item_type);
589                        }
590                    }
591                    serde_json::Value::Object(_) => {
592                        // Initialize nested object schema if needed
593                        if field_schema.object_schema.is_none() {
594                            field_schema.object_schema = Some(HashMap::new());
595                        }
596                        // Collect for recursion (single pass optimization)
597                        nested_objects.push((field_name, val));
598                    }
599                    _ => {}
600                }
601            }
602            // schema_guard dropped here at end of block
603        }
604
605        // Recurse into nested objects (guard already dropped)
606        for (field_name, val) in nested_objects {
607            self.update_schema_fields(template, val, target, &field_name, depth + 1);
608        }
609    }
610
611    // ========================================================================
612    // Validation
613    // ========================================================================
614
    /// Validate a request body against the learned request schema for `template`.
    ///
    /// Returns the collected violations; an empty result means validation passed.
    /// The result is also empty when no schema exists for `template` or when the
    /// endpoint has fewer than `min_samples_for_validation` samples.
    pub fn validate_request(
        &self,
        template: &str,
        request_body: &serde_json::Value,
    ) -> ValidationResult {
        self.validate_internal(template, request_body, SchemaTarget::Request)
    }
626
    /// Validate a response body against the learned response schema for `template`.
    ///
    /// Returns the collected violations; an empty result means validation passed.
    /// The result is also empty when no schema exists for `template` or when the
    /// endpoint has fewer than `min_samples_for_validation` samples.
    pub fn validate_response(
        &self,
        template: &str,
        response_body: &serde_json::Value,
    ) -> ValidationResult {
        self.validate_internal(template, response_body, SchemaTarget::Response)
    }
635
636    /// Internal validation implementation.
637    fn validate_internal(
638        &self,
639        template: &str,
640        body: &serde_json::Value,
641        target: SchemaTarget,
642    ) -> ValidationResult {
643        let mut result = ValidationResult::new();
644
645        let schema = match self.schemas.get(template) {
646            Some(s) => s,
647            None => return result, // No schema = no validation
648        };
649
650        // Skip validation if insufficient samples
651        if schema.sample_count < self.config.min_samples_for_validation {
652            return result;
653        }
654
655        let schema_map = match target {
656            SchemaTarget::Request => &schema.request_schema,
657            SchemaTarget::Response => &schema.response_schema,
658        };
659
660        self.validate_against_schema(
661            schema_map,
662            body,
663            "",
664            &mut result,
665            schema.sample_count,
666            0, // Start at depth 0
667        );
668
669        result
670    }
671
672    /// Validate data against a schema map.
673    fn validate_against_schema(
674        &self,
675        root_schema_map: &HashMap<String, FieldSchema>,
676        data: &serde_json::Value,
677        prefix: &str,
678        result: &mut ValidationResult,
679        sample_count: u32,
680        depth: usize,
681    ) {
682        // Protect against stack overflow from malicious deeply nested JSON
683        if depth >= self.config.max_nesting_depth {
684            return;
685        }
686
687        let obj = match data.as_object() {
688            Some(o) => o,
689            None => return,
690        };
691
692        // Check for unexpected fields
693        for (key, val) in obj {
694            let field_name = if prefix.is_empty() {
695                key.clone()
696            } else {
697                format!("{}.{}", prefix, key)
698            };
699
700            let field_schema = match root_schema_map.get(&field_name) {
701                Some(s) => s,
702                None => {
703                    result.add(SchemaViolation::unexpected_field(&field_name));
704                    continue;
705                }
706            };
707
708            let actual_type = FieldType::from_json_value(val);
709
710            // Type mismatch check
711            let dominant_type = field_schema.dominant_type();
712            if actual_type != dominant_type && !(val.is_null() && field_schema.nullable) {
713                result.add(SchemaViolation::type_mismatch(
714                    &field_name,
715                    dominant_type,
716                    actual_type,
717                ));
718            }
719
720            // String constraint checks
721            if let serde_json::Value::String(s) = val {
722                self.validate_string_field(&field_name, s, field_schema, result);
723            }
724
725            // Number constraint checks
726            if let serde_json::Value::Number(n) = val {
727                if let Some(f) = n.as_f64() {
728                    self.validate_number_field(&field_name, f, field_schema, result);
729                }
730            }
731
732            // Recurse into nested objects using the root map and dotted prefix
733            if val.is_object() {
734                self.validate_against_schema(
735                    root_schema_map,
736                    val,
737                    &field_name,
738                    result,
739                    sample_count,
740                    depth + 1,
741                );
742            }
743        }
744
745        // Check for missing required fields (seen in >90% of samples)
746        let threshold = (sample_count as f64 * self.config.required_field_threshold) as u32;
747        for (field_name, field_schema) in root_schema_map {
748            // Only check fields that are immediate children of this prefix
749            let is_direct_child = if prefix.is_empty() {
750                !field_name.contains('.')
751            } else if field_name.starts_with(prefix) && field_name.len() > prefix.len() + 1 {
752                let suffix = &field_name[prefix.len() + 1..];
753                !suffix.contains('.')
754            } else {
755                false
756            };
757
758            if is_direct_child && field_schema.seen_count >= threshold {
759                let key = field_name.rsplit('.').next().unwrap_or(field_name);
760                if !obj.contains_key(key) {
761                    result.add(SchemaViolation::missing_field(field_name));
762                }
763            }
764        }
765    }
766
767    /// Validate string field constraints.
768    fn validate_string_field(
769        &self,
770        field_name: &str,
771        value: &str,
772        schema: &FieldSchema,
773        result: &mut ValidationResult,
774    ) {
775        let len = value.len() as u32;
776
777        // Length too short
778        if let Some(min) = schema.min_length {
779            if len < min {
780                result.add(SchemaViolation::string_too_short(field_name, min, len));
781            }
782        }
783
784        // Length too long (with tolerance)
785        if let Some(max) = schema.max_length {
786            let allowed_max = (max as f64 * self.config.string_length_tolerance) as u32;
787            if len > allowed_max {
788                result.add(SchemaViolation::string_too_long(
789                    field_name,
790                    allowed_max,
791                    len,
792                ));
793            }
794        }
795
796        // Pattern mismatch
797        if let Some(expected_pattern) = schema.pattern {
798            let actual_pattern = detect_pattern(value);
799            if actual_pattern != Some(expected_pattern) {
800                result.add(SchemaViolation::pattern_mismatch(
801                    field_name,
802                    expected_pattern,
803                    actual_pattern,
804                ));
805            }
806        }
807    }
808
809    /// Validate number field constraints.
810    fn validate_number_field(
811        &self,
812        field_name: &str,
813        value: f64,
814        schema: &FieldSchema,
815        result: &mut ValidationResult,
816    ) {
817        // Value too small (with tolerance)
818        if let Some(min) = schema.min_value {
819            let allowed_min = min * (1.0 / self.config.number_value_tolerance);
820            if value < allowed_min {
821                result.add(SchemaViolation::number_too_small(
822                    field_name,
823                    allowed_min,
824                    value,
825                ));
826            }
827        }
828
829        // Value too large (with tolerance)
830        if let Some(max) = schema.max_value {
831            let allowed_max = max * self.config.number_value_tolerance;
832            if value > allowed_max {
833                result.add(SchemaViolation::number_too_large(
834                    field_name,
835                    allowed_max,
836                    value,
837                ));
838            }
839        }
840    }
841
842    // ========================================================================
843    // Schema Access
844    // ========================================================================
845
846    /// Get schema for an endpoint.
847    pub fn get_schema(&self, template: &str) -> Option<EndpointSchema> {
848        self.schemas.get(template).map(|s| s.value().clone())
849    }
850
851    /// Get all schemas.
852    pub fn get_all_schemas(&self) -> Vec<EndpointSchema> {
853        self.schemas
854            .iter()
855            .map(|entry| entry.value().clone())
856            .collect()
857    }
858
859    /// Get statistics.
860    pub fn get_stats(&self) -> SchemaLearnerStats {
861        let schemas: Vec<_> = self.schemas.iter().collect();
862        let total_samples: u32 = schemas.iter().map(|s| s.sample_count).sum();
863        let total_fields: usize = schemas
864            .iter()
865            .map(|s| s.request_schema.len() + s.response_schema.len())
866            .sum();
867
868        SchemaLearnerStats {
869            total_schemas: schemas.len(),
870            total_samples,
871            avg_fields_per_endpoint: if schemas.is_empty() {
872                0.0
873            } else {
874                total_fields as f64 / schemas.len() as f64
875            },
876        }
877    }
878
879    // ========================================================================
880    // Persistence
881    // ========================================================================
882
883    /// Export all schemas for persistence.
884    pub fn export(&self) -> Vec<EndpointSchema> {
885        self.get_all_schemas()
886    }
887
888    /// Import schemas from persistence.
889    pub fn import(&self, schemas: Vec<EndpointSchema>) {
890        // Clear both schemas and LRU tracker
891        self.schemas.clear();
892        let mut lru = self.lru.lock();
893        lru.clear();
894
895        // Re-insert all schemas, sorted by last_updated_ms to preserve LRU order
896        let mut sorted_schemas = schemas;
897        sorted_schemas.sort_by_key(|s| s.last_updated_ms);
898
899        for schema in sorted_schemas {
900            lru.touch(&schema.template);
901            self.schemas.insert(schema.template.clone(), schema);
902        }
903    }
904
905    /// Clear all schemas.
906    pub fn clear(&self) {
907        self.schemas.clear();
908        self.lru.lock().clear();
909    }
910}
911
912// ============================================================================
913// Helper Types
914// ============================================================================
915
/// Selects which schema an operation targets: request or response.
#[derive(Clone, Copy, Debug)]
enum SchemaTarget {
    /// The request schema.
    Request,
    /// The response schema.
    Response,
}
922
923/// Statistics about the schema learner.
924#[derive(Debug, Clone, Serialize)]
925pub struct SchemaLearnerStats {
926    /// Total number of endpoint schemas
927    pub total_schemas: usize,
928    /// Total samples across all schemas
929    pub total_samples: u32,
930    /// Average fields per endpoint
931    pub avg_fields_per_endpoint: f64,
932}
933
934// ============================================================================
935// Tests
936// ============================================================================
937
#[cfg(test)]
mod tests {
    use super::*;
    use crate::profiler::schema_types::{PatternType, ViolationType};
    use serde_json::json;

    /// Returns `true` when `result` contains at least one violation of `kind`.
    fn has_violation(result: &ValidationResult, kind: ViolationType) -> bool {
        result
            .violations
            .iter()
            .any(|violation| violation.violation_type == kind)
    }

    #[test]
    fn test_learn_from_request() {
        let learner = SchemaLearner::new();
        let payload = json!({
            "username": "john_doe",
            "email": "john@example.com",
            "age": 30
        });

        learner.learn_from_request("/api/users", &payload);

        let learned = learner.get_schema("/api/users").unwrap();
        assert_eq!(learned.sample_count, 1);
        for field in ["username", "email", "age"] {
            assert!(learned.request_schema.contains_key(field));
        }
    }

    #[test]
    fn test_learn_type_tracking() {
        let learner = SchemaLearner::new();

        // Ten requests whose fields keep the same types throughout.
        for n in 0..10 {
            let payload = json!({
                "id": n,
                "name": format!("user_{}", n)
            });
            learner.learn_from_request("/api/users", &payload);
        }

        let learned = learner.get_schema("/api/users").unwrap();
        let id_field = learned.request_schema.get("id").unwrap();
        let name_field = learned.request_schema.get("name").unwrap();

        assert_eq!(id_field.dominant_type(), FieldType::Number);
        assert_eq!(name_field.dominant_type(), FieldType::String);
        assert_eq!(id_field.seen_count, 10);
    }

    #[test]
    fn test_learn_string_constraints() {
        let learner = SchemaLearner::new();

        let samples = [
            json!({"name": "ab"}),     // 2 chars
            json!({"name": "abcdef"}), // 6 chars
            json!({"name": "abcd"}),   // 4 chars
        ];
        for sample in &samples {
            learner.learn_from_request("/api/test", sample);
        }

        let learned = learner.get_schema("/api/test").unwrap();
        let name_field = learned.request_schema.get("name").unwrap();

        assert_eq!(name_field.min_length, Some(2));
        assert_eq!(name_field.max_length, Some(6));
    }

    #[test]
    fn test_learn_pattern_detection() {
        let learner = SchemaLearner::new();
        let payload = json!({
            "id": "550e8400-e29b-41d4-a716-446655440000",
            "email": "user@example.com"
        });

        learner.learn_from_request("/api/users", &payload);

        let learned = learner.get_schema("/api/users").unwrap();
        let id_field = learned.request_schema.get("id").unwrap();
        let email_field = learned.request_schema.get("email").unwrap();

        assert_eq!(id_field.pattern, Some(PatternType::Uuid));
        assert_eq!(email_field.pattern, Some(PatternType::Email));
    }

    #[test]
    fn test_learn_nested_objects() {
        let learner = SchemaLearner::new();
        let payload = json!({
            "user": {
                "name": "John",
                "address": {
                    "city": "NYC"
                }
            }
        });

        learner.learn_from_request("/api/data", &payload);

        // Nested fields are keyed by their dotted path.
        let learned = learner.get_schema("/api/data").unwrap();
        for path in ["user", "user.name", "user.address", "user.address.city"] {
            assert!(learned.request_schema.contains_key(path));
        }
    }

    #[test]
    fn test_validate_unexpected_field() {
        let config = SchemaLearnerConfig {
            min_samples_for_validation: 5,
            ..Default::default()
        };
        let learner = SchemaLearner::with_config(config);

        // Train on bodies that only ever contain "name".
        for _ in 0..10 {
            learner.learn_from_request("/api/users", &json!({"name": "test"}));
        }

        // A body carrying an extra, never-seen field must be flagged.
        let probe = json!({"name": "test", "malicious": "value"});
        let result = learner.validate_request("/api/users", &probe);

        assert!(!result.is_valid());
        assert!(has_violation(&result, ViolationType::UnexpectedField));
    }

    #[test]
    fn test_validate_type_mismatch() {
        let config = SchemaLearnerConfig {
            min_samples_for_validation: 5,
            ..Default::default()
        };
        let learner = SchemaLearner::with_config(config);

        // Train with a numeric "id".
        for n in 0..10 {
            learner.learn_from_request("/api/users", &json!({"id": n}));
        }

        // Then present a string in its place.
        let probe = json!({"id": "not_a_number"});
        let result = learner.validate_request("/api/users", &probe);

        assert!(!result.is_valid());
        assert!(has_violation(&result, ViolationType::TypeMismatch));
    }

    #[test]
    fn test_validate_string_too_long() {
        let config = SchemaLearnerConfig {
            min_samples_for_validation: 5,
            string_length_tolerance: 2.0,
            ..Default::default()
        };
        let learner = SchemaLearner::with_config(config);

        // Train with short strings (4 chars).
        for _ in 0..10 {
            learner.learn_from_request("/api/users", &json!({"name": "john"}));
        }

        // 20 chars is well beyond 4 * 2 = 8.
        let oversized = "a".repeat(20);
        let result = learner.validate_request("/api/users", &json!({"name": oversized}));

        assert!(!result.is_valid());
        assert!(has_violation(&result, ViolationType::StringTooLong));
    }

    #[test]
    fn test_validate_pattern_mismatch() {
        let config = SchemaLearnerConfig {
            min_samples_for_validation: 5,
            ..Default::default()
        };
        let learner = SchemaLearner::with_config(config);

        // Train with a UUID-shaped "id".
        let uuid_body = json!({"id": "550e8400-e29b-41d4-a716-446655440000"});
        for _ in 0..10 {
            learner.learn_from_request("/api/users", &uuid_body);
        }

        // A non-UUID value must be reported as a pattern mismatch.
        let probe = json!({"id": "not-a-uuid-value"});
        let result = learner.validate_request("/api/users", &probe);

        assert!(!result.is_valid());
        assert!(has_violation(&result, ViolationType::PatternMismatch));
    }

    #[test]
    fn test_validate_insufficient_samples() {
        let config = SchemaLearnerConfig {
            min_samples_for_validation: 10,
            ..Default::default()
        };
        let learner = SchemaLearner::with_config(config);

        // Only 5 samples: below the configured minimum of 10.
        for _ in 0..5 {
            learner.learn_from_request("/api/users", &json!({"name": "test"}));
        }

        // With too few samples nothing is enforced, so validation passes.
        let probe = json!({"malicious": "field"});
        let result = learner.validate_request("/api/users", &probe);
        assert!(result.is_valid());
    }

    #[test]
    fn test_lru_eviction() {
        let config = SchemaLearnerConfig {
            max_schemas: 3,
            ..Default::default()
        };
        let learner = SchemaLearner::with_config(config);
        let pause = std::time::Duration::from_millis(10);

        // Four templates against a capacity of three.
        learner.learn_from_request("/api/users", &json!({"a": 1}));
        std::thread::sleep(pause);
        learner.learn_from_request("/api/orders", &json!({"b": 2}));
        std::thread::sleep(pause);
        learner.learn_from_request("/api/products", &json!({"c": 3}));
        std::thread::sleep(pause);
        learner.learn_from_request("/api/inventory", &json!({"d": 4}));

        // The oldest template (users) should be the one evicted.
        assert_eq!(learner.len(), 3);
        assert!(learner.get_schema("/api/users").is_none());
        assert!(learner.get_schema("/api/orders").is_some());
    }

    #[test]
    fn test_stats() {
        let learner = SchemaLearner::new();

        for n in 0..10 {
            learner.learn_from_request("/api/users", &json!({"id": n, "name": "test"}));
        }
        for n in 0..5 {
            learner.learn_from_request("/api/orders", &json!({"order_id": n}));
        }

        let summary = learner.get_stats();
        assert_eq!(summary.total_schemas, 2);
        assert_eq!(summary.total_samples, 15);
        assert!(summary.avg_fields_per_endpoint > 0.0);
    }

    #[test]
    fn test_export_import() {
        let source = SchemaLearner::new();
        source.learn_from_request("/api/users", &json!({"id": 1, "name": "test"}));
        source.learn_from_request("/api/orders", &json!({"order_id": 100}));

        let snapshot = source.export();
        assert_eq!(snapshot.len(), 2);

        // Load the snapshot into a fresh learner.
        let restored = SchemaLearner::new();
        restored.import(snapshot);

        assert_eq!(restored.len(), 2);
        assert!(restored.get_schema("/api/users").is_some());
        assert!(restored.get_schema("/api/orders").is_some());
    }

    #[test]
    fn test_nullable_fields() {
        let config = SchemaLearnerConfig {
            min_samples_for_validation: 5,
            ..Default::default()
        };
        let learner = SchemaLearner::with_config(config);

        // Alternate between a string value and null for "name".
        for n in 0..10 {
            let payload = match n % 2 {
                0 => json!({"name": "test"}),
                _ => json!({"name": null}),
            };
            learner.learn_from_request("/api/users", &payload);
        }

        let learned = learner.get_schema("/api/users").unwrap();
        let name_field = learned.request_schema.get("name").unwrap();
        assert!(name_field.nullable);

        // A null value on a nullable field must not raise a type mismatch.
        let result = learner.validate_request("/api/users", &json!({"name": null}));
        let name_type_mismatch = result.violations.iter().any(|violation| {
            violation.violation_type == ViolationType::TypeMismatch && violation.field == "name"
        });
        assert!(!name_type_mismatch);
    }

    #[test]
    fn test_array_item_types() {
        let learner = SchemaLearner::new();
        let payload = json!({
            "tags": ["tag1", "tag2"],
            "numbers": [1, 2, 3]
        });

        learner.learn_from_request("/api/items", &payload);

        let learned = learner.get_schema("/api/items").unwrap();
        let tags_field = learned.request_schema.get("tags").unwrap();
        let numbers_field = learned.request_schema.get("numbers").unwrap();

        let tag_item_types = tags_field.array_item_types.as_ref().unwrap();
        assert!(tag_item_types.contains(&FieldType::String));

        let number_item_types = numbers_field.array_item_types.as_ref().unwrap();
        assert!(number_item_types.contains(&FieldType::Number));
    }

    #[test]
    fn test_validate_missing_required_field() {
        let config = SchemaLearnerConfig {
            min_samples_for_validation: 5,
            required_field_threshold: 0.9,
            ..Default::default()
        };
        let learner = SchemaLearner::with_config(config);

        // Both "id" and "name" appear in every training sample.
        for n in 0..10 {
            learner.learn_from_request("/api/users", &json!({"id": n, "name": "test"}));
        }

        // Omit "name" from the body under validation.
        let result = learner.validate_request("/api/users", &json!({"id": 1}));

        assert!(!result.is_valid());
        let name_missing = result.violations.iter().any(|violation| {
            violation.violation_type == ViolationType::MissingField && violation.field == "name"
        });
        assert!(name_missing);
    }

    #[test]
    fn test_validate_number_constraints() {
        let config = SchemaLearnerConfig {
            min_samples_for_validation: 5,
            number_value_tolerance: 2.0,
            ..Default::default()
        };
        let learner = SchemaLearner::with_config(config);

        // Train with prices 10, 20, ..., 100.
        for step in 0..10 {
            learner.learn_from_request("/api/items", &json!({"price": 10 + step * 10}));
        }

        // 500 exceeds the relaxed maximum (100 * 2 = 200).
        let too_large = learner.validate_request("/api/items", &json!({"price": 500}));
        assert!(!too_large.is_valid());
        assert!(has_violation(&too_large, ViolationType::NumberTooLarge));

        // 1 is below the relaxed minimum (10 * 0.5 = 5).
        let too_small = learner.validate_request("/api/items", &json!({"price": 1}));
        assert!(!too_small.is_valid());
        assert!(has_violation(&too_small, ViolationType::NumberTooSmall));
    }

    #[test]
    fn test_validate_deeply_nested_json_does_not_stack_overflow() {
        let config = SchemaLearnerConfig {
            max_nesting_depth: 10,
            min_samples_for_validation: 0, // Always validate
            ..Default::default()
        };
        let learner = SchemaLearner::with_config(config);

        // Wrap a leaf object in 100 levels of nesting.
        let nested = (0..100).fold(json!({"leaf": true}), |inner, level| {
            json!({ format!("nest_{}", level): inner })
        });

        // Neither learning nor validating may crash on this input.
        learner.learn_from_request("/api/nested", &nested);
        let result = learner.validate_request("/api/nested", &nested);

        // Levels beyond max_nesting_depth are expected to be ignored.
        assert!(result.is_valid());
    }

    #[test]
    fn test_learn_array_root_body_is_silently_skipped() {
        let learner = SchemaLearner::new();
        let array_root = json!([{"id": 1}, {"id": 2}]);

        learner.learn_from_request("/api/arrays", &array_root);

        // No schema should exist for an array-root body.
        assert_eq!(learner.len(), 0);
    }

    #[test]
    fn test_learn_from_response_does_not_increment_sample_count() {
        let learner = SchemaLearner::new();

        // Learning from a response leaves sample_count untouched...
        learner.learn_from_response("/api/test", &json!({"ok": true}));

        let after_response = learner.get_schema("/api/test").unwrap();
        assert_eq!(after_response.sample_count, 0);
        assert!(after_response.response_schema.contains_key("ok"));

        // ...whereas learning from a request bumps it.
        learner.learn_from_request("/api/test", &json!({"id": 1}));
        let after_request = learner.get_schema("/api/test").unwrap();
        assert_eq!(after_request.sample_count, 1);
    }

    #[test]
    fn test_learn_from_pair_both_none() {
        let learner = SchemaLearner::new();

        // A pair with no bodies still counts as one sample.
        learner.learn_from_pair("/api/empty", None, None);

        let learned = learner.get_schema("/api/empty").unwrap();
        assert_eq!(learned.sample_count, 1);
        assert!(learned.request_schema.is_empty());
        assert!(learned.response_schema.is_empty());
    }
}