// torsh_graph/generative.rs
1//! Graph Generation Models
2//!
3//! Advanced implementation of generative models for graphs including
4//! Variational Autoencoders (VAE) and Generative Adversarial Networks (GAN)
5//! specifically designed for graph-structured data.
6//!
7//! # Features:
8//! - Graph Variational Autoencoder (GraphVAE)
9//! - Graph Generative Adversarial Network (GraphGAN)
10//! - Conditional graph generation
11//! - Graph reconstruction and completion
12//! - Latent space graph interpolation
13//! - Property-guided graph generation
14
15// Framework infrastructure - components designed for future use
16#![allow(dead_code)]
17use crate::parameter::Parameter;
18use crate::{GraphData, GraphLayer};
19use scirs2_core::random::thread_rng;
20use torsh_tensor::{
21    creation::{from_vec, randn, zeros},
22    Tensor,
23};
24
/// Graph Variational Autoencoder (GraphVAE)
/// Learns a probabilistic latent representation of graphs
///
/// A two-layer MLP encoder mean-pools node features into a graph embedding;
/// `mu_layer`/`logvar_layer` map it to a Gaussian posterior, and a two-layer
/// decoder reconstructs node features and pairwise edge logits.
#[derive(Debug)]
pub struct GraphVAE {
    // Encoder dimensions (retained for introspection / future use)
    encoder_in_features: usize,
    encoder_hidden_features: usize,
    latent_dim: usize,

    // Encoder layers: [in, hidden] then [hidden, hidden] weight matrices
    encoder_layer1: Parameter,
    encoder_layer2: Parameter,

    // Variational parameters (mean and log-variance), each [hidden, latent]
    mu_layer: Parameter,
    logvar_layer: Parameter,

    // Decoder parameters: latent -> hidden -> hidden, plus two output heads
    decoder_layer1: Parameter,
    decoder_layer2: Parameter,
    node_decoder: Parameter, // [hidden, in_features] node-feature head
    edge_decoder: Parameter, // [hidden, 1]; currently not consulted by decode_edges

    // KL divergence weight (beta-VAE style scaling of the KL term)
    beta: f32,

