1use std::collections::HashMap;
29
30use scirs2_core::random::{Rng, RngExt};
31
32use crate::error::{GraphError, Result};
33
34#[derive(Debug, Clone)]
44pub struct KGDataset {
45 pub triples: Vec<(usize, usize, usize)>,
47 pub n_entities: usize,
49 pub n_relations: usize,
51 pub entity_labels: Vec<String>,
53 pub relation_labels: Vec<String>,
55}
56
57impl KGDataset {
58 pub fn new(
65 triples: Vec<(usize, usize, usize)>,
66 n_entities: usize,
67 n_relations: usize,
68 ) -> Result<Self> {
69 for &(h, r, t) in &triples {
70 if h >= n_entities || t >= n_entities {
71 return Err(GraphError::InvalidParameter {
72 param: "triples".to_string(),
73 value: format!("entity index ({h},{t}) out of range"),
74 expected: format!("< n_entities={n_entities}"),
75 context: "KGDataset::new".to_string(),
76 });
77 }
78 if r >= n_relations {
79 return Err(GraphError::InvalidParameter {
80 param: "triples".to_string(),
81 value: format!("relation index {r} out of range"),
82 expected: format!("< n_relations={n_relations}"),
83 context: "KGDataset::new".to_string(),
84 });
85 }
86 }
87 let entity_labels = (0..n_entities).map(|i| format!("e{i}")).collect();
88 let relation_labels = (0..n_relations).map(|i| format!("r{i}")).collect();
89 Ok(KGDataset {
90 triples,
91 n_entities,
92 n_relations,
93 entity_labels,
94 relation_labels,
95 })
96 }
97
98 pub fn from_str_triples(triples: &[(&str, &str, &str)]) -> Self {
103 let mut entity_map: HashMap<String, usize> = HashMap::new();
104 let mut relation_map: HashMap<String, usize> = HashMap::new();
105 let mut entity_labels: Vec<String> = Vec::new();
106 let mut relation_labels: Vec<String> = Vec::new();
107
108 let mut get_or_insert_entity = |s: &str| -> usize {
109 if let Some(&idx) = entity_map.get(s) {
110 idx
111 } else {
112 let idx = entity_labels.len();
113 entity_map.insert(s.to_string(), idx);
114 entity_labels.push(s.to_string());
115 idx
116 }
117 };
118
119 let mut indexed_triples: Vec<(usize, usize, usize)> = Vec::with_capacity(triples.len());
120 for &(h, r, t) in triples {
121 let hi = get_or_insert_entity(h);
122 let ti = get_or_insert_entity(t);
123 let ri = if let Some(&idx) = relation_map.get(r) {
124 idx
125 } else {
126 let idx = relation_labels.len();
127 relation_map.insert(r.to_string(), idx);
128 relation_labels.push(r.to_string());
129 idx
130 };
131 indexed_triples.push((hi, ri, ti));
132 }
133
134 let n_entities = entity_labels.len();
135 let n_relations = relation_labels.len();
136
137 KGDataset {
138 triples: indexed_triples,
139 n_entities,
140 n_relations,
141 entity_labels,
142 relation_labels,
143 }
144 }
145
146 pub fn len(&self) -> usize {
148 self.triples.len()
149 }
150
151 pub fn is_empty(&self) -> bool {
153 self.triples.is_empty()
154 }
155
156 pub fn corrupt_triple(&self, triple: (usize, usize, usize)) -> (usize, usize, usize) {
159 let (h, r, t) = triple;
160 let mut rng = scirs2_core::random::rng();
161 let replace_head = rng.random::<f64>() < 0.5;
162 if replace_head {
163 let mut new_h = (rng.random::<f64>() * self.n_entities as f64) as usize;
164 new_h = new_h.min(self.n_entities - 1);
165 if new_h == h && self.n_entities > 1 {
167 new_h = (new_h + 1) % self.n_entities;
168 }
169 (new_h, r, t)
170 } else {
171 let mut new_t = (rng.random::<f64>() * self.n_entities as f64) as usize;
172 new_t = new_t.min(self.n_entities - 1);
173 if new_t == t && self.n_entities > 1 {
174 new_t = (new_t + 1) % self.n_entities;
175 }
176 (h, r, new_t)
177 }
178 }
179}
180
181fn init_embeddings(n_items: usize, dim: usize, scale: f64) -> Vec<Vec<f64>> {
188 let mut rng = scirs2_core::random::rng();
189 (0..n_items)
190 .map(|_| {
191 let mut row: Vec<f64> = (0..dim)
192 .map(|_| rng.random::<f64>() * 2.0 * scale - scale)
193 .collect();
194 let norm = row.iter().map(|x| x * x).sum::<f64>().sqrt().max(1e-12);
195 row.iter_mut().for_each(|x| *x /= norm);
196 row
197 })
198 .collect()
199}
200
/// Euclidean (L2) norm of `v`.
#[inline]
fn l2_norm(v: &[f64]) -> f64 {
    let sum_of_squares: f64 = v.iter().map(|&x| x * x).sum();
    sum_of_squares.sqrt()
}
206
207fn l2_normalize(v: &mut [f64]) {
209 let norm = l2_norm(v).max(1e-12);
210 v.iter_mut().for_each(|x| *x /= norm);
211}
212
213#[derive(Debug, Clone)]
223pub struct TransE {
224 pub entity_embeddings: Vec<Vec<f64>>,
226 pub relation_embeddings: Vec<Vec<f64>>,
228 pub dim: usize,
230 pub norm_order: u32,
232}
233
234impl TransE {
235 pub fn new(n_entities: usize, n_relations: usize, dim: usize) -> Result<Self> {
242 if dim == 0 {
243 return Err(GraphError::InvalidParameter {
244 param: "dim".to_string(),
245 value: "0".to_string(),
246 expected: "> 0".to_string(),
247 context: "TransE::new".to_string(),
248 });
249 }
250 let entity_embeddings = init_embeddings(n_entities, dim, 1.0 / (dim as f64).sqrt());
251 let relation_embeddings = init_embeddings(n_relations, dim, 1.0 / (dim as f64).sqrt());
252 Ok(TransE {
253 entity_embeddings,
254 relation_embeddings,
255 dim,
256 norm_order: 2,
257 })
258 }
259
260 pub fn score(&self, h: usize, r: usize, t: usize) -> Result<f64> {
262 self.validate_indices(h, r, t)?;
263 let he = &self.entity_embeddings[h];
264 let re = &self.relation_embeddings[r];
265 let te = &self.entity_embeddings[t];
266 let dist = translation_distance(he, re, te, self.norm_order);
267 Ok(-dist)
268 }
269
270 pub fn predict_tails(&self, h: usize, r: usize, k: usize) -> Result<Vec<usize>> {
272 let n = self.entity_embeddings.len();
273 if h >= n {
274 return Err(GraphError::InvalidParameter {
275 param: "h".to_string(),
276 value: format!("{h}"),
277 expected: format!("< {n}"),
278 context: "TransE::predict_tails".to_string(),
279 });
280 }
281 let he = &self.entity_embeddings[h];
282 let re = &self.relation_embeddings[r];
283 let mut scores: Vec<(usize, f64)> = (0..n)
284 .map(|t| {
285 let te = &self.entity_embeddings[t];
286 let dist = translation_distance(he, re, te, self.norm_order);
287 (t, -dist) })
289 .collect();
290 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
291 Ok(scores.into_iter().take(k).map(|(idx, _)| idx).collect())
292 }
293
294 pub fn link_prediction_score(&self, h: usize, r: usize, t: usize) -> Result<f64> {
296 self.score(h, r, t)
297 }
298
299 pub fn train_epoch(&mut self, dataset: &KGDataset, lr: f64, margin: f64) -> f64 {
307 let mut total_loss = 0.0;
308
309 for &(h, r, t) in &dataset.triples {
310 let (nh, nr, nt) = dataset.corrupt_triple((h, r, t));
311
312 let pos_score = {
313 let he = &self.entity_embeddings[h];
314 let re = &self.relation_embeddings[r];
315 let te = &self.entity_embeddings[t];
316 translation_distance(he, re, te, self.norm_order)
317 };
318 let neg_score = {
319 let he = &self.entity_embeddings[nh];
320 let re = &self.relation_embeddings[nr];
321 let te = &self.entity_embeddings[nt];
322 translation_distance(he, re, te, self.norm_order)
323 };
324
325 let loss = (margin + pos_score - neg_score).max(0.0);
326 total_loss += loss;
327
328 if loss > 0.0 {
329 let dim = self.dim;
331
332 let g_pos: Vec<f64> = (0..dim)
334 .map(|k| {
335 let diff = self.entity_embeddings[h][k] + self.relation_embeddings[r][k]
336 - self.entity_embeddings[t][k];
337 if diff >= 0.0 {
338 1.0
339 } else {
340 -1.0
341 }
342 })
343 .collect();
344
345 let g_neg: Vec<f64> = (0..dim)
347 .map(|k| {
348 let diff = self.entity_embeddings[nh][k] + self.relation_embeddings[nr][k]
349 - self.entity_embeddings[nt][k];
350 if diff >= 0.0 {
351 1.0
352 } else {
353 -1.0
354 }
355 })
356 .collect();
357
358 for k in 0..dim {
360 self.entity_embeddings[h][k] -= lr * g_pos[k];
361 self.entity_embeddings[t][k] += lr * g_pos[k];
362 self.relation_embeddings[r][k] -= lr * g_pos[k];
363 }
364 for k in 0..dim {
366 self.entity_embeddings[nh][k] += lr * g_neg[k];
367 self.entity_embeddings[nt][k] -= lr * g_neg[k];
368 }
369
370 l2_normalize(&mut self.entity_embeddings[h]);
372 l2_normalize(&mut self.entity_embeddings[t]);
373 l2_normalize(&mut self.entity_embeddings[nh]);
374 l2_normalize(&mut self.entity_embeddings[nt]);
375 }
376 }
377
378 total_loss
379 }
380
381 fn validate_indices(&self, h: usize, r: usize, t: usize) -> Result<()> {
382 let ne = self.entity_embeddings.len();
383 let nr = self.relation_embeddings.len();
384 if h >= ne || t >= ne {
385 return Err(GraphError::InvalidParameter {
386 param: "entity_index".to_string(),
387 value: format!("({h},{t})"),
388 expected: format!("< {ne}"),
389 context: "TransE score".to_string(),
390 });
391 }
392 if r >= nr {
393 return Err(GraphError::InvalidParameter {
394 param: "relation_index".to_string(),
395 value: format!("{r}"),
396 expected: format!("< {nr}"),
397 context: "TransE score".to_string(),
398 });
399 }
400 Ok(())
401 }
402}
403
/// Distance `‖h + r − t‖` between the translated head and the tail.
///
/// `norm_order == 1` selects the L1 norm; every other value selects the L2 norm.
/// Slices of unequal length are truncated to the shortest one.
fn translation_distance(h: &[f64], r: &[f64], t: &[f64], norm_order: u32) -> f64 {
    let residuals = h
        .iter()
        .zip(r)
        .zip(t)
        .map(|((&hk, &rk), &tk)| hk + rk - tk);
    if norm_order == 1 {
        residuals.map(f64::abs).sum::<f64>()
    } else {
        residuals.map(|d| d * d).sum::<f64>().sqrt()
    }
}
423
424#[derive(Debug, Clone)]
432pub struct DistMult {
433 pub entity_embeddings: Vec<Vec<f64>>,
435 pub relation_embeddings: Vec<Vec<f64>>,
437 pub dim: usize,
439}
440
441impl DistMult {
442 pub fn new(n_entities: usize, n_relations: usize, dim: usize) -> Result<Self> {
444 if dim == 0 {
445 return Err(GraphError::InvalidParameter {
446 param: "dim".to_string(),
447 value: "0".to_string(),
448 expected: "> 0".to_string(),
449 context: "DistMult::new".to_string(),
450 });
451 }
452 let mut rng = scirs2_core::random::rng();
453 let scale = 1.0 / (dim as f64).sqrt();
454 let mut mk_table = |n: usize| -> Vec<Vec<f64>> {
455 (0..n)
456 .map(|_| {
457 (0..dim)
458 .map(|_| rng.random::<f64>() * 2.0 * scale - scale)
459 .collect()
460 })
461 .collect()
462 };
463 Ok(DistMult {
464 entity_embeddings: mk_table(n_entities),
465 relation_embeddings: mk_table(n_relations),
466 dim,
467 })
468 }
469
470 pub fn score(&self, h: usize, r: usize, t: usize) -> Result<f64> {
472 self.validate_indices(h, r, t)?;
473 let score = distmult_score(
474 &self.entity_embeddings[h],
475 &self.relation_embeddings[r],
476 &self.entity_embeddings[t],
477 );
478 Ok(score)
479 }
480
481 pub fn link_prediction_score(&self, h: usize, r: usize, t: usize) -> Result<f64> {
483 self.score(h, r, t)
484 }
485
486 pub fn predict_tails(&self, h: usize, r: usize, k: usize) -> Result<Vec<usize>> {
488 let n = self.entity_embeddings.len();
489 if h >= n {
490 return Err(GraphError::InvalidParameter {
491 param: "h".to_string(),
492 value: format!("{h}"),
493 expected: format!("< {n}"),
494 context: "DistMult::predict_tails".to_string(),
495 });
496 }
497 let he = &self.entity_embeddings[h];
498 let re = &self.relation_embeddings[r];
499 let mut scores: Vec<(usize, f64)> = (0..n)
500 .map(|ti| {
501 let te = &self.entity_embeddings[ti];
502 (ti, distmult_score(he, re, te))
503 })
504 .collect();
505 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
506 Ok(scores.into_iter().take(k).map(|(idx, _)| idx).collect())
507 }
508
509 pub fn train_epoch(&mut self, dataset: &KGDataset, lr: f64, margin: f64) -> f64 {
511 let mut total_loss = 0.0;
512 for &(h, r, t) in &dataset.triples {
513 let (nh, nr, nt) = dataset.corrupt_triple((h, r, t));
514
515 let pos = distmult_score(
516 &self.entity_embeddings[h],
517 &self.relation_embeddings[r],
518 &self.entity_embeddings[t],
519 );
520 let neg = distmult_score(
521 &self.entity_embeddings[nh],
522 &self.relation_embeddings[nr],
523 &self.entity_embeddings[nt],
524 );
525
526 let loss = (margin - pos + neg).max(0.0);
527 total_loss += loss;
528
529 if loss > 0.0 {
530 let dim = self.dim;
531 for k in 0..dim {
533 let re = self.relation_embeddings[r][k];
534 let te = self.entity_embeddings[t][k];
535 self.entity_embeddings[h][k] += lr * re * te;
536 }
537 for k in 0..dim {
538 let re = self.relation_embeddings[nr][k];
539 let te = self.entity_embeddings[nt][k];
540 self.entity_embeddings[nh][k] -= lr * re * te;
541 }
542 }
543 }
544 total_loss
545 }
546
547 fn validate_indices(&self, h: usize, r: usize, t: usize) -> Result<()> {
548 let ne = self.entity_embeddings.len();
549 let nr = self.relation_embeddings.len();
550 if h >= ne || t >= ne {
551 return Err(GraphError::InvalidParameter {
552 param: "entity_index".to_string(),
553 value: format!("({h},{t})"),
554 expected: format!("< {ne}"),
555 context: "DistMult score".to_string(),
556 });
557 }
558 if r >= nr {
559 return Err(GraphError::InvalidParameter {
560 param: "relation_index".to_string(),
561 value: format!("{r}"),
562 expected: format!("< {nr}"),
563 context: "DistMult score".to_string(),
564 });
565 }
566 Ok(())
567 }
568}
569
/// Trilinear DistMult score `Σ_k h_k · r_k · t_k`.
///
/// Slices of unequal length are truncated to the shortest one.
fn distmult_score(h: &[f64], r: &[f64], t: &[f64]) -> f64 {
    let products = h.iter().zip(r).zip(t).map(|((hk, rk), tk)| hk * rk * tk);
    products.sum()
}
577
578#[derive(Debug, Clone)]
591pub struct ComplEx {
592 pub entity_re: Vec<Vec<f64>>,
594 pub entity_im: Vec<Vec<f64>>,
596 pub relation_re: Vec<Vec<f64>>,
598 pub relation_im: Vec<Vec<f64>>,
600 pub dim: usize,
602}
603
604impl ComplEx {
605 pub fn new(n_entities: usize, n_relations: usize, dim: usize) -> Result<Self> {
607 if dim == 0 {
608 return Err(GraphError::InvalidParameter {
609 param: "dim".to_string(),
610 value: "0".to_string(),
611 expected: "> 0".to_string(),
612 context: "ComplEx::new".to_string(),
613 });
614 }
615 let scale = 1.0 / (dim as f64).sqrt();
616 Ok(ComplEx {
617 entity_re: init_embeddings(n_entities, dim, scale),
618 entity_im: init_embeddings(n_entities, dim, scale),
619 relation_re: init_embeddings(n_relations, dim, scale),
620 relation_im: init_embeddings(n_relations, dim, scale),
621 dim,
622 })
623 }
624
625 pub fn score(&self, h: usize, r: usize, t: usize) -> Result<f64> {
627 self.validate_indices(h, r, t)?;
628 let s = complex_score(
629 &self.entity_re[h],
630 &self.entity_im[h],
631 &self.relation_re[r],
632 &self.relation_im[r],
633 &self.entity_re[t],
634 &self.entity_im[t],
635 );
636 Ok(s)
637 }
638
639 pub fn link_prediction_score(&self, h: usize, r: usize, t: usize) -> Result<f64> {
641 self.score(h, r, t)
642 }
643
644 pub fn predict_tails(&self, h: usize, r: usize, k: usize) -> Result<Vec<usize>> {
646 let n = self.entity_re.len();
647 if h >= n {
648 return Err(GraphError::InvalidParameter {
649 param: "h".to_string(),
650 value: format!("{h}"),
651 expected: format!("< {n}"),
652 context: "ComplEx::predict_tails".to_string(),
653 });
654 }
655 let mut scores: Vec<(usize, f64)> = (0..n)
656 .map(|ti| {
657 let s = complex_score(
658 &self.entity_re[h],
659 &self.entity_im[h],
660 &self.relation_re[r],
661 &self.relation_im[r],
662 &self.entity_re[ti],
663 &self.entity_im[ti],
664 );
665 (ti, s)
666 })
667 .collect();
668 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
669 Ok(scores.into_iter().take(k).map(|(idx, _)| idx).collect())
670 }
671
672 pub fn train_epoch(&mut self, dataset: &KGDataset, lr: f64, margin: f64) -> f64 {
674 let mut total_loss = 0.0;
675 for &(h, r, t) in &dataset.triples {
676 let (nh, nr, nt) = dataset.corrupt_triple((h, r, t));
677
678 let pos = complex_score(
679 &self.entity_re[h],
680 &self.entity_im[h],
681 &self.relation_re[r],
682 &self.relation_im[r],
683 &self.entity_re[t],
684 &self.entity_im[t],
685 );
686 let neg = complex_score(
687 &self.entity_re[nh],
688 &self.entity_im[nh],
689 &self.relation_re[nr],
690 &self.relation_im[nr],
691 &self.entity_re[nt],
692 &self.entity_im[nt],
693 );
694
695 let loss = (margin - pos + neg).max(0.0);
696 total_loss += loss;
697
698 if loss > 0.0 {
699 let dim = self.dim;
700 for k in 0..dim {
702 let re_r = self.relation_re[r][k];
703 let im_r = self.relation_im[r][k];
704 let re_t = self.entity_re[t][k];
705 let im_t = self.entity_im[t][k];
706 let g_re_h = re_r * re_t + im_r * im_t;
707 let g_im_h = re_r * im_t - im_r * re_t;
708 self.entity_re[h][k] += lr * g_re_h;
709 self.entity_im[h][k] += lr * g_im_h;
710
711 let re_rn = self.relation_re[nr][k];
713 let im_rn = self.relation_im[nr][k];
714 let re_tn = self.entity_re[nt][k];
715 let im_tn = self.entity_im[nt][k];
716 let g_re_hn = re_rn * re_tn + im_rn * im_tn;
717 let g_im_hn = re_rn * im_tn - im_rn * re_tn;
718 self.entity_re[nh][k] -= lr * g_re_hn;
719 self.entity_im[nh][k] -= lr * g_im_hn;
720 }
721 }
722 }
723 total_loss
724 }
725
726 fn validate_indices(&self, h: usize, r: usize, t: usize) -> Result<()> {
727 let ne = self.entity_re.len();
728 let nr = self.relation_re.len();
729 if h >= ne || t >= ne {
730 return Err(GraphError::InvalidParameter {
731 param: "entity_index".to_string(),
732 value: format!("({h},{t})"),
733 expected: format!("< {ne}"),
734 context: "ComplEx score".to_string(),
735 });
736 }
737 if r >= nr {
738 return Err(GraphError::InvalidParameter {
739 param: "relation_index".to_string(),
740 value: format!("{r}"),
741 expected: format!("< {nr}"),
742 context: "ComplEx score".to_string(),
743 });
744 }
745 Ok(())
746 }
747}
748
/// ComplEx score `Re(Σ_k h_k · r_k · conj(t_k))`, with each complex vector given as
/// separate real and imaginary slices.
///
/// Expanding one component with `h = a + bi`, `r = c + di` and `t = e + fi` gives
/// `a·c·e + b·c·f + a·d·f − b·d·e`. Slices of unequal length are truncated to the
/// shortest one.
fn complex_score(
    h_re: &[f64],
    h_im: &[f64],
    r_re: &[f64],
    r_im: &[f64],
    t_re: &[f64],
    t_im: &[f64],
) -> f64 {
    let heads = h_re.iter().zip(h_im);
    let relations = r_re.iter().zip(r_im);
    let tails = t_re.iter().zip(t_im);
    heads
        .zip(relations)
        .zip(tails)
        .map(|(((a, b), (c, d)), (e, f))| a * c * e + b * c * f + a * d * f - b * d * e)
        .sum()
}
777
778pub enum KgModel {
784 TransE(TransE),
786 DistMult(DistMult),
788 ComplEx(ComplEx),
790}
791
792impl KgModel {
793 pub fn link_prediction_score(&self, h: usize, r: usize, t: usize) -> Result<f64> {
795 match self {
796 KgModel::TransE(m) => m.link_prediction_score(h, r, t),
797 KgModel::DistMult(m) => m.link_prediction_score(h, r, t),
798 KgModel::ComplEx(m) => m.link_prediction_score(h, r, t),
799 }
800 }
801}
802
#[cfg(test)]
mod tests {
    use super::*;

