1use std::collections::{HashMap, VecDeque};
22use std::time::{SystemTime, UNIX_EPOCH};
23
24use dashmap::DashMap;
25use parking_lot::Mutex;
26use serde::{Deserialize, Serialize};
27
28use crate::profiler::patterns::detect_pattern;
29use crate::profiler::schema_types::{
30 EndpointSchema, FieldSchema, FieldType, SchemaViolation, ValidationResult,
31};
32
/// Tuning knobs for [`SchemaLearner`].
///
/// All fields are public so a config can be built with struct-update syntax
/// (`SchemaLearnerConfig { max_schemas: 10, ..Default::default() }`).
/// Construction does not check the values; call
/// [`SchemaLearnerConfig::validate`] to reject unusable ones.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaLearnerConfig {
    /// Maximum number of endpoint schemas kept. When a new template arrives and
    /// this many schemas already exist, the least-recently-used one is evicted.
    pub max_schemas: usize,

    /// Number of learned samples an endpoint needs before validation reports
    /// anything; below this, validation returns an empty result. Only request
    /// bodies and request/response pairs count as samples.
    pub min_samples_for_validation: u32,

    /// Maximum object nesting depth that is learned and validated; deeper
    /// nested objects are ignored. This also bounds recursion.
    pub max_nesting_depth: usize,

    /// Maximum number of distinct field paths (flattened as "a.b.c") stored in
    /// each request schema and each response schema.
    pub max_fields_per_schema: usize,

    /// Factor applied to the longest learned string length to get the maximum
    /// accepted length. Must be >= 1.0.
    pub string_length_tolerance: f64,

    /// Factor used to widen the learned numeric range before values are
    /// reported as too small or too large. Must be >= 1.0.
    pub number_value_tolerance: f64,

    /// Fraction (0.0..=1.0) of samples a field must have appeared in for its
    /// absence to be reported as a missing field.
    pub required_field_threshold: f64,
}
146
/// Error produced by [`SchemaLearnerConfig::validate`] for an unusable setting.
#[derive(Debug, Clone, PartialEq)]
pub struct ConfigValidationError {
    /// Name of the configuration field that was rejected.
    pub field: &'static str,
    /// Human-readable explanation, including the rejected value.
    pub message: String,
}

impl std::fmt::Display for ConfigValidationError {
    /// Renders as `Invalid <field>: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { field, message } = self;
        write!(f, "Invalid {}: {}", field, message)
    }
}

// Marker impl so the error composes with `Box<dyn Error>` and `?`.
impl std::error::Error for ConfigValidationError {}
163
164impl SchemaLearnerConfig {
165 pub fn validate(&self) -> Result<(), ConfigValidationError> {
187 if self.string_length_tolerance < 1.0 {
188 return Err(ConfigValidationError {
189 field: "string_length_tolerance",
190 message: format!(
191 "must be >= 1.0 to avoid rejecting baseline data (got {})",
192 self.string_length_tolerance
193 ),
194 });
195 }
196
197 if self.number_value_tolerance < 1.0 {
198 return Err(ConfigValidationError {
199 field: "number_value_tolerance",
200 message: format!(
201 "must be >= 1.0 to avoid rejecting baseline data (got {})",
202 self.number_value_tolerance
203 ),
204 });
205 }
206
207 if !(0.0..=1.0).contains(&self.required_field_threshold) {
208 return Err(ConfigValidationError {
209 field: "required_field_threshold",
210 message: format!(
211 "must be between 0.0 and 1.0 (got {})",
212 self.required_field_threshold
213 ),
214 });
215 }
216
217 Ok(())
218 }
219}
220
221impl Default for SchemaLearnerConfig {
222 fn default() -> Self {
223 Self {
224 max_schemas: 5000,
225 min_samples_for_validation: 10,
226 max_nesting_depth: 10,
227 max_fields_per_schema: 100,
228 string_length_tolerance: 1.5,
229 number_value_tolerance: 1.5,
230 required_field_threshold: 0.9,
231 }
232 }
233}
234
/// One access record in [`LruTracker`]'s queue.
///
/// Entries are never moved. Touching a key appends a new entry with a fresh
/// generation, so a key may have several entries. Only the one whose
/// `generation` equals the key's current generation is live.
#[derive(Debug, Clone)]
struct LruEntry {
    /// Template key the access refers to.
    key: String,
    /// Generation stamp assigned by `LruTracker::touch`.
    generation: u64,
}