    // Bias terms (None when constructed with use_bias == false)
    encoder_bias1: Option<Parameter>,
    encoder_bias2: Option<Parameter>,
    decoder_bias1: Option<Parameter>,
    decoder_bias2: Option<Parameter>,
}
57
58impl GraphVAE {
59    /// Create a new Graph Variational Autoencoder
60    pub fn new(
61        in_features: usize,
62        hidden_features: usize,
63        latent_dim: usize,
64        beta: f32,
65        use_bias: bool,
66    ) -> Self {
67        // Encoder layers
68        let encoder_layer1 = Parameter::new(
69            randn(&[in_features, hidden_features]).expect("failed to create encoder_layer1 tensor"),
70        );
71        let encoder_layer2 = Parameter::new(
72            randn(&[hidden_features, hidden_features])
73                .expect("failed to create encoder_layer2 tensor"),
74        );
75
76        // Variational layers
77        let mu_layer = Parameter::new(
78            randn(&[hidden_features, latent_dim]).expect("failed to create mu_layer tensor"),
79        );
80        let logvar_layer = Parameter::new(
81            randn(&[hidden_features, latent_dim]).expect("failed to create logvar_layer tensor"),
82        );
83
84        // Decoder layers
85        let decoder_layer1 = Parameter::new(
86            randn(&[latent_dim, hidden_features]).expect("failed to create decoder_layer1 tensor"),
87        );
88        let decoder_layer2 = Parameter::new(
89            randn(&[hidden_features, hidden_features])
90                .expect("failed to create decoder_layer2 tensor"),
91        );
92        let node_decoder = Parameter::new(
93            randn(&[hidden_features, in_features]).expect("failed to create node_decoder tensor"),
94        );
95        let edge_decoder = Parameter::new(
96            randn(&[hidden_features, 1]).expect("failed to create edge_decoder tensor"),
97        );
98
99        let (encoder_bias1, encoder_bias2, decoder_bias1, decoder_bias2) = if use_bias {
100            (
101                Some(Parameter::new(
102                    zeros(&[hidden_features]).expect("failed to create encoder_bias1 tensor"),
103                )),
104                Some(Parameter::new(
105                    zeros(&[hidden_features]).expect("failed to create encoder_bias2 tensor"),
106                )),
107                Some(Parameter::new(
108                    zeros(&[hidden_features]).expect("failed to create decoder_bias1 tensor"),
109                )),
110                Some(Parameter::new(
111                    zeros(&[hidden_features]).expect("failed to create decoder_bias2 tensor"),
112                )),
113            )
114        } else {
115            (None, None, None, None)
116        };
117
118        Self {
119            encoder_in_features: in_features,
120            encoder_hidden_features: hidden_features,
121            latent_dim,
122            encoder_layer1,
123            encoder_layer2,
124            mu_layer,
125            logvar_layer,
126            decoder_layer1,
127            decoder_layer2,
128            node_decoder,
129            edge_decoder,
130            beta,
131            encoder_bias1,
132            encoder_bias2,
133            decoder_bias1,
134            decoder_bias2,
135        }
136    }
137
    /// Encode graph to latent distribution parameters
    ///
    /// Runs node features through two ReLU-activated linear layers, mean-pools
    /// over the node dimension into a single graph embedding, and returns the
    /// Gaussian posterior parameters `(mu, logvar)`, each of shape
    /// [1, latent_dim].
    pub fn encode(&self, graph: &GraphData) -> (Tensor, Tensor) {
        // Forward through encoder
        let mut h = graph
            .x
            .matmul(&self.encoder_layer1.clone_data())
            .expect("operation should succeed");
        if let Some(ref bias) = self.encoder_bias1 {
            h = h.add(&bias.clone_data()).expect("operation should succeed");
        }
        h = self.relu(&h);

        h = h
            .matmul(&self.encoder_layer2.clone_data())
            .expect("operation should succeed");
        if let Some(ref bias) = self.encoder_bias2 {
            h = h.add(&bias.clone_data()).expect("operation should succeed");
        }
        h = self.relu(&h);

        // Global mean pooling over the node dimension (dim 0)
        let graph_embedding = h
            .mean(Some(&[0]), false)
            .expect("mean pooling should succeed");
        let graph_embedding_2d = graph_embedding
            .unsqueeze(0)
            .expect("unsqueeze should succeed"); // Make 2D for matmul

        // Compute mu and logvar via the two variational heads
        let mu = graph_embedding_2d
            .matmul(&self.mu_layer.clone_data())
            .expect("mu layer matmul should succeed");
        let logvar = graph_embedding_2d
            .matmul(&self.logvar_layer.clone_data())
            .expect("logvar layer matmul should succeed");

        (mu, logvar)
    }
176
177    /// Reparameterization trick for sampling from latent distribution
178    pub fn reparameterize(&self, mu: &Tensor, logvar: &Tensor) -> Tensor {
179        // std = exp(0.5 * logvar)
180        let std = logvar
181            .mul_scalar(0.5)
182            .expect("logvar scaling should succeed")
183            .exp()
184            .expect("exp should succeed");
185
186        // Sample epsilon from N(0, 1)
187        let epsilon = randn(mu.shape().dims()).expect("epsilon sampling should succeed");
188
189        // z = mu + std * epsilon
190        mu.add(&std.mul(&epsilon).expect("operation should succeed"))
191            .expect("operation should succeed")
192    }
193
    /// Decode latent representation to graph
    ///
    /// Expands the latent vector `z` through two ReLU-activated linear layers,
    /// copies the resulting graph embedding to every node, then decodes node
    /// features and samples a discrete edge index for `num_nodes` nodes.
    pub fn decode(&self, z: &Tensor, num_nodes: usize) -> GraphData {
        // Forward through decoder
        let mut h = z
            .matmul(&self.decoder_layer1.clone_data())
            .expect("operation should succeed");
        if let Some(ref bias) = self.decoder_bias1 {
            h = h.add(&bias.clone_data()).expect("operation should succeed");
        }
        h = self.relu(&h);

        h = h
            .matmul(&self.decoder_layer2.clone_data())
            .expect("operation should succeed");
        if let Some(ref bias) = self.decoder_bias2 {
            h = h.add(&bias.clone_data()).expect("operation should succeed");
        }
        h = self.relu(&h);

        // Expand to node-level representation (one copy of h per node)
        let h_expanded = self.expand_to_nodes(&h, num_nodes);

        // Decode node features
        let node_features = h_expanded
            .matmul(&self.node_decoder.clone_data())
            .expect("operation should succeed");

        // Decode edge logits, then threshold into a discrete edge set
        let edge_logits = self.decode_edges(&h_expanded, num_nodes);
        let edge_index = self.sample_edges(&edge_logits, num_nodes);

        GraphData::new(node_features, edge_index)
    }
227
228    /// Forward pass through GraphVAE
229    pub fn forward(&self, graph: &GraphData) -> (GraphData, Tensor, Tensor) {
230        // Encode
231        let (mu, logvar) = self.encode(graph);
232
233        // Sample latent variable
234        let z = self.reparameterize(&mu, &logvar);
235
236        // Decode
237        let reconstructed = self.decode(&z, graph.num_nodes);
238
239        (reconstructed, mu, logvar)
240    }
241
242    /// Compute VAE loss (reconstruction + KL divergence)
243    pub fn compute_loss(
244        &self,
245        graph: &GraphData,
246        reconstructed: &GraphData,
247        mu: &Tensor,
248        logvar: &Tensor,
249    ) -> f32 {
250        // Reconstruction loss (MSE for node features)
251        let recon_loss = self.reconstruction_loss(graph, reconstructed);
252
253        // KL divergence: -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
254        let kl_loss = self.kl_divergence(mu, logvar);
255
256        // Total loss
257        recon_loss + self.beta * kl_loss
258    }
259
260    /// Reconstruction loss (MSE)
261    fn reconstruction_loss(&self, original: &GraphData, reconstructed: &GraphData) -> f32 {
262        let orig_data = original.x.to_vec().expect("conversion should succeed");
263        let recon_data = reconstructed.x.to_vec().expect("conversion should succeed");
264
265        let mut mse = 0.0;
266        let len = orig_data.len().min(recon_data.len());
267
268        for i in 0..len {
269            mse += (orig_data[i] - recon_data[i]).powi(2);
270        }
271
272        mse / len as f32
273    }
274
275    /// KL divergence loss
276    fn kl_divergence(&self, mu: &Tensor, logvar: &Tensor) -> f32 {
277        let mu_data = mu.to_vec().expect("conversion should succeed");
278        let logvar_data = logvar.to_vec().expect("conversion should succeed");
279
280        let mut kl = 0.0;
281        for i in 0..mu_data.len() {
282            kl += -0.5 * (1.0 + logvar_data[i] - mu_data[i].powi(2) - logvar_data[i].exp());
283        }
284
285        kl / mu_data.len() as f32
286    }
287
288    /// Generate new graph from random latent vector
289    pub fn generate(&self, num_nodes: usize) -> GraphData {
290        // Sample from standard normal
291        let z = randn(&[1, self.latent_dim]).expect("latent vector sampling should succeed");
292
293        // Decode to graph
294        self.decode(&z, num_nodes)
295    }
296
297    /// Interpolate between two graphs in latent space
298    pub fn interpolate(
299        &self,
300        graph1: &GraphData,
301        graph2: &GraphData,
302        alpha: f32,
303        num_nodes: usize,
304    ) -> GraphData {
305        let (mu1, _) = self.encode(graph1);
306        let (mu2, _) = self.encode(graph2);
307
308        // Linear interpolation
309        let z_interp = mu1
310            .mul_scalar(1.0 - alpha)
311            .expect("mu1 scaling should succeed")
312            .add(&mu2.mul_scalar(alpha).expect("operation should succeed"))
313            .expect("interpolation addition should succeed");
314
315        // Decode interpolated latent
316        self.decode(&z_interp, num_nodes)
317    }
318
319    // Helper methods
320
321    fn relu(&self, x: &Tensor) -> Tensor {
322        let data = x.to_vec().expect("conversion should succeed");
323        let activated: Vec<f32> = data.iter().map(|&v| v.max(0.0)).collect();
324        from_vec(
325            activated,
326            x.shape().dims(),
327            torsh_core::device::DeviceType::Cpu,
328        )
329        .expect("relu tensor creation should succeed")
330    }
331
332    fn expand_to_nodes(&self, h: &Tensor, num_nodes: usize) -> Tensor {
333        // Repeat graph-level embedding for each node
334        let h_data = h.to_vec().expect("conversion should succeed");
335        let feat_dim = h_data.len();
336
337        let mut expanded_data = Vec::new();
338        for _ in 0..num_nodes {
339            expanded_data.extend(&h_data);
340        }
341
342        from_vec(
343            expanded_data,
344            &[num_nodes, feat_dim],
345            torsh_core::device::DeviceType::Cpu,
346        )
347        .expect("expanded nodes tensor creation should succeed")
348    }
349
    /// Score every ordered node pair with an edge logit.
    ///
    /// Uses the dot product of the two node embeddings as the logit (the
    /// dedicated `edge_decoder` weight is not consulted here). Self-pairs get
    /// a large negative logit so no self-loop survives the sigmoid threshold
    /// in `sample_edges`. Returns a [num_nodes, num_nodes] tensor whose data
    /// is laid out row-major (row i = logits for edges from node i).
    fn decode_edges(&self, h: &Tensor, num_nodes: usize) -> Tensor {
        // Compute pairwise edge logits; O(num_nodes^2) slices + dot products.
        let mut edge_logits_data = Vec::new();

        for i in 0..num_nodes {
            for j in 0..num_nodes {
                if i != j {
                    // Simplified: use dot product of node embeddings as edge logit
                    let h_i = h
                        .slice_tensor(0, i, i + 1)
                        .expect("node i slice should succeed");
                    let h_j = h
                        .slice_tensor(0, j, j + 1)
                        .expect("node j slice should succeed");

                    let logit = h_i
                        .dot(&h_j.t().expect("transpose should succeed"))
                        .expect("dot product should succeed")
                        .item()
                        .expect("tensor should have single item");
                    edge_logits_data.push(logit);
                } else {
                    edge_logits_data.push(-1000.0); // No self-loops: sigmoid(-1000) ~ 0
                }
            }
        }

        from_vec(
            edge_logits_data,
            &[num_nodes, num_nodes],
            torsh_core::device::DeviceType::Cpu,
        )
        .expect("edge logits tensor creation should succeed")
    }
384
385    fn sample_edges(&self, edge_logits: &Tensor, num_nodes: usize) -> Tensor {
386        let logits_data = edge_logits.to_vec().expect("conversion should succeed");
387        let mut edges = Vec::new();
388
389        // Sample edges based on probabilities (threshold at 0.5)
390        for i in 0..num_nodes {
391            for j in 0..num_nodes {
392                if i != j {
393                    let idx = i * num_nodes + j;
394                    let prob = 1.0 / (1.0 + (-logits_data[idx]).exp()); // Sigmoid
395
396                    if prob > 0.5 {
397                        edges.push(i as f32);
398                        edges.push(j as f32);
399                    }
400                }
401            }
402        }
403
404        if edges.is_empty() {
405            // Return empty edge index
406            return zeros(&[2, 0]).expect("empty edge index creation should succeed");
407        }
408
409        let num_edges = edges.len() / 2;
410        from_vec(edges, &[2, num_edges], torsh_core::device::DeviceType::Cpu)
411            .expect("edge index tensor creation should succeed")
412    }
413}
414
415impl GraphLayer for GraphVAE {
416    fn forward(&self, graph: &GraphData) -> GraphData {
417        let (reconstructed, _, _) = self.forward(graph);
418        reconstructed
419    }
420
421    fn parameters(&self) -> Vec<Tensor> {
422        let mut params = vec![
423            self.encoder_layer1.clone_data(),
424            self.encoder_layer2.clone_data(),
425            self.mu_layer.clone_data(),
426            self.logvar_layer.clone_data(),
427            self.decoder_layer1.clone_data(),
428            self.decoder_layer2.clone_data(),
429            self.node_decoder.clone_data(),
430            self.edge_decoder.clone_data(),
431        ];
432
433        if let Some(ref b) = self.encoder_bias1 {
434            params.push(b.clone_data());
435        }
436        if let Some(ref b) = self.encoder_bias2 {
437            params.push(b.clone_data());
438        }
439        if let Some(ref b) = self.decoder_bias1 {
440            params.push(b.clone_data());
441        }
442        if let Some(ref b) = self.decoder_bias2 {
443            params.push(b.clone_data());
444        }
445
446        params
447    }
448}
449
/// Graph Generative Adversarial Network (GraphGAN)
/// Learns to generate realistic graphs through adversarial training
#[derive(Debug)]
pub struct GraphGAN {
    latent_dim: usize,      // generator noise dimensionality
    hidden_dim: usize,      // hidden width of both sub-networks
    output_features: usize, // node feature dimensionality of generated graphs

