Skip to main content

tensorlogic_adapters/
learning.rs

1//! Schema learning from data.
2//!
3//! This module provides automatic schema inference from sample data,
4//! enabling rapid prototyping and schema bootstrapping from existing datasets.
5//!
6//! # Overview
7//!
8//! Instead of manually defining schemas, you can learn them from:
9//! - JSON objects and arrays
10//! - CSV files with headers
11//! - Relational data patterns
12//! - Example predicates and relationships
13//!
14//! The learner analyzes data to infer:
15//! - Domain types and cardinalities
16//! - Predicate signatures and properties
17//! - Type hierarchies
18//! - Value ranges and constraints
19//! - Functional dependencies
20//!
21//! # Architecture
22//!
23//! - **SchemaLearner**: Main inference engine
24//! - **DataSample**: Represents sample data for analysis
25//! - **InferenceConfig**: Configuration for learning behavior
26//! - **LearningStatistics**: Statistics about the learning process
27//! - **ConfidenceScore**: Confidence in inferred schema elements
28//!
29//! # Example
30//!
31//! ```rust
32//! use tensorlogic_adapters::{SchemaLearner, DataSample, InferenceConfig};
33//!
34//! let json_data = r#"[
35//!     {"id": 1, "name": "Alice", "age": 30, "city": "NYC"},
36//!     {"id": 2, "name": "Bob", "age": 25, "city": "LA"},
37//!     {"id": 3, "name": "Charlie", "age": 35, "city": "NYC"}
38//! ]"#;
39//!
40//! let sample = DataSample::from_json(json_data).unwrap();
41//! let config = InferenceConfig::default();
42//! let mut learner = SchemaLearner::new(config);
43//!
44//! let schema = learner.learn_from_sample(&sample).unwrap();
45//! let stats = learner.statistics();
46//!
47//! assert!(stats.domains_inferred > 0);
48//! assert!(stats.predicates_inferred > 0);
49//! ```
50
51use anyhow::{anyhow, Result};
52use serde::{Deserialize, Serialize};
53use serde_json::Value;
54use std::collections::{HashMap, HashSet};
55
56use crate::{DomainInfo, PredicateInfo, SymbolTable, ValueRange};
57
58/// Configuration for schema inference.
59#[derive(Clone, Debug, Serialize, Deserialize)]
60pub struct InferenceConfig {
61    /// Minimum confidence threshold for inferred elements (0.0 to 1.0)
62    pub min_confidence: f64,
63    /// Whether to infer domain hierarchies
64    pub infer_hierarchies: bool,
65    /// Whether to infer constraints
66    pub infer_constraints: bool,
67    /// Whether to infer functional dependencies
68    pub infer_dependencies: bool,
69    /// Cardinality multiplier for domain size estimation
70    pub cardinality_multiplier: f64,
71    /// Maximum depth for nested object analysis
72    pub max_nesting_depth: usize,
73}
74
75impl Default for InferenceConfig {
76    fn default() -> Self {
77        Self {
78            min_confidence: 0.7,
79            infer_hierarchies: true,
80            infer_constraints: true,
81            infer_dependencies: true,
82            cardinality_multiplier: 2.0,
83            max_nesting_depth: 5,
84        }
85    }
86}
87
/// Confidence score for inferred schema elements.
///
/// The score is always a finite value in `0.0..=1.0`; see
/// [`ConfidenceScore::new`] for how out-of-range input is normalised.
#[derive(Clone, Debug, PartialEq)]
pub struct ConfidenceScore {
    /// Confidence in the range `0.0..=1.0`.
    pub score: f64,
    /// Number of observations the score is based on.
    pub evidence_count: usize,
    /// Human-readable explanation of how the score was obtained.
    pub reasoning: String,
}

impl ConfidenceScore {
    /// Build a score, normalising `score` into `0.0..=1.0`.
    ///
    /// Values below 0.0 become 0.0 and values above 1.0 become 1.0. A NaN
    /// input is treated as "no confidence" (0.0): `f64::clamp` would pass NaN
    /// through unchanged, which would break both the documented range and
    /// the derived `PartialEq` (NaN never equals itself).
    pub fn new(score: f64, evidence_count: usize, reasoning: impl Into<String>) -> Self {
        let score = if score.is_nan() {
            0.0
        } else {
            score.clamp(0.0, 1.0)
        };

        Self {
            score,
            evidence_count,
            reasoning: reasoning.into(),
        }
    }

