Skip to main content

tensorlogic_adapters/
embeddings.rs

//! Schema embeddings for similarity search and ML applications.
//!
//! This module provides functionality to generate vector embeddings for domains,
//! predicates, and entire schemas. These embeddings can be used for:
//! - Similarity search (find similar domains/predicates)
//! - Schema recommendation
//! - Clustering and analysis
//! - ML-based schema completion
//!
//! The embeddings are based on structural and semantic features of the schema elements.
11
12use std::collections::HashMap;
13
14use crate::{DomainInfo, PredicateInfo, SymbolTable};
15
/// Number of components in each embedding vector.
///
/// 64 dimensions is a compromise between expressiveness and the cost of
/// computing and comparing embeddings for typical schema sizes.
pub const EMBEDDING_DIM: usize = 64;

/// A dense embedding: a vector of `f64` components.
pub type Embedding = Vec<f64>;
24
25/// Schema element embedding generator.
26///
27/// Generates vector embeddings for domains, predicates, and schemas
28/// based on their structural and semantic properties.
29pub struct SchemaEmbedder {
30    /// Whether to normalize embeddings to unit length
31    normalize: bool,
32    /// Feature weights for embedding computation
33    weights: EmbeddingWeights,
34}
35
/// Multipliers applied to the different feature groups of an embedding.
#[derive(Clone, Debug)]
pub struct EmbeddingWeights {
    /// Multiplier for features derived from cardinalities.
    pub cardinality_weight: f64,
    /// Multiplier for features derived from predicate arity.
    pub arity_weight: f64,
    /// Multiplier for features derived from element names.
    pub name_weight: f64,
    /// Multiplier for structural features.
    pub structural_weight: f64,
}
48
49impl Default for EmbeddingWeights {
50    fn default() -> Self {
51        Self {
52            cardinality_weight: 1.0,
53            arity_weight: 1.0,
54            name_weight: 0.5,
55            structural_weight: 0.8,
56        }
57    }
58}
59
60impl SchemaEmbedder {
61    /// Create a new schema embedder with default settings.
62    pub fn new() -> Self {
63        Self {
64            normalize: true,
65            weights: EmbeddingWeights::default(),
66        }
67    }
68
69    /// Set whether to normalize embeddings.
70    pub fn with_normalization(mut self, normalize: bool) -> Self {
71        self.normalize = normalize;
72        self
73    }
74
75    /// Set custom feature weights.
76    pub fn with_weights(mut self, weights: EmbeddingWeights) -> Self {
77        self.weights = weights;
78        self
79    }
80
81    /// Generate embedding for a domain.
82    pub fn embed_domain(&self, domain: &DomainInfo) -> Embedding {
83        let mut embedding = vec![0.0; EMBEDDING_DIM];
84
85        // Cardinality-based features (dimensions 0-15)
86        let log_card = (domain.cardinality as f64).ln();
87        embedding[0] = log_card * self.weights.cardinality_weight;
88        embedding[1] = (domain.cardinality as f64).sqrt() * self.weights.cardinality_weight;
89        embedding[2] = (domain.cardinality as f64).cbrt() * self.weights.cardinality_weight;
90
91        // Cardinality ranges (binary features)
92        embedding[3] = if domain.cardinality < 10 { 1.0 } else { 0.0 };
93        embedding[4] = if domain.cardinality < 100 { 1.0 } else { 0.0 };
94        embedding[5] = if domain.cardinality < 1000 { 1.0 } else { 0.0 };
95        embedding[6] = if domain.cardinality < 10000 { 1.0 } else { 0.0 };
96
97        // Name-based features (dimensions 16-31)
98        self.add_name_features(&mut embedding, &domain.name, 16);
99
100        // Description features (dimensions 32-39)
101        if let Some(ref desc) = domain.description {
102            embedding[32] = (desc.len() as f64).ln() * self.weights.structural_weight;
103            embedding[33] =
104                (desc.split_whitespace().count() as f64).ln() * self.weights.structural_weight;
105            embedding[34] = if desc.contains("person") || desc.contains("user") {
106                1.0
107            } else {
108                0.0
109            };
110            embedding[35] = if desc.contains("time") || desc.contains("temporal") {
111                1.0
112            } else {
113                0.0
114            };
115        }
116
117        // Metadata features (dimensions 40-47)
118        if let Some(ref metadata) = domain.metadata {
119            embedding[40] = if metadata.provenance.is_some() {
120                1.0
121            } else {
122                0.0
123            };
124            embedding[41] = metadata.version_history.len() as f64;
125            embedding[42] = metadata.tags.len() as f64;
126        }
127
128        if self.normalize {
129            self.normalize_embedding(&mut embedding);
130        }
131
132        embedding
133    }
134
135    /// Generate embedding for a predicate.
136    pub fn embed_predicate(&self, predicate: &PredicateInfo) -> Embedding {
137        let mut embedding = vec![0.0; EMBEDDING_DIM];
138
139        // Arity-based features (dimensions 0-15)
140        let arity = predicate.arg_domains.len();
141        embedding[0] = arity as f64 * self.weights.arity_weight;
142        embedding[1] = (arity as f64).sqrt() * self.weights.arity_weight;
143
144        // Arity ranges (binary features)
145        embedding[2] = if arity == 0 { 1.0 } else { 0.0 }; // Nullary
146        embedding[3] = if arity == 1 { 1.0 } else { 0.0 }; // Unary
147        embedding[4] = if arity == 2 { 1.0 } else { 0.0 }; // Binary
148        embedding[5] = if arity == 3 { 1.0 } else { 0.0 }; // Ternary
149        embedding[6] = if arity > 3 { 1.0 } else { 0.0 }; // N-ary
150
151        // Name-based features (dimensions 16-31)
152        self.add_name_features(&mut embedding, &predicate.name, 16);
153
154        // Constraint features (dimensions 32-47)
155        if let Some(ref constraints) = predicate.constraints {
156            embedding[32] = constraints.properties.len() as f64 * self.weights.structural_weight;
157            embedding[33] = if constraints.properties.iter().any(|p| {
158                matches!(
159                    p,
160                    crate::PredicateProperty::Symmetric | crate::PredicateProperty::Transitive
161                )
162            }) {
163                1.0
164            } else {
165                0.0
166            };
167            embedding[34] =
168                constraints.functional_dependencies.len() as f64 * self.weights.structural_weight;
169
170            // Count non-None value ranges
171            let num_ranges = constraints
172                .value_ranges
173                .iter()
174                .filter(|r| r.is_some())
175                .count();
176            embedding[35] = num_ranges as f64;
177        }
178
179        // Description features (dimensions 48-55)
180        if let Some(ref desc) = predicate.description {
181            embedding[48] = (desc.len() as f64).ln() * self.weights.structural_weight;
182            embedding[49] =
183                (desc.split_whitespace().count() as f64).ln() * self.weights.structural_weight;
184        }
185
186        if self.normalize {
187            self.normalize_embedding(&mut embedding);
188        }
189
190        embedding
191    }
192
193    /// Generate embedding for an entire schema.
194    pub fn embed_schema(&self, table: &SymbolTable) -> Embedding {
195        let mut embedding = vec![0.0; EMBEDDING_DIM];
196
197        // Schema size features (dimensions 0-15)
198        // Use max(1, len) to avoid ln(0) = -inf
199        embedding[0] = ((table.domains.len().max(1)) as f64).ln() * self.weights.structural_weight;
200        embedding[1] =
201            ((table.predicates.len().max(1)) as f64).ln() * self.weights.structural_weight;
202        embedding[2] =
203            ((table.variables.len().max(1)) as f64).ln() * self.weights.structural_weight;
204
205        // Total cardinality
206        let total_card: usize = table.domains.values().map(|d| d.cardinality).sum();
207        embedding[3] = ((total_card.max(1)) as f64).ln() * self.weights.cardinality_weight;
208
209        // Average arity
210        let avg_arity: f64 = if table.predicates.is_empty() {
211            0.0
212        } else {
213            table
214                .predicates
215                .values()
216                .map(|p| p.arg_domains.len())
217                .sum::<usize>() as f64
218                / table.predicates.len() as f64
219        };
220        embedding[4] = avg_arity * self.weights.arity_weight;
221
222        // Domain histogram (dimensions 16-23)
223        for domain in table.domains.values() {
224            let log_card = (domain.cardinality as f64).ln();
225            let idx = ((log_card / 10.0).min(7.0) as usize).min(7);
226            embedding[16 + idx] += 1.0;
227        }
228
229        // Arity histogram (dimensions 24-31)
230        for predicate in table.predicates.values() {
231            let arity = predicate.arg_domains.len().min(7);
232            embedding[24 + arity] += 1.0;
233        }
234
235        // Graph density (dimension 32)
236        let max_edges = table.domains.len() * table.domains.len();
237        let actual_edges = table
238            .predicates
239            .values()
240            .filter(|p| p.arg_domains.len() == 2)
241            .count();
242        embedding[32] = if max_edges > 0 {
243            actual_edges as f64 / max_edges as f64
244        } else {
245            0.0
246        };
247
248        if self.normalize {
249            self.normalize_embedding(&mut embedding);
250        }
251
252        embedding
253    }
254
255    /// Compute cosine similarity between two embeddings.
256    pub fn cosine_similarity(a: &Embedding, b: &Embedding) -> f64 {
257        assert_eq!(a.len(), b.len(), "Embeddings must have same dimension");
258
259        let dot_product: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
260        let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
261        let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
262
263        if norm_a == 0.0 || norm_b == 0.0 {
264            0.0
265        } else {
266            dot_product / (norm_a * norm_b)
267        }
268    }
269
270    /// Compute Euclidean distance between two embeddings.
271    pub fn euclidean_distance(a: &Embedding, b: &Embedding) -> f64 {
272        assert_eq!(a.len(), b.len(), "Embeddings must have same dimension");
273
274        a.iter()
275            .zip(b.iter())
276            .map(|(x, y)| (x - y).powi(2))
277            .sum::<f64>()
278            .sqrt()
279    }
280
281    /// Add name-based features to embedding.
282    fn add_name_features(&self, embedding: &mut [f64], name: &str, start_idx: usize) {
283        let name_lower = name.to_lowercase();
284
285        // Length features
286        embedding[start_idx] = (name.len() as f64).ln() * self.weights.name_weight;
287        embedding[start_idx + 1] =
288            name.chars().filter(|c| c.is_uppercase()).count() as f64 * self.weights.name_weight;
289
290        // Character distribution
291        let vowels = name_lower.chars().filter(|c| "aeiou".contains(*c)).count();
292        embedding[start_idx + 2] = vowels as f64 / name.len().max(1) as f64;
293
294        // Common patterns
295        embedding[start_idx + 3] = if name_lower.contains('_') { 1.0 } else { 0.0 };
296        embedding[start_idx + 4] = if name_lower.starts_with("is") || name_lower.starts_with("has")
297        {
298            1.0
299        } else {
300            0.0
301        };
302
303        // Domain-specific keywords
304        embedding[start_idx + 5] = if name_lower.contains("person")
305            || name_lower.contains("user")
306            || name_lower.contains("agent")
307        {
308            1.0
309        } else {
310            0.0
311        };
312        embedding[start_idx + 6] = if name_lower.contains("time")
313            || name_lower.contains("date")
314            || name_lower.contains("temporal")
315        {
316            1.0
317        } else {
318            0.0
319        };
320        embedding[start_idx + 7] = if name_lower.contains("value")
321            || name_lower.contains("number")
322            || name_lower.contains("count")
323        {
324            1.0
325        } else {
326            0.0
327        };
328    }
329
330    /// Normalize embedding to unit length.
331    fn normalize_embedding(&self, embedding: &mut [f64]) {
332        let norm: f64 = embedding.iter().map(|x| x * x).sum::<f64>().sqrt();
333        if norm > 0.0 {
334            for x in embedding.iter_mut() {
335                *x /= norm;
336            }
337        }
338    }
339}
340
341impl Default for SchemaEmbedder {
342    fn default() -> Self {
343        Self::new()
344    }
345}
346
347/// Schema similarity search engine.
348///
349/// Provides functionality to find similar domains, predicates, or schemas
350/// based on their embeddings.
351pub struct SimilaritySearch {
352    embedder: SchemaEmbedder,
353    domain_embeddings: HashMap<String, Embedding>,
354    predicate_embeddings: HashMap<String, Embedding>,
355}
356
357impl SimilaritySearch {
358    /// Create a new similarity search engine.
359    pub fn new() -> Self {
360        Self {
361            embedder: SchemaEmbedder::new(),
362            domain_embeddings: HashMap::new(),
363            predicate_embeddings: HashMap::new(),
364        }
365    }
366
367    /// Create with custom embedder.
368    pub fn with_embedder(embedder: SchemaEmbedder) -> Self {
369        Self {
370            embedder,
371            domain_embeddings: HashMap::new(),
372            predicate_embeddings: HashMap::new(),
373        }
374    }
375
376    /// Index a symbol table for similarity search.
377    pub fn index_table(&mut self, table: &SymbolTable) {
378        // Index domains
379        for (name, domain) in &table.domains {
380            let embedding = self.embedder.embed_domain(domain);
381            self.domain_embeddings.insert(name.clone(), embedding);
382        }
383
384        // Index predicates
385        for (name, predicate) in &table.predicates {
386            let embedding = self.embedder.embed_predicate(predicate);
387            self.predicate_embeddings.insert(name.clone(), embedding);
388        }
389    }
390
391    /// Find most similar domains to a query domain.
392    pub fn find_similar_domains(&self, query: &DomainInfo, top_k: usize) -> Vec<(String, f64)> {
393        let query_emb = self.embedder.embed_domain(query);
394        self.find_top_k(&self.domain_embeddings, &query_emb, top_k)
395    }
396
397    /// Find most similar predicates to a query predicate.
398    pub fn find_similar_predicates(
399        &self,
400        query: &PredicateInfo,
401        top_k: usize,
402    ) -> Vec<(String, f64)> {
403        let query_emb = self.embedder.embed_predicate(query);
404        self.find_top_k(&self.predicate_embeddings, &query_emb, top_k)
405    }
406
407    /// Find most similar domains by name.
408    pub fn find_similar_domains_by_name(&self, name: &str, top_k: usize) -> Vec<(String, f64)> {
409        if let Some(query_emb) = self.domain_embeddings.get(name) {
410            self.find_top_k(&self.domain_embeddings, query_emb, top_k + 1)
411                .into_iter()
412                .filter(|(n, _)| n != name)
413                .take(top_k)
414                .collect()
415        } else {
416            Vec::new()
417        }
418    }
419
420    /// Find most similar predicates by name.
421    pub fn find_similar_predicates_by_name(&self, name: &str, top_k: usize) -> Vec<(String, f64)> {
422        if let Some(query_emb) = self.predicate_embeddings.get(name) {
423            self.find_top_k(&self.predicate_embeddings, query_emb, top_k + 1)
424                .into_iter()
425                .filter(|(n, _)| n != name)
426                .take(top_k)
427                .collect()
428        } else {
429            Vec::new()
430        }
431    }
432
433    /// Get statistics about indexed elements.
434    pub fn stats(&self) -> SimilarityStats {
435        SimilarityStats {
436            num_domains: self.domain_embeddings.len(),
437            num_predicates: self.predicate_embeddings.len(),
438            embedding_dim: EMBEDDING_DIM,
439        }
440    }
441
442    /// Internal: Find top-k similar items from a set of embeddings.
443    fn find_top_k(
444        &self,
445        embeddings: &HashMap<String, Embedding>,
446        query: &Embedding,
447        k: usize,
448    ) -> Vec<(String, f64)> {
449        let mut similarities: Vec<(String, f64)> = embeddings
450            .iter()
451            .map(|(name, emb)| {
452                let sim = SchemaEmbedder::cosine_similarity(query, emb);
453                (name.clone(), sim)
454            })
455            .collect();
456
457        // Sort by similarity (descending)
458        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
459
460        // Return top k
461        similarities.into_iter().take(k).collect()
462    }
463}
464
465impl Default for SimilaritySearch {
466    fn default() -> Self {
467        Self::new()
468    }
469}
470
/// Summary counts for a similarity search index.
#[derive(Clone, Debug)]
pub struct SimilarityStats {
    /// How many domains are indexed.
    pub num_domains: usize,
    /// How many predicates are indexed.
    pub num_predicates: usize,
    /// Dimensionality of the embeddings.
    pub embedding_dim: usize,
}
481
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean (L2) length of a vector.
    fn l2_norm(v: &[f64]) -> f64 {
        v.iter().map(|x| x * x).sum::<f64>().sqrt()
    }

    /// Asserts that `v` has unit L2 length within `1e-6`.
    fn assert_unit_length(v: &[f64]) {
        assert!((l2_norm(v) - 1.0).abs() < 1e-6);
    }

    /// Builds a binary predicate whose two arguments are both `Person`.
    fn person_relation(name: &'static str) -> PredicateInfo {
        PredicateInfo::new(name, vec!["Person".to_string(), "Person".to_string()])
    }

    #[test]
    fn test_domain_embedding_generation() {
        let embedding = SchemaEmbedder::new().embed_domain(&DomainInfo::new("Person", 100));

        assert_eq!(embedding.len(), EMBEDDING_DIM);
        // Normalization is on by default, so the result has unit length.
        assert_unit_length(&embedding);
    }

    #[test]
    fn test_predicate_embedding_generation() {
        let embedding = SchemaEmbedder::new().embed_predicate(&person_relation("knows"));

        assert_eq!(embedding.len(), EMBEDDING_DIM);
        assert_unit_length(&embedding);
    }

    #[test]
    fn test_schema_embedding_generation() {
        let mut table = SymbolTable::new();
        table.add_domain(DomainInfo::new("Person", 100)).unwrap();
        table.add_domain(DomainInfo::new("Course", 50)).unwrap();

        let embedding = SchemaEmbedder::new().embed_schema(&table);

        assert_eq!(embedding.len(), EMBEDDING_DIM);
        assert_unit_length(&embedding);
    }

    #[test]
    fn test_cosine_similarity() {
        let x_axis = vec![1.0, 0.0, 0.0];
        let x_axis_again = vec![1.0, 0.0, 0.0];
        let y_axis = vec![0.0, 1.0, 0.0];

        let same = SchemaEmbedder::cosine_similarity(&x_axis, &x_axis_again);
        let orthogonal = SchemaEmbedder::cosine_similarity(&x_axis, &y_axis);

        assert!((same - 1.0).abs() < 1e-6);
        assert!((orthogonal - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_euclidean_distance() {
        let origin = vec![0.0, 0.0, 0.0];
        let ones = vec![1.0, 1.0, 1.0];

        let dist = SchemaEmbedder::euclidean_distance(&origin, &ones);
        assert!((dist - 3.0_f64.sqrt()).abs() < 1e-6);
    }

    #[test]
    fn test_similarity_search_indexing() {
        let mut table = SymbolTable::new();
        table.add_domain(DomainInfo::new("Person", 100)).unwrap();
        table.add_domain(DomainInfo::new("Student", 50)).unwrap();
        table.add_domain(DomainInfo::new("Course", 30)).unwrap();

        let mut search = SimilaritySearch::new();
        search.index_table(&table);

        let stats = search.stats();
        assert_eq!(stats.num_domains, 3);
        assert_eq!(stats.embedding_dim, EMBEDDING_DIM);
    }

    #[test]
    fn test_find_similar_domains() {
        let mut table = SymbolTable::new();
        table.add_domain(DomainInfo::new("Person", 100)).unwrap();
        table.add_domain(DomainInfo::new("Student", 80)).unwrap();
        table.add_domain(DomainInfo::new("Course", 50)).unwrap();

        let mut search = SimilaritySearch::new();
        search.index_table(&table);

        let similar = search.find_similar_domains(&DomainInfo::new("Teacher", 90), 2);

        assert_eq!(similar.len(), 2);
        // Teacher (90) sits between Person (100) and Student (80), so the
        // best match should score high.
        let (_, best_score) = &similar[0];
        assert!(*best_score > 0.5);
    }

    #[test]
    fn test_find_similar_predicates() {
        let mut table = SymbolTable::new();
        table.add_domain(DomainInfo::new("Person", 100)).unwrap();

        for &name in ["knows", "likes", "teaches"].iter() {
            table.add_predicate(person_relation(name)).unwrap();
        }

        let mut search = SimilaritySearch::new();
        search.index_table(&table);

        let similar = search.find_similar_predicates(&person_relation("loves"), 3);

        assert_eq!(similar.len(), 3);
        // Every indexed predicate is binary, like the query, so all should
        // score high.
        assert!(similar.iter().all(|(_, sim)| *sim > 0.8));
    }

    #[test]
    fn test_similar_domains_by_name() {
        let mut table = SymbolTable::new();
        table.add_domain(DomainInfo::new("Person", 100)).unwrap();
        table.add_domain(DomainInfo::new("Student", 80)).unwrap();
        table.add_domain(DomainInfo::new("Course", 50)).unwrap();

        let mut search = SimilaritySearch::new();
        search.index_table(&table);

        let similar = search.find_similar_domains_by_name("Person", 2);

        assert_eq!(similar.len(), 2);
        // The queried domain itself must not be part of the answer.
        assert!(similar.iter().all(|(n, _)| n != "Person"));
    }

    #[test]
    fn test_unnormalized_embeddings() {
        let embedder = SchemaEmbedder::new().with_normalization(false);
        let embedding = embedder.embed_domain(&DomainInfo::new("Person", 100));

        assert_eq!(embedding.len(), EMBEDDING_DIM);
        // Without normalization the length need not be 1, but it must not
        // be zero.
        assert!(l2_norm(&embedding) > 0.0);
    }

    #[test]
    fn test_custom_weights() {
        let embedder = SchemaEmbedder::new().with_weights(EmbeddingWeights {
            cardinality_weight: 2.0,
            arity_weight: 1.0,
            name_weight: 0.5,
            structural_weight: 0.8,
        });
        let embedding = embedder.embed_domain(&DomainInfo::new("Person", 100));

        assert_eq!(embedding.len(), EMBEDDING_DIM);
    }

    #[test]
    fn test_empty_schema_embedding() {
        let embedding = SchemaEmbedder::new().embed_schema(&SymbolTable::new());

        assert_eq!(embedding.len(), EMBEDDING_DIM);
        // An empty schema must still yield a usable vector.
        assert!(l2_norm(&embedding) >= 0.0);
    }

    #[test]
    fn test_similarity_transitivity() {
        let embedder = SchemaEmbedder::new();

        let person = embedder.embed_domain(&DomainInfo::new("Person", 100));
        let student = embedder.embed_domain(&DomainInfo::new("Student", 90));
        let teacher = embedder.embed_domain(&DomainInfo::new("Teacher", 95));

        // All three share a cardinality range, so every pair should be
        // highly similar.
        let pairs = [
            (&person, &student),
            (&person, &teacher),
            (&student, &teacher),
        ];
        for (left, right) in pairs.iter() {
            assert!(SchemaEmbedder::cosine_similarity(left, right) > 0.8);
        }
    }
}