1#![allow(dead_code)]
17use crate::parameter::Parameter;
18use crate::{GraphData, GraphLayer};
19use scirs2_core::random::thread_rng;
20use torsh_tensor::{
21 creation::{from_vec, randn, zeros},
22 Tensor,
23};
24
/// Graph variational autoencoder: encodes a whole graph into one latent vector and decodes a
/// latent vector back into node features plus an edge index.
#[derive(Debug)]
pub struct GraphVAE {
    /// Width of the input node features (also the width of the reconstructed features).
    encoder_in_features: usize,
    /// Width of the hidden layers in both encoder and decoder.
    encoder_hidden_features: usize,
    /// Size of the graph-level latent vector.
    latent_dim: usize,

    /// First encoder weight, `[in_features, hidden_features]`.
    encoder_layer1: Parameter,
    /// Second encoder weight, `[hidden_features, hidden_features]`.
    encoder_layer2: Parameter,

    /// Projection of the pooled graph embedding to the posterior mean, `[hidden_features, latent_dim]`.
    mu_layer: Parameter,
    /// Projection of the pooled graph embedding to the posterior log-variance, `[hidden_features, latent_dim]`.
    logvar_layer: Parameter,

    /// First decoder weight, `[latent_dim, hidden_features]`.
    decoder_layer1: Parameter,
    /// Second decoder weight, `[hidden_features, hidden_features]`.
    decoder_layer2: Parameter,
    /// Projection from the decoder hidden state to node features, `[hidden_features, in_features]`.
    node_decoder: Parameter,
    /// `[hidden_features, 1]`. NOTE(review): not read by the decoding path in this file; edge
    /// logits come from pairwise dot products instead.
    edge_decoder: Parameter,

    /// Weight of the KL-divergence term in the total loss.
    beta: f32,