    /// Four entities, two relations, four triples.
    fn simple_dataset() -> KGDataset {
        let triples = vec![(0, 0, 1), (1, 0, 2), (2, 1, 3), (0, 1, 3)];
        KGDataset::new(triples, 4, 2).expect("dataset")
    }

    // ---- KGDataset ---------------------------------------------------------

    #[test]
    fn test_dataset_creation() {
        let ds = simple_dataset();
        assert_eq!(ds.n_entities, 4);
        assert_eq!(ds.n_relations, 2);
        assert_eq!(ds.len(), 4);
        assert!(!ds.is_empty());
    }

    #[test]
    fn test_dataset_from_str_triples() {
        let raw = [
            ("Alice", "knows", "Bob"),
            ("Bob", "likes", "Carol"),
            ("Alice", "likes", "Carol"),
        ];
        let ds = KGDataset::from_str_triples(&raw);
        assert_eq!(ds.n_entities, 3);
        assert_eq!(ds.n_relations, 2);
        assert_eq!(ds.len(), 3);
        // Ids follow first appearance, with the head interned before the tail.
        assert_eq!(ds.entity_labels, vec!["Alice", "Bob", "Carol"]);
        assert_eq!(ds.relation_labels, vec!["knows", "likes"]);
        assert_eq!(ds.triples, vec![(0, 0, 1), (1, 1, 2), (0, 1, 2)]);
    }

