1use std::collections::HashMap;
13
14use crate::{DomainInfo, PredicateInfo, SymbolTable};
15
/// Number of components in every embedding vector built by this module.
pub const EMBEDDING_DIM: usize = 64;

/// A feature vector of `f64` components describing a schema element.
pub type Embedding = Vec<f64>;
24
25pub struct SchemaEmbedder {
30 normalize: bool,
32 weights: EmbeddingWeights,
34}
35
/// Multipliers that scale the individual feature groups of an embedding.
#[derive(Clone, Debug)]
pub struct EmbeddingWeights {
    /// Scales the cardinality-derived features (log, square root, cube root).
    pub cardinality_weight: f64,
    /// Scales the predicate-arity features.
    pub arity_weight: f64,
    /// Scales the name-length and uppercase-count features.
    pub name_weight: f64,
    /// Scales description, constraint and schema-size features.
    pub structural_weight: f64,
}
48
49impl Default for EmbeddingWeights {
50 fn default() -> Self {
51 Self {
52 cardinality_weight: 1.0,
53 arity_weight: 1.0,
54 name_weight: 0.5,
55 structural_weight: 0.8,
56 }
57 }
58}
59
60impl SchemaEmbedder {
61 pub fn new() -> Self {
63 Self {
64 normalize: true,
65 weights: EmbeddingWeights::default(),
66 }
67 }
68
69 pub fn with_normalization(mut self, normalize: bool) -> Self {
71 self.normalize = normalize;
72 self
73 }
74
75 pub fn with_weights(mut self, weights: EmbeddingWeights) -> Self {
77 self.weights = weights;
78 self
79 }
80
81 pub fn embed_domain(&self, domain: &DomainInfo) -> Embedding {
83 let mut embedding = vec![0.0; EMBEDDING_DIM];
84
85 let log_card = (domain.cardinality as f64).ln();
87 embedding[0] = log_card * self.weights.cardinality_weight;
88 embedding[1] = (domain.cardinality as f64).sqrt() * self.weights.cardinality_weight;
89 embedding[2] = (domain.cardinality as f64).cbrt() * self.weights.cardinality_weight;
90
91 embedding[3] = if domain.cardinality < 10 { 1.0 } else { 0.0 };
93 embedding[4] = if domain.cardinality < 100 { 1.0 } else { 0.0 };
94 embedding[5] = if domain.cardinality < 1000 { 1.0 } else { 0.0 };
95 embedding[6] = if domain.cardinality < 10000 { 1.0 } else { 0.0 };
96
97 self.add_name_features(&mut embedding, &domain.name, 16);
99
100 if let Some(ref desc) = domain.description {
102 embedding[32] = (desc.len() as f64).ln() * self.weights.structural_weight;
103 embedding[33] =
104 (desc.split_whitespace().count() as f64).ln() * self.weights.structural_weight;
105 embedding[34] = if desc.contains("person") || desc.contains("user") {
106 1.0
107 } else {
108 0.0
109 };
110 embedding[35] = if desc.contains("time") || desc.contains("temporal") {
111 1.0
112 } else {
113 0.0
114 };
115 }
116
117 if let Some(ref metadata) = domain.metadata {
119 embedding[40] = if metadata.provenance.is_some() {
120 1.0
121 } else {
122 0.0
123 };
124 embedding[41] = metadata.version_history.len() as f64;
125 embedding[42] = metadata.tags.len() as f64;
126 }
127
128 if self.normalize {
129 self.normalize_embedding(&mut embedding);
130 }
131
132 embedding
133 }
134
135 pub fn embed_predicate(&self, predicate: &PredicateInfo) -> Embedding {
137 let mut embedding = vec![0.0; EMBEDDING_DIM];
138
139 let arity = predicate.arg_domains.len();
141 embedding[0] = arity as f64 * self.weights.arity_weight;
142 embedding[1] = (arity as f64).sqrt() * self.weights.arity_weight;
143
144 embedding[2] = if arity == 0 { 1.0 } else { 0.0 }; embedding[3] = if arity == 1 { 1.0 } else { 0.0 }; embedding[4] = if arity == 2 { 1.0 } else { 0.0 }; embedding[5] = if arity == 3 { 1.0 } else { 0.0 }; embedding[6] = if arity > 3 { 1.0 } else { 0.0 }; self.add_name_features(&mut embedding, &predicate.name, 16);
153
154 if let Some(ref constraints) = predicate.constraints {
156 embedding[32] = constraints.properties.len() as f64 * self.weights.structural_weight;
157 embedding[33] = if constraints.properties.iter().any(|p| {
158 matches!(
159 p,
160 crate::PredicateProperty::Symmetric | crate::PredicateProperty::Transitive
161 )
162 }) {
163 1.0
164 } else {
165 0.0
166 };
167 embedding[34] =
168 constraints.functional_dependencies.len() as f64 * self.weights.structural_weight;
169
170 let num_ranges = constraints
172 .value_ranges
173 .iter()
174 .filter(|r| r.is_some())
175 .count();
176 embedding[35] = num_ranges as f64;
177 }
178
179 if let Some(ref desc) = predicate.description {
181 embedding[48] = (desc.len() as f64).ln() * self.weights.structural_weight;
182 embedding[49] =
183 (desc.split_whitespace().count() as f64).ln() * self.weights.structural_weight;
184 }
185
186 if self.normalize {
187 self.normalize_embedding(&mut embedding);
188 }
189
190 embedding
191 }
192
193 pub fn embed_schema(&self, table: &SymbolTable) -> Embedding {
195 let mut embedding = vec![0.0; EMBEDDING_DIM];
196
197 embedding[0] = ((table.domains.len().max(1)) as f64).ln() * self.weights.structural_weight;
200 embedding[1] =
201 ((table.predicates.len().max(1)) as f64).ln() * self.weights.structural_weight;
202 embedding[2] =
203 ((table.variables.len().max(1)) as f64).ln() * self.weights.structural_weight;
204
205 let total_card: usize = table.domains.values().map(|d| d.cardinality).sum();
207 embedding[3] = ((total_card.max(1)) as f64).ln() * self.weights.cardinality_weight;
208
209 let avg_arity: f64 = if table.predicates.is_empty() {
211 0.0
212 } else {
213 table
214 .predicates
215 .values()
216 .map(|p| p.arg_domains.len())
217 .sum::<usize>() as f64
218 / table.predicates.len() as f64
219 };
220 embedding[4] = avg_arity * self.weights.arity_weight;
221
222 for domain in table.domains.values() {
224 let log_card = (domain.cardinality as f64).ln();
225 let idx = ((log_card / 10.0).min(7.0) as usize).min(7);
226 embedding[16 + idx] += 1.0;
227 }
228
229 for predicate in table.predicates.values() {
231 let arity = predicate.arg_domains.len().min(7);
232 embedding[24 + arity] += 1.0;
233 }
234
235 let max_edges = table.domains.len() * table.domains.len();
237 let actual_edges = table
238 .predicates
239 .values()
240 .filter(|p| p.arg_domains.len() == 2)
241 .count();
242 embedding[32] = if max_edges > 0 {
243 actual_edges as f64 / max_edges as f64
244 } else {
245 0.0
246 };
247
248 if self.normalize {
249 self.normalize_embedding(&mut embedding);
250 }
251
252 embedding
253 }
254
255 pub fn cosine_similarity(a: &Embedding, b: &Embedding) -> f64 {
257 assert_eq!(a.len(), b.len(), "Embeddings must have same dimension");
258
259 let dot_product: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
260 let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
261 let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
262
263 if norm_a == 0.0 || norm_b == 0.0 {
264 0.0
265 } else {
266 dot_product / (norm_a * norm_b)
267 }
268 }
269
270 pub fn euclidean_distance(a: &Embedding, b: &Embedding) -> f64 {
272 assert_eq!(a.len(), b.len(), "Embeddings must have same dimension");
273
274 a.iter()
275 .zip(b.iter())
276 .map(|(x, y)| (x - y).powi(2))
277 .sum::<f64>()
278 .sqrt()
279 }
280
281 fn add_name_features(&self, embedding: &mut [f64], name: &str, start_idx: usize) {
283 let name_lower = name.to_lowercase();
284
285 embedding[start_idx] = (name.len() as f64).ln() * self.weights.name_weight;
287 embedding[start_idx + 1] =
288 name.chars().filter(|c| c.is_uppercase()).count() as f64 * self.weights.name_weight;
289
290 let vowels = name_lower.chars().filter(|c| "aeiou".contains(*c)).count();
292 embedding[start_idx + 2] = vowels as f64 / name.len().max(1) as f64;
293
294 embedding[start_idx + 3] = if name_lower.contains('_') { 1.0 } else { 0.0 };
296 embedding[start_idx + 4] = if name_lower.starts_with("is") || name_lower.starts_with("has")
297 {
298 1.0
299 } else {
300 0.0
301 };
302
303 embedding[start_idx + 5] = if name_lower.contains("person")
305 || name_lower.contains("user")
306 || name_lower.contains("agent")
307 {
308 1.0
309 } else {
310 0.0
311 };
312 embedding[start_idx + 6] = if name_lower.contains("time")
313 || name_lower.contains("date")
314 || name_lower.contains("temporal")
315 {
316 1.0
317 } else {
318 0.0
319 };
320 embedding[start_idx + 7] = if name_lower.contains("value")
321 || name_lower.contains("number")
322 || name_lower.contains("count")
323 {
324 1.0
325 } else {
326 0.0
327 };
328 }
329
330 fn normalize_embedding(&self, embedding: &mut [f64]) {
332 let norm: f64 = embedding.iter().map(|x| x * x).sum::<f64>().sqrt();
333 if norm > 0.0 {
334 for x in embedding.iter_mut() {
335 *x /= norm;
336 }
337 }
338 }
339}
340
341impl Default for SchemaEmbedder {
342 fn default() -> Self {
343 Self::new()
344 }
345}
346
/// Name-keyed index of domain and predicate embeddings that answers
/// nearest-neighbour queries by cosine similarity.
pub struct SimilaritySearch {
    /// Embedder used both for indexed items and for query items.
    embedder: SchemaEmbedder,
    /// Embeddings of indexed domains, keyed by the name they were indexed under.
    domain_embeddings: HashMap<String, Embedding>,
    /// Embeddings of indexed predicates, keyed by the name they were indexed under.
    predicate_embeddings: HashMap<String, Embedding>,
}
356
357impl SimilaritySearch {
358 pub fn new() -> Self {
360 Self {
361 embedder: SchemaEmbedder::new(),
362 domain_embeddings: HashMap::new(),
363 predicate_embeddings: HashMap::new(),
364 }
365 }
366
367 pub fn with_embedder(embedder: SchemaEmbedder) -> Self {
369 Self {
370 embedder,
371 domain_embeddings: HashMap::new(),
372 predicate_embeddings: HashMap::new(),
373 }
374 }
375
376 pub fn index_table(&mut self, table: &SymbolTable) {
378 for (name, domain) in &table.domains {
380 let embedding = self.embedder.embed_domain(domain);
381 self.domain_embeddings.insert(name.clone(), embedding);
382 }
383
384 for (name, predicate) in &table.predicates {
386 let embedding = self.embedder.embed_predicate(predicate);
387 self.predicate_embeddings.insert(name.clone(), embedding);
388 }
389 }
390
391 pub fn find_similar_domains(&self, query: &DomainInfo, top_k: usize) -> Vec<(String, f64)> {
393 let query_emb = self.embedder.embed_domain(query);
394 self.find_top_k(&self.domain_embeddings, &query_emb, top_k)
395 }
396
397 pub fn find_similar_predicates(
399 &self,
400 query: &PredicateInfo,
401 top_k: usize,
402 ) -> Vec<(String, f64)> {
403 let query_emb = self.embedder.embed_predicate(query);
404 self.find_top_k(&self.predicate_embeddings, &query_emb, top_k)
405 }
406
407 pub fn find_similar_domains_by_name(&self, name: &str, top_k: usize) -> Vec<(String, f64)> {
409 if let Some(query_emb) = self.domain_embeddings.get(name) {
410 self.find_top_k(&self.domain_embeddings, query_emb, top_k + 1)
411 .into_iter()
412 .filter(|(n, _)| n != name)
413 .take(top_k)
414 .collect()
415 } else {
416 Vec::new()
417 }
418 }
419
420 pub fn find_similar_predicates_by_name(&self, name: &str, top_k: usize) -> Vec<(String, f64)> {
422 if let Some(query_emb) = self.predicate_embeddings.get(name) {
423 self.find_top_k(&self.predicate_embeddings, query_emb, top_k + 1)
424 .into_iter()
425 .filter(|(n, _)| n != name)
426 .take(top_k)
427 .collect()
428 } else {
429 Vec::new()
430 }
431 }
432
433 pub fn stats(&self) -> SimilarityStats {
435 SimilarityStats {
436 num_domains: self.domain_embeddings.len(),
437 num_predicates: self.predicate_embeddings.len(),
438 embedding_dim: EMBEDDING_DIM,
439 }
440 }
441
442 fn find_top_k(
444 &self,
445 embeddings: &HashMap<String, Embedding>,
446 query: &Embedding,
447 k: usize,
448 ) -> Vec<(String, f64)> {
449 let mut similarities: Vec<(String, f64)> = embeddings
450 .iter()
451 .map(|(name, emb)| {
452 let sim = SchemaEmbedder::cosine_similarity(query, emb);
453 (name.clone(), sim)
454 })
455 .collect();
456
457 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
459
460 similarities.into_iter().take(k).collect()
462 }
463}
464
465impl Default for SimilaritySearch {
466 fn default() -> Self {
467 Self::new()
468 }
469}
470
/// Size summary of a [`SimilaritySearch`] index.
#[derive(Clone, Debug)]
pub struct SimilarityStats {
    /// Number of domain embeddings currently stored.
    pub num_domains: usize,
    /// Number of predicate embeddings currently stored.
    pub num_predicates: usize,
    /// Length of each embedding vector (`EMBEDDING_DIM`).
    pub embedding_dim: usize,
}
481
#[cfg(test)]
mod tests {
    use super::*;

