// tensorlogic_oxirs_bridge/knowledge_embeddings.rs

1//! Knowledge graph embeddings for TensorLogic integration.
2//!
3//! This module provides embedding generation from RDF knowledge graphs,
4//! enabling machine learning and neural-symbolic integration.
5//!
6//! # Overview
7//!
8//! Knowledge graph embeddings map entities and relations to dense vector spaces,
9//! useful for:
10//! - Link prediction (predicting missing triples)
11//! - Entity classification
12//! - Similarity computation
13//! - Integration with neural networks
14//!
15//! # Supported Embedding Models
16//!
17//! - **TransE**: Translation-based model (h + r ≈ t)
18//! - **DistMult**: Bilinear model (h ⊙ r ⊙ t)
19//! - **ComplEx**: Complex-valued embeddings
20//! - **Random**: Baseline random embeddings
21//!
22//! # Example
23//!
24//! ```no_run
25//! use tensorlogic_oxirs_bridge::knowledge_embeddings::{
26//!     KnowledgeEmbeddings, EmbeddingConfig, EmbeddingModel,
27//! };
28//!
29//! let mut embeddings = KnowledgeEmbeddings::new(EmbeddingConfig::default()).unwrap();
30//!
31//! // Load knowledge graph
32//! embeddings.load_turtle(r#"
33//!     @prefix ex: <http://example.org/> .
34//!     ex:Alice ex:knows ex:Bob .
35//!     ex:Bob ex:knows ex:Carol .
36//! "#).unwrap();
37//!
38//! // Train embeddings
39//! embeddings.train(100).unwrap();
40//!
41//! // Get entity embeddings
42//! let alice_emb = embeddings.entity_embedding("http://example.org/Alice");
43//! ```
44
45use crate::oxirs_executor::OxirsSparqlExecutor;
46use anyhow::{anyhow, Result};
47use scirs2_core::ndarray::{Array1, ArrayD};
48use scirs2_core::random::{thread_rng, Rng, SeedableRng, StdRng};
49use serde::{Deserialize, Serialize};
50use std::collections::HashMap;
51use tensorlogic_ir::{TLExpr, Term};
52
/// Embedding model type.
///
/// Selects which scoring function and update rule `KnowledgeEmbeddings`
/// uses during training, link prediction, and triple scoring.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum EmbeddingModel {
    /// TransE: Translation-based model (h + r ≈ t)
    #[default]
    TransE,
    /// DistMult: Bilinear diagonal model
    DistMult,
    /// ComplEx: Complex-valued embeddings (currently trained with the
    /// simplified real-valued DistMult update)
    ComplEx,
    /// Random baseline (no training is performed for this model)
    Random,
}
66
/// Configuration for knowledge embeddings.
///
/// Construct via `EmbeddingConfig::default()` or `EmbeddingConfig::new`
/// and refine with the `with_*` builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Embedding dimension (default: 50)
    pub embedding_dim: usize,
    /// Learning rate for SGD updates (default: 0.01)
    pub learning_rate: f64,
    /// L2 regularization coefficient (default: 0.001)
    pub regularization: f64,
    /// Margin for margin-based ranking loss (default: 1.0)
    pub margin: f64,
    /// Batch size for training (default: 100)
    pub batch_size: usize,
    /// Embedding model type (default: TransE)
    pub model: EmbeddingModel,
    /// Random seed for reproducible embedding initialization;
    /// `None` uses a thread-local RNG (default: `None`)
    pub seed: Option<u64>,
}
85
86impl Default for EmbeddingConfig {
87    fn default() -> Self {
88        Self {
89            embedding_dim: 50,
90            learning_rate: 0.01,
91            regularization: 0.001,
92            margin: 1.0,
93            batch_size: 100,
94            model: EmbeddingModel::TransE,
95            seed: None,
96        }
97    }
98}
99
100impl EmbeddingConfig {
101    /// Create a new config with specified dimension.
102    pub fn new(embedding_dim: usize) -> Self {
103        Self {
104            embedding_dim,
105            ..Default::default()
106        }
107    }
108
109    /// Set the model type.
110    pub fn with_model(mut self, model: EmbeddingModel) -> Self {
111        self.model = model;
112        self
113    }
114
115    /// Set the learning rate.
116    pub fn with_learning_rate(mut self, lr: f64) -> Self {
117        self.learning_rate = lr;
118        self
119    }
120
121    /// Set the batch size.
122    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
123        self.batch_size = batch_size;
124        self
125    }
126}
127
/// A triple in the knowledge graph.
///
/// All three components are stored as strings exactly as returned by the
/// SPARQL executor (subject and predicate IRIs; object IRI or literal).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct KGTriple {
    /// Subject of the triple
    pub head: String,
    /// Predicate (relation) of the triple
    pub relation: String,
    /// Object of the triple
    pub tail: String,
}
135
/// Knowledge graph embeddings.
///
/// This struct manages entity and relation embeddings learned from
/// a knowledge graph. Data is loaded via [`KnowledgeEmbeddings::load_turtle`],
/// trained with [`KnowledgeEmbeddings::train`], and queried through the
/// embedding accessors and prediction/scoring methods.
pub struct KnowledgeEmbeddings {
    /// Configuration
    config: EmbeddingConfig,
    /// Entity embeddings: entity IRI -> embedding vector
    entity_embeddings: HashMap<String, Array1<f64>>,
    /// Relation embeddings: relation IRI -> embedding vector
    relation_embeddings: HashMap<String, Array1<f64>>,
    /// Entity index: entity IRI -> dense index (first-seen order)
    entity_index: HashMap<String, usize>,
    /// Relation index: relation IRI -> dense index (first-seen order)
    relation_index: HashMap<String, usize>,
    /// Training triples extracted from the store
    triples: Vec<KGTriple>,
    /// SPARQL executor for data access
    executor: OxirsSparqlExecutor,
}
156
157impl KnowledgeEmbeddings {
158    /// Create new knowledge embeddings.
159    pub fn new(config: EmbeddingConfig) -> Result<Self> {
160        Ok(Self {
161            config,
162            entity_embeddings: HashMap::new(),
163            relation_embeddings: HashMap::new(),
164            entity_index: HashMap::new(),
165            relation_index: HashMap::new(),
166            triples: Vec::new(),
167            executor: OxirsSparqlExecutor::new()?,
168        })
169    }
170
    /// Load knowledge graph from Turtle format.
    ///
    /// Parses `turtle` into the underlying store, re-runs triple
    /// extraction over everything in the store, and (re)initializes all
    /// embeddings with random values — any previously trained embeddings
    /// are discarded. Returns the statement count reported by the
    /// executor's loader.
    pub fn load_turtle(&mut self, turtle: &str) -> Result<usize> {
        let count = self.executor.load_turtle(turtle)?;
        self.extract_triples()?;
        self.initialize_embeddings();
        Ok(count)
    }
178
179    /// Extract triples from the executor.
180    fn extract_triples(&mut self) -> Result<()> {
181        // Query all triples
182        let query = "SELECT ?s ?p ?o WHERE { ?s ?p ?o }";
183        let results = self.executor.execute(query)?;
184
185        if let crate::oxirs_executor::QueryResults::Select { bindings, .. } = results {
186            for binding in bindings {
187                let head = binding
188                    .get("s")
189                    .map(|v| v.as_str().to_string())
190                    .unwrap_or_default();
191                let relation = binding
192                    .get("p")
193                    .map(|v| v.as_str().to_string())
194                    .unwrap_or_default();
195                let tail = binding
196                    .get("o")
197                    .map(|v| v.as_str().to_string())
198                    .unwrap_or_default();
199
200                if !head.is_empty() && !relation.is_empty() && !tail.is_empty() {
201                    self.triples.push(KGTriple {
202                        head,
203                        relation,
204                        tail,
205                    });
206                }
207            }
208        }
209
210        // Build indices
211        let mut entity_idx = 0;
212        let mut relation_idx = 0;
213
214        for triple in &self.triples {
215            if !self.entity_index.contains_key(&triple.head) {
216                self.entity_index.insert(triple.head.clone(), entity_idx);
217                entity_idx += 1;
218            }
219            if !self.entity_index.contains_key(&triple.tail) {
220                self.entity_index.insert(triple.tail.clone(), entity_idx);
221                entity_idx += 1;
222            }
223            if !self.relation_index.contains_key(&triple.relation) {
224                self.relation_index
225                    .insert(triple.relation.clone(), relation_idx);
226                relation_idx += 1;
227            }
228        }
229
230        Ok(())
231    }
232
    /// Initialize embeddings with random values.
    ///
    /// Every indexed entity and relation gets a fresh vector of dimension
    /// `config.embedding_dim`. Components are `random::<f64>() * scale`
    /// with `scale = 1/sqrt(dim)` — presumably uniform in
    /// [0, 1/sqrt(dim)) if `random::<f64>()` yields [0, 1) as in `rand`;
    /// TODO confirm against scirs2's RNG docs.
    ///
    /// NOTE(review): components are non-negative; translation models are
    /// often initialized symmetrically around zero — confirm intentional.
    ///
    /// When `config.seed` is set, a seeded `StdRng` makes initialization
    /// reproducible; otherwise a thread-local RNG is used.
    fn initialize_embeddings(&mut self) {
        let mut rng_box: Box<dyn scirs2_core::random::RngCore> =
            if let Some(seed) = self.config.seed {
                Box::new(StdRng::seed_from_u64(seed))
            } else {
                Box::new(thread_rng())
            };

        let dim = self.config.embedding_dim;
        // Scale keeps the expected vector norm roughly independent of dim.
        let scale = 1.0 / (dim as f64).sqrt();

        // Initialize entity embeddings
        for entity in self.entity_index.keys() {
            let embedding: Vec<f64> = (0..dim).map(|_| rng_box.random::<f64>() * scale).collect();
            self.entity_embeddings
                .insert(entity.clone(), Array1::from(embedding));
        }

        // Initialize relation embeddings
        for relation in self.relation_index.keys() {
            let embedding: Vec<f64> = (0..dim).map(|_| rng_box.random::<f64>() * scale).collect();
            self.relation_embeddings
                .insert(relation.clone(), Array1::from(embedding));
        }
    }
259
    /// Train the embeddings.
    ///
    /// Runs `num_epochs` epochs of mini-batch SGD with one negative
    /// sample per positive triple, and returns the mean per-triple loss
    /// of the *final* epoch only (earlier epochs' losses are overwritten).
    ///
    /// # Errors
    /// Fails if no triples have been loaded, or if a training step
    /// encounters a missing embedding.
    pub fn train(&mut self, num_epochs: usize) -> Result<f64> {
        if self.triples.is_empty() {
            return Err(anyhow!("No triples to train on"));
        }

        let mut total_loss = 0.0;
        let mut rng = thread_rng();

        for _epoch in 0..num_epochs {
            let mut epoch_loss = 0.0;

            // Shuffle triples — manual Fisher–Yates so each epoch sees a
            // fresh triple order.
            let mut indices: Vec<usize> = (0..self.triples.len()).collect();
            for i in (1..indices.len()).rev() {
                let j = rng.random_range(0..=i);
                indices.swap(i, j);
            }

            // Mini-batch training
            for batch_start in (0..indices.len()).step_by(self.config.batch_size) {
                let batch_end = (batch_start + self.config.batch_size).min(indices.len());

                for &idx in &indices[batch_start..batch_end] {
                    // Clone the triple to avoid borrow conflict
                    let triple = self.triples[idx].clone();

                    // Generate negative sample
                    let neg_triple = self.generate_negative_sample(&triple, &mut rng);

                    // Compute loss and update
                    let loss = self.train_step(&triple, &neg_triple)?;
                    epoch_loss += loss;
                }
            }

            // Overwritten each epoch: only the last epoch's mean is kept.
            total_loss = epoch_loss / self.triples.len() as f64;
        }

        Ok(total_loss)
    }
301
302    /// Generate a negative sample by corrupting head or tail.
303    fn generate_negative_sample(&self, triple: &KGTriple, rng: &mut impl Rng) -> KGTriple {
304        let entities: Vec<_> = self.entity_index.keys().collect();
305        if entities.is_empty() {
306            return triple.clone();
307        }
308
309        let corrupt_head = rng.random();
310        let random_entity = entities[rng.random_range(0..entities.len())].clone();
311
312        if corrupt_head {
313            KGTriple {
314                head: random_entity,
315                relation: triple.relation.clone(),
316                tail: triple.tail.clone(),
317            }
318        } else {
319            KGTriple {
320                head: triple.head.clone(),
321                relation: triple.relation.clone(),
322                tail: random_entity,
323            }
324        }
325    }
326
    /// Perform one training step.
    ///
    /// Dispatches to the model-specific SGD update for one
    /// (positive, negative) triple pair and returns the step's loss.
    /// The `Random` baseline performs no update and reports zero loss.
    fn train_step(&mut self, pos_triple: &KGTriple, neg_triple: &KGTriple) -> Result<f64> {
        match self.config.model {
            EmbeddingModel::TransE => self.train_step_transe(pos_triple, neg_triple),
            EmbeddingModel::DistMult => self.train_step_distmult(pos_triple, neg_triple),
            EmbeddingModel::ComplEx => self.train_step_complex(pos_triple, neg_triple),
            EmbeddingModel::Random => Ok(0.0), // No training for random
        }
    }
336
    /// TransE training step.
    ///
    /// Scores both triples by the L2 norm of `h + r - t` (smaller means
    /// more plausible) and applies one margin-based ranking-loss SGD
    /// update when the margin is violated.
    ///
    /// NOTE(review): only the positive triple's head/tail/relation are
    /// updated; the negative sample's embeddings receive no gradient — a
    /// simplification relative to standard TransE. Confirm intended.
    ///
    /// # Errors
    /// Fails if any referenced entity or relation has no embedding.
    fn train_step_transe(&mut self, pos_triple: &KGTriple, neg_triple: &KGTriple) -> Result<f64> {
        // Clone all needed embeddings up front so the mutable updates
        // below don't conflict with these borrows.
        let h_pos = self
            .entity_embeddings
            .get(&pos_triple.head)
            .ok_or_else(|| anyhow!("Missing head embedding"))?
            .clone();
        let r = self
            .relation_embeddings
            .get(&pos_triple.relation)
            .ok_or_else(|| anyhow!("Missing relation embedding"))?
            .clone();
        let t_pos = self
            .entity_embeddings
            .get(&pos_triple.tail)
            .ok_or_else(|| anyhow!("Missing tail embedding"))?
            .clone();

        let h_neg = self
            .entity_embeddings
            .get(&neg_triple.head)
            .ok_or_else(|| anyhow!("Missing negative head embedding"))?
            .clone();
        let t_neg = self
            .entity_embeddings
            .get(&neg_triple.tail)
            .ok_or_else(|| anyhow!("Missing negative tail embedding"))?
            .clone();

        // TransE score: ||h + r - t||
        let pos_diff = &h_pos + &r - &t_pos;
        let neg_diff = &h_neg + &r - &t_neg;

        let pos_score = pos_diff.iter().map(|x| x * x).sum::<f64>().sqrt();
        let neg_score = neg_diff.iter().map(|x| x * x).sum::<f64>().sqrt();

        // Margin-based ranking loss
        let loss = (self.config.margin + pos_score - neg_score).max(0.0);

        if loss > 0.0 {
            let lr = self.config.learning_rate;
            let reg = self.config.regularization;

            // Gradient of ||h + r - t|| w.r.t. (h + r - t); the max()
            // guards against division by zero when the score is ~0.
            let grad_pos: Array1<f64> = pos_diff.mapv(|x| x / pos_score.max(1e-10));

            // Update head (positive): move h to shrink the distance.
            if let Some(h) = self.entity_embeddings.get_mut(&pos_triple.head) {
                *h = &*h - &(&grad_pos * lr);
                // L2 regularization
                *h = &*h - &(&*h * (lr * reg));
            }

            // Update tail (positive): opposite sign to the head update.
            if let Some(t) = self.entity_embeddings.get_mut(&pos_triple.tail) {
                *t = &*t + &(&grad_pos * lr);
                *t = &*t - &(&*t * (lr * reg));
            }

            // Update relation
            if let Some(r) = self.relation_embeddings.get_mut(&pos_triple.relation) {
                *r = &*r - &(&grad_pos * lr);
                *r = &*r - &(&*r * (lr * reg));
            }
        }

        Ok(loss)
    }
405
    /// DistMult training step.
    ///
    /// Scores triples by the trilinear product sum(h ⊙ r ⊙ t) — larger
    /// is more plausible — and applies a margin-based loss when the
    /// positive triple does not beat the negative by the margin.
    ///
    /// NOTE(review): the gradient step updates only the positive head
    /// embedding (the tail, relation, and negative sample are untouched),
    /// per the "simplified" comments below — confirm intended.
    ///
    /// # Errors
    /// Fails if any referenced entity or relation has no embedding.
    fn train_step_distmult(&mut self, pos_triple: &KGTriple, neg_triple: &KGTriple) -> Result<f64> {
        // DistMult score: h ⊙ r ⊙ t
        // Clone values to avoid borrow conflicts during update
        let h_pos: Array1<f64> = self
            .entity_embeddings
            .get(&pos_triple.head)
            .ok_or_else(|| anyhow!("Missing head embedding"))?
            .clone();
        let r: Array1<f64> = self
            .relation_embeddings
            .get(&pos_triple.relation)
            .ok_or_else(|| anyhow!("Missing relation embedding"))?
            .clone();
        let t_pos: Array1<f64> = self
            .entity_embeddings
            .get(&pos_triple.tail)
            .ok_or_else(|| anyhow!("Missing tail embedding"))?
            .clone();

        let h_neg: Array1<f64> = self
            .entity_embeddings
            .get(&neg_triple.head)
            .ok_or_else(|| anyhow!("Missing negative head embedding"))?
            .clone();
        let t_neg: Array1<f64> = self
            .entity_embeddings
            .get(&neg_triple.tail)
            .ok_or_else(|| anyhow!("Missing negative tail embedding"))?
            .clone();

        let pos_score: f64 = h_pos
            .iter()
            .zip(r.iter())
            .zip(t_pos.iter())
            .map(|((h, r), t)| h * r * t)
            .sum();
        let neg_score: f64 = h_neg
            .iter()
            .zip(r.iter())
            .zip(t_neg.iter())
            .map(|((h, r), t)| h * r * t)
            .sum();

        // Margin-based loss: positive score should exceed negative by
        // at least the margin (higher score = more plausible here).
        let loss = (self.config.margin - pos_score + neg_score).max(0.0);

        // Simplified gradient update (similar structure to TensE)
        if loss > 0.0 {
            let lr = self.config.learning_rate;

            // Update embeddings (simplified): gradient ascent on the
            // positive score w.r.t. the head only (d score/d h = r ⊙ t).
            if let Some(h) = self.entity_embeddings.get_mut(&pos_triple.head) {
                let grad: Array1<f64> = r
                    .iter()
                    .zip(t_pos.iter())
                    .map(|(ri, ti)| ri * ti * lr)
                    .collect();
                *h = &*h + &grad;
            }
        }

        Ok(loss)
    }
470
    /// ComplEx training step (simplified).
    ///
    /// Full ComplEx uses complex-valued embeddings with a Hermitian dot
    /// product; this implementation delegates to the real-valued DistMult
    /// update instead.
    fn train_step_complex(&mut self, pos_triple: &KGTriple, neg_triple: &KGTriple) -> Result<f64> {
        // For simplicity, treat as real-valued DistMult
        // Full ComplEx would use complex arithmetic
        self.train_step_distmult(pos_triple, neg_triple)
    }
477
    /// Get entity embedding.
    ///
    /// Returns `None` if the entity IRI was never seen in a loaded graph.
    pub fn entity_embedding(&self, entity: &str) -> Option<&Array1<f64>> {
        self.entity_embeddings.get(entity)
    }
482
    /// Get relation embedding.
    ///
    /// Returns `None` if the relation IRI was never seen in a loaded graph.
    pub fn relation_embedding(&self, relation: &str) -> Option<&Array1<f64>> {
        self.relation_embeddings.get(relation)
    }
487
488    /// Generate embeddings for all entities.
489    pub fn generate_entity_embeddings(&self) -> Result<HashMap<String, ArrayD<f64>>> {
490        let mut result: HashMap<String, ArrayD<f64>> = HashMap::new();
491        for (entity, embedding) in &self.entity_embeddings {
492            let entity_str: String = entity.to_string();
493            let emb: &Array1<f64> = embedding;
494            let shape = vec![emb.len()];
495            let data = emb.to_vec();
496            let array = ArrayD::from_shape_vec(shape, data)
497                .map_err(|e| anyhow!("Failed to reshape: {}", e))?;
498            result.insert(entity_str, array);
499        }
500        Ok(result)
501    }
502
503    /// Convert embeddings to weighted TensorLogic predicates.
504    ///
505    /// Creates predicates with weights based on embedding similarities.
506    pub fn to_weighted_predicates(&self) -> Result<Vec<TLExpr>> {
507        let mut predicates = Vec::new();
508
509        for triple in &self.triples {
510            // Compute triple score as weight
511            let score = self.score_triple(triple)?;
512            // Weight can be used for probabilistic reasoning (reserved for future use)
513            let _weight = (-score).exp().min(1.0); // Convert distance to probability
514
515            // Create weighted predicate
516            let relation_name = Self::iri_to_name(&triple.relation);
517            let pred = TLExpr::pred(
518                &relation_name,
519                vec![Term::constant(&triple.head), Term::constant(&triple.tail)],
520            );
521
522            // Wrap with weight (using a pseudo-weight representation)
523            // In a full implementation, this would integrate with TensorLogic's weight system
524            predicates.push(pred);
525        }
526
527        Ok(predicates)
528    }
529
530    /// Predict missing links.
531    ///
532    /// Given a subject and relation, predict likely objects.
533    pub fn predict_links(&self, subject: &str, relation: &str) -> Result<Vec<(String, f64)>> {
534        let h = self
535            .entity_embeddings
536            .get(subject)
537            .ok_or_else(|| anyhow!("Unknown subject: {}", subject))?;
538        let r = self
539            .relation_embeddings
540            .get(relation)
541            .ok_or_else(|| anyhow!("Unknown relation: {}", relation))?;
542
543        let mut predictions: Vec<(String, f64)> = Vec::new();
544
545        for (entity, t) in &self.entity_embeddings {
546            let entity_str: &String = entity;
547            let t_emb: &Array1<f64> = t;
548
549            if entity_str == subject {
550                continue; // Skip self-links
551            }
552
553            let score: f64 = match self.config.model {
554                EmbeddingModel::TransE => {
555                    // TransE: score = -||h + r - t||
556                    let diff = h + r - t_emb;
557                    -diff.iter().map(|x| x * x).sum::<f64>().sqrt()
558                }
559                EmbeddingModel::DistMult | EmbeddingModel::ComplEx => {
560                    // DistMult: score = h ⊙ r ⊙ t
561                    let h_arr: &Array1<f64> = h;
562                    let r_arr: &Array1<f64> = r;
563                    let t_arr: &Array1<f64> = t_emb;
564                    h_arr
565                        .iter()
566                        .zip(r_arr.iter())
567                        .zip(t_arr.iter())
568                        .map(|((hi, ri), ti): ((&f64, &f64), &f64)| hi * ri * ti)
569                        .sum()
570                }
571                EmbeddingModel::Random => thread_rng().random(),
572            };
573
574            predictions.push((entity_str.clone(), score));
575        }
576
577        // Sort by score (descending)
578        predictions.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
579
580        Ok(predictions)
581    }
582
    /// Score a triple.
    ///
    /// Lower scores indicate more plausible triples for every model:
    /// TransE returns the distance `||h + r - t||`, DistMult/ComplEx
    /// return the *negated* trilinear product, and Random returns a
    /// constant 0.5. (Note this is the opposite orientation from
    /// `predict_links`, where higher is better.)
    ///
    /// # Errors
    /// Fails if head, relation, or tail has no embedding.
    pub fn score_triple(&self, triple: &KGTriple) -> Result<f64> {
        let h = self
            .entity_embeddings
            .get(&triple.head)
            .ok_or_else(|| anyhow!("Unknown head"))?;
        let r = self
            .relation_embeddings
            .get(&triple.relation)
            .ok_or_else(|| anyhow!("Unknown relation"))?;
        let t = self
            .entity_embeddings
            .get(&triple.tail)
            .ok_or_else(|| anyhow!("Unknown tail"))?;

        let score = match self.config.model {
            EmbeddingModel::TransE => {
                let diff = h + r - t;
                diff.iter().map(|x| x * x).sum::<f64>().sqrt()
            }
            // Negated so that, like TransE's distance, smaller = better.
            EmbeddingModel::DistMult | EmbeddingModel::ComplEx => -h
                .iter()
                .zip(r.iter())
                .zip(t.iter())
                .map(|((hi, ri), ti)| hi * ri * ti)
                .sum::<f64>(),
            EmbeddingModel::Random => 0.5,
        };

        Ok(score)
    }
614
    /// Get the number of entities.
    ///
    /// Counts distinct entity IRIs seen as head or tail of any triple.
    pub fn num_entities(&self) -> usize {
        self.entity_index.len()
    }
619
    /// Get the number of relations.
    ///
    /// Counts distinct relation IRIs seen as the predicate of any triple.
    pub fn num_relations(&self) -> usize {
        self.relation_index.len()
    }
624
    /// Get the number of triples.
    ///
    /// Counts the triples extracted for training (incomplete bindings
    /// are excluded during extraction).
    pub fn num_triples(&self) -> usize {
        self.triples.len()
    }
629
630    /// Extract local name from IRI.
631    fn iri_to_name(iri: &str) -> String {
632        iri.split(['/', '#']).next_back().unwrap_or(iri).to_string()
633    }
634}
635
636/// Compute cosine similarity between two vectors.
637pub fn cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
638    let dot: f64 = a.iter().zip(b.iter()).map(|(ai, bi)| ai * bi).sum();
639    let norm_a = a.iter().map(|x| x * x).sum::<f64>().sqrt();
640    let norm_b = b.iter().map(|x| x * x).sum::<f64>().sqrt();
641
642    if norm_a > 0.0 && norm_b > 0.0 {
643        dot / (norm_a * norm_b)
644    } else {
645        0.0
646    }
647}
648
649/// Compute Euclidean distance between two vectors.
650pub fn euclidean_distance(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
651    a.iter()
652        .zip(b.iter())
653        .map(|(ai, bi): (&f64, &f64)| (ai - bi).powi(2))
654        .sum::<f64>()
655        .sqrt()
656}
657
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    // Builder methods should override exactly the fields they name.
    #[test]
    fn test_embedding_config() {
        let config = EmbeddingConfig::new(100)
            .with_model(EmbeddingModel::DistMult)
            .with_learning_rate(0.001)
            .with_batch_size(256);

        assert_eq!(config.embedding_dim, 100);
        assert_eq!(config.model, EmbeddingModel::DistMult);
        assert_abs_diff_eq!(config.learning_rate, 0.001);
        assert_eq!(config.batch_size, 256);
    }

    // A fresh instance starts with no entities, relations, or triples.
    #[test]
    fn test_embeddings_creation() {
        let config = EmbeddingConfig::default();
        let embeddings = KnowledgeEmbeddings::new(config).expect("Failed to create embeddings");

        assert_eq!(embeddings.num_entities(), 0);
        assert_eq!(embeddings.num_relations(), 0);
        assert_eq!(embeddings.num_triples(), 0);
    }

    // Loading Turtle extracts one triple per statement and indexes
    // the entities.
    #[test]
    fn test_load_turtle() {
        let config = EmbeddingConfig::default();
        let mut embeddings = KnowledgeEmbeddings::new(config).expect("Failed to create embeddings");

        let result = embeddings.load_turtle(
            r#"
            @prefix ex: <http://example.org/> .
            ex:Alice ex:knows ex:Bob .
            ex:Bob ex:knows ex:Carol .
        "#,
        );

        assert!(result.is_ok());
        assert_eq!(embeddings.num_triples(), 2);
        assert!(embeddings.num_entities() >= 2);
    }

    // After loading, every entity has an embedding of the configured
    // dimension.
    #[test]
    fn test_entity_embedding() {
        let config = EmbeddingConfig::new(10);
        let mut embeddings = KnowledgeEmbeddings::new(config).expect("Failed to create embeddings");

        embeddings
            .load_turtle(
                r#"
            @prefix ex: <http://example.org/> .
            ex:Alice ex:knows ex:Bob .
        "#,
            )
            .expect("Load failed");

        let alice_emb = embeddings.entity_embedding("http://example.org/Alice");
        assert!(alice_emb.is_some());
        assert_eq!(alice_emb.map(|e| e.len()), Some(10));
    }

    // Training on a small graph completes without error (default TransE).
    #[test]
    fn test_train() {
        let config = EmbeddingConfig::new(10).with_batch_size(2);
        let mut embeddings = KnowledgeEmbeddings::new(config).expect("Failed to create embeddings");

        embeddings
            .load_turtle(
                r#"
            @prefix ex: <http://example.org/> .
            ex:Alice ex:knows ex:Bob .
            ex:Bob ex:knows ex:Carol .
        "#,
            )
            .expect("Load failed");

        let loss = embeddings.train(5);
        assert!(loss.is_ok());
    }

    // Link prediction returns a nonempty ranking for a known
    // subject/relation pair.
    #[test]
    fn test_predict_links() {
        let config = EmbeddingConfig::new(10);
        let mut embeddings = KnowledgeEmbeddings::new(config).expect("Failed to create embeddings");

        embeddings
            .load_turtle(
                r#"
            @prefix ex: <http://example.org/> .
            ex:Alice ex:knows ex:Bob .
            ex:Bob ex:knows ex:Carol .
            ex:Carol ex:knows ex:Dave .
        "#,
            )
            .expect("Load failed");

        let predictions =
            embeddings.predict_links("http://example.org/Alice", "http://example.org/knows");
        assert!(predictions.is_ok());

        let predictions = predictions.expect("Prediction failed");
        assert!(!predictions.is_empty());
    }

    // Scoring a loaded triple succeeds (all embeddings present).
    #[test]
    fn test_score_triple() {
        let config = EmbeddingConfig::new(10);
        let mut embeddings = KnowledgeEmbeddings::new(config).expect("Failed to create embeddings");

        embeddings
            .load_turtle(
                r#"
            @prefix ex: <http://example.org/> .
            ex:Alice ex:knows ex:Bob .
        "#,
            )
            .expect("Load failed");

        let triple = KGTriple {
            head: "http://example.org/Alice".to_string(),
            relation: "http://example.org/knows".to_string(),
            tail: "http://example.org/Bob".to_string(),
        };

        let score = embeddings.score_triple(&triple);
        assert!(score.is_ok());
    }

    // Identical vectors have similarity 1; orthogonal vectors 0.
    #[test]
    fn test_cosine_similarity() {
        let a = Array1::from(vec![1.0, 0.0, 0.0]);
        let b = Array1::from(vec![1.0, 0.0, 0.0]);

        let sim = cosine_similarity(&a, &b);
        assert_abs_diff_eq!(sim, 1.0, epsilon = 1e-6);

        let c = Array1::from(vec![0.0, 1.0, 0.0]);
        let sim_orthogonal = cosine_similarity(&a, &c);
        assert_abs_diff_eq!(sim_orthogonal, 0.0, epsilon = 1e-6);
    }

    // Classic 3-4-5 right triangle.
    #[test]
    fn test_euclidean_distance() {
        let a = Array1::from(vec![0.0, 0.0, 0.0]);
        let b = Array1::from(vec![3.0, 4.0, 0.0]);

        let dist = euclidean_distance(&a, &b);
        assert_abs_diff_eq!(dist, 5.0, epsilon = 1e-6);
    }

    // Exported ArrayD map contains every loaded entity.
    #[test]
    fn test_generate_entity_embeddings() {
        let config = EmbeddingConfig::new(5);
        let mut embeddings = KnowledgeEmbeddings::new(config).expect("Failed to create embeddings");

        embeddings
            .load_turtle(
                r#"
            @prefix ex: <http://example.org/> .
            ex:Alice ex:knows ex:Bob .
        "#,
            )
            .expect("Load failed");

        let entity_embs = embeddings.generate_entity_embeddings();
        assert!(entity_embs.is_ok());

        let entity_embs = entity_embs.expect("Generation failed");
        assert!(entity_embs.contains_key("http://example.org/Alice"));
    }

    // One predicate is produced per loaded triple.
    #[test]
    fn test_to_weighted_predicates() {
        let config = EmbeddingConfig::new(5);
        let mut embeddings = KnowledgeEmbeddings::new(config).expect("Failed to create embeddings");

        embeddings
            .load_turtle(
                r#"
            @prefix ex: <http://example.org/> .
            ex:Alice ex:knows ex:Bob .
        "#,
            )
            .expect("Load failed");

        let predicates = embeddings.to_weighted_predicates();
        assert!(predicates.is_ok());

        let predicates = predicates.expect("Predicate generation failed");
        assert!(!predicates.is_empty());
    }

    // Training also succeeds under the DistMult model.
    #[test]
    fn test_distmult_model() {
        let config = EmbeddingConfig::new(10).with_model(EmbeddingModel::DistMult);
        let mut embeddings = KnowledgeEmbeddings::new(config).expect("Failed to create embeddings");

        embeddings
            .load_turtle(
                r#"
            @prefix ex: <http://example.org/> .
            ex:Alice ex:knows ex:Bob .
        "#,
            )
            .expect("Load failed");

        let loss = embeddings.train(3);
        assert!(loss.is_ok());
    }
}