/// Least-recently-used ordering of template keys, implemented with lazy deletion.
///
/// `touch` appends to `queue` and records the new generation in `generations`.
/// Stale queue entries (older generations, removed keys) are skipped by
/// `evict_oldest` and periodically dropped by `compact`. This keeps the queue
/// length proportional to the number of tracked keys.
struct LruTracker {
    /// Access log, oldest first; may contain stale entries.
    queue: VecDeque<LruEntry>,
    /// Current generation of every tracked key.
    generations: HashMap<String, u64>,
    /// Generation stamp handed out by the next `touch`.
    next_generation: u64,
}

impl LruTracker {
    /// Stale entries tolerated on top of `2 * tracked keys` before `touch`
    /// compacts the queue. This avoids compacting on nearly every touch when
    /// only a few keys are tracked.
    const COMPACTION_SLACK: usize = 32;

    /// Creates an empty tracker with room for `capacity` keys preallocated.
    fn new(capacity: usize) -> Self {
        Self {
            queue: VecDeque::with_capacity(capacity),
            generations: HashMap::with_capacity(capacity),
            next_generation: 0,
        }
    }

    /// Marks `key` as the most recently used key.
    ///
    /// Returns `true` if `key` was not tracked before this call.
    fn touch(&mut self, key: &str) -> bool {
        let generation = self.next_generation;
        self.next_generation = self.next_generation.wrapping_add(1);

        let is_new = self
            .generations
            .insert(key.to_string(), generation)
            .is_none();
        self.queue.push_back(LruEntry {
            key: key.to_string(),
            generation,
        });

        // Re-touching a tracked key only appends. Stale entries are otherwise
        // discarded only when `evict_oldest` reaches them, which the owner calls
        // only once it is at capacity. Without compaction, a few hot keys would
        // grow the queue by one entry per touch, without bound.
        if self.queue.len() > self.generations.len() * 2 + Self::COMPACTION_SLACK {
            self.compact();
        }

        is_new
    }

    /// Drops all stale queue entries, keeping one entry per tracked key in
    /// recency order.
    ///
    /// This is O(queue length). `touch` only triggers it after the queue has
    /// roughly doubled past the live key count, so the cost is amortized
    /// constant per touch.
    fn compact(&mut self) {
        let generations = &self.generations;
        self.queue
            .retain(|entry| generations.get(&entry.key) == Some(&entry.generation));
    }

    /// Stops tracking `key`; its queue entries become stale and are dropped lazily.
    // Not called by `SchemaLearner` at present, hence the dead-code allow.
    #[allow(dead_code)]
    fn remove(&mut self, key: &str) {
        self.generations.remove(key);
    }

    /// Removes and returns the least recently used tracked key, or `None` when
    /// no key is tracked.
    fn evict_oldest(&mut self) -> Option<String> {
        while let Some(entry) = self.queue.pop_front() {
            // Only the entry carrying the key's current generation is live.
            // Older entries for the same key, and entries of removed keys, are
            // simply discarded.
            if self.generations.get(&entry.key) == Some(&entry.generation) {
                self.generations.remove(&entry.key);
                return Some(entry.key);
            }
        }
        None
    }

    /// Number of tracked keys (not queue entries).
    // Not called by `SchemaLearner` at present, hence the dead-code allow.
    #[allow(dead_code)]
    fn len(&self) -> usize {
        self.generations.len()
    }

    /// Forgets every key and resets the generation counter.
    fn clear(&mut self) {
        self.queue.clear();
        self.generations.clear();
        self.next_generation = 0;
    }
}
335
/// Learns per-endpoint JSON request/response schemas from observed bodies and
/// validates new bodies against them.
///
/// All methods take `&self`. Schemas live in a `DashMap` keyed by endpoint
/// template. The recency of templates is tracked behind a mutex so that, once
/// `config.max_schemas` schemas exist, a new template evicts the
/// least-recently-used one.
pub struct SchemaLearner {
    /// Learned schema per endpoint template (e.g. "/api/users").
    schemas: DashMap<String, EndpointSchema>,

    /// Recency order of templates, used to choose which schema to evict.
    lru: Mutex<LruTracker>,