    /// L2 norm of `v`.
    fn l2_norm(v: &[f64]) -> f64 {
        v.iter().map(|x| x * x).sum::<f64>().sqrt()
    }

    /// `true` when `v` has unit length up to a `1e-6` tolerance.
    fn is_unit_length(v: &[f64]) -> bool {
        (l2_norm(v) - 1.0).abs() < 1e-6
    }

    /// Builds a symbol table holding the given `(name, cardinality)` domains,
    /// added in the order listed.
    fn table_of_domains(entries: &[(&str, usize)]) -> SymbolTable {
        let mut table = SymbolTable::new();
        for &(name, cardinality) in entries {
            table.add_domain(DomainInfo::new(name, cardinality)).unwrap();
        }
        table
    }

    /// A binary predicate over `Person x Person` called `name`.
    fn person_relation(name: &str) -> PredicateInfo {
        PredicateInfo::new(name, vec!["Person".to_string(), "Person".to_string()])
    }

    /// A default search index populated from `table`.
    fn indexed(table: &SymbolTable) -> SimilaritySearch {
        let mut search = SimilaritySearch::new();
        search.index_table(table);
        search
    }

    #[test]
    fn test_domain_embedding_generation() {
        let embedding = SchemaEmbedder::new().embed_domain(&DomainInfo::new("Person", 100));

        assert_eq!(embedding.len(), EMBEDDING_DIM);
        // Normalization is on by default.
        assert!(is_unit_length(&embedding));
    }