    #[test]
    fn test_dataset_out_of_bounds() {
        let triples = vec![(10, 0, 1)];
        let result = KGDataset::new(triples, 4, 2);
        assert!(result.is_err());
        // Relation index out of range is rejected too.
        assert!(KGDataset::new(vec![(0, 5, 1)], 4, 2).is_err());
    }

    #[test]
    fn test_corrupt_triple_changes_entity() {
        let ds = simple_dataset();
        let original = (0, 0, 1);
        // Sampling is random, so check the invariants over several draws.
        for _ in 0..50 {
            let corrupted = ds.corrupt_triple(original);
            // The relation is never corrupted.
            assert_eq!(corrupted.1, 0);
            // Exactly one of head/tail is replaced, by a valid entity index.
            assert!((corrupted.0 != 0) ^ (corrupted.2 != 1));
            assert!(corrupted.0 < ds.n_entities && corrupted.2 < ds.n_entities);
        }
    }

    // ---- TransE ------------------------------------------------------------

    #[test]
    fn test_transe_score_finite() {
        let model = TransE::new(4, 2, 8).expect("TransE::new");
        let score = model.score(0, 0, 1).expect("score");
        assert!(score.is_finite());
    }

    #[test]
    fn test_transe_score_range() {
        let model = TransE::new(4, 2, 8).expect("TransE::new");
        let score = model.score(0, 0, 1).expect("score");
        // The score is a negated distance, so it can never be positive.
        assert!(score <= 0.0);
    }

