1use anyhow::{anyhow, Result};
52use serde::{Deserialize, Serialize};
53use serde_json::Value;
54use std::collections::{HashMap, HashSet};
55
56use crate::{DomainInfo, PredicateInfo, SymbolTable, ValueRange};
57
58#[derive(Clone, Debug, Serialize, Deserialize)]
60pub struct InferenceConfig {
61 pub min_confidence: f64,
63 pub infer_hierarchies: bool,
65 pub infer_constraints: bool,
67 pub infer_dependencies: bool,
69 pub cardinality_multiplier: f64,
71 pub max_nesting_depth: usize,
73}
74
75impl Default for InferenceConfig {
76 fn default() -> Self {
77 Self {
78 min_confidence: 0.7,
79 infer_hierarchies: true,
80 infer_constraints: true,
81 infer_dependencies: true,
82 cardinality_multiplier: 2.0,
83 max_nesting_depth: 5,
84 }
85 }
86}
87
/// A confidence estimate attached to one inferred schema element.
#[derive(Debug, Clone, PartialEq)]
pub struct ConfidenceScore {
    /// Confidence value; [`ConfidenceScore::new`] clamps it into `0.0..=1.0`.
    pub score: f64,
    /// Number of observations the score is based on.
    pub evidence_count: usize,
    /// Free-text explanation of how the score was obtained.
    pub reasoning: String,
}

impl ConfidenceScore {
    /// Creates a score, clamping `score` into the inclusive range `0.0..=1.0`.
    ///
    /// `reasoning` accepts anything convertible into a `String`.
    pub fn new(score: f64, evidence_count: usize, reasoning: impl Into<String>) -> Self {
        let bounded = score.clamp(0.0, 1.0);
        let reasoning: String = reasoning.into();

        ConfidenceScore {
            score: bounded,
            evidence_count,
            reasoning,
        }
    }