    /// Configuration supplied at construction time.
    config: SchemaLearnerConfig,
}
354
355impl Default for SchemaLearner {
356 fn default() -> Self {
357 Self::new()
358 }
359}
360
361impl SchemaLearner {
362 pub fn new() -> Self {
364 Self::with_config(SchemaLearnerConfig::default())
365 }
366
367 pub fn with_config(config: SchemaLearnerConfig) -> Self {
369 Self {
370 schemas: DashMap::with_capacity(config.max_schemas),
371 lru: Mutex::new(LruTracker::new(config.max_schemas)),
372 config,
373 }
374 }
375
376 pub fn config(&self) -> &SchemaLearnerConfig {
378 &self.config
379 }
380
381 pub fn len(&self) -> usize {
383 self.schemas.len()
384 }
385
386 pub fn is_empty(&self) -> bool {
388 self.schemas.is_empty()
389 }
390
391 fn now_ms() -> u64 {
393 SystemTime::now()
394 .duration_since(UNIX_EPOCH)
395 .map(|d| d.as_millis() as u64)
396 .unwrap_or(0)
397 }
398
399 pub fn learn_from_request(&self, template: &str, request_body: &serde_json::Value) {
414 self.learn_internal(template, request_body, SchemaTarget::Request);
415 }
416
417 pub fn learn_from_response(&self, template: &str, response_body: &serde_json::Value) {
425 self.learn_internal(template, response_body, SchemaTarget::Response);
426 }
427
428 pub fn learn_from_pair(
436 &self,
437 template: &str,
438 request_body: Option<&serde_json::Value>,
439 response_body: Option<&serde_json::Value>,
440 ) {
441 let now = Self::now_ms();
442
443 self.ensure_schema(template, now);
445
446 if let Some(req) = request_body {
447 if req.is_object() {
448 self.update_schema_fields(template, req, SchemaTarget::Request, "", 0);
449 }
450 }
451
452 if let Some(resp) = response_body {
453 if resp.is_object() {
454 self.update_schema_fields(template, resp, SchemaTarget::Response, "", 0);
455 }
456 }
457
458 if let Some(mut schema) = self.schemas.get_mut(template) {
460 schema.sample_count += 1;
461 schema.last_updated_ms = now;
462 }
463 }
464
465 fn learn_internal(&self, template: &str, body: &serde_json::Value, target: SchemaTarget) {
467 if !body.is_object() {
468 return;
469 }
470
471 let now = Self::now_ms();
472 self.ensure_schema(template, now);
473 self.update_schema_fields(template, body, target, "", 0);
474
475 if matches!(target, SchemaTarget::Request) {
477 if let Some(mut schema) = self.schemas.get_mut(template) {
478 schema.sample_count += 1;
479 schema.last_updated_ms = now;
480 }
481 }
482 }
483
484 fn ensure_schema(&self, template: &str, now: u64) {
486 if self.schemas.contains_key(template) {
488 let mut lru = self.lru.lock();
490 lru.touch(template);
491 return;
492 }
493
494 let mut lru = self.lru.lock();
496
497 if self.schemas.contains_key(template) {
499 lru.touch(template);
500 return;
501 }
502
503 if self.schemas.len() >= self.config.max_schemas {
505 if let Some(evict_key) = lru.evict_oldest() {
506 self.schemas.remove(&evict_key);
507 }
508 }
509
510 lru.touch(template);
512 self.schemas.insert(
513 template.to_string(),
514 EndpointSchema::new(template.to_string(), now),
515 );
516 }
517
518 fn update_schema_fields(
521 &self,
522 template: &str,
523 value: &serde_json::Value,
524 target: SchemaTarget,
525 prefix: &str,
526 depth: usize,
527 ) {
528 if depth >= self.config.max_nesting_depth {
530 return;
531 }
532
533 let obj = match value.as_object() {
534 Some(o) => o,
535 None => return,
536 };
537
538 let mut nested_objects: Vec<(String, &serde_json::Value)> = Vec::new();
540
541 {
542 let mut schema_guard = match self.schemas.get_mut(template) {
543 Some(s) => s,
544 None => return,
545 };
546
547 let schema_map = match target {
548 SchemaTarget::Request => &mut schema_guard.request_schema,
549 SchemaTarget::Response => &mut schema_guard.response_schema,
550 };
551
552 for (key, val) in obj {
553 if schema_map.len() >= self.config.max_fields_per_schema {
555 break;
556 }
557
558 let field_name = if prefix.is_empty() {
559 key.clone()
560 } else {
561 format!("{}.{}", prefix, key)
562 };
563
564 let field_type = FieldType::from_json_value(val);
565
566 let field_schema = schema_map
568 .entry(field_name.clone())
569 .or_insert_with(|| FieldSchema::new(field_name.clone()));
570
571 field_schema.record_type(field_type);
573
574 match val {
576 serde_json::Value::String(s) => {
577 let pattern = detect_pattern(s);
578 field_schema.update_string_constraints(s.len() as u32, pattern);
579 }
580 serde_json::Value::Number(n) => {
581 if let Some(f) = n.as_f64() {
582 field_schema.update_number_constraints(f);
583 }
584 }
585 serde_json::Value::Array(arr) => {
586 for item in arr {
587 let item_type = FieldType::from_json_value(item);
588 field_schema.add_array_item_type(item_type);
589 }
590 }
591 serde_json::Value::Object(_) => {
592 if field_schema.object_schema.is_none() {
594 field_schema.object_schema = Some(HashMap::new());
595 }
596 nested_objects.push((field_name, val));
598 }
599 _ => {}
600 }
601 }
602 }
604
605 for (field_name, val) in nested_objects {
607 self.update_schema_fields(template, val, target, &field_name, depth + 1);
608 }
609 }
610
611 pub fn validate_request(
620 &self,
621 template: &str,
622 request_body: &serde_json::Value,
623 ) -> ValidationResult {
624 self.validate_internal(template, request_body, SchemaTarget::Request)
625 }
626
627 pub fn validate_response(
629 &self,
630 template: &str,
631 response_body: &serde_json::Value,
632 ) -> ValidationResult {
633 self.validate_internal(template, response_body, SchemaTarget::Response)
634 }
635
636 fn validate_internal(
638 &self,
639 template: &str,
640 body: &serde_json::Value,
641 target: SchemaTarget,
642 ) -> ValidationResult {
643 let mut result = ValidationResult::new();
644
645 let schema = match self.schemas.get(template) {
646 Some(s) => s,
647 None => return result, };
649
650 if schema.sample_count < self.config.min_samples_for_validation {
652 return result;
653 }
654
655 let schema_map = match target {
656 SchemaTarget::Request => &schema.request_schema,
657 SchemaTarget::Response => &schema.response_schema,
658 };
659
660 self.validate_against_schema(
661 schema_map,
662 body,
663 "",
664 &mut result,
665 schema.sample_count,
666 0, );
668
669 result
670 }
671
672 fn validate_against_schema(
674 &self,
675 root_schema_map: &HashMap<String, FieldSchema>,
676 data: &serde_json::Value,
677 prefix: &str,
678 result: &mut ValidationResult,
679 sample_count: u32,
680 depth: usize,
681 ) {
682 if depth >= self.config.max_nesting_depth {
684 return;
685 }
686
687 let obj = match data.as_object() {
688 Some(o) => o,
689 None => return,
690 };
691
692 for (key, val) in obj {
694 let field_name = if prefix.is_empty() {
695 key.clone()
696 } else {
697 format!("{}.{}", prefix, key)
698 };
699
700 let field_schema = match root_schema_map.get(&field_name) {
701 Some(s) => s,
702 None => {
703 result.add(SchemaViolation::unexpected_field(&field_name));
704 continue;
705 }
706 };
707
708 let actual_type = FieldType::from_json_value(val);
709
710 let dominant_type = field_schema.dominant_type();
712 if actual_type != dominant_type && !(val.is_null() && field_schema.nullable) {
713 result.add(SchemaViolation::type_mismatch(
714 &field_name,
715 dominant_type,
716 actual_type,
717 ));
718 }
719
720 if let serde_json::Value::String(s) = val {
722 self.validate_string_field(&field_name, s, field_schema, result);
723 }
724
725 if let serde_json::Value::Number(n) = val {
727 if let Some(f) = n.as_f64() {
728 self.validate_number_field(&field_name, f, field_schema, result);
729 }
730 }
731
732 if val.is_object() {
734 self.validate_against_schema(
735 root_schema_map,
736 val,
737 &field_name,
738 result,
739 sample_count,
740 depth + 1,
741 );
742 }
743 }
744
745 let threshold = (sample_count as f64 * self.config.required_field_threshold) as u32;
747 for (field_name, field_schema) in root_schema_map {
748 let is_direct_child = if prefix.is_empty() {
750 !field_name.contains('.')
751 } else if field_name.starts_with(prefix) && field_name.len() > prefix.len() + 1 {
752 let suffix = &field_name[prefix.len() + 1..];
753 !suffix.contains('.')
754 } else {
755 false
756 };
757
758 if is_direct_child && field_schema.seen_count >= threshold {
759 let key = field_name.rsplit('.').next().unwrap_or(field_name);
760 if !obj.contains_key(key) {
761 result.add(SchemaViolation::missing_field(field_name));
762 }
763 }
764 }
765 }
766
767 fn validate_string_field(
769 &self,
770 field_name: &str,
771 value: &str,
772 schema: &FieldSchema,
773 result: &mut ValidationResult,
774 ) {
775 let len = value.len() as u32;
776
777 if let Some(min) = schema.min_length {
779 if len < min {
780 result.add(SchemaViolation::string_too_short(field_name, min, len));
781 }
782 }
783
784 if let Some(max) = schema.max_length {
786 let allowed_max = (max as f64 * self.config.string_length_tolerance) as u32;
787 if len > allowed_max {
788 result.add(SchemaViolation::string_too_long(
789 field_name,
790 allowed_max,
791 len,
792 ));
793 }
794 }
795
796 if let Some(expected_pattern) = schema.pattern {
798 let actual_pattern = detect_pattern(value);
799 if actual_pattern != Some(expected_pattern) {
800 result.add(SchemaViolation::pattern_mismatch(
801 field_name,
802 expected_pattern,
803 actual_pattern,
804 ));
805 }
806 }
807 }
808
809 fn validate_number_field(
811 &self,
812 field_name: &str,
813 value: f64,
814 schema: &FieldSchema,
815 result: &mut ValidationResult,
816 ) {
817 if let Some(min) = schema.min_value {
819 let allowed_min = min * (1.0 / self.config.number_value_tolerance);
820 if value < allowed_min {
821 result.add(SchemaViolation::number_too_small(
822 field_name,
823 allowed_min,
824 value,
825 ));
826 }
827 }
828
829 if let Some(max) = schema.max_value {
831 let allowed_max = max * self.config.number_value_tolerance;
832 if value > allowed_max {
833 result.add(SchemaViolation::number_too_large(
834 field_name,
835 allowed_max,
836 value,
837 ));
838 }
839 }
840 }
841
842 pub fn get_schema(&self, template: &str) -> Option<EndpointSchema> {
848 self.schemas.get(template).map(|s| s.value().clone())
849 }
850
851 pub fn get_all_schemas(&self) -> Vec<EndpointSchema> {
853 self.schemas
854 .iter()
855 .map(|entry| entry.value().clone())
856 .collect()
857 }
858
859 pub fn get_stats(&self) -> SchemaLearnerStats {
861 let schemas: Vec<_> = self.schemas.iter().collect();
862 let total_samples: u32 = schemas.iter().map(|s| s.sample_count).sum();
863 let total_fields: usize = schemas
864 .iter()
865 .map(|s| s.request_schema.len() + s.response_schema.len())
866 .sum();
867
868 SchemaLearnerStats {
869 total_schemas: schemas.len(),
870 total_samples,
871 avg_fields_per_endpoint: if schemas.is_empty() {
872 0.0
873 } else {
874 total_fields as f64 / schemas.len() as f64
875 },
876 }
877 }
878
879 pub fn export(&self) -> Vec<EndpointSchema> {
885 self.get_all_schemas()
886 }
887
888 pub fn import(&self, schemas: Vec<EndpointSchema>) {
890 self.schemas.clear();
892 let mut lru = self.lru.lock();
893 lru.clear();
894
895 let mut sorted_schemas = schemas;
897 sorted_schemas.sort_by_key(|s| s.last_updated_ms);
898
899 for schema in sorted_schemas {
900 lru.touch(&schema.template);
901 self.schemas.insert(schema.template.clone(), schema);
902 }
903 }
904
905 pub fn clear(&self) {
907 self.schemas.clear();
908 self.lru.lock().clear();
909 }
910}
911
/// Selects which half of an [`EndpointSchema`] an operation reads or writes.
#[derive(Debug, Clone, Copy)]
enum SchemaTarget {
    /// Operates on `EndpointSchema::request_schema`.
    Request,
    /// Operates on `EndpointSchema::response_schema`.
    Response,
}
922
/// Aggregate statistics returned by `SchemaLearner::get_stats`.
#[derive(Debug, Clone, Serialize)]
pub struct SchemaLearnerStats {
    /// Number of endpoint schemas currently stored.
    pub total_schemas: usize,
    /// Sum of `sample_count` over all stored schemas.
    pub total_samples: u32,
    /// Mean number of request plus response fields per stored schema
    /// (0.0 when no schema is stored).
    pub avg_fields_per_endpoint: f64,
}
933
#[cfg(test)]
mod tests {
    use super::*;
    use crate::profiler::schema_types::{PatternType, ViolationType};
    use serde_json::json;

