1use anyhow::{anyhow, Result};
52use serde::{Deserialize, Serialize};
53use serde_json::Value;
54use std::collections::{HashMap, HashSet};
55
56use crate::{DomainInfo, PredicateInfo, SymbolTable, ValueRange};
57
58#[derive(Clone, Debug, Serialize, Deserialize)]
60pub struct InferenceConfig {
61 pub min_confidence: f64,
63 pub infer_hierarchies: bool,
65 pub infer_constraints: bool,
67 pub infer_dependencies: bool,
69 pub cardinality_multiplier: f64,
71 pub max_nesting_depth: usize,
73}
74
75impl Default for InferenceConfig {
76 fn default() -> Self {
77 Self {
78 min_confidence: 0.7,
79 infer_hierarchies: true,
80 infer_constraints: true,
81 infer_dependencies: true,
82 cardinality_multiplier: 2.0,
83 max_nesting_depth: 5,
84 }
85 }
86}
87
/// A confidence rating attached to one inferred schema element.
#[derive(Clone, Debug, PartialEq)]
pub struct ConfidenceScore {
    /// Confidence value; `ConfidenceScore::new` keeps it within `0.0..=1.0`.
    pub score: f64,
    /// Number of observations the inference was based on.
    pub evidence_count: usize,
    /// Human-readable explanation of how the score was obtained.
    pub reasoning: String,
}

impl ConfidenceScore {
    /// Builds a score, forcing `score` into the `0.0..=1.0` range.
    ///
    /// Values below `0.0` become `0.0`, values above `1.0` become `1.0`, and
    /// a NaN input is treated as no confidence at all (`0.0`).
    pub fn new(score: f64, evidence_count: usize, reasoning: impl Into<String>) -> Self {
        // `f64::clamp` returns NaN for a NaN input, which would leave the
        // score outside the documented range and make the derived
        // `PartialEq` report the value as unequal to itself.
        let score = if score.is_nan() {
            0.0
        } else {
            score.clamp(0.0, 1.0)
        };

        Self {
            score,
            evidence_count,
            reasoning: reasoning.into(),
        }
    }