    #[test]
    fn test_predicate_embedding_generation() {
        let embedding = SchemaEmbedder::new().embed_predicate(&person_relation("knows"));

        assert_eq!(embedding.len(), EMBEDDING_DIM);
        assert!(is_unit_length(&embedding));
    }

    #[test]
    fn test_schema_embedding_generation() {
        let table = table_of_domains(&[("Person", 100), ("Course", 50)]);
        let embedding = SchemaEmbedder::new().embed_schema(&table);

        assert_eq!(embedding.len(), EMBEDDING_DIM);
        assert!(is_unit_length(&embedding));
    }

    #[test]
    fn test_cosine_similarity() {
        let x_axis = vec![1.0, 0.0, 0.0];
        let x_axis_again = vec![1.0, 0.0, 0.0];
        let y_axis = vec![0.0, 1.0, 0.0];

        let same = SchemaEmbedder::cosine_similarity(&x_axis, &x_axis_again);
        let orthogonal = SchemaEmbedder::cosine_similarity(&x_axis, &y_axis);

        assert!((same - 1.0).abs() < 1e-6);
        assert!(orthogonal.abs() < 1e-6);
    }

    #[test]
    fn test_euclidean_distance() {
        let origin = vec![0.0, 0.0, 0.0];
        let corner = vec![1.0, 1.0, 1.0];

        let dist = SchemaEmbedder::euclidean_distance(&origin, &corner);
        assert!((dist - 3.0_f64.sqrt()).abs() < 1e-6);
    }