    /// Optional bias for `encoder_layer1`, `[hidden_features]`.
    encoder_bias1: Option<Parameter>,
    /// Optional bias for `encoder_layer2`, `[hidden_features]`.
    encoder_bias2: Option<Parameter>,
    /// Optional bias for `decoder_layer1`, `[hidden_features]`.
    decoder_bias1: Option<Parameter>,
    /// Optional bias for `decoder_layer2`, `[hidden_features]`.
    decoder_bias2: Option<Parameter>,
}
57
58impl GraphVAE {
59 pub fn new(
61 in_features: usize,
62 hidden_features: usize,
63 latent_dim: usize,
64 beta: f32,
65 use_bias: bool,
66 ) -> Self {
67 let encoder_layer1 = Parameter::new(
69 randn(&[in_features, hidden_features]).expect("failed to create encoder_layer1 tensor"),
70 );
71 let encoder_layer2 = Parameter::new(
72 randn(&[hidden_features, hidden_features])
73 .expect("failed to create encoder_layer2 tensor"),
74 );
75
76 let mu_layer = Parameter::new(
78 randn(&[hidden_features, latent_dim]).expect("failed to create mu_layer tensor"),
79 );
80 let logvar_layer = Parameter::new(
81 randn(&[hidden_features, latent_dim]).expect("failed to create logvar_layer tensor"),
82 );
83
84 let decoder_layer1 = Parameter::new(
86 randn(&[latent_dim, hidden_features]).expect("failed to create decoder_layer1 tensor"),
87 );
88 let decoder_layer2 = Parameter::new(
89 randn(&[hidden_features, hidden_features])
90 .expect("failed to create decoder_layer2 tensor"),
91 );
92 let node_decoder = Parameter::new(
93 randn(&[hidden_features, in_features]).expect("failed to create node_decoder tensor"),
94 );
95 let edge_decoder = Parameter::new(
96 randn(&[hidden_features, 1]).expect("failed to create edge_decoder tensor"),
97 );
98
99 let (encoder_bias1, encoder_bias2, decoder_bias1, decoder_bias2) = if use_bias {
100 (
101 Some(Parameter::new(
102 zeros(&[hidden_features]).expect("failed to create encoder_bias1 tensor"),
103 )),
104 Some(Parameter::new(
105 zeros(&[hidden_features]).expect("failed to create encoder_bias2 tensor"),
106 )),
107 Some(Parameter::new(
108 zeros(&[hidden_features]).expect("failed to create decoder_bias1 tensor"),
109 )),
110 Some(Parameter::new(
111 zeros(&[hidden_features]).expect("failed to create decoder_bias2 tensor"),
112 )),
113 )
114 } else {
115 (None, None, None, None)
116 };
117
118 Self {
119 encoder_in_features: in_features,
120 encoder_hidden_features: hidden_features,
121 latent_dim,
122 encoder_layer1,
123 encoder_layer2,
124 mu_layer,
125 logvar_layer,
126 decoder_layer1,
127 decoder_layer2,
128 node_decoder,
129 edge_decoder,
130 beta,
131 encoder_bias1,
132 encoder_bias2,
133 decoder_bias1,
134 decoder_bias2,
135 }
136 }
137
138 pub fn encode(&self, graph: &GraphData) -> (Tensor, Tensor) {
140 let mut h = graph
142 .x
143 .matmul(&self.encoder_layer1.clone_data())
144 .expect("operation should succeed");
145 if let Some(ref bias) = self.encoder_bias1 {
146 h = h.add(&bias.clone_data()).expect("operation should succeed");
147 }
148 h = self.relu(&h);
149
150 h = h
151 .matmul(&self.encoder_layer2.clone_data())
152 .expect("operation should succeed");
153 if let Some(ref bias) = self.encoder_bias2 {
154 h = h.add(&bias.clone_data()).expect("operation should succeed");
155 }
156 h = self.relu(&h);
157
158 let graph_embedding = h
160 .mean(Some(&[0]), false)
161 .expect("mean pooling should succeed");
162 let graph_embedding_2d = graph_embedding
163 .unsqueeze(0)
164 .expect("unsqueeze should succeed"); let mu = graph_embedding_2d
168 .matmul(&self.mu_layer.clone_data())
169 .expect("mu layer matmul should succeed");
170 let logvar = graph_embedding_2d
171 .matmul(&self.logvar_layer.clone_data())
172 .expect("logvar layer matmul should succeed");
173
174 (mu, logvar)
175 }
176
177 pub fn reparameterize(&self, mu: &Tensor, logvar: &Tensor) -> Tensor {
179 let std = logvar
181 .mul_scalar(0.5)
182 .expect("logvar scaling should succeed")
183 .exp()
184 .expect("exp should succeed");
185
186 let epsilon = randn(mu.shape().dims()).expect("epsilon sampling should succeed");
188
189 mu.add(&std.mul(&epsilon).expect("operation should succeed"))
191 .expect("operation should succeed")
192 }
193
194 pub fn decode(&self, z: &Tensor, num_nodes: usize) -> GraphData {
196 let mut h = z
198 .matmul(&self.decoder_layer1.clone_data())
199 .expect("operation should succeed");
200 if let Some(ref bias) = self.decoder_bias1 {
201 h = h.add(&bias.clone_data()).expect("operation should succeed");
202 }
203 h = self.relu(&h);
204
205 h = h
206 .matmul(&self.decoder_layer2.clone_data())
207 .expect("operation should succeed");
208 if let Some(ref bias) = self.decoder_bias2 {
209 h = h.add(&bias.clone_data()).expect("operation should succeed");
210 }
211 h = self.relu(&h);
212
213 let h_expanded = self.expand_to_nodes(&h, num_nodes);
215
216 let node_features = h_expanded
218 .matmul(&self.node_decoder.clone_data())
219 .expect("operation should succeed");
220
221 let edge_logits = self.decode_edges(&h_expanded, num_nodes);
223 let edge_index = self.sample_edges(&edge_logits, num_nodes);
224
225 GraphData::new(node_features, edge_index)
226 }
227
228 pub fn forward(&self, graph: &GraphData) -> (GraphData, Tensor, Tensor) {
230 let (mu, logvar) = self.encode(graph);
232
233 let z = self.reparameterize(&mu, &logvar);
235
236 let reconstructed = self.decode(&z, graph.num_nodes);
238
239 (reconstructed, mu, logvar)
240 }
241
242 pub fn compute_loss(
244 &self,
245 graph: &GraphData,
246 reconstructed: &GraphData,
247 mu: &Tensor,
248 logvar: &Tensor,
249 ) -> f32 {
250 let recon_loss = self.reconstruction_loss(graph, reconstructed);
252
253 let kl_loss = self.kl_divergence(mu, logvar);
255
256 recon_loss + self.beta * kl_loss
258 }
259
260 fn reconstruction_loss(&self, original: &GraphData, reconstructed: &GraphData) -> f32 {
262 let orig_data = original.x.to_vec().expect("conversion should succeed");
263 let recon_data = reconstructed.x.to_vec().expect("conversion should succeed");
264
265 let mut mse = 0.0;
266 let len = orig_data.len().min(recon_data.len());
267
268 for i in 0..len {
269 mse += (orig_data[i] - recon_data[i]).powi(2);
270 }
271
272 mse / len as f32
273 }
274
275 fn kl_divergence(&self, mu: &Tensor, logvar: &Tensor) -> f32 {
277 let mu_data = mu.to_vec().expect("conversion should succeed");
278 let logvar_data = logvar.to_vec().expect("conversion should succeed");
279
280 let mut kl = 0.0;
281 for i in 0..mu_data.len() {
282 kl += -0.5 * (1.0 + logvar_data[i] - mu_data[i].powi(2) - logvar_data[i].exp());
283 }
284
285 kl / mu_data.len() as f32
286 }
287
288 pub fn generate(&self, num_nodes: usize) -> GraphData {
290 let z = randn(&[1, self.latent_dim]).expect("latent vector sampling should succeed");
292
293 self.decode(&z, num_nodes)
295 }
296
297 pub fn interpolate(
299 &self,
300 graph1: &GraphData,
301 graph2: &GraphData,
302 alpha: f32,
303 num_nodes: usize,
304 ) -> GraphData {
305 let (mu1, _) = self.encode(graph1);
306 let (mu2, _) = self.encode(graph2);
307
308 let z_interp = mu1
310 .mul_scalar(1.0 - alpha)
311 .expect("mu1 scaling should succeed")
312 .add(&mu2.mul_scalar(alpha).expect("operation should succeed"))
313 .expect("interpolation addition should succeed");
314
315 self.decode(&z_interp, num_nodes)
317 }
318
319 fn relu(&self, x: &Tensor) -> Tensor {
322 let data = x.to_vec().expect("conversion should succeed");
323 let activated: Vec<f32> = data.iter().map(|&v| v.max(0.0)).collect();
324 from_vec(
325 activated,
326 x.shape().dims(),
327 torsh_core::device::DeviceType::Cpu,
328 )
329 .expect("relu tensor creation should succeed")
330 }
331
332 fn expand_to_nodes(&self, h: &Tensor, num_nodes: usize) -> Tensor {
333 let h_data = h.to_vec().expect("conversion should succeed");
335 let feat_dim = h_data.len();
336
337 let mut expanded_data = Vec::new();
338 for _ in 0..num_nodes {
339 expanded_data.extend(&h_data);
340 }
341
342 from_vec(
343 expanded_data,
344 &[num_nodes, feat_dim],
345 torsh_core::device::DeviceType::Cpu,
346 )
347 .expect("expanded nodes tensor creation should succeed")
348 }
349
350 fn decode_edges(&self, h: &Tensor, num_nodes: usize) -> Tensor {
351 let mut edge_logits_data = Vec::new();
353
354 for i in 0..num_nodes {
355 for j in 0..num_nodes {
356 if i != j {
357 let h_i = h
359 .slice_tensor(0, i, i + 1)
360 .expect("node i slice should succeed");
361 let h_j = h
362 .slice_tensor(0, j, j + 1)
363 .expect("node j slice should succeed");
364
365 let logit = h_i
366 .dot(&h_j.t().expect("transpose should succeed"))
367 .expect("dot product should succeed")
368 .item()
369 .expect("tensor should have single item");
370 edge_logits_data.push(logit);
371 } else {
372 edge_logits_data.push(-1000.0); }
374 }
375 }
376
377 from_vec(
378 edge_logits_data,
379 &[num_nodes, num_nodes],
380 torsh_core::device::DeviceType::Cpu,
381 )
382 .expect("edge logits tensor creation should succeed")
383 }
384
385 fn sample_edges(&self, edge_logits: &Tensor, num_nodes: usize) -> Tensor {
386 let logits_data = edge_logits.to_vec().expect("conversion should succeed");
387 let mut edges = Vec::new();
388
389 for i in 0..num_nodes {
391 for j in 0..num_nodes {
392 if i != j {
393 let idx = i * num_nodes + j;
394 let prob = 1.0 / (1.0 + (-logits_data[idx]).exp()); if prob > 0.5 {
397 edges.push(i as f32);
398 edges.push(j as f32);
399 }
400 }
401 }
402 }
403
404 if edges.is_empty() {
405 return zeros(&[2, 0]).expect("empty edge index creation should succeed");
407 }
408
409 let num_edges = edges.len() / 2;
410 from_vec(edges, &[2, num_edges], torsh_core::device::DeviceType::Cpu)
411 .expect("edge index tensor creation should succeed")
412 }
413}
414
415impl GraphLayer for GraphVAE {
416 fn forward(&self, graph: &GraphData) -> GraphData {
417 let (reconstructed, _, _) = self.forward(graph);
418 reconstructed
419 }
420
421 fn parameters(&self) -> Vec<Tensor> {
422 let mut params = vec![
423 self.encoder_layer1.clone_data(),
424 self.encoder_layer2.clone_data(),
425 self.mu_layer.clone_data(),
426 self.logvar_layer.clone_data(),
427 self.decoder_layer1.clone_data(),
428 self.decoder_layer2.clone_data(),
429 self.node_decoder.clone_data(),
430 self.edge_decoder.clone_data(),
431 ];
432
433 if let Some(ref b) = self.encoder_bias1 {
434 params.push(b.clone_data());
435 }
436 if let Some(ref b) = self.encoder_bias2 {
437 params.push(b.clone_data());
438 }
439 if let Some(ref b) = self.decoder_bias1 {
440 params.push(b.clone_data());
441 }
442 if let Some(ref b) = self.decoder_bias2 {
443 params.push(b.clone_data());
444 }
445
446 params
447 }
448}
449
/// Generative adversarial network over graphs.
///
/// A generator maps a latent vector to node features plus a randomly sampled edge index, and
/// a discriminator scores a graph with a probability in `[0, 1]`.
#[derive(Debug)]
pub struct GraphGAN {
    /// Size of the latent noise vector fed to the generator.
    latent_dim: usize,
    /// Hidden width shared by generator and discriminator.
    hidden_dim: usize,
    /// Width of the generated node features (and of the discriminator's input).
    output_features: usize,

