Skip to main content

tensorlogic_adapters/
learning.rs

1//! Schema learning from data.
2//!
3//! This module provides automatic schema inference from sample data,
4//! enabling rapid prototyping and schema bootstrapping from existing datasets.
5//!
6//! # Overview
7//!
8//! Instead of manually defining schemas, you can learn them from:
9//! - JSON objects and arrays
10//! - CSV files with headers
11//! - Relational data patterns
12//! - Example predicates and relationships
13//!
14//! The learner analyzes data to infer:
15//! - Domain types and cardinalities
16//! - Predicate signatures and properties
17//! - Type hierarchies
18//! - Value ranges and constraints
19//! - Functional dependencies
20//!
21//! # Architecture
22//!
23//! - **SchemaLearner**: Main inference engine
24//! - **DataSample**: Represents sample data for analysis
25//! - **InferenceConfig**: Configuration for learning behavior
26//! - **LearningStatistics**: Statistics about the learning process
27//! - **ConfidenceScore**: Confidence in inferred schema elements
28//!
29//! # Example
30//!
31//! ```rust
32//! use tensorlogic_adapters::{SchemaLearner, DataSample, InferenceConfig};
33//!
34//! let json_data = r#"[
35//!     {"id": 1, "name": "Alice", "age": 30, "city": "NYC"},
36//!     {"id": 2, "name": "Bob", "age": 25, "city": "LA"},
37//!     {"id": 3, "name": "Charlie", "age": 35, "city": "NYC"}
38//! ]"#;
39//!
40//! let sample = DataSample::from_json(json_data).expect("unwrap");
41//! let config = InferenceConfig::default();
42//! let mut learner = SchemaLearner::new(config);
43//!
44//! let schema = learner.learn_from_sample(&sample).expect("unwrap");
45//! let stats = learner.statistics();
46//!
47//! assert!(stats.domains_inferred > 0);
48//! assert!(stats.predicates_inferred > 0);
49//! ```
50
51use anyhow::{anyhow, Result};
52use serde::{Deserialize, Serialize};
53use serde_json::Value;
54use std::collections::{HashMap, HashSet};
55
56use crate::{DomainInfo, PredicateInfo, SymbolTable, ValueRange};
57
58/// Configuration for schema inference.
59#[derive(Clone, Debug, Serialize, Deserialize)]
60pub struct InferenceConfig {
61    /// Minimum confidence threshold for inferred elements (0.0 to 1.0)
62    pub min_confidence: f64,
63    /// Whether to infer domain hierarchies
64    pub infer_hierarchies: bool,
65    /// Whether to infer constraints
66    pub infer_constraints: bool,
67    /// Whether to infer functional dependencies
68    pub infer_dependencies: bool,
69    /// Cardinality multiplier for domain size estimation
70    pub cardinality_multiplier: f64,
71    /// Maximum depth for nested object analysis
72    pub max_nesting_depth: usize,
73}
74
75impl Default for InferenceConfig {
76    fn default() -> Self {
77        Self {
78            min_confidence: 0.7,
79            infer_hierarchies: true,
80            infer_constraints: true,
81            infer_dependencies: true,
82            cardinality_multiplier: 2.0,
83            max_nesting_depth: 5,
84        }
85    }
86}
87
/// Confidence score for inferred schema elements.
#[derive(Clone, Debug, PartialEq)]
pub struct ConfidenceScore {
    /// Confidence value; [`ConfidenceScore::new`] stores it within `0.0..=1.0`.
    pub score: f64,
    /// Number of observations the inference was based on.
    pub evidence_count: usize,
    /// Human-readable explanation of how the score was obtained.
    pub reasoning: String,
}

impl ConfidenceScore {
    /// Creates a score, forcing `score` into the range `0.0..=1.0`.
    ///
    /// Values below 0.0 become 0.0 and values above 1.0 become 1.0. A NaN
    /// input is stored as 0.0 (no confidence).
    pub fn new(score: f64, evidence_count: usize, reasoning: impl Into<String>) -> Self {
        // `f64::clamp` returns NaN unchanged, which would break the range
        // invariant and make the value compare unequal to itself.
        let score = if score.is_nan() {
            0.0
        } else {
            score.clamp(0.0, 1.0)
        };

        Self {
            score,
            evidence_count,
            reasoning: reasoning.into(),
        }
    }