    // Generator network: latent noise -> graph
    generator: GraphGANGenerator,

    // Discriminator network: graph -> real/fake score in (0, 1)
    discriminator: GraphGANDiscriminator,
}
464
465impl GraphGAN {
466    /// Create a new Graph GAN
467    pub fn new(
468        latent_dim: usize,
469        hidden_dim: usize,
470        output_features: usize,
471        use_bias: bool,
472    ) -> Self {
473        let generator = GraphGANGenerator::new(latent_dim, hidden_dim, output_features, use_bias);
474        let discriminator = GraphGANDiscriminator::new(output_features, hidden_dim, use_bias);
475
476        Self {
477            latent_dim,
478            hidden_dim,
479            output_features,
480            generator,
481            discriminator,
482        }
483    }
484
485    /// Generate fake graph from random noise
486    pub fn generate(&self, num_nodes: usize) -> GraphData {
487        let z = randn(&[1, self.latent_dim]).expect("latent vector sampling should succeed");
488        self.generator.generate(&z, num_nodes)
489    }
490
491    /// Discriminator forward pass (returns real/fake score)
492    pub fn discriminate(&self, graph: &GraphData) -> f32 {
493        self.discriminator.forward(graph)
494    }
495
496    /// Train generator (maximize discriminator error)
497    pub fn generator_loss(&self, num_nodes: usize) -> f32 {
498        let fake_graph = self.generate(num_nodes);
499        let fake_score = self.discriminate(&fake_graph);
500
501        // Generator loss: -log(D(G(z)))
502        -(fake_score.ln())
503    }
504
505    /// Train discriminator (distinguish real from fake)
506    pub fn discriminator_loss(&self, real_graph: &GraphData, num_nodes: usize) -> f32 {
507        // Real graph score
508        let real_score = self.discriminate(real_graph);
509
510        // Fake graph score
511        let fake_graph = self.generate(num_nodes);
512        let fake_score = self.discriminate(&fake_graph);
513
514        // Discriminator loss: -log(D(real)) - log(1 - D(fake))
515        -(real_score.ln()) - ((1.0 - fake_score).ln())
516    }
517
518    /// Get generator parameters
519    pub fn generator_parameters(&self) -> Vec<Tensor> {
520        self.generator.parameters()
521    }
522
523    /// Get discriminator parameters
524    pub fn discriminator_parameters(&self) -> Vec<Tensor> {
525        self.discriminator.parameters()
526    }
527}
528
/// Generator network for GraphGAN
///
/// A two-layer MLP from latent noise to a graph embedding, with separate
/// output heads for node features (`node_layer`) and edges (`edge_layer`).
#[derive(Debug)]
struct GraphGANGenerator {
    latent_dim: usize,
    hidden_dim: usize,
    output_features: usize,

