zuna_rs/encoder.rs
1//! Standalone ZUNA encoder — produce latent EEG embeddings.
2//!
3//! # Regularisation
4//!
5//! The ZUNA encoder uses an **MMD bottleneck** (Maximum Mean Discrepancy).
6//! During training, an MMD loss constrains the encoder's output distribution
7//! to be close to **N(0, I)**. At inference the bottleneck is a pure
8//! passthrough — no reparameterisation or additional normalisation is applied.
9//! The weights therefore carry the regularisation implicitly: the encoder
10//! output is already approximately normally distributed with zero mean and
11//! unit variance per dimension.
12//!
13//! # Loading — encoder only
14//!
15//! ```rust,ignore
16//! use zuna_rs::encoder::ZunaEncoder;
17//!
18//! let (enc, ms) = ZunaEncoder::<B>::load(
19//! Path::new("config.json"),
20//! Path::new("model.safetensors"),
21//! device,
22//! )?;
23//!
24//! // Encode a FIF recording
25//! let result = enc.encode_fif(Path::new("recording.fif"), 10.0)?;
26//! result.save_safetensors("data/embeddings.safetensors")?;
27//! ```
28
29use std::{path::Path, time::Instant};
30
31use anyhow::Context;
32use burn::prelude::*;
33
34use crate::{
35 config::{DataConfig, ModelConfig},
36 data::{load_batch, load_from_fif, FifInfo, InputBatch},
37 model::{encoder::EncoderTransformer, rope::RotaryEmbedding},
38 weights::load_encoder_weights,
39};
40
41// ── Output types ──────────────────────────────────────────────────────────────
42
/// Per-epoch encoder embedding produced by [`ZunaEncoder`].
///
/// Represents one 5-second EEG epoch in the model's latent space.
///
/// ## Shape
/// `embeddings` is a flat row-major `f32` buffer of shape
/// `[n_tokens, output_dim]` where:
/// - `n_tokens = n_channels × tc` (S in model notation)
/// - `output_dim = encoder_output_dim` from config (32 by default)
/// - `tc = T / num_fine_time_pts` (coarse time steps)
///
/// ## Regularisation
/// Per the MMD training objective, each dimension of the embedding is
/// approximately N(0, 1) at the population level. Individual samples will
/// deviate; no further normalisation is needed before downstream use.
#[derive(Debug, Clone)]
pub struct EpochEmbedding {
    /// Latent tokens: row-major f32, shape `[n_tokens, output_dim]`.
    pub embeddings: Vec<f32>,
    /// Shape `[n_tokens, output_dim]`.
    pub shape: Vec<usize>,
    /// Discrete token indices needed to re-decode this embedding.
    /// Row-major i64, shape `[n_tokens, 4]` (x_bin, y_bin, z_bin, t_coarse).
    pub tok_idx: Vec<i64>,
    /// Channel positions in metres, row-major f32, shape `[n_channels, 3]`.
    pub chan_pos: Vec<f32>,
    /// Number of EEG channels (C).
    pub n_channels: usize,
    /// Coarse time steps per epoch (tc = T / num_fine_time_pts).
    pub tc: usize,
}

impl EpochEmbedding {
    /// Total number of tokens S = n_channels × tc.
    #[inline]
    pub fn n_tokens(&self) -> usize {
        self.n_channels * self.tc
    }