    #[test]
    fn test_similarity_search_indexing() {
        let table = table_of_domains(&[("Person", 100), ("Student", 50), ("Course", 30)]);
        let stats = indexed(&table).stats();

        assert_eq!(stats.num_domains, 3);
        assert_eq!(stats.embedding_dim, EMBEDDING_DIM);
    }

    #[test]
    fn test_find_similar_domains() {
        let table = table_of_domains(&[("Person", 100), ("Student", 80), ("Course", 50)]);
        let search = indexed(&table);

        let similar = search.find_similar_domains(&DomainInfo::new("Teacher", 90), 2);

        assert_eq!(similar.len(), 2);
        assert!(similar[0].1 > 0.5);
    }

    #[test]
    fn test_find_similar_predicates() {
        let mut table = table_of_domains(&[("Person", 100)]);
        for &name in &["knows", "likes", "teaches"] {
            table.add_predicate(person_relation(name)).unwrap();
        }
        let search = indexed(&table);

        let similar = search.find_similar_predicates(&person_relation("loves"), 3);

        assert_eq!(similar.len(), 3);
        // Same arity and argument domains, so every hit should score high.
        assert!(similar.iter().all(|(_, sim)| *sim > 0.8));
    }

    #[test]
    fn test_similar_domains_by_name() {
        let table = table_of_domains(&[("Person", 100), ("Student", 80), ("Course", 50)]);
        let search = indexed(&table);

        let similar = search.find_similar_domains_by_name("Person", 2);

        assert_eq!(similar.len(), 2);
        // The queried domain itself must be excluded from the results.
        assert!(similar.iter().all(|(n, _)| n != "Person"));
    }