    #[test]
    fn test_learn_from_request() {
        let learner = SchemaLearner::new();

        let body = json!({
            "username": "john_doe",
            "email": "john@example.com",
            "age": 30
        });

        learner.learn_from_request("/api/users", &body);

        let schema = learner.get_schema("/api/users").unwrap();
        assert_eq!(schema.sample_count, 1);
        assert!(schema.request_schema.contains_key("username"));
        assert!(schema.request_schema.contains_key("email"));
        assert!(schema.request_schema.contains_key("age"));
    }

    #[test]
    fn test_learn_type_tracking() {
        let learner = SchemaLearner::new();

        // Ten samples with a consistently numeric `id` and string `name`.
        for i in 0..10 {
            let body = json!({
                "id": i,
                "name": format!("user_{}", i)
            });
            learner.learn_from_request("/api/users", &body);
        }

        let schema = learner.get_schema("/api/users").unwrap();
        let id_schema = schema.request_schema.get("id").unwrap();
        let name_schema = schema.request_schema.get("name").unwrap();

        assert_eq!(id_schema.dominant_type(), FieldType::Number);
        assert_eq!(name_schema.dominant_type(), FieldType::String);
        assert_eq!(id_schema.seen_count, 10);
    }

    #[test]
    fn test_learn_string_constraints() {
        let learner = SchemaLearner::new();

        let bodies = vec![
            json!({"name": "ab"}),     // length 2
            json!({"name": "abcdef"}), // length 6
            json!({"name": "abcd"}),   // length 4
        ];

        for body in bodies {
            learner.learn_from_request("/api/test", &body);
        }

        let schema = learner.get_schema("/api/test").unwrap();
        let name_schema = schema.request_schema.get("name").unwrap();

        assert_eq!(name_schema.min_length, Some(2));
        assert_eq!(name_schema.max_length, Some(6));
    }