    /// Returns `true` when the score is at least `threshold`.
    pub fn is_confident(&self, threshold: f64) -> bool {
        self.score >= threshold
    }
}
109
110/// Sample data for schema learning.
111#[derive(Clone, Debug)]
112pub struct DataSample {
113    records: Vec<HashMap<String, Value>>,
114}
115
116impl DataSample {
117    /// Create a data sample from JSON array.
118    pub fn from_json(json: &str) -> Result<Self> {
119        let value: Value = serde_json::from_str(json)?;
120
121        let records = match value {
122            Value::Array(arr) => arr
123                .into_iter()
124                .filter_map(|v| {
125                    if let Value::Object(map) = v {
126                        Some(map.into_iter().collect::<HashMap<_, _>>())
127                    } else {
128                        None
129                    }
130                })
131                .collect(),
132            Value::Object(map) => {
133                vec![map.into_iter().collect()]
134            }
135            _ => return Err(anyhow!("Expected JSON array or object")),
136        };
137
138        Ok(Self { records })
139    }
140
141    /// Create a data sample from CSV data.
142    pub fn from_csv(csv: &str) -> Result<Self> {
143        let mut lines = csv.lines();
144        let headers: Vec<String> = lines
145            .next()
146            .ok_or_else(|| anyhow!("Empty CSV"))?
147            .split(',')
148            .map(|s| s.trim().to_string())
149            .collect();
150
151        let records = lines
152            .map(|line| {
153                let values: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
154                headers
155                    .iter()
156                    .zip(values.iter())
157                    .map(|(k, v)| {
158                        let json_val = if let Ok(num) = v.parse::<f64>() {
159                            Value::Number(
160                                serde_json::Number::from_f64(num)
161                                    .unwrap_or_else(|| serde_json::Number::from(0i64)),
162                            )
163                        } else if *v == "true" || *v == "false" {
164                            Value::Bool(*v == "true")
165                        } else {
166                            Value::String(v.to_string())
167                        };
168                        (k.clone(), json_val)
169                    })
170                    .collect()
171            })
172            .collect();
173
174        Ok(Self { records })
175    }
176
177    /// Get all unique field names across records.
178    pub fn field_names(&self) -> HashSet<String> {
179        self.records
180            .iter()
181            .flat_map(|record| record.keys().cloned())
182            .collect()
183    }
184
185    /// Get values for a specific field.
186    pub fn field_values(&self, field: &str) -> Vec<&Value> {
187        self.records
188            .iter()
189            .filter_map(|record| record.get(field))
190            .collect()
191    }
192
193    /// Get number of records.
194    pub fn len(&self) -> usize {
195        self.records.len()
196    }
197
198    /// Check if sample is empty.
199    pub fn is_empty(&self) -> bool {
200        self.records.is_empty()
201    }
202}
203
/// Counters and timing gathered while a schema is being learned.
///
/// All fields start at zero via `Default`.
#[derive(Debug, Clone, Default)]
pub struct LearningStatistics {
    /// Count of inferred domains.
    pub domains_inferred: usize,
    /// Count of inferred predicates.
    pub predicates_inferred: usize,
    /// Count of inferred constraints.
    pub constraints_inferred: usize,
    /// Count of inferred hierarchy relations.
    pub hierarchies_inferred: usize,
    /// Count of inferred functional dependencies.
    pub dependencies_inferred: usize,
    /// Count of sample records that were analyzed.
    pub total_samples_analyzed: usize,
    /// Time spent on inference, in milliseconds.
    pub inference_time_ms: u128,
}
215
216/// Schema learner for automatic inference from data.
217pub struct SchemaLearner {
218    config: InferenceConfig,
219    statistics: LearningStatistics,
220    confidence_scores: HashMap<String, ConfidenceScore>,
221}
222
223impl SchemaLearner {
224    /// Create a new schema learner with configuration.
225    pub fn new(config: InferenceConfig) -> Self {
226        Self {
227            config,
228            statistics: LearningStatistics::default(),
229            confidence_scores: HashMap::new(),
230        }
231    }
232
233    /// Learn a complete schema from a data sample.
234    pub fn learn_from_sample(&mut self, sample: &DataSample) -> Result<SymbolTable> {
235        let start = std::time::Instant::now();
236
237        let mut table = SymbolTable::new();
238
239        // Infer domains from data types
240        self.infer_domains(sample, &mut table)?;
241
242        // Infer predicates from fields
243        self.infer_predicates(sample, &mut table)?;
244
245        // Infer constraints if enabled
246        if self.config.infer_constraints {
247            self.infer_constraints(sample, &mut table)?;
248        }
249
250        // Infer hierarchies if enabled
251        if self.config.infer_hierarchies {
252            self.infer_hierarchies(sample, &mut table)?;
253        }
254
255        self.statistics.total_samples_analyzed = sample.len();
256        self.statistics.inference_time_ms = start.elapsed().as_millis();
257
258        Ok(table)
259    }
260
261    /// Infer domains from data types in the sample.
262    fn infer_domains(&mut self, sample: &DataSample, table: &mut SymbolTable) -> Result<()> {
263        let mut domain_types: HashMap<String, HashSet<String>> = HashMap::new();
264
265        // Analyze each field's type distribution
266        for field in sample.field_names() {
267            let values = sample.field_values(&field);
268            let types: HashSet<String> = values.iter().map(|v| self.infer_type(v)).collect();
269            domain_types.insert(field.clone(), types);
270        }
271
272        // Create domains for inferred types
273        let mut inferred_types: HashSet<String> = HashSet::new();
274        for types in domain_types.values() {
275            inferred_types.extend(types.clone());
276        }
277
278        for type_name in inferred_types {
279            let cardinality = self.estimate_cardinality(sample, &type_name);
280            let domain = DomainInfo::new(&type_name, cardinality);
281
282            if table.add_domain(domain).is_ok() {
283                self.statistics.domains_inferred += 1;
284                self.confidence_scores.insert(
285                    format!("domain:{}", type_name),
286                    ConfidenceScore::new(
287                        0.9,
288                        sample.len(),
289                        format!("Inferred from {} samples", sample.len()),
290                    ),
291                );
292            }
293        }
294
295        Ok(())
296    }
297
298    /// Infer predicates from field relationships.
299    fn infer_predicates(&mut self, sample: &DataSample, table: &mut SymbolTable) -> Result<()> {
300        let fields: Vec<String> = sample.field_names().into_iter().collect();
301
302        // Create unary predicates for each field
303        for field in &fields {
304            let values = sample.field_values(field);
305            if values.is_empty() {
306                continue;
307            }
308
309            let type_name = self.infer_type(values[0]);
310            let predicate = PredicateInfo::new(field, vec![type_name.clone()]);
311
312            if table.add_predicate(predicate).is_ok() {
313                self.statistics.predicates_inferred += 1;
314                self.confidence_scores.insert(
315                    format!("predicate:{}", field),
316                    ConfidenceScore::new(
317                        0.85,
318                        values.len(),
319                        format!("Inferred from {} values", values.len()),
320                    ),
321                );
322            }
323        }
324
325        // Infer binary predicates from field co-occurrence
326        for i in 0..fields.len() {
327            for j in (i + 1)..fields.len() {
328                let field1 = &fields[i];
329                let field2 = &fields[j];
330
331                if self.has_relationship(sample, field1, field2) {
332                    let type1 = self.infer_type(sample.field_values(field1)[0]);
333                    let type2 = self.infer_type(sample.field_values(field2)[0]);
334
335                    let rel_name = format!("{}_{}", field1, field2);
336                    let predicate = PredicateInfo::new(&rel_name, vec![type1, type2]);
337
338                    if table.add_predicate(predicate).is_ok() {
339                        self.statistics.predicates_inferred += 1;
340                    }
341                }
342            }
343        }
344
345        Ok(())
346    }
347
348    /// Infer constraints from data patterns.
349    fn infer_constraints(&mut self, sample: &DataSample, _table: &mut SymbolTable) -> Result<()> {
350        for field in sample.field_names() {
351            let values = sample.field_values(&field);
352
353            // Infer value ranges for numeric fields
354            if let Some(range) = self.infer_value_range(&values) {
355                self.statistics.constraints_inferred += 1;
356                self.confidence_scores.insert(
357                    format!("constraint:{}:range", field),
358                    ConfidenceScore::new(
359                        0.8,
360                        values.len(),
361                        "Inferred from numeric values".to_string(),
362                    ),
363                );
364                // Note: Constraints would be attached to predicates in a full implementation
365                let _ = range; // Suppress unused warning
366            }
367        }
368
369        Ok(())
370    }
371
372    /// Infer domain hierarchies from data.
373    fn infer_hierarchies(&mut self, _sample: &DataSample, _table: &mut SymbolTable) -> Result<()> {
374        // Placeholder for hierarchy inference
375        // Would analyze naming patterns, value containment, etc.
376        Ok(())
377    }
378
379    /// Infer the JSON value type.
380    fn infer_type(&self, value: &Value) -> String {
381        match value {
382            Value::Number(_) => "Number".to_string(),
383            Value::String(_) => "String".to_string(),
384            Value::Bool(_) => "Boolean".to_string(),
385            Value::Array(_) => "Array".to_string(),
386            Value::Object(_) => "Object".to_string(),
387            Value::Null => "Unknown".to_string(),
388        }
389    }
390
391    /// Estimate domain cardinality from sample.
392    fn estimate_cardinality(&self, sample: &DataSample, type_name: &str) -> usize {
393        let mut unique_values = HashSet::new();
394
395        for record in &sample.records {
396            for value in record.values() {
397                if self.infer_type(value) == type_name {
398                    unique_values.insert(format!("{:?}", value));
399                }
400            }
401        }
402
403        ((unique_values.len() as f64) * self.config.cardinality_multiplier).ceil() as usize
404    }
405
406    /// Check if two fields have a meaningful relationship.
407    fn has_relationship(&self, sample: &DataSample, field1: &str, field2: &str) -> bool {
408        let values1 = sample.field_values(field1);
409        let values2 = sample.field_values(field2);
410
411        // Simple heuristic: if both fields are present in most records
412        let co_occurrence = values1.len().min(values2.len());
413        let threshold = (sample.len() as f64 * 0.8).ceil() as usize;
414
415        co_occurrence >= threshold
416    }
417
418    /// Infer value range from numeric values.
419    fn infer_value_range(&self, values: &[&Value]) -> Option<ValueRange> {
420        let numbers: Vec<f64> = values.iter().filter_map(|v| v.as_f64()).collect();
421
422        if numbers.is_empty() {
423            return None;
424        }
425
426        let min = numbers.iter().copied().fold(f64::INFINITY, f64::min);
427        let max = numbers.iter().copied().fold(f64::NEG_INFINITY, f64::max);
428
429        Some(ValueRange::new().with_min(min, true).with_max(max, true))
430    }
431
432    /// Get learning statistics.
433    pub fn statistics(&self) -> &LearningStatistics {
434        &self.statistics
435    }
436
437    /// Get confidence score for a schema element.
438    pub fn confidence(&self, element: &str) -> Option<&ConfidenceScore> {
439        self.confidence_scores.get(element)
440    }
441
442    /// Get all confidence scores.
443    pub fn all_confidences(&self) -> &HashMap<String, ConfidenceScore> {
444        &self.confidence_scores
445    }
446}
447
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a JSON fixture, failing the test with a clear message otherwise.
    fn json_sample(json: &str) -> DataSample {
        DataSample::from_json(json).expect("JSON fixture should parse into a sample")
    }

    /// Parses a CSV fixture, failing the test with a clear message otherwise.
    fn csv_sample(csv: &str) -> DataSample {
        DataSample::from_csv(csv).expect("CSV fixture should parse into a sample")
    }

    /// Builds a learner that uses the default configuration.
    fn default_learner() -> SchemaLearner {
        SchemaLearner::new(InferenceConfig::default())
    }

    #[test]
    fn test_data_sample_from_json() {
        let sample = json_sample(
            r#"[
            {"id": 1, "name": "Alice"},
            {"id": 2, "name": "Bob"}
        ]"#,
        );

        assert_eq!(sample.len(), 2);
        assert_eq!(sample.field_names().len(), 2);
    }

    #[test]
    fn test_data_sample_from_json_rejects_scalars() {
        assert!(DataSample::from_json("42").is_err());
        assert!(DataSample::from_json("not json").is_err());
    }

    #[test]
    fn test_data_sample_from_csv() {
        let sample = csv_sample("id,name,age\n1,Alice,30\n2,Bob,25");

        assert_eq!(sample.len(), 2);
        assert_eq!(sample.field_names().len(), 3);
    }

    #[test]
    fn test_data_sample_from_csv_rejects_empty_input() {
        assert!(DataSample::from_csv("").is_err());
    }

    #[test]
    fn test_schema_learner_basic() {
        let sample = json_sample(
            r#"[
            {"id": 1, "name": "Alice", "age": 30},
            {"id": 2, "name": "Bob", "age": 25}
        ]"#,
        );
        let mut learner = default_learner();

        let _schema = learner
            .learn_from_sample(&sample)
            .expect("learning from a well-formed sample should succeed");
        let stats = learner.statistics();

        assert!(stats.domains_inferred > 0);
        assert!(stats.predicates_inferred > 0);
        assert_eq!(stats.total_samples_analyzed, 2);
    }

    #[test]
    fn test_type_inference() {
        let learner = default_learner();

        assert_eq!(learner.infer_type(&Value::Number(42.into())), "Number");
        assert_eq!(learner.infer_type(&Value::String("test".into())), "String");
        assert_eq!(learner.infer_type(&Value::Bool(true)), "Boolean");
        assert_eq!(learner.infer_type(&Value::Null), "Unknown");
    }

    #[test]
    fn test_value_range_inference() {
        let val1 = Value::Number(10.into());
        let val2 = Value::Number(20.into());
        let val3 = Value::Number(30.into());
        let values = vec![&val1, &val2, &val3];

        let learner = default_learner();
        let range = learner
            .infer_value_range(&values)
            .expect("numeric values should produce a range");

        assert_eq!(range.min, Some(10.0));
        assert_eq!(range.max, Some(30.0));
    }

    #[test]
    fn test_value_range_inference_without_numbers() {
        let text = Value::String("n/a".into());
        let learner = default_learner();

        assert!(learner.infer_value_range(&[&text]).is_none());
        assert!(learner.infer_value_range(&[]).is_none());
    }

    #[test]
    fn test_confidence_score() {
        let score = ConfidenceScore::new(0.85, 100, "High confidence");
        assert_eq!(score.score, 0.85);
        assert_eq!(score.evidence_count, 100);
        assert!(score.is_confident(0.7));
        assert!(!score.is_confident(0.9));
    }

    #[test]
    fn test_inference_config_default() {
        let config = InferenceConfig::default();
        assert_eq!(config.min_confidence, 0.7);
        assert!(config.infer_hierarchies);
        assert!(config.infer_constraints);
    }

    #[test]
    fn test_cardinality_estimation() {
        let sample = json_sample(
            r#"[
            {"id": 1, "type": "A"},
            {"id": 2, "type": "B"},
            {"id": 3, "type": "A"}
        ]"#,
        );
        let learner = default_learner();

        let cardinality = learner.estimate_cardinality(&sample, "Number");
        assert!(cardinality > 0);
    }

    #[test]
    fn test_field_values_extraction() {
        let sample = json_sample(
            r#"[
            {"name": "Alice", "age": 30},
            {"name": "Bob", "age": 25}
        ]"#,
        );

        assert_eq!(sample.field_values("name").len(), 2);
    }

    #[test]
    fn test_relationship_detection() {
        let sample = json_sample(
            r#"[
            {"person": "Alice", "city": "NYC"},
            {"person": "Bob", "city": "LA"}
        ]"#,
        );
        let learner = default_learner();

        assert!(learner.has_relationship(&sample, "person", "city"));
    }

    #[test]
    fn test_empty_sample() {
        let sample = json_sample("[]");
        assert!(sample.is_empty());
        assert_eq!(sample.len(), 0);
    }

    #[test]
    fn test_single_object_json() {
        let sample = json_sample(r#"{"id": 1, "name": "Alice"}"#);
        assert_eq!(sample.len(), 1);
    }

    #[test]
    fn test_csv_type_detection() {
        let sample = csv_sample("id,name,active\n1,Alice,true\n2,Bob,false");

        let active_values = sample.field_values("active");
        assert!(active_values.iter().all(|v| v.is_boolean()));
    }

    #[test]
    fn test_confidence_scores_tracking() {
        let sample = json_sample(r#"[{"id": 1, "name": "Alice"}]"#);
        let mut learner = default_learner();

        learner
            .learn_from_sample(&sample)
            .expect("learning from a well-formed sample should succeed");
        assert!(!learner.all_confidences().is_empty());
    }

    #[test]
    fn test_learning_statistics() {
        let sample = json_sample(r#"[{"id": 1}, {"id": 2}, {"id": 3}]"#);
        let mut learner = default_learner();

        learner
            .learn_from_sample(&sample)
            .expect("learning from a well-formed sample should succeed");
        let stats = learner.statistics();

        assert_eq!(stats.total_samples_analyzed, 3);
        // The elapsed time may legitimately be 0 ms for such a small sample,
        // so only the inference counters are checked.
        assert!(stats.domains_inferred > 0 || stats.predicates_inferred > 0);
    }
}