    #[test]
    fn test_transe_score_manual_value() {
        let mut model = TransE::new(2, 1, 2).expect("TransE");
        model.entity_embeddings[0] = vec![1.0, 0.0];
        model.relation_embeddings[0] = vec![0.0, 1.0];
        model.entity_embeddings[1] = vec![0.0, 0.0];
        // h + r - t = (1, 1): the L2 distance is √2 and the L1 distance is 2.
        let l2 = model.score(0, 0, 1).expect("L2 score");
        assert!((l2 + 2f64.sqrt()).abs() < 1e-12, "expected -√2, got {l2}");
        model.norm_order = 1;
        let l1 = model.score(0, 0, 1).expect("L1 score");
        assert!((l1 + 2.0).abs() < 1e-12, "expected -2, got {l1}");
    }

    #[test]
    fn test_transe_predict_tails_length() {
        let model = TransE::new(10, 3, 16).expect("TransE");
        let preds = model.predict_tails(0, 0, 5).expect("predict_tails");
        assert_eq!(preds.len(), 5);
        for &idx in &preds {
            assert!(idx < 10);
        }
    }

    #[test]
    fn test_transe_train_epoch_finite_loss() {
        // Negatives are random, so only finiteness can be checked deterministically.
        let ds = simple_dataset();
        let mut model = TransE::new(4, 2, 8).expect("TransE");
        let loss0 = model.train_epoch(&ds, 0.01, 1.0);
        let loss1 = model.train_epoch(&ds, 0.01, 1.0);
        assert!(loss0.is_finite());
        assert!(loss1.is_finite());
    }