    #[test]
    fn test_learn_pattern_detection() {
        let learner = SchemaLearner::new();

        let body = json!({
            "id": "550e8400-e29b-41d4-a716-446655440000",
            "email": "user@example.com"
        });

        learner.learn_from_request("/api/users", &body);

        let schema = learner.get_schema("/api/users").unwrap();
        let id_schema = schema.request_schema.get("id").unwrap();
        let email_schema = schema.request_schema.get("email").unwrap();

        assert_eq!(id_schema.pattern, Some(PatternType::Uuid));
        assert_eq!(email_schema.pattern, Some(PatternType::Email));
    }

    #[test]
    fn test_learn_nested_objects() {
        let learner = SchemaLearner::new();

        let body = json!({
            "user": {
                "name": "John",
                "address": {
                    "city": "NYC"
                }
            }
        });

        learner.learn_from_request("/api/data", &body);

        let schema = learner.get_schema("/api/data").unwrap();
        assert!(schema.request_schema.contains_key("user"));
        assert!(schema.request_schema.contains_key("user.name"));
        assert!(schema.request_schema.contains_key("user.address"));
        assert!(schema.request_schema.contains_key("user.address.city"));
    }

    #[test]
    fn test_validate_unexpected_field() {
        let learner = SchemaLearner::with_config(SchemaLearnerConfig {
            min_samples_for_validation: 5,
            ..Default::default()
        });

        for _ in 0..10 {
            learner.learn_from_request("/api/users", &json!({"name": "test"}));
        }

        // `malicious` was never seen during learning.
        let result =
            learner.validate_request("/api/users", &json!({"name": "test", "malicious": "value"}));

        assert!(!result.is_valid());
        assert!(result
            .violations
            .iter()
            .any(|v| v.violation_type == ViolationType::UnexpectedField));
    }

