//! Standalone ZUNA encoder — produce latent EEG embeddings.
//!
//! # Regularisation
//!
//! The ZUNA encoder uses an **MMD bottleneck** (Maximum Mean Discrepancy).
//! During training, an MMD loss constrains the encoder's output distribution
//! to be close to **N(0, I)**.  At inference the bottleneck is a pure
//! passthrough — no reparameterisation or additional normalisation is applied.
//! The weights therefore carry the regularisation implicitly: the encoder
//! output is already approximately normally distributed with zero mean and
//! unit variance per dimension.
//!
//! # Loading — encoder only
//!
//! ```rust,ignore
//! use zuna_rs::encoder::ZunaEncoder;
//!
//! let (enc, ms) = ZunaEncoder::<B>::load(
//!     Path::new("config.json"),
//!     Path::new("model.safetensors"),
//!     device,
//! )?;
//!
//! // Encode a FIF recording
//! let result = enc.encode_fif(Path::new("recording.fif"), 10.0)?;
//! result.save_safetensors("data/embeddings.safetensors")?;
//! ```

29use std::{path::Path, time::Instant};
30
31use anyhow::Context;
32use burn::prelude::*;
33
34use crate::{
35    config::{DataConfig, ModelConfig},
36    data::{load_batch, load_from_fif, FifInfo, InputBatch},
37    model::{encoder::EncoderTransformer, rope::RotaryEmbedding},
38    weights::load_encoder_weights,
39};
40
41// ── Output types ──────────────────────────────────────────────────────────────
42
/// Per-epoch encoder embedding produced by [`ZunaEncoder`].
///
/// Represents one 5-second EEG epoch in the model's latent space.
///
/// ## Shape
/// `embeddings` is a flat row-major `f32` buffer of shape
/// `[n_tokens, output_dim]` where:
/// - `n_tokens = n_channels × tc`  (S in model notation)
/// - `output_dim = encoder_output_dim` from config (32 by default)
/// - `tc = T / num_fine_time_pts`   (coarse time steps)
///
/// ## Regularisation
/// Per the MMD training objective, each dimension of the embedding is
/// approximately N(0, 1) at the population level.  Individual samples will
/// deviate; no further normalisation is needed before downstream use.
pub struct EpochEmbedding {
    /// Latent tokens: row-major f32, shape `[n_tokens, output_dim]`.
    pub embeddings: Vec<f32>,
    /// Logical shape of `embeddings`: `[n_tokens, output_dim]`.
    /// Kept alongside the flat buffer so it can be written to safetensors
    /// without recomputation.
    pub shape: Vec<usize>,
    /// Discrete token indices needed to re-decode this embedding.
    /// Row-major i64, shape `[n_tokens, 4]`  (x_bin, y_bin, z_bin, t_coarse).
    pub tok_idx: Vec<i64>,
    /// Channel positions in metres, row-major f32, shape `[n_channels, 3]`.
    pub chan_pos: Vec<f32>,
    /// Number of EEG channels (C).
    pub n_channels: usize,
    /// Coarse time steps per epoch (tc = T / num_fine_time_pts).
    pub tc: usize,
}
73
74impl EpochEmbedding {
75    /// Total number of tokens  S = n_channels × tc.
76    #[inline] pub fn n_tokens(&self) -> usize { self.n_channels * self.tc }
77    /// Output dimension of the encoder bottleneck (32 by default).
78    #[inline] pub fn output_dim(&self) -> usize { self.shape.get(1).copied().unwrap_or(0) }
79}
80
/// Collection of per-epoch embeddings returned by [`ZunaEncoder`].
///
/// Also carries timing metadata so callers can report preprocessing and
/// encoding cost separately.
pub struct EncodingResult {
    /// One entry per 5-second EEG epoch.
    pub epochs: Vec<EpochEmbedding>,
    /// Metadata extracted from the FIF file; `None` for safetensors batch input.
    pub fif_info: Option<FifInfo>,
    /// Preprocessing time in milliseconds.
    pub ms_preproc: f64,
    /// Encoder forward-pass time in milliseconds (all epochs combined).
    pub ms_encode: f64,
}
92
impl EncodingResult {
    /// Persist embeddings to a safetensors file.
    ///
    /// Keys written per epoch `N`:
    /// - `embeddings_N` — `[n_tokens, output_dim]` float32
    /// - `tok_idx_N`    — `[n_tokens, 4]` int64  (needed for decoding)
    /// - `chan_pos_N`   — `[n_channels, 3]` float32
    ///
    /// Plus:
    /// - `n_samples`    — scalar float32 = number of epochs
    ///
    /// # Errors
    /// Propagates serialization failures from `safetensors::serialize` and
    /// I/O errors from writing `path`.
    pub fn save_safetensors(&self, path: &str) -> anyhow::Result<()> {
        use safetensors::{Dtype, View};
        use std::borrow::Cow;

        // Minimal owned-buffer implementation of the safetensors `View`
        // trait, so raw little-endian byte buffers can be serialized without
        // first wrapping them in a framework tensor type.
        struct RawTensor { data: Vec<u8>, shape: Vec<usize>, dtype: Dtype }
        impl View for RawTensor {
            fn dtype(&self)    -> Dtype         { self.dtype }
            fn shape(&self)    -> &[usize]      { &self.shape }
            fn data(&self)     -> Cow<'_, [u8]> { Cow::Borrowed(&self.data) }
            fn data_len(&self) -> usize          { self.data.len() }
        }

        // safetensors stores tensor payloads as little-endian bytes.
        let f32_bytes = |v: &[f32]| -> Vec<u8> { v.iter().flat_map(|f| f.to_le_bytes()).collect() };
        let i64_bytes = |v: &[i64]| -> Vec<u8> { v.iter().flat_map(|i| i.to_le_bytes()).collect() };

        // Keys live in their own Vec so borrowed `&str` keys can be zipped
        // with the owned tensors when building the serializer input below.
        let mut keys:    Vec<String>    = Vec::new();
        let mut tensors: Vec<RawTensor> = Vec::new();

        for (i, ep) in self.epochs.iter().enumerate() {
            let n_tok = ep.n_tokens();

            keys.push(format!("embeddings_{i}"));
            tensors.push(RawTensor {
                data: f32_bytes(&ep.embeddings),
                shape: ep.shape.clone(),
                dtype: Dtype::F32,
            });

            keys.push(format!("tok_idx_{i}"));
            tensors.push(RawTensor {
                data: i64_bytes(&ep.tok_idx),
                shape: vec![n_tok, 4],
                dtype: Dtype::I64,
            });

            keys.push(format!("chan_pos_{i}"));
            tensors.push(RawTensor {
                data: f32_bytes(&ep.chan_pos),
                shape: vec![ep.n_channels, 3],
                dtype: Dtype::F32,
            });
        }

        // Epoch count is stored as a 1-element f32 tensor rather than header
        // metadata, matching the convention used by the batch input files.
        let n = self.epochs.len() as f32;
        keys.push("n_samples".into());
        tensors.push(RawTensor { data: f32_bytes(&[n]), shape: vec![1], dtype: Dtype::F32 });

        let pairs: Vec<(&str, RawTensor)> = keys.iter().map(|s| s.as_str()).zip(tensors).collect();
        let bytes = safetensors::serialize(pairs, None)?;
        std::fs::write(path, bytes)?;
        Ok(())
    }
}
156
157// ── ZunaEncoder ───────────────────────────────────────────────────────────────
158
/// Standalone ZUNA encoder.
///
/// Loads only the encoder half of the pretrained weights — useful when you only
/// need latent embeddings and want to save memory and startup time compared to
/// loading the full [`crate::ZunaInference`].
///
/// # Backend
/// Compile-time choice (same as the full model):
/// - CPU (default): `--features ndarray`
/// - GPU: `--no-default-features --features wgpu`
pub struct ZunaEncoder<B: Backend> {
    // Encoder transformer stack (the only weights loaded from disk).
    encoder:       EncoderTransformer<B>,
    // Precomputed rotary position-embedding tables, built from config only.
    rope:          RotaryEmbedding<B>,
    /// Architecture hyperparameters (from config.json).
    pub model_cfg: ModelConfig,
    /// Preprocessing / tokenisation parameters.
    pub data_cfg:  DataConfig,
    // Burn device all tensors are allocated on; exposed via `device()`.
    device:        B::Device,
}
178
impl<B: Backend> ZunaEncoder<B> {
    // ── Construction ──────────────────────────────────────────────────────────