    /// Produces fake graphs from latent vectors.
    generator: GraphGANGenerator,

    /// Scores graphs; its output is treated as the probability that a graph is real.
    discriminator: GraphGANDiscriminator,
}
464
465impl GraphGAN {
466 pub fn new(
468 latent_dim: usize,
469 hidden_dim: usize,
470 output_features: usize,
471 use_bias: bool,
472 ) -> Self {
473 let generator = GraphGANGenerator::new(latent_dim, hidden_dim, output_features, use_bias);
474 let discriminator = GraphGANDiscriminator::new(output_features, hidden_dim, use_bias);
475
476 Self {
477 latent_dim,
478 hidden_dim,
479 output_features,
480 generator,
481 discriminator,
482 }
483 }
484
485 pub fn generate(&self, num_nodes: usize) -> GraphData {
487 let z = randn(&[1, self.latent_dim]).expect("latent vector sampling should succeed");
488 self.generator.generate(&z, num_nodes)
489 }
490
491 pub fn discriminate(&self, graph: &GraphData) -> f32 {
493 self.discriminator.forward(graph)
494 }
495
496 pub fn generator_loss(&self, num_nodes: usize) -> f32 {
498 let fake_graph = self.generate(num_nodes);
499 let fake_score = self.discriminate(&fake_graph);
500
501 -(fake_score.ln())
503 }
504
505 pub fn discriminator_loss(&self, real_graph: &GraphData, num_nodes: usize) -> f32 {
507 let real_score = self.discriminate(real_graph);
509
510 let fake_graph = self.generate(num_nodes);
512 let fake_score = self.discriminate(&fake_graph);
513
514 -(real_score.ln()) - ((1.0 - fake_score).ln())
516 }
517
518 pub fn generator_parameters(&self) -> Vec<Tensor> {
520 self.generator.parameters()
521 }
522
523 pub fn discriminator_parameters(&self) -> Vec<Tensor> {
525 self.discriminator.parameters()
526 }
527}
528
/// Generator half of [`GraphGAN`]: latent vector -> node features + random edge index.
#[derive(Debug)]
struct GraphGANGenerator {
    /// Size of the input latent vector.
    latent_dim: usize,
    /// Width of the hidden layers.
    hidden_dim: usize,
    /// Width of the generated node features.
    output_features: usize,