    layer1: Parameter,     // [latent_dim, hidden_dim]
    layer2: Parameter,     // [hidden_dim, hidden_dim]
    node_layer: Parameter, // [hidden_dim, output_features]
    edge_layer: Parameter, // [hidden_dim, 1]; not consulted by generate_edges yet

    bias1: Option<Parameter>,
    bias2: Option<Parameter>,
}
544
545impl GraphGANGenerator {
546    fn new(latent_dim: usize, hidden_dim: usize, output_features: usize, use_bias: bool) -> Self {
547        let layer1 = Parameter::new(
548            randn(&[latent_dim, hidden_dim]).expect("failed to create generator layer1 tensor"),
549        );
550        let layer2 = Parameter::new(
551            randn(&[hidden_dim, hidden_dim]).expect("failed to create generator layer2 tensor"),
552        );
553        let node_layer = Parameter::new(
554            randn(&[hidden_dim, output_features])
555                .expect("failed to create generator node_layer tensor"),
556        );
557        let edge_layer = Parameter::new(
558            randn(&[hidden_dim, 1]).expect("failed to create generator edge_layer tensor"),
559        );
560
561        let (bias1, bias2) = if use_bias {
562            (
563                Some(Parameter::new(
564                    zeros(&[hidden_dim]).expect("failed to create generator bias1 tensor"),
565                )),
566                Some(Parameter::new(
567                    zeros(&[hidden_dim]).expect("failed to create generator bias2 tensor"),
568                )),
569            )
570        } else {
571            (None, None)
572        };
573
574        Self {
575            latent_dim,
576            hidden_dim,
577            output_features,
578            layer1,
579            layer2,
580            node_layer,
581            edge_layer,
582            bias1,
583            bias2,
584        }
585    }
586
    /// Map a latent vector `z` to a generated graph with `num_nodes` nodes.
    ///
    /// Two LeakyReLU(0.2)-activated linear layers produce a graph embedding,
    /// which is copied to every node; node features go through `node_layer`
    /// followed by tanh, and edges are wired randomly (see `generate_edges`).
    fn generate(&self, z: &Tensor, num_nodes: usize) -> GraphData {
        // Forward through generator
        let mut h = z
            .matmul(&self.layer1.clone_data())
            .expect("operation should succeed");
        if let Some(ref bias) = self.bias1 {
            h = h.add(&bias.clone_data()).expect("operation should succeed");
        }
        h = self.leaky_relu(&h, 0.2);

        h = h
            .matmul(&self.layer2.clone_data())
            .expect("operation should succeed");
        if let Some(ref bias) = self.bias2 {
            h = h.add(&bias.clone_data()).expect("operation should succeed");
        }
        h = self.leaky_relu(&h, 0.2);

        // Expand to node-level (one copy of the graph embedding per node)
        let h_expanded = self.expand_to_nodes(&h, num_nodes);

        // Generate node features, squashed into (-1, 1) by tanh
        let node_features = h_expanded
            .matmul(&self.node_layer.clone_data())
            .expect("operation should succeed");
        let node_features = self.tanh(&node_features);

        // Generate edges
        let edge_index = self.generate_edges(&h_expanded, num_nodes);

        GraphData::new(node_features, edge_index)
    }
619
620    fn leaky_relu(&self, x: &Tensor, alpha: f32) -> Tensor {
621        let data = x.to_vec().expect("conversion should succeed");
622        let activated: Vec<f32> = data
623            .iter()
624            .map(|&v| if v > 0.0 { v } else { alpha * v })
625            .collect();
626        from_vec(
627            activated,
628            x.shape().dims(),
629            torsh_core::device::DeviceType::Cpu,
630        )
631        .expect("leaky_relu tensor creation should succeed")
632    }
633
634    fn tanh(&self, x: &Tensor) -> Tensor {
635        let data = x.to_vec().expect("conversion should succeed");
636        let activated: Vec<f32> = data.iter().map(|&v| v.tanh()).collect();
637        from_vec(
638            activated,
639            x.shape().dims(),
640            torsh_core::device::DeviceType::Cpu,
641        )
642        .expect("tanh tensor creation should succeed")
643    }
644
645    fn expand_to_nodes(&self, h: &Tensor, num_nodes: usize) -> Tensor {
646        let h_data = h.to_vec().expect("conversion should succeed");
647        let feat_dim = h_data.len();
648
649        let mut expanded_data = Vec::new();
650        for _ in 0..num_nodes {
651            expanded_data.extend(&h_data);
652        }
653
654        from_vec(
655            expanded_data,
656            &[num_nodes, feat_dim],
657            torsh_core::device::DeviceType::Cpu,
658        )
659        .expect("expanded nodes tensor creation should succeed")
660    }
661
662    fn generate_edges(&self, _h: &Tensor, num_nodes: usize) -> Tensor {
663        let mut edges = Vec::new();
664        let mut rng = thread_rng();
665
666        // Generate edges probabilistically
667        for i in 0..num_nodes {
668            for j in (i + 1)..num_nodes {
669                // Use node embeddings to determine edge probability
670                if rng.gen_range(0.0..1.0) > 0.7 {
671                    edges.push(i as f32);
672                    edges.push(j as f32);
673                    edges.push(j as f32);
674                    edges.push(i as f32);
675                }
676            }
677        }
678
679        if edges.is_empty() {
680            return zeros(&[2, 0]).expect("empty edge index creation should succeed");
681        }
682
683        let num_edges = edges.len() / 2;
684        from_vec(edges, &[2, num_edges], torsh_core::device::DeviceType::Cpu)
685            .expect("edge index tensor creation should succeed")
686    }
687
688    fn parameters(&self) -> Vec<Tensor> {
689        let mut params = vec![
690            self.layer1.clone_data(),
691            self.layer2.clone_data(),
692            self.node_layer.clone_data(),
693            self.edge_layer.clone_data(),
694        ];
695
696        if let Some(ref b) = self.bias1 {
697            params.push(b.clone_data());
698        }
699        if let Some(ref b) = self.bias2 {
700            params.push(b.clone_data());
701        }
702
703        params
704    }
705}
706
/// Discriminator network for GraphGAN
///
/// Two hidden linear layers over node features, global mean pooling, and a
/// scalar output head; `forward` applies a sigmoid to the final logit.
#[derive(Debug)]
struct GraphGANDiscriminator {
    input_features: usize,
    hidden_dim: usize,