    #[test]
    fn test_validate_type_mismatch() {
        let learner = SchemaLearner::with_config(SchemaLearnerConfig {
            min_samples_for_validation: 5,
            ..Default::default()
        });

        for i in 0..10 {
            learner.learn_from_request("/api/users", &json!({"id": i}));
        }

        // `id` was always numeric; a string must be flagged.
        let result = learner.validate_request("/api/users", &json!({"id": "not_a_number"}));

        assert!(!result.is_valid());
        assert!(result
            .violations
            .iter()
            .any(|v| v.violation_type == ViolationType::TypeMismatch));
    }

    #[test]
    fn test_validate_string_too_long() {
        let learner = SchemaLearner::with_config(SchemaLearnerConfig {
            min_samples_for_validation: 5,
            string_length_tolerance: 2.0,
            ..Default::default()
        });

        for _ in 0..10 {
            // Learned max length is 4, so the limit is 8 with tolerance 2.0.
            learner.learn_from_request("/api/users", &json!({"name": "john"}));
        }

        let long_name = "a".repeat(20);
        let result = learner.validate_request("/api/users", &json!({"name": long_name}));

        assert!(!result.is_valid());
        assert!(result
            .violations
            .iter()
            .any(|v| v.violation_type == ViolationType::StringTooLong));
    }

    #[test]
    fn test_validate_pattern_mismatch() {
        let learner = SchemaLearner::with_config(SchemaLearnerConfig {
            min_samples_for_validation: 5,
            ..Default::default()
        });

        for _ in 0..10 {
            learner.learn_from_request(
                "/api/users",
                &json!({"id": "550e8400-e29b-41d4-a716-446655440000"}),
            );
        }

        // Learned as UUID; a non-UUID string must be flagged.
        let result = learner.validate_request("/api/users", &json!({"id": "not-a-uuid-value"}));

        assert!(!result.is_valid());
        assert!(result
            .violations
            .iter()
            .any(|v| v.violation_type == ViolationType::PatternMismatch));
    }