    /// First dense layer, `[latent_dim, hidden_dim]`.
    layer1: Parameter,
    /// Second dense layer, `[hidden_dim, hidden_dim]`.
    layer2: Parameter,
    /// Projection to node features, `[hidden_dim, output_features]`.
    node_layer: Parameter,
    /// `[hidden_dim, 1]`. NOTE(review): edges are currently sampled at random, so this weight
    /// is not read by the generation path.
    edge_layer: Parameter,

    /// Optional bias for `layer1`, `[hidden_dim]`.
    bias1: Option<Parameter>,
    /// Optional bias for `layer2`, `[hidden_dim]`.
    bias2: Option<Parameter>,
}
544
545impl GraphGANGenerator {
546 fn new(latent_dim: usize, hidden_dim: usize, output_features: usize, use_bias: bool) -> Self {
547 let layer1 = Parameter::new(
548 randn(&[latent_dim, hidden_dim]).expect("failed to create generator layer1 tensor"),
549 );
550 let layer2 = Parameter::new(
551 randn(&[hidden_dim, hidden_dim]).expect("failed to create generator layer2 tensor"),
552 );
553 let node_layer = Parameter::new(
554 randn(&[hidden_dim, output_features])
555 .expect("failed to create generator node_layer tensor"),
556 );
557 let edge_layer = Parameter::new(
558 randn(&[hidden_dim, 1]).expect("failed to create generator edge_layer tensor"),
559 );
560
561 let (bias1, bias2) = if use_bias {
562 (
563 Some(Parameter::new(
564 zeros(&[hidden_dim]).expect("failed to create generator bias1 tensor"),
565 )),
566 Some(Parameter::new(
567 zeros(&[hidden_dim]).expect("failed to create generator bias2 tensor"),
568 )),
569 )
570 } else {
571 (None, None)
572 };
573
574 Self {
575 latent_dim,
576 hidden_dim,
577 output_features,
578 layer1,
579 layer2,
580 node_layer,
581 edge_layer,
582 bias1,
583 bias2,
584 }
585 }
586
587 fn generate(&self, z: &Tensor, num_nodes: usize) -> GraphData {
588 let mut h = z
590 .matmul(&self.layer1.clone_data())
591 .expect("operation should succeed");
592 if let Some(ref bias) = self.bias1 {
593 h = h.add(&bias.clone_data()).expect("operation should succeed");
594 }
595 h = self.leaky_relu(&h, 0.2);
596
597 h = h
598 .matmul(&self.layer2.clone_data())
599 .expect("operation should succeed");
600 if let Some(ref bias) = self.bias2 {
601 h = h.add(&bias.clone_data()).expect("operation should succeed");
602 }
603 h = self.leaky_relu(&h, 0.2);
604
605 let h_expanded = self.expand_to_nodes(&h, num_nodes);
607
608 let node_features = h_expanded
610 .matmul(&self.node_layer.clone_data())
611 .expect("operation should succeed");
612 let node_features = self.tanh(&node_features);
613
614 let edge_index = self.generate_edges(&h_expanded, num_nodes);
616
617 GraphData::new(node_features, edge_index)
618 }
619
620 fn leaky_relu(&self, x: &Tensor, alpha: f32) -> Tensor {
621 let data = x.to_vec().expect("conversion should succeed");
622 let activated: Vec<f32> = data
623 .iter()
624 .map(|&v| if v > 0.0 { v } else { alpha * v })
625 .collect();
626 from_vec(
627 activated,
628 x.shape().dims(),
629 torsh_core::device::DeviceType::Cpu,
630 )
631 .expect("leaky_relu tensor creation should succeed")
632 }
633
634 fn tanh(&self, x: &Tensor) -> Tensor {
635 let data = x.to_vec().expect("conversion should succeed");
636 let activated: Vec<f32> = data.iter().map(|&v| v.tanh()).collect();
637 from_vec(
638 activated,
639 x.shape().dims(),
640 torsh_core::device::DeviceType::Cpu,
641 )
642 .expect("tanh tensor creation should succeed")
643 }
644
645 fn expand_to_nodes(&self, h: &Tensor, num_nodes: usize) -> Tensor {
646 let h_data = h.to_vec().expect("conversion should succeed");
647 let feat_dim = h_data.len();
648
649 let mut expanded_data = Vec::new();
650 for _ in 0..num_nodes {
651 expanded_data.extend(&h_data);
652 }
653
654 from_vec(
655 expanded_data,
656 &[num_nodes, feat_dim],
657 torsh_core::device::DeviceType::Cpu,
658 )
659 .expect("expanded nodes tensor creation should succeed")
660 }
661
662 fn generate_edges(&self, _h: &Tensor, num_nodes: usize) -> Tensor {
663 let mut edges = Vec::new();
664 let mut rng = thread_rng();
665
666 for i in 0..num_nodes {
668 for j in (i + 1)..num_nodes {
669 if rng.gen_range(0.0..1.0) > 0.7 {
671 edges.push(i as f32);
672 edges.push(j as f32);
673 edges.push(j as f32);
674 edges.push(i as f32);
675 }
676 }
677 }
678
679 if edges.is_empty() {
680 return zeros(&[2, 0]).expect("empty edge index creation should succeed");
681 }
682
683 let num_edges = edges.len() / 2;
684 from_vec(edges, &[2, num_edges], torsh_core::device::DeviceType::Cpu)
685 .expect("edge index tensor creation should succeed")
686 }
687
688 fn parameters(&self) -> Vec<Tensor> {
689 let mut params = vec![
690 self.layer1.clone_data(),
691 self.layer2.clone_data(),
692 self.node_layer.clone_data(),
693 self.edge_layer.clone_data(),
694 ];
695
696 if let Some(ref b) = self.bias1 {
697 params.push(b.clone_data());
698 }
699 if let Some(ref b) = self.bias2 {
700 params.push(b.clone_data());
701 }
702
703 params
704 }
705}
706
/// Discriminator half of [`GraphGAN`]: scores a graph with a probability in `[0, 1]`.
#[derive(Debug)]
struct GraphGANDiscriminator {
    /// Width of the node features it consumes.
    input_features: usize,
    /// Width of the hidden layers.
    hidden_dim: usize,