    /// Returns `true` when the stored score is at least `threshold`.
    pub fn is_confident(&self, threshold: f64) -> bool {
        threshold <= self.score
    }
}
109
110#[derive(Clone, Debug)]
112pub struct DataSample {
113 records: Vec<HashMap<String, Value>>,
114}
115
116impl DataSample {
117 pub fn from_json(json: &str) -> Result<Self> {
119 let value: Value = serde_json::from_str(json)?;
120
121 let records = match value {
122 Value::Array(arr) => arr
123 .into_iter()
124 .filter_map(|v| {
125 if let Value::Object(map) = v {
126 Some(map.into_iter().collect::<HashMap<_, _>>())
127 } else {
128 None
129 }
130 })
131 .collect(),
132 Value::Object(map) => {
133 vec![map.into_iter().collect()]
134 }
135 _ => return Err(anyhow!("Expected JSON array or object")),
136 };
137
138 Ok(Self { records })
139 }
140
141 pub fn from_csv(csv: &str) -> Result<Self> {
143 let mut lines = csv.lines();
144 let headers: Vec<String> = lines
145 .next()
146 .ok_or_else(|| anyhow!("Empty CSV"))?
147 .split(',')
148 .map(|s| s.trim().to_string())
149 .collect();
150
151 let records = lines
152 .map(|line| {
153 let values: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
154 headers
155 .iter()
156 .zip(values.iter())
157 .map(|(k, v)| {
158 let json_val = if let Ok(num) = v.parse::<f64>() {
159 Value::Number(serde_json::Number::from_f64(num).unwrap())
160 } else if *v == "true" || *v == "false" {
161 Value::Bool(*v == "true")
162 } else {
163 Value::String(v.to_string())
164 };
165 (k.clone(), json_val)
166 })
167 .collect()
168 })
169 .collect();
170
171 Ok(Self { records })
172 }
173
174 pub fn field_names(&self) -> HashSet<String> {
176 self.records
177 .iter()
178 .flat_map(|record| record.keys().cloned())
179 .collect()
180 }
181
182 pub fn field_values(&self, field: &str) -> Vec<&Value> {
184 self.records
185 .iter()
186 .filter_map(|record| record.get(field))
187 .collect()
188 }
189
190 pub fn len(&self) -> usize {
192 self.records.len()
193 }
194
195 pub fn is_empty(&self) -> bool {
197 self.records.is_empty()
198 }
199}
200
/// Counters describing what a schema-learning run produced.
///
/// Every field starts at zero via `Default`.
#[derive(Debug, Default, Clone)]
pub struct LearningStatistics {
    /// Domains successfully added to the symbol table.
    pub domains_inferred: usize,
    /// Predicates (single-field and pairwise) successfully added to the symbol table.
    pub predicates_inferred: usize,
    /// Fields for which a numeric value range was derived.
    pub constraints_inferred: usize,
    /// Hierarchy counter; no code visible in this file writes to it.
    pub hierarchies_inferred: usize,
    /// Dependency counter; no code visible in this file writes to it.
    pub dependencies_inferred: usize,
    /// Record count of the sample processed by the most recent learning call.
    pub total_samples_analyzed: usize,
    /// Wall-clock duration of the most recent learning call, in milliseconds.
    pub inference_time_ms: u128,
}
212
213pub struct SchemaLearner {
215 config: InferenceConfig,
216 statistics: LearningStatistics,
217 confidence_scores: HashMap<String, ConfidenceScore>,
218}
219
220impl SchemaLearner {
221 pub fn new(config: InferenceConfig) -> Self {
223 Self {
224 config,
225 statistics: LearningStatistics::default(),
226 confidence_scores: HashMap::new(),
227 }
228 }
229
230 pub fn learn_from_sample(&mut self, sample: &DataSample) -> Result<SymbolTable> {
232 let start = std::time::Instant::now();
233
234 let mut table = SymbolTable::new();
235
236 self.infer_domains(sample, &mut table)?;
238
239 self.infer_predicates(sample, &mut table)?;
241
242 if self.config.infer_constraints {
244 self.infer_constraints(sample, &mut table)?;
245 }
246
247 if self.config.infer_hierarchies {
249 self.infer_hierarchies(sample, &mut table)?;
250 }
251
252 self.statistics.total_samples_analyzed = sample.len();
253 self.statistics.inference_time_ms = start.elapsed().as_millis();
254
255 Ok(table)
256 }
257
258 fn infer_domains(&mut self, sample: &DataSample, table: &mut SymbolTable) -> Result<()> {
260 let mut domain_types: HashMap<String, HashSet<String>> = HashMap::new();
261
262 for field in sample.field_names() {
264 let values = sample.field_values(&field);
265 let types: HashSet<String> = values.iter().map(|v| self.infer_type(v)).collect();
266 domain_types.insert(field.clone(), types);
267 }
268
269 let mut inferred_types: HashSet<String> = HashSet::new();
271 for types in domain_types.values() {
272 inferred_types.extend(types.clone());
273 }
274
275 for type_name in inferred_types {
276 let cardinality = self.estimate_cardinality(sample, &type_name);
277 let domain = DomainInfo::new(&type_name, cardinality);
278
279 if table.add_domain(domain).is_ok() {
280 self.statistics.domains_inferred += 1;
281 self.confidence_scores.insert(
282 format!("domain:{}", type_name),
283 ConfidenceScore::new(
284 0.9,
285 sample.len(),
286 format!("Inferred from {} samples", sample.len()),
287 ),
288 );
289 }
290 }
291
292 Ok(())
293 }
294
295 fn infer_predicates(&mut self, sample: &DataSample, table: &mut SymbolTable) -> Result<()> {
297 let fields: Vec<String> = sample.field_names().into_iter().collect();
298
299 for field in &fields {
301 let values = sample.field_values(field);
302 if values.is_empty() {
303 continue;
304 }
305
306 let type_name = self.infer_type(values[0]);
307 let predicate = PredicateInfo::new(field, vec![type_name.clone()]);
308
309 if table.add_predicate(predicate).is_ok() {
310 self.statistics.predicates_inferred += 1;
311 self.confidence_scores.insert(
312 format!("predicate:{}", field),
313 ConfidenceScore::new(
314 0.85,
315 values.len(),
316 format!("Inferred from {} values", values.len()),
317 ),
318 );
319 }
320 }
321
322 for i in 0..fields.len() {
324 for j in (i + 1)..fields.len() {
325 let field1 = &fields[i];
326 let field2 = &fields[j];
327
328 if self.has_relationship(sample, field1, field2) {
329 let type1 = self.infer_type(sample.field_values(field1)[0]);
330 let type2 = self.infer_type(sample.field_values(field2)[0]);
331
332 let rel_name = format!("{}_{}", field1, field2);
333 let predicate = PredicateInfo::new(&rel_name, vec![type1, type2]);
334
335 if table.add_predicate(predicate).is_ok() {
336 self.statistics.predicates_inferred += 1;
337 }
338 }
339 }
340 }
341
342 Ok(())
343 }
344
345 fn infer_constraints(&mut self, sample: &DataSample, _table: &mut SymbolTable) -> Result<()> {
347 for field in sample.field_names() {
348 let values = sample.field_values(&field);
349
350 if let Some(range) = self.infer_value_range(&values) {
352 self.statistics.constraints_inferred += 1;
353 self.confidence_scores.insert(
354 format!("constraint:{}:range", field),
355 ConfidenceScore::new(
356 0.8,
357 values.len(),
358 "Inferred from numeric values".to_string(),
359 ),
360 );
361 let _ = range; }
364 }
365
366 Ok(())
367 }
368
369 fn infer_hierarchies(&mut self, _sample: &DataSample, _table: &mut SymbolTable) -> Result<()> {
371 Ok(())
374 }
375
376 fn infer_type(&self, value: &Value) -> String {
378 match value {
379 Value::Number(_) => "Number".to_string(),
380 Value::String(_) => "String".to_string(),
381 Value::Bool(_) => "Boolean".to_string(),
382 Value::Array(_) => "Array".to_string(),
383 Value::Object(_) => "Object".to_string(),
384 Value::Null => "Unknown".to_string(),
385 }
386 }
387
388 fn estimate_cardinality(&self, sample: &DataSample, type_name: &str) -> usize {
390 let mut unique_values = HashSet::new();
391
392 for record in &sample.records {
393 for value in record.values() {
394 if self.infer_type(value) == type_name {
395 unique_values.insert(format!("{:?}", value));
396 }
397 }
398 }
399
400 ((unique_values.len() as f64) * self.config.cardinality_multiplier).ceil() as usize
401 }
402
403 fn has_relationship(&self, sample: &DataSample, field1: &str, field2: &str) -> bool {
405 let values1 = sample.field_values(field1);
406 let values2 = sample.field_values(field2);
407
408 let co_occurrence = values1.len().min(values2.len());
410 let threshold = (sample.len() as f64 * 0.8).ceil() as usize;
411
412 co_occurrence >= threshold
413 }
414
415 fn infer_value_range(&self, values: &[&Value]) -> Option<ValueRange> {
417 let numbers: Vec<f64> = values.iter().filter_map(|v| v.as_f64()).collect();
418
419 if numbers.is_empty() {
420 return None;
421 }
422
423 let min = numbers.iter().copied().fold(f64::INFINITY, f64::min);
424 let max = numbers.iter().copied().fold(f64::NEG_INFINITY, f64::max);
425
426 Some(ValueRange::new().with_min(min, true).with_max(max, true))
427 }
428
429 pub fn statistics(&self) -> &LearningStatistics {
431 &self.statistics
432 }
433
434 pub fn confidence(&self, element: &str) -> Option<&ConfidenceScore> {
436 self.confidence_scores.get(element)
437 }
438
439 pub fn all_confidences(&self) -> &HashMap<String, ConfidenceScore> {
441 &self.confidence_scores
442 }
443}
444
#[cfg(test)]
mod tests {
    use super::*;

