Skip to main content

zuna_rs/
config.rs

1/// Model and runtime configuration for ZUNA inference.
2///
3/// `ModelConfig` is deserialised from the HuggingFace `config.json`
4/// (the `"model"` sub-object).  Field names must match exactly.
5
6// ── ModelConfig ───────────────────────────────────────────────────────────────
7
8#[derive(Debug, Clone, serde::Deserialize)]
9pub struct ModelConfig {
10    // Core transformer
11    pub dim:      usize,    // 1024
12    pub n_layers: usize,    // 16
13    pub head_dim: usize,    // 64
14
15    // Token I/O
16    pub input_dim:          usize,  // 32
17    pub encoder_output_dim: usize,  // 32
18
19    // Encoder register/downsampling
20    pub encoder_latent_downsample_factor: usize,  // 1
21
22    // Decoder timestep conditioner output dim
23    #[serde(default = "default_t_dim")]
24    pub t_dim: usize,  // 64
25
26    // Rotary embeddings
27    pub max_seqlen:  usize,  // 50
28    pub rope_dim:    usize,  // 4
29    pub rope_theta:  f64,    // 10_000.0
30
31    // Normalisation
32    #[serde(default = "default_norm_eps")]
33    pub norm_eps: f64,  // 1e-5
34
35    // Feed-forward rounding
36    #[serde(default)]
37    pub ffn_dim_multiplier: Option<f64>,
38    #[serde(default = "default_multiple_of")]
39    pub multiple_of: usize,  // 256
40
41    // Diffusion noise std
42    pub stft_global_sigma: f64,  // 0.1
43}
44
/// Serde fallback for `ModelConfig::t_dim`.
fn default_t_dim() -> usize {
    64
}

/// Serde fallback for `ModelConfig::norm_eps`.
fn default_norm_eps() -> f64 {
    1e-5
}

/// Serde fallback for `ModelConfig::multiple_of`.
fn default_multiple_of() -> usize {
    256
}

49impl ModelConfig {
50    /// n_heads is NOT dim/head_dim for this checkpoint.
51    /// It must be inferred from the wq weight shape at load time.
52    /// Use `WeightMap::infer_n_heads()` instead of calling this.
53    pub fn n_heads_fallback(&self) -> usize { self.dim / self.head_dim }
54
55    /// Feed-forward hidden dim (matches Python FeedForward.__init__):
56    ///   hidden = int(2 * 4 * dim / 3)  →  2730
57    ///   hidden = 256 * ceil(2730 / 256) →  2816
58    pub fn ffn_hidden_dim(&self) -> usize {
59        let mut h = (2 * 4 * self.dim) / 3;
60        if let Some(m) = self.ffn_dim_multiplier {
61            h = (m * h as f64) as usize;
62        }
63        self.multiple_of * ((h + self.multiple_of - 1) / self.multiple_of)
64    }
65}
66
// ── InferConfig ───────────────────────────────────────────────────────────────

/// Sampling-time settings for inference.
#[derive(Debug, Clone)]
pub struct InferConfig {
    /// Number of sampling steps (50 by default).
    pub sample_steps: usize,
    /// Guidance strength; `1.0` (the default) means no guidance.
    pub cfg: f32,
    /// Data normalisation constant (10.0 by default).
    pub data_norm: f32,
}

76impl Default for InferConfig {
77    fn default() -> Self {
78        Self { sample_steps: 50, cfg: 1.0, data_norm: 10.0 }
79    }
80}
81
// ── DataConfig ────────────────────────────────────────────────────────────────

/// Input-side data layout settings.
#[derive(Debug, Clone)]
pub struct DataConfig {
    /// Fine time points per EEG token; equals the model's `input_dim` (32 by default).
    pub num_fine_time_pts: usize,
    /// Bin count for discretising x/y/z channel positions (50 by default).
    pub num_bins: usize,
    /// Lower corner of the scalp-position bounding box (metres), used in discretisation.
    pub xyz_min: [f32; 3],
    /// Upper corner of the scalp-position bounding box (metres), used in discretisation.
    pub xyz_max: [f32; 3],
}

95impl Default for DataConfig {
96    fn default() -> Self {
97        Self {
98            num_fine_time_pts: 32,
99            num_bins: 50,
100            xyz_min: [-0.12, -0.12, -0.12],
101            xyz_max: [ 0.12,  0.12,  0.12],
102        }
103    }
104}