zuna_rs/config.rs
//! Model, inference, and data-preprocessing configuration for ZUNA inference.
//!
//! `ModelConfig` is deserialised from the `"model"` sub-object of the
//! HuggingFace `config.json`; its field names must match the JSON keys exactly.
5
6// ── ModelConfig ───────────────────────────────────────────────────────────────
7
8#[derive(Debug, Clone, serde::Deserialize)]
9pub struct ModelConfig {
10 // Core transformer
11 pub dim: usize, // 1024
12 pub n_layers: usize, // 16
13 pub head_dim: usize, // 64
14
15 // Token I/O
16 pub input_dim: usize, // 32
17 pub encoder_output_dim: usize, // 32
18
19 // Encoder register/downsampling
20 pub encoder_latent_downsample_factor: usize, // 1
21
22 // Decoder timestep conditioner output dim
23 #[serde(default = "default_t_dim")]
24 pub t_dim: usize, // 64
25
26 // Rotary embeddings
27 pub max_seqlen: usize, // 50
28 pub rope_dim: usize, // 4
29 pub rope_theta: f64, // 10_000.0
30
31 // Normalisation
32 #[serde(default = "default_norm_eps")]
33 pub norm_eps: f64, // 1e-5
34
35 // Feed-forward rounding
36 #[serde(default)]
37 pub ffn_dim_multiplier: Option<f64>,
38 #[serde(default = "default_multiple_of")]
39 pub multiple_of: usize, // 256
40
41 // Diffusion noise std
42 pub stft_global_sigma: f64, // 0.1
43}
44
/// Serde fallback for `ModelConfig::t_dim` when the key is absent.
fn default_t_dim() -> usize {
    64
}

/// Serde fallback for `ModelConfig::norm_eps` when the key is absent.
fn default_norm_eps() -> f64 {
    1.0e-5
}

/// Serde fallback for `ModelConfig::multiple_of` when the key is absent.
fn default_multiple_of() -> usize {
    256
}
48
49impl ModelConfig {
50 /// n_heads is NOT dim/head_dim for this checkpoint.
51 /// It must be inferred from the wq weight shape at load time.
52 /// Use `WeightMap::infer_n_heads()` instead of calling this.
53 pub fn n_heads_fallback(&self) -> usize { self.dim / self.head_dim }
54
55 /// Feed-forward hidden dim (matches Python FeedForward.__init__):
56 /// hidden = int(2 * 4 * dim / 3) → 2730
57 /// hidden = 256 * ceil(2730 / 256) → 2816
58 pub fn ffn_hidden_dim(&self) -> usize {
59 let mut h = (2 * 4 * self.dim) / 3;
60 if let Some(m) = self.ffn_dim_multiplier {
61 h = (m * h as f64) as usize;
62 }
63 self.multiple_of * ((h + self.multiple_of - 1) / self.multiple_of)
64 }
65}
66
67// ── InferConfig ───────────────────────────────────────────────────────────────
68
/// Runtime knobs for the diffusion sampling loop.
#[derive(Debug, Clone)]
pub struct InferConfig {
    /// Number of sampling steps (default 50).
    pub sample_steps: usize,
    /// Guidance scale; 1.0 means no guidance (the default).
    /// NOTE(review): presumably classifier-free guidance — confirm at call site.
    pub cfg: f32,
    /// Data normalisation constant (default 10.0).
    /// NOTE(review): confirm whether signals are divided or multiplied by this.
    pub data_norm: f32,
}

impl Default for InferConfig {
    /// 50 sampling steps, guidance disabled (`cfg = 1.0`), `data_norm = 10.0`.
    fn default() -> Self {
        InferConfig {
            data_norm: 10.0,
            cfg: 1.0,
            sample_steps: 50,
        }
    }
}
81
82// ── DataConfig ────────────────────────────────────────────────────────────────
83
/// Preprocessing parameters for turning raw EEG and electrode positions
/// into model tokens.
#[derive(Debug, Clone)]
pub struct DataConfig {
    /// Fine time points per EEG token; equals the model's `input_dim` (32).
    pub num_fine_time_pts: usize,
    /// Number of bins used to discretise each of the x/y/z channel
    /// coordinates (50).
    pub num_bins: usize,
    /// Lower corner of the scalp-position bounding box, in metres, used
    /// during discretisation.
    pub xyz_min: [f32; 3],
    /// Upper corner of the scalp-position bounding box, in metres.
    pub xyz_max: [f32; 3],
}

impl Default for DataConfig {
    /// 32 fine time points, 50 bins, and a ±0.12 m cube centred on the origin.
    fn default() -> Self {
        // The default box is symmetric, so both corners come from one extent.
        let half_extent: f32 = 0.12;
        DataConfig {
            xyz_min: [-half_extent; 3],
            xyz_max: [half_extent; 3],
            num_bins: 50,
            num_fine_time_pts: 32,
        }
    }
}