    /// Output dimension of the encoder bottleneck (32 by default).
    /// Returns 0 if `shape` has fewer than two entries.
    #[inline]
    pub fn output_dim(&self) -> usize {
        self.shape.get(1).copied().unwrap_or(0)
    }
}
80
/// Collection of per-epoch embeddings returned by [`ZunaEncoder`].
///
/// Produced by [`ZunaEncoder::encode_fif`] (which populates `fif_info`) or
/// [`ZunaEncoder::encode_batch`] (which leaves `fif_info` as `None`).
pub struct EncodingResult {
    /// One entry per 5-second EEG epoch.
    pub epochs: Vec<EpochEmbedding>,
    /// Metadata extracted from the FIF file; `None` for safetensors batch input.
    pub fif_info: Option<FifInfo>,
    /// Preprocessing time in milliseconds.
    pub ms_preproc: f64,
    /// Encoder forward-pass time in milliseconds (all epochs combined).
    pub ms_encode: f64,
}
92
93impl EncodingResult {
94 /// Persist embeddings to a safetensors file.
95 ///
96 /// Keys written per epoch `N`:
97 /// - `embeddings_N` — `[n_tokens, output_dim]` float32
98 /// - `tok_idx_N` — `[n_tokens, 4]` int32 (needed for decoding)
99 /// - `chan_pos_N` — `[n_channels, 3]` float32
100 ///
101 /// Plus:
102 /// - `n_samples` — scalar float32 = number of epochs
103 pub fn save_safetensors(&self, path: &str) -> anyhow::Result<()> {
104 use safetensors::{Dtype, View};
105 use std::borrow::Cow;
106
107 struct RawTensor { data: Vec<u8>, shape: Vec<usize>, dtype: Dtype }
108 impl View for RawTensor {
109 fn dtype(&self) -> Dtype { self.dtype }
110 fn shape(&self) -> &[usize] { &self.shape }
111 fn data(&self) -> Cow<'_, [u8]> { Cow::Borrowed(&self.data) }
112 fn data_len(&self) -> usize { self.data.len() }
113 }
114
115 let f32_bytes = |v: &[f32]| -> Vec<u8> { v.iter().flat_map(|f| f.to_le_bytes()).collect() };
116 let i64_bytes = |v: &[i64]| -> Vec<u8> { v.iter().flat_map(|i| i.to_le_bytes()).collect() };
117
118 let mut keys: Vec<String> = Vec::new();
119 let mut tensors: Vec<RawTensor> = Vec::new();
120
121 for (i, ep) in self.epochs.iter().enumerate() {
122 let n_tok = ep.n_tokens();
123
124 keys.push(format!("embeddings_{i}"));
125 tensors.push(RawTensor {
126 data: f32_bytes(&ep.embeddings),
127 shape: ep.shape.clone(),
128 dtype: Dtype::F32,
129 });
130
131 keys.push(format!("tok_idx_{i}"));
132 tensors.push(RawTensor {
133 data: i64_bytes(&ep.tok_idx),
134 shape: vec![n_tok, 4],
135 dtype: Dtype::I64,
136 });
137
138 keys.push(format!("chan_pos_{i}"));
139 tensors.push(RawTensor {
140 data: f32_bytes(&ep.chan_pos),
141 shape: vec![ep.n_channels, 3],
142 dtype: Dtype::F32,
143 });
144 }
145
146 let n = self.epochs.len() as f32;
147 keys.push("n_samples".into());
148 tensors.push(RawTensor { data: f32_bytes(&[n]), shape: vec![1], dtype: Dtype::F32 });
149
150 let pairs: Vec<(&str, RawTensor)> = keys.iter().map(|s| s.as_str()).zip(tensors).collect();
151 let bytes = safetensors::serialize(pairs, None)?;
152 std::fs::write(path, bytes)?;
153 Ok(())
154 }
155}
156
157// ── ZunaEncoder ───────────────────────────────────────────────────────────────
158
/// Standalone ZUNA encoder.
///
/// Loads only the encoder half of the pretrained weights — useful when you only
/// need latent embeddings and want to save memory and startup time compared to
/// loading the full [`crate::ZunaInference`].
///
/// # Backend
/// Compile-time choice (same as the full model):
/// - CPU (default): `--features ndarray`
/// - GPU: `--no-default-features --features wgpu`
pub struct ZunaEncoder<B: Backend> {
    /// Encoder transformer stack — the only part of the model kept in memory.
    encoder: EncoderTransformer<B>,
    /// Rotary position embedding tables (built once in [`Self::load`]).
    rope: RotaryEmbedding<B>,
    /// Architecture hyperparameters (from config.json).
    pub model_cfg: ModelConfig,
    /// Preprocessing / tokenisation parameters.
    pub data_cfg: DataConfig,
    /// Burn device all tensors live on; exposed read-only via [`Self::device`].
    device: B::Device,
}
178
179impl<B: Backend> ZunaEncoder<B> {
180 // ── Construction ──────────────────────────────────────────────────────────
181
182 /// Load encoder weights from a HuggingFace `config.json` and
183 /// `model.safetensors`. Decoder tensors are read from disk but not kept
184 /// in memory (the full file is parsed once for key extraction).
185 ///
186 /// Returns `(encoder, weight_load_ms)`.
187 pub fn load(
188 config_path: &Path,
189 weights_path: &Path,
190 device: B::Device,
191 ) -> anyhow::Result<(Self, f64)> {
192 let cfg_str = std::fs::read_to_string(config_path)
193 .with_context(|| format!("config: {}", config_path.display()))?;
194 let hf_val: serde_json::Value = serde_json::from_str(&cfg_str)?;
195 let model_cfg: ModelConfig = serde_json::from_value(hf_val["model"].clone())
196 .context("parsing model config")?;
197
198 let rope = RotaryEmbedding::<B>::new(
199 model_cfg.head_dim, model_cfg.rope_dim,
200 model_cfg.max_seqlen, model_cfg.rope_theta, &device,
201 );
202
203 let t = Instant::now();
204 let (encoder, n_heads) = load_encoder_weights::<B>(
205 &model_cfg,
206 weights_path.to_str().context("weights path not valid UTF-8")?,
207 &device,
208 )?;
209 let ms = t.elapsed().as_secs_f64() * 1000.0;
210
211 println!("Detected n_heads = {n_heads}");
212
213 Ok((Self { encoder, rope, model_cfg, data_cfg: DataConfig::default(), device }, ms))
214 }
215
216 // ── Accessors ─────────────────────────────────────────────────────────────
217
218 /// One-line description of the loaded encoder.
219 pub fn describe(&self) -> String {
220 let c = &self.model_cfg;
221 format!(
222 "ZUNA encoder dim={} layers={} head_dim={} out_dim={}",
223 c.dim, c.n_layers, c.head_dim, c.encoder_output_dim,
224 )
225 }
226
227 // ── High-level encode API ─────────────────────────────────────────────────
228
229 /// Preprocess a `.fif` recording and encode it into latent embeddings.
230 ///
231 /// `data_norm` is the same divisor used to train ZUNA (default: 10.0).
232 /// It is applied during preprocessing; the encoder output is **not**
233 /// re-scaled — it reflects the MMD-regularised latent space directly.
234 pub fn encode_fif(
235 &self,
236 fif_path: &Path,
237 data_norm: f32,
238 ) -> anyhow::Result<EncodingResult> {
239 let t_pp = Instant::now();
240 let (batches, fif_info) = load_from_fif::<B>(
241 fif_path, &self.data_cfg, data_norm, &self.device,
242 ).with_context(|| format!("exg on {}", fif_path.display()))?;
243 let ms_preproc = t_pp.elapsed().as_secs_f64() * 1000.0;
244
245 let t_enc = Instant::now();
246 let epochs = self.encode_inputs(batches)?;
247 let ms_encode = t_enc.elapsed().as_secs_f64() * 1000.0;
248
249 Ok(EncodingResult { epochs, fif_info: Some(fif_info), ms_preproc, ms_encode })
250 }
251
    /// Encode a pre-processed safetensors batch (Python / legacy input path).
    ///
    /// The batch file is assumed to have been normalised (÷ data_norm) when it
    /// was created; unlike [`Self::encode_fif`], no normalisation divisor is
    /// applied here.
    pub fn encode_batch(
        &self,
        batch_path: &Path,
    ) -> anyhow::Result<EncodingResult> {
        // Load/time the pre-tokenised batch from disk.
        let t_pp = Instant::now();
        let batches = load_batch::<B>(
            batch_path.to_str().context("batch path not valid UTF-8")?,
            &self.data_cfg,
            &self.device,
        )?;
        let ms_preproc = t_pp.elapsed().as_secs_f64() * 1000.0;

        // Time the encoder forward passes separately from loading.
        let t_enc = Instant::now();
        let epochs = self.encode_inputs(batches)?;
        let ms_encode = t_enc.elapsed().as_secs_f64() * 1000.0;

        // No FIF metadata exists on this input path.
        Ok(EncodingResult { epochs, fif_info: None, ms_preproc, ms_encode })
    }
275
276 /// Encode a single prepared [`InputBatch`], returning the raw encoder
277 /// output tensor `[1, S, output_dim]`.
278 ///
279 /// This is the **MMD-regularised embedding**: training constrains the
280 /// distribution to N(0, I); at inference the bottleneck is a passthrough.
281 /// No further normalisation is applied here.
282 pub fn encode_tensor(&self, batch: &InputBatch<B>) -> Tensor<B, 3> {
283 self.encoder.forward(
284 batch.encoder_input.clone(),
285 batch.tok_idx.clone(),
286 &self.rope,
287 )
288 }
289
290 // ── Lower-level API (benchmark / export) ─────────────────────────────────
291
292 /// Run the FIF preprocessing pipeline and return raw [`InputBatch`]es
293 /// **without** running the encoder.
294 ///
295 /// Use together with [`Self::encode_batches`] to time encode separately,
296 /// or to export the pre-tokenised tensors for external comparison.
297 pub fn preprocess_fif(
298 &self,
299 fif_path: &Path,
300 data_norm: f32,
301 ) -> anyhow::Result<(Vec<InputBatch<B>>, FifInfo)> {
302 load_from_fif(fif_path, &self.data_cfg, data_norm, &self.device)
303 }
304
305 /// Encode a list of [`InputBatch`]es produced by [`Self::preprocess_fif`].
306 pub fn encode_batches(
307 &self,
308 batches: Vec<InputBatch<B>>,
309 ) -> anyhow::Result<Vec<EpochEmbedding>> {
310 self.encode_inputs(batches)
311 }
312
313 /// Reference to the Burn device this encoder was loaded on.
314 pub fn device(&self) -> &B::Device { &self.device }
315
316 // ── Internal ──────────────────────────────────────────────────────────────
317
318 pub(crate) fn encode_inputs(
319 &self,
320 batches: Vec<InputBatch<B>>,
321 ) -> anyhow::Result<Vec<EpochEmbedding>> {
322 batches.into_iter().map(|b| self.encode_one(b)).collect()
323 }
324
325 fn encode_one(&self, batch: InputBatch<B>) -> anyhow::Result<EpochEmbedding> {
326 let n_channels = batch.n_channels;
327 let tc = batch.tc;
328
329 // Keep a copy of tok_idx and chan_pos before the encoder consumes them.
330 let tok_idx_saved = batch.tok_idx.clone();
331 let chan_pos_saved = batch.chan_pos.clone();
332
333 // Run encoder: [1, S, output_dim]
334 let enc_out = self.encoder.forward(
335 batch.encoder_input,
336 batch.tok_idx,
337 &self.rope,
338 );
339 let [_, s, output_dim] = enc_out.dims();
340
341 // Squeeze batch dim → [S, output_dim] and extract as Vec<f32>.
342 let embeddings = enc_out
343 .squeeze::<2>()
344 .into_data()
345 .to_vec::<f32>()
346 .map_err(|e| anyhow::anyhow!("embedding→vec: {e:?}"))?;
347
348 // tok_idx [S, 4] → Vec<i64>.
349 // NdArray backend stores Int as i64; wgpu backend stores Int as i32.
350 // Try i64 first, fall back to i32 and widen.
351 let tok_idx_data = tok_idx_saved.into_data();
352 let tok_idx: Vec<i64> = tok_idx_data
353 .to_vec::<i64>()
354 .or_else(|_| tok_idx_data.to_vec::<i32>()
355 .map(|v| v.into_iter().map(|x| x as i64).collect()))
356 .map_err(|e| anyhow::anyhow!("tok_idx→vec: {e:?}"))?;
357
358 // chan_pos [C, 3] → Vec<f32>.
359 let chan_pos = chan_pos_saved
360 .into_data()
361 .to_vec::<f32>()
362 .map_err(|e| anyhow::anyhow!("chan_pos→vec: {e:?}"))?;
363
364 Ok(EpochEmbedding {
365 embeddings,
366 shape: vec![s, output_dim],
367 tok_idx,
368 chan_pos,
369 n_channels,
370 tc,
371 })
372 }
373}