    /// Load encoder weights from a HuggingFace `config.json` and
    /// `model.safetensors`.  Decoder tensors are read from disk but not kept
    /// in memory (the full file is parsed once for key extraction).
    ///
    /// Returns `(encoder, weight_load_ms)`.
    ///
    /// # Errors
    /// Fails if the config file cannot be read or parsed, if the weights path
    /// is not valid UTF-8, or if weight loading rejects the file.
    pub fn load(
        config_path:  &Path,
        weights_path: &Path,
        device:       B::Device,
    ) -> anyhow::Result<(Self, f64)> {
        let cfg_str = std::fs::read_to_string(config_path)
            .with_context(|| format!("config: {}", config_path.display()))?;
        let hf_val: serde_json::Value = serde_json::from_str(&cfg_str)?;
        // Indexing a missing "model" key yields `Value::Null`, which then
        // fails `from_value` with the "parsing model config" context below.
        let model_cfg: ModelConfig = serde_json::from_value(hf_val["model"].clone())
            .context("parsing model config")?;

        // RoPE tables depend only on hyperparameters, not on weights, so they
        // are built up-front and excluded from the reported weight-load time.
        let rope = RotaryEmbedding::<B>::new(
            model_cfg.head_dim, model_cfg.rope_dim,
            model_cfg.max_seqlen, model_cfg.rope_theta, &device,
        );

        // Time only the weight-loading step; this is the `ms` value returned.
        let t = Instant::now();
        let (encoder, n_heads) = load_encoder_weights::<B>(
            &model_cfg,
            weights_path.to_str().context("weights path not valid UTF-8")?,
            &device,
        )?;
        let ms = t.elapsed().as_secs_f64() * 1000.0;

        println!("Detected n_heads = {n_heads}");

        Ok((Self { encoder, rope, model_cfg, data_cfg: DataConfig::default(), device }, ms))
    }