    /// Return `true` when the score is at least `threshold`.
    pub fn is_confident(&self, threshold: f64) -> bool {
        self.score >= threshold
    }
}
109
110/// Sample data for schema learning.
111#[derive(Clone, Debug)]
112pub struct DataSample {
113    records: Vec<HashMap<String, Value>>,
114}
115
116impl DataSample {
117    /// Create a data sample from JSON array.
118    pub fn from_json(json: &str) -> Result<Self> {
119        let value: Value = serde_json::from_str(json)?;
120
121        let records = match value {
122            Value::Array(arr) => arr
123                .into_iter()
124                .filter_map(|v| {
125                    if let Value::Object(map) = v {
126                        Some(map.into_iter().collect::<HashMap<_, _>>())
127                    } else {
128                        None
129                    }
130                })
131                .collect(),
132            Value::Object(map) => {
133                vec![map.into_iter().collect()]
134            }
135            _ => return Err(anyhow!("Expected JSON array or object")),
136        };
137
138        Ok(Self { records })
139    }
140
141    /// Create a data sample from CSV data.
142    pub fn from_csv(csv: &str) -> Result<Self> {
143        let mut lines = csv.lines();
144        let headers: Vec<String> = lines
145            .next()
146            .ok_or_else(|| anyhow!("Empty CSV"))?
147            .split(',')
148            .map(|s| s.trim().to_string())
149            .collect();
150
151        let records = lines
152            .map(|line| {
153                let values: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
154                headers
155                    .iter()
156                    .zip(values.iter())
157                    .map(|(k, v)| {
158                        let json_val = if let Ok(num) = v.parse::<f64>() {
159                            Value::Number(serde_json::Number::from_f64(num).unwrap())
160                        } else if *v == "true" || *v == "false" {
161                            Value::Bool(*v == "true")
162                        } else {
163                            Value::String(v.to_string())
164                        };
165                        (k.clone(), json_val)
166                    })
167                    .collect()
168            })
169            .collect();
170
171        Ok(Self { records })
172    }
173
174    /// Get all unique field names across records.
175    pub fn field_names(&self) -> HashSet<String> {
176        self.records
177            .iter()
178            .flat_map(|record| record.keys().cloned())
179            .collect()
180    }
181
182    /// Get values for a specific field.
183    pub fn field_values(&self, field: &str) -> Vec<&Value> {
184        self.records
185            .iter()
186            .filter_map(|record| record.get(field))
187            .collect()
188    }
189
190    /// Get number of records.
191    pub fn len(&self) -> usize {
192        self.records.len()
193    }
194
195    /// Check if sample is empty.
196    pub fn is_empty(&self) -> bool {
197        self.records.is_empty()
198    }
199}
200
/// Counters and timing collected while a schema is being learned.
///
/// Every field starts at zero via the derived `Default`.
#[derive(Debug, Default, Clone)]
pub struct LearningStatistics {
    /// Number of domains added to the symbol table.
    pub domains_inferred: usize,
    /// Number of predicates added to the symbol table.
    pub predicates_inferred: usize,
    /// Number of constraints inferred.
    pub constraints_inferred: usize,
    /// Number of hierarchy relations inferred.
    pub hierarchies_inferred: usize,
    /// Number of functional dependencies inferred.
    pub dependencies_inferred: usize,
    /// Number of records in the analyzed sample.
    pub total_samples_analyzed: usize,
    /// Wall-clock duration of the learning run, in milliseconds.
    pub inference_time_ms: u128,
}
212
213/// Schema learner for automatic inference from data.
214pub struct SchemaLearner {
215    config: InferenceConfig,
216    statistics: LearningStatistics,
217    confidence_scores: HashMap<String, ConfidenceScore>,
218}
219
220impl SchemaLearner {
221    /// Create a new schema learner with configuration.
222    pub fn new(config: InferenceConfig) -> Self {
223        Self {
224            config,
225            statistics: LearningStatistics::default(),
226            confidence_scores: HashMap::new(),
227        }
228    }
229
230    /// Learn a complete schema from a data sample.
231    pub fn learn_from_sample(&mut self, sample: &DataSample) -> Result<SymbolTable> {
232        let start = std::time::Instant::now();
233
234        let mut table = SymbolTable::new();
235
236        // Infer domains from data types
237        self.infer_domains(sample, &mut table)?;
238
239        // Infer predicates from fields
240        self.infer_predicates(sample, &mut table)?;
241
242        // Infer constraints if enabled
243        if self.config.infer_constraints {
244            self.infer_constraints(sample, &mut table)?;
245        }
246
247        // Infer hierarchies if enabled
248        if self.config.infer_hierarchies {
249            self.infer_hierarchies(sample, &mut table)?;
250        }
251
252        self.statistics.total_samples_analyzed = sample.len();
253        self.statistics.inference_time_ms = start.elapsed().as_millis();
254
255        Ok(table)
256    }
257
258    /// Infer domains from data types in the sample.
259    fn infer_domains(&mut self, sample: &DataSample, table: &mut SymbolTable) -> Result<()> {
260        let mut domain_types: HashMap<String, HashSet<String>> = HashMap::new();
261
262        // Analyze each field's type distribution
263        for field in sample.field_names() {
264            let values = sample.field_values(&field);
265            let types: HashSet<String> = values.iter().map(|v| self.infer_type(v)).collect();
266            domain_types.insert(field.clone(), types);
267        }
268
269        // Create domains for inferred types
270        let mut inferred_types: HashSet<String> = HashSet::new();
271        for types in domain_types.values() {
272            inferred_types.extend(types.clone());
273        }
274
275        for type_name in inferred_types {
276            let cardinality = self.estimate_cardinality(sample, &type_name);
277            let domain = DomainInfo::new(&type_name, cardinality);
278
279            if table.add_domain(domain).is_ok() {
280                self.statistics.domains_inferred += 1;
281                self.confidence_scores.insert(
282                    format!("domain:{}", type_name),
283                    ConfidenceScore::new(
284                        0.9,
285                        sample.len(),
286                        format!("Inferred from {} samples", sample.len()),
287                    ),
288                );
289            }
290        }
291
292        Ok(())
293    }
294
295    /// Infer predicates from field relationships.
296    fn infer_predicates(&mut self, sample: &DataSample, table: &mut SymbolTable) -> Result<()> {
297        let fields: Vec<String> = sample.field_names().into_iter().collect();
298
299        // Create unary predicates for each field
300        for field in &fields {
301            let values = sample.field_values(field);
302            if values.is_empty() {
303                continue;
304            }
305
306            let type_name = self.infer_type(values[0]);
307            let predicate = PredicateInfo::new(field, vec![type_name.clone()]);
308
309            if table.add_predicate(predicate).is_ok() {
310                self.statistics.predicates_inferred += 1;
311                self.confidence_scores.insert(
312                    format!("predicate:{}", field),
313                    ConfidenceScore::new(
314                        0.85,
315                        values.len(),
316                        format!("Inferred from {} values", values.len()),
317                    ),
318                );
319            }
320        }
321
322        // Infer binary predicates from field co-occurrence
323        for i in 0..fields.len() {
324            for j in (i + 1)..fields.len() {
325                let field1 = &fields[i];
326                let field2 = &fields[j];
327
328                if self.has_relationship(sample, field1, field2) {
329                    let type1 = self.infer_type(sample.field_values(field1)[0]);
330                    let type2 = self.infer_type(sample.field_values(field2)[0]);
331
332                    let rel_name = format!("{}_{}", field1, field2);
333                    let predicate = PredicateInfo::new(&rel_name, vec![type1, type2]);
334
335                    if table.add_predicate(predicate).is_ok() {
336                        self.statistics.predicates_inferred += 1;
337                    }
338                }
339            }
340        }
341
342        Ok(())
343    }
344
345    /// Infer constraints from data patterns.
346    fn infer_constraints(&mut self, sample: &DataSample, _table: &mut SymbolTable) -> Result<()> {
347        for field in sample.field_names() {
348            let values = sample.field_values(&field);
349
350            // Infer value ranges for numeric fields
351            if let Some(range) = self.infer_value_range(&values) {
352                self.statistics.constraints_inferred += 1;
353                self.confidence_scores.insert(
354                    format!("constraint:{}:range", field),
355                    ConfidenceScore::new(
356                        0.8,
357                        values.len(),
358                        "Inferred from numeric values".to_string(),
359                    ),
360                );
361                // Note: Constraints would be attached to predicates in a full implementation
362                let _ = range; // Suppress unused warning
363            }
364        }
365
366        Ok(())
367    }
368
369    /// Infer domain hierarchies from data.
370    fn infer_hierarchies(&mut self, _sample: &DataSample, _table: &mut SymbolTable) -> Result<()> {
371        // Placeholder for hierarchy inference
372        // Would analyze naming patterns, value containment, etc.
373        Ok(())
374    }
375
376    /// Infer the JSON value type.
377    fn infer_type(&self, value: &Value) -> String {
378        match value {
379            Value::Number(_) => "Number".to_string(),
380            Value::String(_) => "String".to_string(),
381            Value::Bool(_) => "Boolean".to_string(),
382            Value::Array(_) => "Array".to_string(),
383            Value::Object(_) => "Object".to_string(),
384            Value::Null => "Unknown".to_string(),
385        }
386    }
387
388    /// Estimate domain cardinality from sample.
389    fn estimate_cardinality(&self, sample: &DataSample, type_name: &str) -> usize {
390        let mut unique_values = HashSet::new();
391
392        for record in &sample.records {
393            for value in record.values() {
394                if self.infer_type(value) == type_name {
395                    unique_values.insert(format!("{:?}", value));
396                }
397            }
398        }
399
400        ((unique_values.len() as f64) * self.config.cardinality_multiplier).ceil() as usize
401    }
402
403    /// Check if two fields have a meaningful relationship.
404    fn has_relationship(&self, sample: &DataSample, field1: &str, field2: &str) -> bool {
405        let values1 = sample.field_values(field1);
406        let values2 = sample.field_values(field2);
407
408        // Simple heuristic: if both fields are present in most records
409        let co_occurrence = values1.len().min(values2.len());
410        let threshold = (sample.len() as f64 * 0.8).ceil() as usize;
411
412        co_occurrence >= threshold
413    }
414
415    /// Infer value range from numeric values.
416    fn infer_value_range(&self, values: &[&Value]) -> Option<ValueRange> {
417        let numbers: Vec<f64> = values.iter().filter_map(|v| v.as_f64()).collect();
418
419        if numbers.is_empty() {
420            return None;
421        }
422
423        let min = numbers.iter().copied().fold(f64::INFINITY, f64::min);
424        let max = numbers.iter().copied().fold(f64::NEG_INFINITY, f64::max);
425
426        Some(ValueRange::new().with_min(min, true).with_max(max, true))
427    }
428
429    /// Get learning statistics.
430    pub fn statistics(&self) -> &LearningStatistics {
431        &self.statistics
432    }
433
434    /// Get confidence score for a schema element.
435    pub fn confidence(&self, element: &str) -> Option<&ConfidenceScore> {
436        self.confidence_scores.get(element)
437    }
438
439    /// Get all confidence scores.
440    pub fn all_confidences(&self) -> &HashMap<String, ConfidenceScore> {
441        &self.confidence_scores
442    }
443}
444
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse a JSON literal into a sample; the fixtures are known-good.
    fn json_sample(json: &str) -> DataSample {
        DataSample::from_json(json).unwrap()
    }

    /// Learner built from the default configuration.
    fn default_learner() -> SchemaLearner {
        SchemaLearner::new(InferenceConfig::default())
    }

    #[test]
    fn test_data_sample_from_json() {
        let sample = json_sample(
            r#"[
            {"id": 1, "name": "Alice"},
            {"id": 2, "name": "Bob"}
        ]"#,
        );

        assert_eq!(sample.len(), 2);
        assert_eq!(sample.field_names().len(), 2);
    }

    #[test]
    fn test_data_sample_from_csv() {
        let sample = DataSample::from_csv("id,name,age\n1,Alice,30\n2,Bob,25").unwrap();

        assert_eq!(sample.len(), 2);
        assert_eq!(sample.field_names().len(), 3);
    }

    #[test]
    fn test_schema_learner_basic() {
        let sample = json_sample(
            r#"[
            {"id": 1, "name": "Alice", "age": 30},
            {"id": 2, "name": "Bob", "age": 25}
        ]"#,
        );
        let mut learner = default_learner();

        let _schema = learner.learn_from_sample(&sample).unwrap();

        let stats = learner.statistics();
        assert!(stats.domains_inferred > 0);
        assert!(stats.predicates_inferred > 0);
        assert_eq!(stats.total_samples_analyzed, 2);
    }

    #[test]
    fn test_type_inference() {
        let learner = default_learner();

        let number = Value::Number(42.into());
        let text = Value::String("test".into());
        let flag = Value::Bool(true);

        assert_eq!(learner.infer_type(&number), "Number");
        assert_eq!(learner.infer_type(&text), "String");
        assert_eq!(learner.infer_type(&flag), "Boolean");
    }

    #[test]
    fn test_value_range_inference() {
        let numbers = [
            Value::Number(10.into()),
            Value::Number(20.into()),
            Value::Number(30.into()),
        ];
        let values: Vec<&Value> = numbers.iter().collect();

        let range = default_learner().infer_value_range(&values).unwrap();

        assert_eq!(range.min, Some(10.0));
        assert_eq!(range.max, Some(30.0));
    }

    #[test]
    fn test_confidence_score() {
        let score = ConfidenceScore::new(0.85, 100, "High confidence");

        assert_eq!(score.score, 0.85);
        assert_eq!(score.evidence_count, 100);
        // 0.85 clears a 0.7 threshold but not a 0.9 one.
        assert!(score.is_confident(0.7));
        assert!(!score.is_confident(0.9));
    }

    #[test]
    fn test_inference_config_default() {
        let InferenceConfig {
            min_confidence,
            infer_hierarchies,
            infer_constraints,
            ..
        } = InferenceConfig::default();

        assert_eq!(min_confidence, 0.7);
        assert!(infer_hierarchies);
        assert!(infer_constraints);
    }

    #[test]
    fn test_cardinality_estimation() {
        let sample = json_sample(
            r#"[
            {"id": 1, "type": "A"},
            {"id": 2, "type": "B"},
            {"id": 3, "type": "A"}
        ]"#,
        );

        let cardinality = default_learner().estimate_cardinality(&sample, "Number");
        assert!(cardinality > 0);
    }

    #[test]
    fn test_field_values_extraction() {
        let sample = json_sample(
            r#"[
            {"name": "Alice", "age": 30},
            {"name": "Bob", "age": 25}
        ]"#,
        );

        assert_eq!(sample.field_values("name").len(), 2);
    }

    #[test]
    fn test_relationship_detection() {
        let sample = json_sample(
            r#"[
            {"person": "Alice", "city": "NYC"},
            {"person": "Bob", "city": "LA"}
        ]"#,
        );

        assert!(default_learner().has_relationship(&sample, "person", "city"));
    }

    #[test]
    fn test_empty_sample() {
        let sample = json_sample("[]");

        assert!(sample.is_empty());
        assert_eq!(sample.len(), 0);
    }

    #[test]
    fn test_single_object_json() {
        let sample = json_sample(r#"{"id": 1, "name": "Alice"}"#);

        assert_eq!(sample.len(), 1);
    }

    #[test]
    fn test_csv_type_detection() {
        let sample = DataSample::from_csv("id,name,active\n1,Alice,true\n2,Bob,false").unwrap();

        let all_boolean = sample
            .field_values("active")
            .iter()
            .all(|value| value.is_boolean());
        assert!(all_boolean);
    }

    #[test]
    fn test_confidence_scores_tracking() {
        let sample = json_sample(r#"[{"id": 1, "name": "Alice"}]"#);
        let mut learner = default_learner();

        learner.learn_from_sample(&sample).unwrap();

        assert!(!learner.all_confidences().is_empty());
    }

    #[test]
    fn test_learning_statistics() {
        let sample = json_sample(r#"[{"id": 1}, {"id": 2}, {"id": 3}]"#);
        let mut learner = default_learner();

        learner.learn_from_sample(&sample).unwrap();

        let stats = learner.statistics();
        assert_eq!(stats.total_samples_analyzed, 3);
        // Inference time is not asserted: it can legitimately be 0 ms for
        // fast runs. Only check that something was inferred.
        assert!(stats.domains_inferred > 0 || stats.predicates_inferred > 0);
    }
}