    /// Learner built from the default configuration.
    fn default_learner() -> SchemaLearner {
        SchemaLearner::new(InferenceConfig::default())
    }

    /// Parses `json` into a sample, panicking on malformed input.
    fn sample_from(json: &str) -> DataSample {
        DataSample::from_json(json).unwrap()
    }

    /// A JSON array of objects yields one record per object.
    #[test]
    fn test_data_sample_from_json() {
        let parsed = sample_from(
            r#"[
            {"id": 1, "name": "Alice"},
            {"id": 2, "name": "Bob"}
        ]"#,
        );

        assert_eq!(parsed.len(), 2);
        assert_eq!(parsed.field_names().len(), 2);
    }

    /// CSV rows after the header line become records keyed by header.
    #[test]
    fn test_data_sample_from_csv() {
        let parsed = DataSample::from_csv("id,name,age\n1,Alice,30\n2,Bob,25").unwrap();

        assert_eq!(parsed.len(), 2);
        assert_eq!(parsed.field_names().len(), 3);
    }

    /// Learning from a small sample produces domains and predicates.
    #[test]
    fn test_schema_learner_basic() {
        let input = sample_from(
            r#"[
            {"id": 1, "name": "Alice", "age": 30},
            {"id": 2, "name": "Bob", "age": 25}
        ]"#,
        );
        let mut learner = default_learner();

        let _schema = learner.learn_from_sample(&input).unwrap();

        let stats = learner.statistics();
        assert!(stats.domains_inferred > 0);
        assert!(stats.predicates_inferred > 0);
        assert_eq!(stats.total_samples_analyzed, 2);
    }

    /// Scalar JSON values map to their domain names.
    #[test]
    fn test_type_inference() {
        let learner = default_learner();

        let number = Value::Number(42.into());
        let text = Value::String("test".into());
        let flag = Value::Bool(true);

        assert_eq!(learner.infer_type(&number), "Number");
        assert_eq!(learner.infer_type(&text), "String");
        assert_eq!(learner.infer_type(&flag), "Boolean");
    }

    /// The inferred range spans the smallest and largest numeric value.
    #[test]
    fn test_value_range_inference() {
        let low = Value::Number(10.into());
        let mid = Value::Number(20.into());
        let high = Value::Number(30.into());
        let observed = vec![&low, &mid, &high];

        let range = default_learner().infer_value_range(&observed).unwrap();

        assert_eq!(range.min, Some(10.0));
        assert_eq!(range.max, Some(30.0));
    }

    /// Scores keep their inputs and compare against thresholds inclusively.
    #[test]
    fn test_confidence_score() {
        let confidence = ConfidenceScore::new(0.85, 100, "High confidence");

        assert_eq!(confidence.score, 0.85);
        assert_eq!(confidence.evidence_count, 100);
        assert!(confidence.is_confident(0.7));
        assert!(!confidence.is_confident(0.9));
    }

    /// The default configuration enables inference with a 0.7 confidence setting.
    #[test]
    fn test_inference_config_default() {
        let defaults = InferenceConfig::default();

        assert_eq!(defaults.min_confidence, 0.7);
        assert!(defaults.infer_hierarchies);
        assert!(defaults.infer_constraints);
    }

    /// A type that occurs in the sample gets a positive cardinality estimate.
    #[test]
    fn test_cardinality_estimation() {
        let input = sample_from(
            r#"[
            {"id": 1, "type": "A"},
            {"id": 2, "type": "B"},
            {"id": 3, "type": "A"}
        ]"#,
        );

        let estimate = default_learner().estimate_cardinality(&input, "Number");

        assert!(estimate > 0);
    }

    /// `field_values` returns one entry per record containing the field.
    #[test]
    fn test_field_values_extraction() {
        let input = sample_from(
            r#"[
            {"name": "Alice", "age": 30},
            {"name": "Bob", "age": 25}
        ]"#,
        );

        assert_eq!(input.field_values("name").len(), 2);
    }

    /// Fields present together in every record are reported as related.
    #[test]
    fn test_relationship_detection() {
        let input = sample_from(
            r#"[
            {"person": "Alice", "city": "NYC"},
            {"person": "Bob", "city": "LA"}
        ]"#,
        );

        let related = default_learner().has_relationship(&input, "person", "city");

        assert!(related);
    }

    /// An empty JSON array produces an empty sample.
    #[test]
    fn test_empty_sample() {
        let input = sample_from("[]");

        assert!(input.is_empty());
        assert_eq!(input.len(), 0);
    }

    /// A top-level JSON object is treated as a single record.
    #[test]
    fn test_single_object_json() {
        let input = sample_from(r#"{"id": 1, "name": "Alice"}"#);

        assert_eq!(input.len(), 1);
    }

    /// `true` / `false` CSV cells are parsed as booleans.
    #[test]
    fn test_csv_type_detection() {
        let parsed = DataSample::from_csv("id,name,active\n1,Alice,true\n2,Bob,false").unwrap();

        let flags = parsed.field_values("active");

        assert!(flags.iter().all(|value| value.is_boolean()));
    }

    /// Learning records at least one confidence score.
    #[test]
    fn test_confidence_scores_tracking() {
        let input = sample_from(r#"[{"id": 1, "name": "Alice"}]"#);
        let mut learner = default_learner();

        learner.learn_from_sample(&input).unwrap();

        assert!(!learner.all_confidences().is_empty());
    }

    /// Statistics reflect the sample size and at least one inferred element.
    #[test]
    fn test_learning_statistics() {
        let input = sample_from(r#"[{"id": 1}, {"id": 2}, {"id": 3}]"#);
        let mut learner = default_learner();

        learner.learn_from_sample(&input).unwrap();

        let stats = learner.statistics();
        assert_eq!(stats.total_samples_analyzed, 3);
        assert!(stats.domains_inferred > 0 || stats.predicates_inferred > 0);
    }
}