    #[test]
    fn test_validate_insufficient_samples() {
        let learner = SchemaLearner::with_config(SchemaLearnerConfig {
            min_samples_for_validation: 10,
            ..Default::default()
        });

        for _ in 0..5 {
            learner.learn_from_request("/api/users", &json!({"name": "test"}));
        }

        // Below the sample threshold nothing is reported, even for unknown fields.
        let result = learner.validate_request("/api/users", &json!({"malicious": "field"}));
        assert!(result.is_valid());
    }

    #[test]
    fn test_lru_eviction() {
        let learner = SchemaLearner::with_config(SchemaLearnerConfig {
            max_schemas: 3,
            ..Default::default()
        });

        // Eviction order comes from the LRU tracker's generation counter, not
        // from wall-clock timestamps, so no sleeps are needed between requests.
        learner.learn_from_request("/api/users", &json!({"a": 1}));
        learner.learn_from_request("/api/orders", &json!({"b": 2}));
        learner.learn_from_request("/api/products", &json!({"c": 3}));
        learner.learn_from_request("/api/inventory", &json!({"d": 4}));

        assert_eq!(learner.len(), 3);
        // The oldest template made room for the fourth one.
        assert!(learner.get_schema("/api/users").is_none());
        assert!(learner.get_schema("/api/orders").is_some());
        assert!(learner.get_schema("/api/products").is_some());
        assert!(learner.get_schema("/api/inventory").is_some());
    }

    #[test]
    fn test_lru_touch_refreshes_recency() {
        let learner = SchemaLearner::with_config(SchemaLearnerConfig {
            max_schemas: 3,
            ..Default::default()
        });

        learner.learn_from_request("/api/users", &json!({"a": 1}));
        learner.learn_from_request("/api/orders", &json!({"b": 2}));
        learner.learn_from_request("/api/products", &json!({"c": 3}));
        // Using /api/users again makes /api/orders the least recently used.
        learner.learn_from_request("/api/users", &json!({"a": 5}));
        learner.learn_from_request("/api/inventory", &json!({"d": 4}));

        assert_eq!(learner.len(), 3);
        assert!(learner.get_schema("/api/users").is_some());
        assert!(learner.get_schema("/api/orders").is_none());
    }

    #[test]
    fn test_config_validation() {
        assert!(SchemaLearnerConfig::default().validate().is_ok());

        let err = SchemaLearnerConfig {
            string_length_tolerance: 0.5,
            ..Default::default()
        }
        .validate()
        .unwrap_err();
        assert_eq!(err.field, "string_length_tolerance");

        let err = SchemaLearnerConfig {
            required_field_threshold: 1.5,
            ..Default::default()
        }
        .validate()
        .unwrap_err();
        assert_eq!(err.field, "required_field_threshold");
        assert!(err.to_string().starts_with("Invalid required_field_threshold"));
    }

    #[test]
    fn test_stats() {
        let learner = SchemaLearner::new();

        for i in 0..10 {
            learner.learn_from_request("/api/users", &json!({"id": i, "name": "test"}));
        }
        for i in 0..5 {
            learner.learn_from_request("/api/orders", &json!({"order_id": i}));
        }

        let stats = learner.get_stats();
        assert_eq!(stats.total_schemas, 2);
        assert_eq!(stats.total_samples, 15);
        assert!(stats.avg_fields_per_endpoint > 0.0);
    }

    #[test]
    fn test_export_import() {
        let learner = SchemaLearner::new();

        learner.learn_from_request("/api/users", &json!({"id": 1, "name": "test"}));
        learner.learn_from_request("/api/orders", &json!({"order_id": 100}));

        let exported = learner.export();
        assert_eq!(exported.len(), 2);

        // Import into a fresh learner.
        let learner2 = SchemaLearner::new();
        learner2.import(exported);

        assert_eq!(learner2.len(), 2);
        assert!(learner2.get_schema("/api/users").is_some());
        assert!(learner2.get_schema("/api/orders").is_some());
    }

    #[test]
    fn test_nullable_fields() {
        let learner = SchemaLearner::with_config(SchemaLearnerConfig {
            min_samples_for_validation: 5,
            ..Default::default()
        });

        // Alternate between a string and null for the same field.
        for i in 0..10 {
            let body = if i % 2 == 0 {
                json!({"name": "test"})
            } else {
                json!({"name": null})
            };
            learner.learn_from_request("/api/users", &body);
        }

        let schema = learner.get_schema("/api/users").unwrap();
        let name_schema = schema.request_schema.get("name").unwrap();
        assert!(name_schema.nullable);

        // Null must not be reported as a type mismatch for a nullable field.
        let result = learner.validate_request("/api/users", &json!({"name": null}));
        assert!(!result
            .violations
            .iter()
            .any(|v| v.violation_type == ViolationType::TypeMismatch && v.field == "name"));
    }