    // ── Accessors ─────────────────────────────────────────────────────────────

    /// One-line description of the loaded encoder.
    pub fn describe(&self) -> String {
        let c = &self.model_cfg;
        format!(
            "ZUNA encoder  dim={}  layers={}  head_dim={}  out_dim={}",
            c.dim, c.n_layers, c.head_dim, c.encoder_output_dim,
        )
    }

    // ── High-level encode API ─────────────────────────────────────────────────

    /// Preprocess a `.fif` recording and encode it into latent embeddings.
    ///
    /// `data_norm` is the same divisor used to train ZUNA (default: 10.0).
    /// It is applied during preprocessing; the encoder output is **not**
    /// re-scaled — it reflects the MMD-regularised latent space directly.
    ///
    /// # Errors
    /// Fails if the FIF file cannot be loaded/preprocessed or if any epoch's
    /// encoder output cannot be extracted to host memory.
    pub fn encode_fif(
        &self,
        fif_path:  &Path,
        data_norm: f32,
    ) -> anyhow::Result<EncodingResult> {
        // Preprocessing and encoding are timed separately so both numbers can
        // be reported in `EncodingResult`.
        let t_pp = Instant::now();
        let (batches, fif_info) = load_from_fif::<B>(
            fif_path, &self.data_cfg, data_norm, &self.device,
        // NOTE(review): "exg" in this context string — presumably ExG
        // (EEG/EMG-style signal) loading; confirm the intended wording.
        ).with_context(|| format!("exg on {}", fif_path.display()))?;
        let ms_preproc = t_pp.elapsed().as_secs_f64() * 1000.0;

        let t_enc = Instant::now();
        let epochs = self.encode_inputs(batches)?;
        let ms_encode = t_enc.elapsed().as_secs_f64() * 1000.0;

        Ok(EncodingResult { epochs, fif_info: Some(fif_info), ms_preproc, ms_encode })
    }

    /// Encode a pre-processed safetensors batch (Python / legacy input path).
    ///
    /// The batch is assumed to already be normalised (÷ data_norm); the
    /// `data_norm` argument is **not** applied again here — it exists only to
    /// document the convention used when the file was created.
    ///
    /// # Errors
    /// Fails if the batch path is not valid UTF-8, the file cannot be loaded,
    /// or encoding any epoch fails.
    pub fn encode_batch(
        &self,
        batch_path: &Path,
    ) -> anyhow::Result<EncodingResult> {
        let t_pp = Instant::now();
        let batches = load_batch::<B>(
            batch_path.to_str().context("batch path not valid UTF-8")?,
            &self.data_cfg,
            &self.device,
        )?;
        let ms_preproc = t_pp.elapsed().as_secs_f64() * 1000.0;

        let t_enc = Instant::now();
        let epochs = self.encode_inputs(batches)?;
        let ms_encode = t_enc.elapsed().as_secs_f64() * 1000.0;

        // No FIF metadata exists for the safetensors input path.
        Ok(EncodingResult { epochs, fif_info: None, ms_preproc, ms_encode })
    }