    /// First dense layer, `[input_features, hidden_dim]`.
    layer1: Parameter,
    /// Second dense layer, `[hidden_dim, hidden_dim]`.
    layer2: Parameter,
    /// Projection of the pooled graph representation to one logit, `[hidden_dim, 1]`.
    output_layer: Parameter,

    /// Optional bias for `layer1`, `[hidden_dim]`.
    bias1: Option<Parameter>,
    /// Optional bias for `layer2`, `[hidden_dim]`.
    bias2: Option<Parameter>,
    /// Optional bias for `output_layer`, `[1]`.
    bias_out: Option<Parameter>,
}
721
722impl GraphGANDiscriminator {
723 fn new(input_features: usize, hidden_dim: usize, use_bias: bool) -> Self {
724 let layer1 = Parameter::new(
725 randn(&[input_features, hidden_dim])
726 .expect("failed to create discriminator layer1 tensor"),
727 );
728 let layer2 = Parameter::new(
729 randn(&[hidden_dim, hidden_dim]).expect("failed to create discriminator layer2 tensor"),
730 );
731 let output_layer = Parameter::new(
732 randn(&[hidden_dim, 1]).expect("failed to create discriminator output_layer tensor"),
733 );
734
735 let (bias1, bias2, bias_out) = if use_bias {
736 (
737 Some(Parameter::new(
738 zeros(&[hidden_dim]).expect("failed to create discriminator bias1 tensor"),
739 )),
740 Some(Parameter::new(
741 zeros(&[hidden_dim]).expect("failed to create discriminator bias2 tensor"),
742 )),
743 Some(Parameter::new(
744 zeros(&[1]).expect("failed to create discriminator bias_out tensor"),
745 )),
746 )
747 } else {
748 (None, None, None)
749 };
750
751 Self {
752 input_features,
753 hidden_dim,
754 layer1,
755 layer2,
756 output_layer,
757 bias1,
758 bias2,
759 bias_out,
760 }
761 }
762
763 fn forward(&self, graph: &GraphData) -> f32 {
764 let mut h = graph
766 .x
767 .matmul(&self.layer1.clone_data())
768 .expect("operation should succeed");
769 if let Some(ref bias) = self.bias1 {
770 h = h.add(&bias.clone_data()).expect("operation should succeed");
771 }
772 h = self.leaky_relu(&h, 0.2);
773
774 h = h
775 .matmul(&self.layer2.clone_data())
776 .expect("operation should succeed");
777 if let Some(ref bias) = self.bias2 {
778 h = h.add(&bias.clone_data()).expect("operation should succeed");
779 }
780 h = self.leaky_relu(&h, 0.2);
781
782 let graph_repr = h
784 .mean(Some(&[0]), false)
785 .expect("mean pooling should succeed");
786 let graph_repr_2d = graph_repr.unsqueeze(0).expect("unsqueeze should succeed"); let mut logit = graph_repr_2d
790 .matmul(&self.output_layer.clone_data())
791 .expect("operation should succeed");
792 if let Some(ref bias) = self.bias_out {
793 logit = logit
794 .add(&bias.clone_data())
795 .expect("operation should succeed");
796 }
797
798 let logit_val = logit.item().expect("tensor should have single item");
800 1.0 / (1.0 + (-logit_val).exp())
801 }
802
803 fn leaky_relu(&self, x: &Tensor, alpha: f32) -> Tensor {
804 let data = x.to_vec().expect("conversion should succeed");
805 let activated: Vec<f32> = data
806 .iter()
807 .map(|&v| if v > 0.0 { v } else { alpha * v })
808 .collect();
809 from_vec(
810 activated,
811 x.shape().dims(),
812 torsh_core::device::DeviceType::Cpu,
813 )
814 .expect("discriminator leaky_relu tensor creation should succeed")
815 }
816
817 fn parameters(&self) -> Vec<Tensor> {
818 let mut params = vec![
819 self.layer1.clone_data(),
820 self.layer2.clone_data(),
821 self.output_layer.clone_data(),
822 ];
823
824 if let Some(ref b) = self.bias1 {
825 params.push(b.clone_data());
826 }
827 if let Some(ref b) = self.bias2 {
828 params.push(b.clone_data());
829 }
830 if let Some(ref b) = self.bias_out {
831 params.push(b.clone_data());
832 }
833
834 params
835 }
836}
837
/// Graph generator that shifts a [`GraphVAE`]'s random latent sample by a learned projection
/// of a condition vector before decoding.
#[derive(Debug)]
pub struct ConditionalGraphGenerator {
    /// Underlying VAE (always built with biases) whose decoder produces the graphs.
    vae: GraphVAE,
    /// Width of the condition vector.
    condition_dim: usize,
    /// Projection of the condition into latent space, `[condition_dim, latent_dim]`.
    condition_layer: Parameter,
}
845
846impl ConditionalGraphGenerator {
847 pub fn new(
849 in_features: usize,
850 hidden_features: usize,
851 latent_dim: usize,
852 condition_dim: usize,
853 beta: f32,
854 ) -> Self {
855 let vae = GraphVAE::new(in_features, hidden_features, latent_dim, beta, true);
856 let condition_layer = Parameter::new(
857 randn(&[condition_dim, latent_dim]).expect("failed to create condition_layer tensor"),
858 );
859
860 Self {
861 vae,
862 condition_dim,
863 condition_layer,
864 }
865 }
866
867 pub fn generate_conditional(&self, condition: &Tensor, num_nodes: usize) -> GraphData {
869 let condition_bias = condition
871 .matmul(&self.condition_layer.clone_data())
872 .expect("condition matmul should succeed");
873
874 let z_base =
876 randn(&[1, self.vae.latent_dim]).expect("latent vector sampling should succeed");
877
878 let z = z_base
880 .add(&condition_bias)
881 .expect("operation should succeed");
882
883 self.vae.decode(&z, num_nodes)
885 }
886
887 fn parameters(&self) -> Vec<Tensor> {
888 let mut params = self.vae.parameters();
889 params.push(self.condition_layer.clone_data());
890 params
891 }
892}
893
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;

    /// Builds a graph with `num_nodes x num_features` random features and the given edge list.
    ///
    /// `edges` is laid out row-major as `[2, num_edges]`: all sources, then all targets.
    fn random_graph(num_nodes: usize, num_features: usize, edges: Vec<f32>) -> GraphData {
        let num_edges = edges.len() / 2;
        let features = randn(&[num_nodes, num_features]).unwrap();
        let edge_index = from_vec(edges, &[2, num_edges], DeviceType::Cpu).unwrap();
        GraphData::new(features, edge_index)
    }

    #[test]
    fn test_graphvae_creation() {
        let vae = GraphVAE::new(8, 16, 10, 1.0, true);

        assert_eq!(
            (vae.encoder_in_features, vae.encoder_hidden_features, vae.latent_dim),
            (8, 16, 10)
        );
        assert_eq!(vae.beta, 1.0);
    }

    #[test]
    fn test_graphvae_encode_decode() {
        let graph = random_graph(5, 8, vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0]);
        let vae = GraphVAE::new(8, 16, 10, 1.0, true);

        let (mu, logvar) = vae.encode(&graph);
        for stat in [&mu, &logvar] {
            assert_eq!(stat.shape().dims(), &[1, 10]);
        }

        let z = vae.reparameterize(&mu, &logvar);
        assert_eq!(z.shape().dims(), &[1, 10]);

        assert_eq!(vae.decode(&z, 5).num_nodes, 5);
    }

    #[test]
    fn test_graphvae_generation() {
        let generated = GraphVAE::new(8, 16, 10, 1.0, true).generate(6);

        assert_eq!(generated.num_nodes, 6);
        assert_eq!(&generated.x.shape().dims()[..2], &[6, 8]);
    }

    #[test]
    fn test_graphvae_interpolation() {
        let edges = vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0];
        let graph1 = random_graph(4, 6, edges.clone());
        let graph2 = random_graph(4, 6, edges);

        let vae = GraphVAE::new(6, 12, 8, 1.0, true);
        let interpolated = vae.interpolate(&graph1, &graph2, 0.5, 4);

        assert_eq!(interpolated.num_nodes, 4);
    }

    #[test]
    fn test_graphgan_creation() {
        let gan = GraphGAN::new(16, 32, 8, true);

        assert_eq!(
            (gan.latent_dim, gan.hidden_dim, gan.output_features),
            (16, 32, 8)
        );
    }

    #[test]
    fn test_graphgan_generation() {
        let generated = GraphGAN::new(16, 32, 8, true).generate(5);

        assert_eq!(generated.num_nodes, 5);
        assert_eq!(generated.x.shape().dims()[1], 8);
    }

    #[test]
    fn test_graphgan_discriminate() {
        let graph = random_graph(4, 8, vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0]);
        let gan = GraphGAN::new(16, 32, 8, true);

        let score = gan.discriminate(&graph);
        assert!((0.0..=1.0).contains(&score));
    }

    #[test]
    fn test_conditional_generation() {
        let cond_gen = ConditionalGraphGenerator::new(8, 16, 10, 4, 1.0);
        let condition = randn(&[1, 4]).unwrap();

        let generated = cond_gen.generate_conditional(&condition, 5);

        assert_eq!(generated.num_nodes, 5);
        assert_eq!(generated.x.shape().dims()[1], 8);
    }

    #[test]
    fn test_graphvae_loss_computation() {
        let graph = random_graph(3, 6, vec![0.0, 1.0, 1.0, 2.0]);
        let vae = GraphVAE::new(6, 12, 8, 1.0, true);

        let (reconstructed, mu, logvar) = vae.forward(&graph);
        let loss = vae.compute_loss(&graph, &reconstructed, &mu, &logvar);

        assert!(loss > 0.0);
    }

    #[test]
    fn test_graphgan_losses() {
        let graph = random_graph(4, 8, vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0]);
        let gan = GraphGAN::new(16, 32, 8, true);

        let gen_loss = gan.generator_loss(4);
        assert!(gen_loss > 0.0);

        let disc_loss = gan.discriminator_loss(&graph, 4);
        assert!(disc_loss.is_finite());
    }
}