    #[test]
    fn test_array_item_types() {
        let learner = SchemaLearner::new();

        let body = json!({
            "tags": ["tag1", "tag2"],
            "numbers": [1, 2, 3]
        });

        learner.learn_from_request("/api/items", &body);

        let schema = learner.get_schema("/api/items").unwrap();
        let tags_schema = schema.request_schema.get("tags").unwrap();
        let numbers_schema = schema.request_schema.get("numbers").unwrap();

        assert!(tags_schema
            .array_item_types
            .as_ref()
            .unwrap()
            .contains(&FieldType::String));
        assert!(numbers_schema
            .array_item_types
            .as_ref()
            .unwrap()
            .contains(&FieldType::Number));
    }

    #[test]
    fn test_validate_missing_required_field() {
        let learner = SchemaLearner::with_config(SchemaLearnerConfig {
            min_samples_for_validation: 5,
            required_field_threshold: 0.9,
            ..Default::default()
        });

        // `name` is present in every sample, so it is required.
        for i in 0..10 {
            learner.learn_from_request("/api/users", &json!({"id": i, "name": "test"}));
        }

        let result = learner.validate_request("/api/users", &json!({"id": 1}));

        assert!(!result.is_valid());
        assert!(result
            .violations
            .iter()
            .any(|v| v.violation_type == ViolationType::MissingField && v.field == "name"));
    }

    #[test]
    fn test_validate_number_constraints() {
        let learner = SchemaLearner::with_config(SchemaLearnerConfig {
            min_samples_for_validation: 5,
            number_value_tolerance: 2.0,
            ..Default::default()
        });

        // Learned range is 10..=100; with tolerance 2.0 the accepted range is 5..=200.
        for i in 0..10 {
            learner.learn_from_request("/api/items", &json!({"price": 10 + i * 10}));
        }

        let result = learner.validate_request("/api/items", &json!({"price": 500}));
        assert!(!result.is_valid());
        assert!(result
            .violations
            .iter()
            .any(|v| v.violation_type == ViolationType::NumberTooLarge));

        let result = learner.validate_request("/api/items", &json!({"price": 1}));
        assert!(!result.is_valid());
        assert!(result
            .violations
            .iter()
            .any(|v| v.violation_type == ViolationType::NumberTooSmall));
    }

    #[test]
    fn test_validate_deeply_nested_json_does_not_stack_overflow() {
        let learner = SchemaLearner::with_config(SchemaLearnerConfig {
            max_nesting_depth: 10,
            // Validate right after a single sample.
            min_samples_for_validation: 0,
            ..Default::default()
        });

        // 100 levels of nesting, far beyond `max_nesting_depth`.
        let mut body = json!({"leaf": true});
        for i in 0..100 {
            body = json!({ format!("nest_{}", i): body });
        }

        learner.learn_from_request("/api/nested", &body);

        let result = learner.validate_request("/api/nested", &body);

        // Levels beyond the depth limit are ignored during both learning and validation.
        assert!(result.is_valid());
    }

    #[test]
    fn test_learn_array_root_body_is_silently_skipped() {
        let learner = SchemaLearner::new();
        let body = json!([{"id": 1}, {"id": 2}]);

        learner.learn_from_request("/api/arrays", &body);

        // Non-object roots do not even create a schema entry.
        assert_eq!(learner.len(), 0);
    }

    #[test]
    fn test_learn_from_response_does_not_increment_sample_count() {
        let learner = SchemaLearner::new();

        learner.learn_from_response("/api/test", &json!({"ok": true}));

        let schema = learner.get_schema("/api/test").unwrap();
        assert_eq!(schema.sample_count, 0);
        assert!(schema.response_schema.contains_key("ok"));

        // A request body does count as a sample.
        learner.learn_from_request("/api/test", &json!({"id": 1}));
        let schema = learner.get_schema("/api/test").unwrap();
        assert_eq!(schema.sample_count, 1);
    }

    #[test]
    fn test_learn_from_pair_both_none() {
        let learner = SchemaLearner::new();

        // A pair without bodies still creates the schema and counts one sample.
        learner.learn_from_pair("/api/empty", None, None);

        let schema = learner.get_schema("/api/empty").unwrap();
        assert_eq!(schema.sample_count, 1);
        assert!(schema.request_schema.is_empty());
        assert!(schema.response_schema.is_empty());
    }
}