    /// Encode a single prepared [`InputBatch`], returning the raw encoder
    /// output tensor `[1, S, output_dim]`.
    ///
    /// This is the **MMD-regularised embedding**: training constrains the
    /// distribution to N(0, I); at inference the bottleneck is a passthrough.
    /// No further normalisation is applied here.
    pub fn encode_tensor(&self, batch: &InputBatch<B>) -> Tensor<B, 3> {
        // Clones keep the caller's batch usable; burn tensor clones are
        // expected to be cheap handle copies — TODO confirm for all backends.
        self.encoder.forward(
            batch.encoder_input.clone(),
            batch.tok_idx.clone(),
            &self.rope,
        )
    }

    // ── Lower-level API (benchmark / export) ─────────────────────────────────

    /// Run the FIF preprocessing pipeline and return raw [`InputBatch`]es
    /// **without** running the encoder.
    ///
    /// Use together with [`Self::encode_batches`] to time encode separately,
    /// or to export the pre-tokenised tensors for external comparison.
    pub fn preprocess_fif(
        &self,
        fif_path:  &Path,
        data_norm: f32,
    ) -> anyhow::Result<(Vec<InputBatch<B>>, FifInfo)> {
        load_from_fif(fif_path, &self.data_cfg, data_norm, &self.device)
    }

    /// Encode a list of [`InputBatch`]es produced by [`Self::preprocess_fif`].
    pub fn encode_batches(
        &self,
        batches: Vec<InputBatch<B>>,
    ) -> anyhow::Result<Vec<EpochEmbedding>> {
        self.encode_inputs(batches)
    }

    /// Reference to the Burn device this encoder was loaded on.
    pub fn device(&self) -> &B::Device { &self.device }

    // ── Internal ──────────────────────────────────────────────────────────────

    /// Encode every batch in order; fails fast on the first epoch that errors.
    pub(crate) fn encode_inputs(
        &self,
        batches: Vec<InputBatch<B>>,
    ) -> anyhow::Result<Vec<EpochEmbedding>> {
        batches.into_iter().map(|b| self.encode_one(b)).collect()
    }

    /// Run the encoder on one batch and pull all outputs back to host memory
    /// as flat `Vec`s, ready for safetensors export.
    fn encode_one(&self, batch: InputBatch<B>) -> anyhow::Result<EpochEmbedding> {
        let n_channels = batch.n_channels;
        let tc         = batch.tc;

        // Keep a copy of tok_idx and chan_pos before the encoder consumes them.
        let tok_idx_saved  = batch.tok_idx.clone();
        let chan_pos_saved = batch.chan_pos.clone();

        // Run encoder: [1, S, output_dim]
        let enc_out = self.encoder.forward(
            batch.encoder_input,
            batch.tok_idx,
            &self.rope,
        );
        let [_, s, output_dim] = enc_out.dims();

        // Squeeze batch dim → [S, output_dim] and extract as Vec<f32>.
        let embeddings = enc_out
            .squeeze::<2>()
            .into_data()
            .to_vec::<f32>()
            .map_err(|e| anyhow::anyhow!("embedding→vec: {e:?}"))?;

        // tok_idx [S, 4] → Vec<i64>.
        // NdArray backend stores Int as i64; wgpu backend stores Int as i32.
        // Try i64 first, fall back to i32 and widen.
        let tok_idx_data = tok_idx_saved.into_data();
        let tok_idx: Vec<i64> = tok_idx_data
            .to_vec::<i64>()
            .or_else(|_| tok_idx_data.to_vec::<i32>()
                .map(|v| v.into_iter().map(|x| x as i64).collect()))
            .map_err(|e| anyhow::anyhow!("tok_idx→vec: {e:?}"))?;

        // chan_pos [C, 3] → Vec<f32>.
        let chan_pos = chan_pos_saved
            .into_data()
            .to_vec::<f32>()
            .map_err(|e| anyhow::anyhow!("chan_pos→vec: {e:?}"))?;

        Ok(EpochEmbedding {
            embeddings,
            // Shape recorded from the live tensor dims, not from config, so it
            // stays correct even if the encoder output width ever changes.
            shape: vec![s, output_dim],
            tok_idx,
            chan_pos,
            n_channels,
            tc,
        })
    }
}