    layer1: Parameter,       // [input_features, hidden_dim]
    layer2: Parameter,       // [hidden_dim, hidden_dim]
    output_layer: Parameter, // [hidden_dim, 1]

    bias1: Option<Parameter>,
    bias2: Option<Parameter>,
    bias_out: Option<Parameter>,
}
721
722impl GraphGANDiscriminator {
723    fn new(input_features: usize, hidden_dim: usize, use_bias: bool) -> Self {
724        let layer1 = Parameter::new(
725            randn(&[input_features, hidden_dim])
726                .expect("failed to create discriminator layer1 tensor"),
727        );
728        let layer2 = Parameter::new(
729            randn(&[hidden_dim, hidden_dim]).expect("failed to create discriminator layer2 tensor"),
730        );
731        let output_layer = Parameter::new(
732            randn(&[hidden_dim, 1]).expect("failed to create discriminator output_layer tensor"),
733        );
734
735        let (bias1, bias2, bias_out) = if use_bias {
736            (
737                Some(Parameter::new(
738                    zeros(&[hidden_dim]).expect("failed to create discriminator bias1 tensor"),
739                )),
740                Some(Parameter::new(
741                    zeros(&[hidden_dim]).expect("failed to create discriminator bias2 tensor"),
742                )),
743                Some(Parameter::new(
744                    zeros(&[1]).expect("failed to create discriminator bias_out tensor"),
745                )),
746            )
747        } else {
748            (None, None, None)
749        };
750
751        Self {
752            input_features,
753            hidden_dim,
754            layer1,
755            layer2,
756            output_layer,
757            bias1,
758            bias2,
759            bias_out,
760        }
761    }
762
    /// Score a graph as real (score near 1.0) or fake (score near 0.0).
    ///
    /// Two LeakyReLU(0.2)-activated linear layers over node features, global
    /// mean pooling over nodes, a linear output head, then a sigmoid applied
    /// to the resulting scalar logit.
    fn forward(&self, graph: &GraphData) -> f32 {
        // Forward through discriminator
        let mut h = graph
            .x
            .matmul(&self.layer1.clone_data())
            .expect("operation should succeed");
        if let Some(ref bias) = self.bias1 {
            h = h.add(&bias.clone_data()).expect("operation should succeed");
        }
        h = self.leaky_relu(&h, 0.2);

        h = h
            .matmul(&self.layer2.clone_data())
            .expect("operation should succeed");
        if let Some(ref bias) = self.bias2 {
            h = h.add(&bias.clone_data()).expect("operation should succeed");
        }
        h = self.leaky_relu(&h, 0.2);

        // Global mean pooling over the node dimension (dim 0)
        let graph_repr = h
            .mean(Some(&[0]), false)
            .expect("mean pooling should succeed");
        let graph_repr_2d = graph_repr.unsqueeze(0).expect("unsqueeze should succeed"); // Make 2D for matmul

        // Output layer
        let mut logit = graph_repr_2d
            .matmul(&self.output_layer.clone_data())
            .expect("operation should succeed");
        if let Some(ref bias) = self.bias_out {
            logit = logit
                .add(&bias.clone_data())
                .expect("operation should succeed");
        }

        // Sigmoid activation on the scalar logit
        let logit_val = logit.item().expect("tensor should have single item");
        1.0 / (1.0 + (-logit_val).exp())
    }
802
803    fn leaky_relu(&self, x: &Tensor, alpha: f32) -> Tensor {
804        let data = x.to_vec().expect("conversion should succeed");
805        let activated: Vec<f32> = data
806            .iter()
807            .map(|&v| if v > 0.0 { v } else { alpha * v })
808            .collect();
809        from_vec(
810            activated,
811            x.shape().dims(),
812            torsh_core::device::DeviceType::Cpu,
813        )
814        .expect("discriminator leaky_relu tensor creation should succeed")
815    }
816
817    fn parameters(&self) -> Vec<Tensor> {
818        let mut params = vec![
819            self.layer1.clone_data(),
820            self.layer2.clone_data(),
821            self.output_layer.clone_data(),
822        ];
823
824        if let Some(ref b) = self.bias1 {
825            params.push(b.clone_data());
826        }
827        if let Some(ref b) = self.bias2 {
828            params.push(b.clone_data());
829        }
830        if let Some(ref b) = self.bias_out {
831            params.push(b.clone_data());
832        }
833
834        params
835    }
836}
837
/// Conditional Graph Generation
///
/// Pairs a `GraphVAE` with a linear map from a condition/property vector into
/// the VAE's latent space, so generation can be steered by properties.
#[derive(Debug)]
pub struct ConditionalGraphGenerator {
    vae: GraphVAE,
    condition_dim: usize,
    condition_layer: Parameter, // [condition_dim, latent_dim]
}
845
846impl ConditionalGraphGenerator {
847    /// Create a new conditional graph generator
848    pub fn new(
849        in_features: usize,
850        hidden_features: usize,
851        latent_dim: usize,
852        condition_dim: usize,
853        beta: f32,
854    ) -> Self {
855        let vae = GraphVAE::new(in_features, hidden_features, latent_dim, beta, true);
856        let condition_layer = Parameter::new(
857            randn(&[condition_dim, latent_dim]).expect("failed to create condition_layer tensor"),
858        );
859
860        Self {
861            vae,
862            condition_dim,
863            condition_layer,
864        }
865    }
866
867    /// Generate graph conditioned on a property vector
868    pub fn generate_conditional(&self, condition: &Tensor, num_nodes: usize) -> GraphData {
869        // Map condition to latent space bias
870        let condition_bias = condition
871            .matmul(&self.condition_layer.clone_data())
872            .expect("condition matmul should succeed");
873
874        // Sample base latent vector
875        let z_base =
876            randn(&[1, self.vae.latent_dim]).expect("latent vector sampling should succeed");
877
878        // Add conditional bias
879        let z = z_base
880            .add(&condition_bias)
881            .expect("operation should succeed");
882
883        // Decode to graph
884        self.vae.decode(&z, num_nodes)
885    }
886
887    fn parameters(&self) -> Vec<Tensor> {
888        let mut params = self.vae.parameters();
889        params.push(self.condition_layer.clone_data());
890        params
891    }
892}
893
894#[cfg(test)]
895mod tests {
896    use super::*;
897    use torsh_core::device::DeviceType;
898
    #[test]
    fn test_graphvae_creation() {
        // Constructor should record the requested dimensions and beta verbatim.
        let vae = GraphVAE::new(8, 16, 10, 1.0, true);
        assert_eq!(vae.encoder_in_features, 8);
        assert_eq!(vae.encoder_hidden_features, 16);
        assert_eq!(vae.latent_dim, 10);
        assert_eq!(vae.beta, 1.0);
    }
907
    #[test]
    fn test_graphvae_encode_decode() {
        // Build a small 5-node graph with 8-dim node features.
        let features = randn(&[5, 8]).unwrap();
        let edges = vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0];
        let edge_index = from_vec(edges, &[2, 4], DeviceType::Cpu).unwrap();
        let graph = GraphData::new(features, edge_index);

        let vae = GraphVAE::new(8, 16, 10, 1.0, true);

        // Posterior parameters are graph-level row vectors in latent space.
        let (mu, logvar) = vae.encode(&graph);
        assert_eq!(mu.shape().dims(), &[1, 10]);
        assert_eq!(logvar.shape().dims(), &[1, 10]);

        // Reparameterized sample keeps the latent shape.
        let z = vae.reparameterize(&mu, &logvar);
        assert_eq!(z.shape().dims(), &[1, 10]);

        // Decoding should honor the requested node count.
        let reconstructed = vae.decode(&z, 5);
        assert_eq!(reconstructed.num_nodes, 5);
    }