    #[test]
    fn test_transe_invalid_index() {
        let model = TransE::new(4, 2, 8).expect("TransE");
        assert!(model.score(10, 0, 1).is_err());
    }

    // ---- DistMult ----------------------------------------------------------

    #[test]
    fn test_distmult_score_finite() {
        let model = DistMult::new(4, 2, 8).expect("DistMult");
        let score = model.score(0, 0, 1).expect("score");
        assert!(score.is_finite());
    }

    #[test]
    fn test_distmult_score_manual_and_symmetric() {
        let mut model = DistMult::new(2, 1, 2).expect("DistMult");
        model.entity_embeddings[0] = vec![1.0, 2.0];
        model.relation_embeddings[0] = vec![3.0, 4.0];
        model.entity_embeddings[1] = vec![5.0, 6.0];
        // 1·3·5 + 2·4·6 = 63, and DistMult is symmetric in head and tail.
        let forward = model.score(0, 0, 1).expect("forward");
        let backward = model.score(1, 0, 0).expect("backward");
        assert!((forward - 63.0).abs() < 1e-12, "expected 63, got {forward}");
        assert!((forward - backward).abs() < 1e-12);
    }

    #[test]
    fn test_distmult_predict_tails() {
        let model = DistMult::new(10, 3, 16).expect("DistMult");
        let preds = model.predict_tails(0, 1, 3).expect("predict");
        assert_eq!(preds.len(), 3);
    }

