Skip to main content

trident/neural/model/
composite.rs

1//! Composite neural compiler model: GNN encoder + Transformer decoder.
2//!
3//! Wraps the encoder and decoder into a single `Module` that can be
4//! saved/loaded as a unit.
5
6use burn::config::Config;
7use burn::module::Module;
8use burn::prelude::*;
9
10use super::decoder::{DecoderConfig, StackAwareDecoder};
11use super::encoder::{GnnEncoder, GnnEncoderConfig};
12use super::vocab::VOCAB_SIZE;
13
/// Configuration for the composite neural compiler model.
///
/// `#[derive(Config)]` provides `NeuralCompilerConfig::new()` and, because
/// every field carries a `#[config(default = ...)]`, a `with_<field>` builder
/// method for each of them. Decoder settings that are not listed here
/// (maximum stack depth, type window) are not configurable through this
/// struct; `init` supplies fixed values for them.
#[derive(Config, Debug)]
pub struct NeuralCompilerConfig {
    /// Model dimension, shared by the GNN encoder and the decoder.
    #[config(default = 256)]
    pub d_model: usize,
    /// Edge embedding dimension for the GNN encoder.
    #[config(default = 32)]
    pub d_edge: usize,
    /// Number of GNN encoder layers (forwarded as the encoder's `num_layers`).
    #[config(default = 4)]
    pub gnn_layers: usize,
    /// Number of decoder layers (forwarded as the decoder's `num_layers`).
    #[config(default = 6)]
    pub decoder_layers: usize,
    /// Number of attention heads; only passed to the decoder and not used by
    /// `param_estimate`.
    #[config(default = 8)]
    pub n_heads: usize,
    /// Inner dimension of the decoder's feed-forward blocks.
    #[config(default = 1024)]
    pub d_ff: usize,
    /// Maximum output sequence length; also sizes the positional-embedding
    /// term in `param_estimate`.
    #[config(default = 256)]
    pub max_seq: usize,
    /// Dropout rate (use 0 for inference). Only forwarded to the decoder.
    #[config(default = 0.1)]
    pub dropout: f64,
}
42
/// Composite model: GNN encoder + Transformer decoder.
///
/// Constructed by `NeuralCompilerConfig::init`. Deriving `Module` over both
/// sub-modules lets the encoder/decoder pair be saved and loaded as a single
/// unit.
#[derive(Module, Debug)]
pub struct NeuralCompilerV2<B: Backend> {
    /// Graph encoder, built from `d_model`, `d_edge` and `gnn_layers`.
    pub encoder: GnnEncoder<B>,
    /// Stack-aware Transformer decoder.
    pub decoder: StackAwareDecoder<B>,
}
49
50impl NeuralCompilerConfig {
51    /// Initialize the composite model.
52    pub fn init<B: Backend>(&self, device: &B::Device) -> NeuralCompilerV2<B> {
53        let encoder = GnnEncoderConfig::new()
54            .with_d_model(self.d_model)
55            .with_d_edge(self.d_edge)
56            .with_num_layers(self.gnn_layers)
57            .init(device);
58
59        let decoder = DecoderConfig {
60            d_model: self.d_model,
61            num_layers: self.decoder_layers,
62            n_heads: self.n_heads,
63            d_ff: self.d_ff,
64            max_seq: self.max_seq,
65            max_stack_depth: 65,
66            type_window: 8,
67            dropout: self.dropout,
68        }
69        .init(device);
70
71        NeuralCompilerV2 { encoder, decoder }
72    }
73
74    /// Parameter count estimate.
75    pub fn param_estimate(&self) -> usize {
76        // GNN: node_proj(59*d) + edge_embed(3*d_e) + layers*(3*d*d + d_e*d + d + d*d + d*d)
77        // Decoder: token_embed(V*d) + pos_embed(S*d) + depth_embed(65*32) + type(24*32)
78        //        + proj(d+64)*d + layers*(3*d*d*3 + d*4d + 4d*d) + out(d*V)
79        let d = self.d_model;
80        let gnn_per_layer = 4 * d * d + self.d_edge * d;
81        let gnn = 59 * d + 3 * self.d_edge + self.gnn_layers * gnn_per_layer + 2 * d * d;
82        let dec_per_layer = 3 * (d * d * 4) + d * self.d_ff + self.d_ff * d;
83        let dec = VOCAB_SIZE * d
84            + self.max_seq * d
85            + 65 * 32
86            + 24 * 32
87            + (d + 64) * d
88            + self.decoder_layers * dec_per_layer
89            + d * VOCAB_SIZE;
90        gnn + dec
91    }
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    type B = NdArray;

    /// A tiny configuration must build both sub-modules without panicking.
    #[test]
    fn composite_model_initializes() {
        let tiny = NeuralCompilerConfig::new()
            .with_d_model(32)
            .with_d_edge(8)
            .with_gnn_layers(1)
            .with_decoder_layers(1)
            .with_n_heads(4)
            .with_d_ff(64)
            .with_max_seq(32)
            .with_dropout(0.0);
        let device = Default::default();
        let _model = tiny.init::<B>(&device);
    }

    /// The default configuration's estimate must fall inside a loose band.
    #[test]
    fn param_estimate_reasonable() {
        let estimate = NeuralCompilerConfig::new().param_estimate();
        // Expected to be roughly 10-15M for the defaults; the bounds are deliberately wide.
        assert!(estimate > 5_000_000, "too few params: {}", estimate);
        assert!(estimate < 50_000_000, "too many params: {}", estimate);
    }
}