927
928    #[test]
929    fn test_graphvae_generation() {
930        let vae = GraphVAE::new(8, 16, 10, 1.0, true);
931        let generated = vae.generate(6);
932
933        assert_eq!(generated.num_nodes, 6);
934        assert_eq!(generated.x.shape().dims()[0], 6);
935        assert_eq!(generated.x.shape().dims()[1], 8);
936    }
937
938    #[test]
939    fn test_graphvae_interpolation() {
940        let features1 = randn(&[4, 6]).unwrap();
941        let features2 = randn(&[4, 6]).unwrap();
942        let edges = vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0];
943        let edge_index = from_vec(edges.clone(), &[2, 3], DeviceType::Cpu).unwrap();
944
945        let graph1 = GraphData::new(features1, edge_index.clone());
946        let graph2 = GraphData::new(features2, edge_index);
947
948        let vae = GraphVAE::new(6, 12, 8, 1.0, true);
949
950        // Interpolate at alpha = 0.5 (midpoint)
951        let interpolated = vae.interpolate(&graph1, &graph2, 0.5, 4);
952        assert_eq!(interpolated.num_nodes, 4);
953    }
954
955    #[test]
956    fn test_graphgan_creation() {
957        let gan = GraphGAN::new(16, 32, 8, true);
958        assert_eq!(gan.latent_dim, 16);
959        assert_eq!(gan.hidden_dim, 32);
960        assert_eq!(gan.output_features, 8);
961    }
962
963    #[test]
964    fn test_graphgan_generation() {
965        let gan = GraphGAN::new(16, 32, 8, true);
966        let generated = gan.generate(5);
967
968        assert_eq!(generated.num_nodes, 5);
969        assert_eq!(generated.x.shape().dims()[1], 8);
970    }
971
972    #[test]
973    fn test_graphgan_discriminate() {
974        let features = randn(&[4, 8]).unwrap();
975        let edges = vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0];
976        let edge_index = from_vec(edges, &[2, 3], DeviceType::Cpu).unwrap();
977        let graph = GraphData::new(features, edge_index);
978
979        let gan = GraphGAN::new(16, 32, 8, true);
980        let score = gan.discriminate(&graph);
981
982        assert!(score >= 0.0 && score <= 1.0);
983    }
984
985    #[test]
986    fn test_conditional_generation() {
987        let cond_gen = ConditionalGraphGenerator::new(8, 16, 10, 4, 1.0);
988
989        let condition = randn(&[1, 4]).unwrap();
990        let generated = cond_gen.generate_conditional(&condition, 5);
991
992        assert_eq!(generated.num_nodes, 5);
993        assert_eq!(generated.x.shape().dims()[1], 8);
994    }
995
996    #[test]
997    fn test_graphvae_loss_computation() {
998        let features = randn(&[3, 6]).unwrap();
999        let edges = vec![0.0, 1.0, 1.0, 2.0];
1000        let edge_index = from_vec(edges, &[2, 2], DeviceType::Cpu).unwrap();
1001        let graph = GraphData::new(features, edge_index);
1002
1003        let vae = GraphVAE::new(6, 12, 8, 1.0, true);
1004        let (reconstructed, mu, logvar) = vae.forward(&graph);
1005
1006        let loss = vae.compute_loss(&graph, &reconstructed, &mu, &logvar);
1007        assert!(loss > 0.0);
1008    }
1009
1010    #[test]
1011    fn test_graphgan_losses() {
1012        let features = randn(&[4, 8]).unwrap();
1013        let edges = vec![0.0, 1.0, 1.0, 2.0, 2.0, 3.0];
1014        let edge_index = from_vec(edges, &[2, 3], DeviceType::Cpu).unwrap();
1015        let graph = GraphData::new(features, edge_index);
1016
1017        let gan = GraphGAN::new(16, 32, 8, true);
1018
1019        let gen_loss = gan.generator_loss(4);
1020        assert!(gen_loss > 0.0);
1021
1022        let disc_loss = gan.discriminator_loss(&graph, 4);
1023        // Discriminator loss can be negative
1024        assert!(disc_loss.is_finite());
1025    }
1026}