    #[test]
    fn test_unnormalized_embeddings() {
        let embedding = SchemaEmbedder::new()
            .with_normalization(false)
            .embed_domain(&DomainInfo::new("Person", 100));

        assert_eq!(embedding.len(), EMBEDDING_DIM);
        assert!(l2_norm(&embedding) > 0.0);
    }

    #[test]
    fn test_custom_weights() {
        let embedder = SchemaEmbedder::new().with_weights(EmbeddingWeights {
            cardinality_weight: 2.0,
            arity_weight: 1.0,
            name_weight: 0.5,
            structural_weight: 0.8,
        });

        let embedding = embedder.embed_domain(&DomainInfo::new("Person", 100));

        assert_eq!(embedding.len(), EMBEDDING_DIM);
    }

    #[test]
    fn test_empty_schema_embedding() {
        let embedding = SchemaEmbedder::new().embed_schema(&SymbolTable::new());

        assert_eq!(embedding.len(), EMBEDDING_DIM);
        assert!(l2_norm(&embedding) >= 0.0);
    }

    #[test]
    fn test_similarity_transitivity() {
        let embedder = SchemaEmbedder::new();
        let embeddings: Vec<Embedding> = [("Person", 100), ("Student", 90), ("Teacher", 95)]
            .iter()
            .map(|&(name, cardinality)| embedder.embed_domain(&DomainInfo::new(name, cardinality)))
            .collect();

        // Every pair of these closely related domains should be similar.
        for &(i, j) in &[(0, 1), (0, 2), (1, 2)] {
            let similarity = SchemaEmbedder::cosine_similarity(&embeddings[i], &embeddings[j]);
            assert!(similarity > 0.8);
        }
    }
}