    #[test]
    fn test_distmult_train_epoch() {
        let ds = simple_dataset();
        let mut model = DistMult::new(4, 2, 8).expect("DistMult");
        let loss = model.train_epoch(&ds, 0.01, 1.0);
        assert!(loss.is_finite());
    }

    // ---- ComplEx -----------------------------------------------------------

    #[test]
    fn test_complex_score_finite() {
        let model = ComplEx::new(4, 2, 8).expect("ComplEx");
        let score = model.score(0, 0, 1).expect("score");
        assert!(score.is_finite());
    }

    #[test]
    fn test_complex_predict_tails() {
        let model = ComplEx::new(10, 3, 16).expect("ComplEx");
        let preds = model.predict_tails(0, 0, 4).expect("predict");
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_complex_train_epoch() {
        let ds = simple_dataset();
        let mut model = ComplEx::new(4, 2, 8).expect("ComplEx");
        let loss = model.train_epoch(&ds, 0.01, 1.0);
        assert!(loss.is_finite());
    }

    #[test]
    fn test_complex_antisymmetry() {
        // With a purely imaginary relation r = i, swapping head and tail flips the sign:
        // h = 1, t = i gives Re(1 · i · conj(i)) = +1, while h = i, t = 1 gives -1.
        let mut model = ComplEx::new(2, 1, 1).expect("ComplEx");
        model.entity_re[0] = vec![1.0];
        model.entity_im[0] = vec![0.0];
        model.entity_re[1] = vec![0.0];
        model.entity_im[1] = vec![1.0];
        model.relation_re[0] = vec![0.0];
        model.relation_im[0] = vec![1.0];
        let s1 = model.score(0, 0, 1).expect("s1");
        let s2 = model.score(1, 0, 0).expect("s2");
        assert!((s1 - 1.0).abs() < 1e-12, "expected 1.0, got {s1}");
        assert!((s2 + 1.0).abs() < 1e-12, "expected -1.0, got {s2}");
    }

    #[test]
    fn test_complex_score_symmetry_check() {
        let mut model = ComplEx::new(2, 1, 2).expect("ComplEx");
        model.entity_re[0] = vec![1.0, 0.0];
        model.entity_im[0] = vec![0.0, 1.0];
        model.relation_re[0] = vec![1.0, 1.0];
        model.relation_im[0] = vec![0.0, 0.0];
        model.entity_re[1] = vec![1.0, 0.0];
        model.entity_im[1] = vec![0.0, 1.0];
        // Component 0 contributes h_re·r_re·t_re = 1; component 1 contributes h_im·r_re·t_im = 1.
        let score = model.score(0, 0, 1).expect("manual score");
        assert!((score - 2.0).abs() < 1e-10, "expected 2.0, got {score}");
    }

    // ---- KgModel -----------------------------------------------------------

    #[test]
    fn test_kgmodel_dispatch() {
        let models = [
            KgModel::TransE(TransE::new(4, 2, 8).expect("TransE")),
            KgModel::DistMult(DistMult::new(4, 2, 8).expect("DistMult")),
            KgModel::ComplEx(ComplEx::new(4, 2, 8).expect("ComplEx")),
        ];
        for model in &models {
            let score = model.link_prediction_score(0, 0, 1).expect("score");
            assert!(score.is_finite());
        }
    }

    #[test]
    fn test_multi_epoch_training_transe() {
        let ds = simple_dataset();
        let mut model = TransE::new(4, 2, 16).expect("TransE");
        let mut losses = Vec::new();
        for _ in 0..5 {
            losses.push(model.train_epoch(&ds, 0.01, 1.0));
        }
        for loss in &losses {
            assert!(loss.is_finite());
        }
    }
}