trident/neural/model/
composite.rs1use burn::config::Config;
7use burn::module::Module;
8use burn::prelude::*;
9
10use super::decoder::{DecoderConfig, StackAwareDecoder};
11use super::encoder::{GnnEncoder, GnnEncoderConfig};
12use super::vocab::VOCAB_SIZE;
13
14#[derive(Config, Debug)]
16pub struct NeuralCompilerConfig {
17 #[config(default = 256)]
19 pub d_model: usize,
20 #[config(default = 32)]
22 pub d_edge: usize,
23 #[config(default = 4)]
25 pub gnn_layers: usize,
26 #[config(default = 6)]
28 pub decoder_layers: usize,
29 #[config(default = 8)]
31 pub n_heads: usize,
32 #[config(default = 1024)]
34 pub d_ff: usize,
35 #[config(default = 256)]
37 pub max_seq: usize,
38 #[config(default = 0.1)]
40 pub dropout: f64,
41}
42
43#[derive(Module, Debug)]
45pub struct NeuralCompilerV2<B: Backend> {
46 pub encoder: GnnEncoder<B>,
47 pub decoder: StackAwareDecoder<B>,
48}
49
50impl NeuralCompilerConfig {
51 pub fn init<B: Backend>(&self, device: &B::Device) -> NeuralCompilerV2<B> {
53 let encoder = GnnEncoderConfig::new()
54 .with_d_model(self.d_model)
55 .with_d_edge(self.d_edge)
56 .with_num_layers(self.gnn_layers)
57 .init(device);
58
59 let decoder = DecoderConfig {
60 d_model: self.d_model,
61 num_layers: self.decoder_layers,
62 n_heads: self.n_heads,
63 d_ff: self.d_ff,
64 max_seq: self.max_seq,
65 max_stack_depth: 65,
66 type_window: 8,
67 dropout: self.dropout,
68 }
69 .init(device);
70
71 NeuralCompilerV2 { encoder, decoder }
72 }
73
74 pub fn param_estimate(&self) -> usize {
76 let d = self.d_model;
80 let gnn_per_layer = 4 * d * d + self.d_edge * d;
81 let gnn = 59 * d + 3 * self.d_edge + self.gnn_layers * gnn_per_layer + 2 * d * d;
82 let dec_per_layer = 3 * (d * d * 4) + d * self.d_ff + self.d_ff * d;
83 let dec = VOCAB_SIZE * d
84 + self.max_seq * d
85 + 65 * 32
86 + 24 * 32
87 + (d + 64) * d
88 + self.decoder_layers * dec_per_layer
89 + d * VOCAB_SIZE;
90 gnn + dec
91 }
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    type B = NdArray;

    /// Deliberately small configuration so building the model on the CPU backend is cheap.
    fn tiny_config() -> NeuralCompilerConfig {
        NeuralCompilerConfig::new()
            .with_d_model(32)
            .with_d_edge(8)
            .with_gnn_layers(1)
            .with_decoder_layers(1)
            .with_n_heads(4)
            .with_d_ff(64)
            .with_max_seq(32)
            .with_dropout(0.0)
    }

    /// Smoke test: constructing the full encoder + decoder must not panic.
    #[test]
    fn composite_model_initializes() {
        let device = Default::default();
        let _model = tiny_config().init::<B>(&device);
    }

    /// The default configuration should land in a plausible size band.
    #[test]
    fn param_estimate_reasonable() {
        let params = NeuralCompilerConfig::new().param_estimate();
        assert!(params > 5_000_000, "too few params: {params}");
        assert!(params < 50_000_000, "too many params: {params}");
    }
}