    /// Returns `true` when the score is at least `threshold`.
    pub fn is_confident(&self, threshold: f64) -> bool {
        self.score >= threshold
    }
}
109
110#[derive(Clone, Debug)]
112pub struct DataSample {
113 records: Vec<HashMap<String, Value>>,
114}
115
116impl DataSample {
117 pub fn from_json(json: &str) -> Result<Self> {
119 let value: Value = serde_json::from_str(json)?;
120
121 let records = match value {
122 Value::Array(arr) => arr
123 .into_iter()
124 .filter_map(|v| {
125 if let Value::Object(map) = v {
126 Some(map.into_iter().collect::<HashMap<_, _>>())
127 } else {
128 None
129 }
130 })
131 .collect(),
132 Value::Object(map) => {
133 vec![map.into_iter().collect()]
134 }
135 _ => return Err(anyhow!("Expected JSON array or object")),
136 };
137
138 Ok(Self { records })
139 }
140
141 pub fn from_csv(csv: &str) -> Result<Self> {
143 let mut lines = csv.lines();
144 let headers: Vec<String> = lines
145 .next()
146 .ok_or_else(|| anyhow!("Empty CSV"))?
147 .split(',')
148 .map(|s| s.trim().to_string())
149 .collect();
150
151 let records = lines
152 .map(|line| {
153 let values: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
154 headers
155 .iter()
156 .zip(values.iter())
157 .map(|(k, v)| {
158 let json_val = if let Ok(num) = v.parse::<f64>() {
159 Value::Number(
160 serde_json::Number::from_f64(num)
161 .unwrap_or_else(|| serde_json::Number::from(0i64)),
162 )
163 } else if *v == "true" || *v == "false" {
164 Value::Bool(*v == "true")
165 } else {
166 Value::String(v.to_string())
167 };
168 (k.clone(), json_val)
169 })
170 .collect()
171 })
172 .collect();
173
174 Ok(Self { records })
175 }
176
177 pub fn field_names(&self) -> HashSet<String> {
179 self.records
180 .iter()
181 .flat_map(|record| record.keys().cloned())
182 .collect()
183 }
184
185 pub fn field_values(&self, field: &str) -> Vec<&Value> {
187 self.records
188 .iter()
189 .filter_map(|record| record.get(field))
190 .collect()
191 }
192
193 pub fn len(&self) -> usize {
195 self.records.len()
196 }
197
198 pub fn is_empty(&self) -> bool {
200 self.records.is_empty()
201 }
202}
203
/// Counters and timings collected while a learner infers schemas.
///
/// Every field starts at zero via `Default`.
#[derive(Debug, Default, Clone)]
pub struct LearningStatistics {
    /// Number of domains inferred.
    pub domains_inferred: usize,
    /// Number of predicates inferred.
    pub predicates_inferred: usize,
    /// Number of constraints inferred.
    pub constraints_inferred: usize,
    /// Number of hierarchies inferred.
    pub hierarchies_inferred: usize,
    /// Number of dependencies inferred.
    pub dependencies_inferred: usize,
    /// Number of sample records analyzed.
    pub total_samples_analyzed: usize,
    /// Time spent on inference, in milliseconds.
    pub inference_time_ms: u128,
}
215
216pub struct SchemaLearner {
218 config: InferenceConfig,
219 statistics: LearningStatistics,
220 confidence_scores: HashMap<String, ConfidenceScore>,
221}
222
223impl SchemaLearner {
224 pub fn new(config: InferenceConfig) -> Self {
226 Self {
227 config,
228 statistics: LearningStatistics::default(),
229 confidence_scores: HashMap::new(),
230 }
231 }
232
233 pub fn learn_from_sample(&mut self, sample: &DataSample) -> Result<SymbolTable> {
235 let start = std::time::Instant::now();
236
237 let mut table = SymbolTable::new();
238
239 self.infer_domains(sample, &mut table)?;
241
242 self.infer_predicates(sample, &mut table)?;
244
245 if self.config.infer_constraints {
247 self.infer_constraints(sample, &mut table)?;
248 }
249
250 if self.config.infer_hierarchies {
252 self.infer_hierarchies(sample, &mut table)?;
253 }
254
255 self.statistics.total_samples_analyzed = sample.len();
256 self.statistics.inference_time_ms = start.elapsed().as_millis();
257
258 Ok(table)
259 }
260
261 fn infer_domains(&mut self, sample: &DataSample, table: &mut SymbolTable) -> Result<()> {
263 let mut domain_types: HashMap<String, HashSet<String>> = HashMap::new();
264
265 for field in sample.field_names() {
267 let values = sample.field_values(&field);
268 let types: HashSet<String> = values.iter().map(|v| self.infer_type(v)).collect();
269 domain_types.insert(field.clone(), types);
270 }
271
272 let mut inferred_types: HashSet<String> = HashSet::new();
274 for types in domain_types.values() {
275 inferred_types.extend(types.clone());
276 }
277
278 for type_name in inferred_types {
279 let cardinality = self.estimate_cardinality(sample, &type_name);
280 let domain = DomainInfo::new(&type_name, cardinality);
281
282 if table.add_domain(domain).is_ok() {
283 self.statistics.domains_inferred += 1;
284 self.confidence_scores.insert(
285 format!("domain:{}", type_name),
286 ConfidenceScore::new(
287 0.9,
288 sample.len(),
289 format!("Inferred from {} samples", sample.len()),
290 ),
291 );
292 }
293 }
294
295 Ok(())
296 }
297
298 fn infer_predicates(&mut self, sample: &DataSample, table: &mut SymbolTable) -> Result<()> {
300 let fields: Vec<String> = sample.field_names().into_iter().collect();
301
302 for field in &fields {
304 let values = sample.field_values(field);
305 if values.is_empty() {
306 continue;
307 }
308
309 let type_name = self.infer_type(values[0]);
310 let predicate = PredicateInfo::new(field, vec![type_name.clone()]);
311
312 if table.add_predicate(predicate).is_ok() {
313 self.statistics.predicates_inferred += 1;
314 self.confidence_scores.insert(
315 format!("predicate:{}", field),
316 ConfidenceScore::new(
317 0.85,
318 values.len(),
319 format!("Inferred from {} values", values.len()),
320 ),
321 );
322 }
323 }
324
325 for i in 0..fields.len() {
327 for j in (i + 1)..fields.len() {
328 let field1 = &fields[i];
329 let field2 = &fields[j];
330
331 if self.has_relationship(sample, field1, field2) {
332 let type1 = self.infer_type(sample.field_values(field1)[0]);
333 let type2 = self.infer_type(sample.field_values(field2)[0]);
334
335 let rel_name = format!("{}_{}", field1, field2);
336 let predicate = PredicateInfo::new(&rel_name, vec![type1, type2]);
337
338 if table.add_predicate(predicate).is_ok() {
339 self.statistics.predicates_inferred += 1;
340 }
341 }
342 }
343 }
344
345 Ok(())
346 }
347
348 fn infer_constraints(&mut self, sample: &DataSample, _table: &mut SymbolTable) -> Result<()> {
350 for field in sample.field_names() {
351 let values = sample.field_values(&field);
352
353 if let Some(range) = self.infer_value_range(&values) {
355 self.statistics.constraints_inferred += 1;
356 self.confidence_scores.insert(
357 format!("constraint:{}:range", field),
358 ConfidenceScore::new(
359 0.8,
360 values.len(),
361 "Inferred from numeric values".to_string(),
362 ),
363 );
364 let _ = range; }
367 }
368
369 Ok(())
370 }
371
372 fn infer_hierarchies(&mut self, _sample: &DataSample, _table: &mut SymbolTable) -> Result<()> {
374 Ok(())
377 }
378
379 fn infer_type(&self, value: &Value) -> String {
381 match value {
382 Value::Number(_) => "Number".to_string(),
383 Value::String(_) => "String".to_string(),
384 Value::Bool(_) => "Boolean".to_string(),
385 Value::Array(_) => "Array".to_string(),
386 Value::Object(_) => "Object".to_string(),
387 Value::Null => "Unknown".to_string(),
388 }
389 }
390
391 fn estimate_cardinality(&self, sample: &DataSample, type_name: &str) -> usize {
393 let mut unique_values = HashSet::new();
394
395 for record in &sample.records {
396 for value in record.values() {
397 if self.infer_type(value) == type_name {
398 unique_values.insert(format!("{:?}", value));
399 }
400 }
401 }
402
403 ((unique_values.len() as f64) * self.config.cardinality_multiplier).ceil() as usize
404 }
405
406 fn has_relationship(&self, sample: &DataSample, field1: &str, field2: &str) -> bool {
408 let values1 = sample.field_values(field1);
409 let values2 = sample.field_values(field2);
410
411 let co_occurrence = values1.len().min(values2.len());
413 let threshold = (sample.len() as f64 * 0.8).ceil() as usize;
414
415 co_occurrence >= threshold
416 }
417
418 fn infer_value_range(&self, values: &[&Value]) -> Option<ValueRange> {
420 let numbers: Vec<f64> = values.iter().filter_map(|v| v.as_f64()).collect();
421
422 if numbers.is_empty() {
423 return None;
424 }
425
426 let min = numbers.iter().copied().fold(f64::INFINITY, f64::min);
427 let max = numbers.iter().copied().fold(f64::NEG_INFINITY, f64::max);
428
429 Some(ValueRange::new().with_min(min, true).with_max(max, true))
430 }
431
432 pub fn statistics(&self) -> &LearningStatistics {
434 &self.statistics
435 }
436
437 pub fn confidence(&self, element: &str) -> Option<&ConfidenceScore> {
439 self.confidence_scores.get(element)
440 }
441
442 pub fn all_confidences(&self) -> &HashMap<String, ConfidenceScore> {
444 &self.confidence_scores
445 }
446}
447
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a JSON fixture, failing the test with a clear message if the
    /// fixture itself is malformed.
    fn sample_from_json(json: &str) -> DataSample {
        DataSample::from_json(json).expect("JSON fixture should parse into a sample")
    }

    /// Parses a CSV fixture, failing the test with a clear message if the
    /// fixture itself is malformed.
    fn sample_from_csv(csv: &str) -> DataSample {
        DataSample::from_csv(csv).expect("CSV fixture should parse into a sample")
    }

    /// A learner using the default configuration.
    fn default_learner() -> SchemaLearner {
        SchemaLearner::new(InferenceConfig::default())
    }

    #[test]
    fn test_data_sample_from_json() {
        let sample = sample_from_json(
            r#"[
            {"id": 1, "name": "Alice"},
            {"id": 2, "name": "Bob"}
        ]"#,
        );

        assert_eq!(sample.len(), 2);
        let names = sample.field_names();
        assert_eq!(names.len(), 2);
        assert!(names.contains("id"));
        assert!(names.contains("name"));
    }

    #[test]
    fn test_data_sample_from_csv() {
        let sample = sample_from_csv("id,name,age\n1,Alice,30\n2,Bob,25");

        assert_eq!(sample.len(), 2);
        let names = sample.field_names();
        assert_eq!(names.len(), 3);
        assert!(names.contains("age"));
    }

    #[test]
    fn test_schema_learner_basic() {
        let sample = sample_from_json(
            r#"[
            {"id": 1, "name": "Alice", "age": 30},
            {"id": 2, "name": "Bob", "age": 25}
        ]"#,
        );
        let mut learner = default_learner();

        let _schema = learner
            .learn_from_sample(&sample)
            .expect("learning from a well-formed sample should succeed");
        let stats = learner.statistics();

        assert!(stats.domains_inferred > 0);
        assert!(stats.predicates_inferred > 0);
        assert_eq!(stats.total_samples_analyzed, 2);
    }

    #[test]
    fn test_type_inference() {
        let learner = default_learner();

        assert_eq!(learner.infer_type(&Value::Number(42.into())), "Number");
        assert_eq!(learner.infer_type(&Value::String("test".into())), "String");
        assert_eq!(learner.infer_type(&Value::Bool(true)), "Boolean");
        assert_eq!(learner.infer_type(&Value::Array(Vec::new())), "Array");
        assert_eq!(learner.infer_type(&Value::Null), "Unknown");
    }

    #[test]
    fn test_value_range_inference() {
        let val1 = Value::Number(10.into());
        let val2 = Value::Number(20.into());
        let val3 = Value::Number(30.into());
        let values = vec![&val1, &val2, &val3];

        let learner = default_learner();
        let range = learner
            .infer_value_range(&values)
            .expect("numeric values should yield a range");

        assert_eq!(range.min, Some(10.0));
        assert_eq!(range.max, Some(30.0));

        // Without any numeric value there is no range to report.
        let text = Value::String("not a number".into());
        assert!(learner.infer_value_range(&[&text]).is_none());
        assert!(learner.infer_value_range(&[]).is_none());
    }

    #[test]
    fn test_confidence_score() {
        let score = ConfidenceScore::new(0.85, 100, "High confidence");
        assert_eq!(score.score, 0.85);
        assert_eq!(score.evidence_count, 100);
        assert_eq!(score.reasoning, "High confidence");
        assert!(score.is_confident(0.7));
        assert!(!score.is_confident(0.9));

        // Out-of-range inputs are clamped into 0.0..=1.0.
        assert_eq!(ConfidenceScore::new(1.5, 1, "too high").score, 1.0);
        assert_eq!(ConfidenceScore::new(-0.5, 1, "too low").score, 0.0);
    }

    #[test]
    fn test_inference_config_default() {
        let config = InferenceConfig::default();
        assert_eq!(config.min_confidence, 0.7);
        assert!(config.infer_hierarchies);
        assert!(config.infer_constraints);
        assert!(config.infer_dependencies);
        assert_eq!(config.cardinality_multiplier, 2.0);
        assert_eq!(config.max_nesting_depth, 5);
    }

    #[test]
    fn test_cardinality_estimation() {
        let sample = sample_from_json(
            r#"[
            {"id": 1, "type": "A"},
            {"id": 2, "type": "B"},
            {"id": 3, "type": "A"}
        ]"#,
        );
        let learner = default_learner();

        // Distinct values times the default multiplier of 2.0.
        assert_eq!(learner.estimate_cardinality(&sample, "Number"), 6);
        assert_eq!(learner.estimate_cardinality(&sample, "String"), 4);
        assert_eq!(learner.estimate_cardinality(&sample, "Boolean"), 0);
    }

    #[test]
    fn test_field_values_extraction() {
        let sample = sample_from_json(
            r#"[
            {"name": "Alice", "age": 30},
            {"name": "Bob", "age": 25}
        ]"#,
        );

        let names = sample.field_values("name");
        assert_eq!(names.len(), 2);
        assert_eq!(names[0].as_str(), Some("Alice"));
        assert_eq!(names[1].as_str(), Some("Bob"));
        assert!(sample.field_values("missing").is_empty());
    }

    #[test]
    fn test_relationship_detection() {
        let sample = sample_from_json(
            r#"[
            {"person": "Alice", "city": "NYC"},
            {"person": "Bob", "city": "LA"}
        ]"#,
        );
        let learner = default_learner();

        assert!(learner.has_relationship(&sample, "person", "city"));
        assert!(!learner.has_relationship(&sample, "person", "missing"));
    }

    #[test]
    fn test_empty_sample() {
        let sample = sample_from_json("[]");
        assert!(sample.is_empty());
        assert_eq!(sample.len(), 0);
        assert!(sample.field_names().is_empty());
    }

    #[test]
    fn test_single_object_json() {
        let sample = sample_from_json(r#"{"id": 1, "name": "Alice"}"#);
        assert_eq!(sample.len(), 1);
        assert_eq!(sample.field_names().len(), 2);
    }

    #[test]
    fn test_csv_type_detection() {
        let sample = sample_from_csv("id,name,active\n1,Alice,true\n2,Bob,false");

        let active_values = sample.field_values("active");
        assert_eq!(active_values.len(), 2);
        assert!(active_values.iter().all(|v| v.is_boolean()));
        assert!(sample.field_values("id").iter().all(|v| v.is_number()));
        assert!(sample.field_values("name").iter().all(|v| v.is_string()));
    }

    #[test]
    fn test_confidence_scores_tracking() {
        let sample = sample_from_json(r#"[{"id": 1, "name": "Alice"}]"#);
        let mut learner = default_learner();

        learner
            .learn_from_sample(&sample)
            .expect("learning from a well-formed sample should succeed");
        assert!(!learner.all_confidences().is_empty());
    }

    #[test]
    fn test_learning_statistics() {
        let sample = sample_from_json(r#"[{"id": 1}, {"id": 2}, {"id": 3}]"#);
        let mut learner = default_learner();

        learner
            .learn_from_sample(&sample)
            .expect("learning from a well-formed sample should succeed");
        let stats = learner.statistics();

        assert_eq!(stats.total_samples_analyzed, 3);
        assert!(stats.domains_inferred > 0 || stats.predicates_inferred > 0);
    }
}