Skip to main content

zuna_rs/
data.rs

1/// Data preparation for ZUNA inference (burn 0.20.1)
2///
3/// Two entry points:
4///   • `load_from_fif`  — read a raw .fif file through the exg
5///                        pipeline and return `InputBatch` structs directly.
6///   • `load_batch`     — read a pre-exported safetensors batch
7///                        (legacy / Python-compatible path).
8use burn::prelude::*;
9use safetensors::SafeTensors;
10use crate::config::DataConfig;
11
12// ── 1. Discretise channel positions ─────────────────────────────────────────
13
14/// Map continuous scalp xyz positions to integer bin indices [0, num_bins-1].
15/// Python equivalent: `discretize_chan_pos` in eeg_data.py.
16///
17/// chan_pos: [C, 3], cfg.xyz_min/max in metres, cfg.num_bins = 50.
18pub fn discretize_chan_pos<B: Backend>(
19    chan_pos: Tensor<B, 2>,
20    cfg:      &DataConfig,
21    device:   &B::Device,
22) -> Tensor<B, 2, Int> {
23    let [_c, _] = chan_pos.dims();
24    let xyz_min = Tensor::<B, 2>::from_data(
25        TensorData::new(cfg.xyz_min.to_vec(), vec![1, 3]), device,
26    );
27    let xyz_max = Tensor::<B, 2>::from_data(
28        TensorData::new(cfg.xyz_max.to_vec(), vec![1, 3]), device,
29    );
30
31    let norm = (chan_pos - xyz_min.clone()) / (xyz_max - xyz_min); // [C, 3] in [0,1]
32    let bins = cfg.num_bins as f32;
33    norm.mul_scalar(bins)
34        .int()
35        .clamp(0i32, cfg.num_bins as i32 - 1)
36}
37
38// ── 2. Chop-and-reshape (mode "B") ──────────────────────────────────────────
39
40/// Reshape [C, T] → [C*tc, tf] token matrix.
41/// Python: `chop_and_reshape_signals(..., use_coarse_time="B")`.
42///
43/// Returns (eeg_tokens [C*tc, tf], chan_pos_rep [C*tc, 3],
44///          chan_pos_disc_rep [C*tc, 3], t_coarse [C*tc, 1])
45pub fn chop_and_reshape<B: Backend>(
46    eeg:          Tensor<B, 2>,       // [C, T]
47    chan_pos:     Tensor<B, 2>,       // [C, 3]
48    chan_pos_disc: Tensor<B, 2, Int>, // [C, 3]
49    tf:           usize,
50) -> (Tensor<B, 2>, Tensor<B, 2>, Tensor<B, 2, Int>, Tensor<B, 2, Int>) {
51    let [c, t_total] = eeg.dims();
52    assert_eq!(t_total % tf, 0, "T must be divisible by tf");
53    let tc = t_total / tf;
54    let s  = c * tc;
55    let device = eeg.device();
56
57    // [C, T] → [C, tc, tf] → [C*tc, tf]
58    let eeg_tokens = eeg.reshape([c, tc, tf]).reshape([s, tf]);
59
60    // Repeat each channel position tc times: [C, 3] → [C*tc, 3]
61    let pos  = repeat_interleave_rows_f(chan_pos,       tc);
62    let posd = repeat_interleave_rows_i(chan_pos_disc,  tc);
63
64    // t_coarse: [0,1,...,tc-1] repeated C times → [C*tc, 1]
65    let tc_vals: Vec<i32> = (0..tc as i32)
66        .cycle()
67        .take(s)
68        .collect();
69    let t_coarse = Tensor::<B, 1, Int>::from_data(
70        TensorData::new(tc_vals, vec![s]),
71        &device,
72    )
73    .reshape([s, 1]);
74
75    (eeg_tokens, pos, posd, t_coarse)
76}
77
78// ── 3. Build token index [S, 4] ──────────────────────────────────────────────
79
80/// Concatenate discrete (x,y,z) and t_coarse → [S, 4].
81/// Python: `cat((chan_pos_discrete, t_coarse), dim=2)` (we drop the batch dim).
82pub fn build_tok_idx<B: Backend>(
83    chan_pos_disc: Tensor<B, 2, Int>,  // [S, 3]
84    t_coarse:     Tensor<B, 2, Int>,  // [S, 1]
85) -> Tensor<B, 2, Int> {
86    Tensor::cat(vec![chan_pos_disc, t_coarse], 1)  // [S, 4]
87}
88
89// ── 4. InputBatch ─────────────────────────────────────────────────────────────
90
91pub struct InputBatch<B: Backend> {
92    /// [1, S, tf]  — encoder input (normalised, zeroed = dropped channel)
93    pub encoder_input: Tensor<B, 3>,
94    /// [S, 4]  — 4-D RoPE token indices
95    pub tok_idx: Tensor<B, 2, Int>,
96    /// [C, 3]  — continuous channel positions (metres)
97    pub chan_pos: Tensor<B, 2>,
98    pub n_channels: usize,
99    pub tc: usize,
100}
101
102// ── 5. Load a safetensors batch file ─────────────────────────────────────────
103
104/// Load a safetensors file created by `scripts/export_batch.py`.
105///
106/// Expected keys:
107///   `n_samples`       int32 scalar
108///   `eeg_{i}`         float32 [C, T]  (already /data_norm)
109///   `chan_pos_{i}`    float32 [C, 3]
110pub fn load_batch<B: Backend>(
111    path:   &str,
112    cfg:    &DataConfig,
113    device: &B::Device,
114) -> anyhow::Result<Vec<InputBatch<B>>> {
115    let bytes = std::fs::read(path)?;
116    let st    = SafeTensors::deserialize(&bytes)?;
117
118    let n_samples = {
119        let v = st.tensor("n_samples")?;
120        match v.dtype() {
121            // preprocess_fif.py writes I32; infer binary writes F32
122            safetensors::Dtype::I32 => {
123                let b: [u8; 4] = v.data().get(..4)
124                    .and_then(|s| s.try_into().ok())
125                    .ok_or_else(|| anyhow::anyhow!("n_samples I32 too short"))?;
126                i32::from_le_bytes(b) as usize
127            }
128            safetensors::Dtype::F32 => {
129                let b: [u8; 4] = v.data().get(..4)
130                    .and_then(|s| s.try_into().ok())
131                    .ok_or_else(|| anyhow::anyhow!("n_samples F32 too short"))?;
132                f32::from_le_bytes(b) as usize
133            }
134            other => anyhow::bail!("unexpected dtype for n_samples: {:?}", other),
135        }
136    };
137
138    let mut batches = Vec::with_capacity(n_samples);
139
140    for i in 0..n_samples {
141        // EEG signal [C, T]
142        let eeg_view = st.tensor(&format!("eeg_{i}"))?;
143        let [c, t]: [usize; 2] = eeg_view.shape().try_into()
144            .map_err(|_| anyhow::anyhow!("eeg_{i} must be 2-D"))?;
145        let eeg_f32 = bytes_to_f32(eeg_view.data(), eeg_view.dtype())?;
146        let eeg = Tensor::<B, 2>::from_data(TensorData::new(eeg_f32, vec![c, t]), device);
147
148        // Channel positions [C, 3]
149        let pos_view = st.tensor(&format!("chan_pos_{i}"))?;
150        let pos_f32  = bytes_to_f32(pos_view.data(), pos_view.dtype())?;
151        let chan_pos = Tensor::<B, 2>::from_data(TensorData::new(pos_f32, vec![c, 3]), device);
152
153        let chan_pos_disc = discretize_chan_pos(chan_pos.clone(), cfg, device);
154        let tc = t / cfg.num_fine_time_pts;
155
156        let (eeg_tokens, _, posd, t_coarse) =
157            chop_and_reshape(eeg.clone(), chan_pos.clone(), chan_pos_disc, cfg.num_fine_time_pts);
158
159        let tok_idx     = build_tok_idx(posd, t_coarse);
160        let encoder_input = eeg_tokens.unsqueeze_dim::<3>(0); // [1, S, tf]
161
162        batches.push(InputBatch { encoder_input, tok_idx, chan_pos, n_channels: c, tc });
163    }
164
165    Ok(batches)
166}
167
168// ── 6. Invert reshape ─────────────────────────────────────────────────────────
169
170/// [C*tc, tf] → [C, T]  (inverse of `chop_and_reshape` mode "B")
171pub fn invert_reshape<B: Backend>(
172    tokens:     Tensor<B, 2>,
173    n_channels: usize,
174    tc:         usize,
175    tf:         usize,
176) -> Tensor<B, 2> {
177    tokens.reshape([n_channels, tc, tf]).reshape([n_channels, tc * tf])
178}
179
180// ── 7. FIF metadata (for verbose printing) ───────────────────────────────────
181
/// Summary of a FIF recording and of the preprocessing applied to it.
///
/// Returned next to the batches so callers can print or log it.
pub struct FifInfo {
    /// Channel names, in file order.
    pub ch_names: Vec<String>,
    /// Scalp position of every channel in **millimetres**, one `[x, y, z]`
    /// triple per channel (x = right, y = anterior, z = superior).
    pub ch_pos_mm: Vec<[f32; 3]>,
    /// Sampling rate of the raw recording, in Hz.
    pub sfreq: f32,
    /// Sample count of the raw recording (before any resampling).
    pub n_times_raw: usize,
    /// Length of the raw recording in seconds (`n_times_raw / sfreq`).
    pub duration_s: f32,
    /// How many epochs the pipeline produced.
    pub n_epochs: usize,
    /// Sampling rate the pipeline resamples to, in Hz.
    pub target_sfreq: f32,
    /// Length of one epoch, in seconds.
    pub epoch_dur_s: f32,
}
201
202// ── 8. Load directly from a FIF file ─────────────────────────────────────────
203
204/// Preprocess a `.fif` file through the exg pipeline and return
205/// ready-to-run `InputBatch` structs plus metadata — no Python required.
206///
207/// Pipeline applied (matches `preprocess_fif.py`):
208///   resample → 0.5 Hz highpass FIR → average reference → global z-score
209///   → epoch (5 s @ 256 Hz = 1280 pts) → ÷ data_norm
210///
211/// Channel positions are read from the FIF `ch_info.loc[0..3]` (metres).
212pub fn load_from_fif<B: Backend>(
213    path:      &std::path::Path,
214    data_cfg:  &DataConfig,
215    data_norm: f32,
216    device:    &B::Device,
217) -> anyhow::Result<(Vec<InputBatch<B>>, FifInfo)> {
218    use exg::{
219        fiff::raw::open_raw,
220        PipelineConfig,
221    };
222    use ndarray::Array2;
223
224    // ── 1. Open FIF ─────────────────────────────────────────────────────────
225    let raw_fif      = open_raw(path)?;
226    let src_sfreq    = raw_fif.info.sfreq as f32;
227    let n_ch         = raw_fif.info.n_chan;
228    let n_times_raw  = raw_fif.n_times();
229    let duration_s   = n_times_raw as f32 / src_sfreq;
230
231    // ── 2. Channel names & positions ────────────────────────────────────────
232    let ch_names: Vec<String> = raw_fif.info.chs.iter()
233        .map(|ch| ch.name.clone())
234        .collect();
235    let ch_pos_mm: Vec<[f32; 3]> = raw_fif.info.chs.iter()
236        .map(|ch| [ch.loc[0] * 1000.0, ch.loc[1] * 1000.0, ch.loc[2] * 1000.0])
237        .collect();
238
239    let pos_flat: Vec<f32> = raw_fif.info.chs.iter()
240        .flat_map(|ch| [ch.loc[0], ch.loc[1], ch.loc[2]])
241        .collect();
242    let chan_pos_arr = Array2::from_shape_vec((n_ch, 3), pos_flat)?;
243
244    // ── 3. Read raw data [C, T] ─────────────────────────────────────────────
245    let data_f64  = raw_fif.read_all_data()?;
246    let data_f32: Array2<f32> = data_f64.mapv(|v| v as f32);
247
248    // ── 4. Preprocessing pipeline ───────────────────────────────────────────
249    let preproc_cfg = PipelineConfig {
250        data_norm,
251        ..PipelineConfig::default()
252    };
253
254    let epochs = exg::preprocess(data_f32, chan_pos_arr, src_sfreq, &preproc_cfg)?;
255    let n_epochs = epochs.len();
256
257    // ── 5. Convert each epoch to InputBatch<B> ──────────────────────────────
258    let mut batches = Vec::with_capacity(n_epochs);
259
260    for (eeg_arr, pos_arr) in epochs {
261        let (c, t) = eeg_arr.dim();
262
263        let eeg_data: Vec<f32> = eeg_arr.iter().copied().collect();
264        let eeg = Tensor::<B, 2>::from_data(TensorData::new(eeg_data, vec![c, t]), device);
265
266        let pos_data: Vec<f32> = pos_arr.iter().copied().collect();
267        let chan_pos_t = Tensor::<B, 2>::from_data(TensorData::new(pos_data, vec![c, 3]), device);
268
269        let chan_pos_disc = discretize_chan_pos(chan_pos_t.clone(), data_cfg, device);
270        let tc = t / data_cfg.num_fine_time_pts;
271
272        let (eeg_tokens, _, posd, t_coarse) = chop_and_reshape(
273            eeg,
274            chan_pos_t.clone(),
275            chan_pos_disc,
276            data_cfg.num_fine_time_pts,
277        );
278
279        let tok_idx       = build_tok_idx(posd, t_coarse);
280        let encoder_input = eeg_tokens.unsqueeze_dim::<3>(0); // [1, S, tf]
281
282        batches.push(InputBatch {
283            encoder_input,
284            tok_idx,
285            chan_pos: chan_pos_t,
286            n_channels: c,
287            tc,
288        });
289    }
290
291    let info = FifInfo {
292        ch_names,
293        ch_pos_mm,
294        sfreq: src_sfreq,
295        n_times_raw,
296        duration_s,
297        n_epochs,
298        target_sfreq: preproc_cfg.target_sfreq,
299        epoch_dur_s:  preproc_cfg.epoch_dur,
300    };
301
302    Ok((batches, info))
303}
304
305// ── 9. CPU-only FIF preprocessing (no burn tensors — Send-safe) ──────────────
306
/// One tokenised epoch held in plain `Vec`s — no burn tensors involved.
///
/// Because it only owns standard-library buffers it can be built on any
/// thread and turned into tensors afterwards.
pub struct PreprocessedEpoch {
    /// Flattened `[S, tf]` token matrix (row-major, float32).
    pub eeg_tokens: Vec<f32>,
    /// Flattened `[S, 4]` token indices (row-major, int32).
    pub tok_idx: Vec<i32>,
    /// Flattened `[C, 3]` channel positions in metres (row-major, float32).
    pub chan_pos: Vec<f32>,
    /// Token count, `S = n_channels × tc`.
    pub s: usize,
    /// Number of fine time points in each token.
    pub tf: usize,
    /// Number of EEG channels `C`.
    pub n_channels: usize,
    /// Number of coarse time steps in the epoch.
    pub tc: usize,
}
325
326/// Result of CPU-only FIF preprocessing.
327pub struct PreprocessedFif {
328    pub epochs: Vec<PreprocessedEpoch>,
329    pub info: FifInfo,
330}
331
332/// Preprocess a FIF file entirely on the CPU without creating burn tensors.
333///
334/// This is the parallel-safe counterpart of [`load_from_fif`]: it returns
335/// plain `Vec<f32>` / `Vec<i32>` buffers that can be produced on a rayon
336/// worker and later converted to tensors on the main thread.
337pub fn preprocess_fif_cpu(
338    path:      &std::path::Path,
339    data_cfg:  &DataConfig,
340    data_norm: f32,
341) -> anyhow::Result<PreprocessedFif> {
342    use exg::{fiff::raw::open_raw, PipelineConfig};
343    use ndarray::Array2;
344
345    let raw_fif     = open_raw(path)?;
346    let src_sfreq   = raw_fif.info.sfreq as f32;
347    let n_ch        = raw_fif.info.n_chan;
348    let n_times_raw = raw_fif.n_times();
349    let duration_s  = n_times_raw as f32 / src_sfreq;
350
351    let ch_names: Vec<String> = raw_fif.info.chs.iter().map(|ch| ch.name.clone()).collect();
352    let ch_pos_mm: Vec<[f32; 3]> = raw_fif.info.chs.iter()
353        .map(|ch| [ch.loc[0] * 1000.0, ch.loc[1] * 1000.0, ch.loc[2] * 1000.0])
354        .collect();
355    let pos_flat: Vec<f32> = raw_fif.info.chs.iter()
356        .flat_map(|ch| [ch.loc[0], ch.loc[1], ch.loc[2]])
357        .collect();
358    let chan_pos_arr = Array2::from_shape_vec((n_ch, 3), pos_flat)?;
359
360    let data_f64 = raw_fif.read_all_data()?;
361    let data_f32: Array2<f32> = data_f64.mapv(|v| v as f32);
362
363    let preproc_cfg = PipelineConfig { data_norm, ..PipelineConfig::default() };
364    let exg_epochs = exg::preprocess(data_f32, chan_pos_arr, src_sfreq, &preproc_cfg)?;
365    let n_epochs = exg_epochs.len();
366
367    // Discretize + chop using temporary ndarray math (no burn tensors).
368    let tf = data_cfg.num_fine_time_pts;
369    let mut epochs = Vec::with_capacity(n_epochs);
370
371    for (eeg_arr, pos_arr) in exg_epochs {
372        let (c, t) = eeg_arr.dim();
373        let tc = t / tf;
374
375        // Discretize channel positions to bins [0, num_bins-1].
376        let bins = data_cfg.num_bins as f32;
377        let disc: Vec<i32> = pos_arr.iter().enumerate().map(|(i, &v)| {
378            let axis = i % 3;
379            let lo = data_cfg.xyz_min[axis];
380            let hi = data_cfg.xyz_max[axis];
381            let norm = (v - lo) / (hi - lo);
382            (norm * bins).min(bins - 1.0).max(0.0) as i32
383        }).collect();
384
385        // Chop and reshape: [C, T] → [C*tc, tf]
386        // Also build tok_idx [S, 4] = [disc_x, disc_y, disc_z, t_coarse]
387        let s = c * tc;
388        let mut eeg_tokens = vec![0f32; s * tf];
389        let mut tok_idx = vec![0i32; s * 4];
390
391        for ch in 0..c {
392            for ti in 0..tc {
393                let token = ch * tc + ti;
394                // Copy tf samples from eeg_arr[ch, ti*tf .. (ti+1)*tf]
395                for f in 0..tf {
396                    eeg_tokens[token * tf + f] = eeg_arr[[ch, ti * tf + f]];
397                }
398                // tok_idx = [x_bin, y_bin, z_bin, t_coarse]
399                tok_idx[token * 4]     = disc[ch * 3];
400                tok_idx[token * 4 + 1] = disc[ch * 3 + 1];
401                tok_idx[token * 4 + 2] = disc[ch * 3 + 2];
402                tok_idx[token * 4 + 3] = ti as i32;
403            }
404        }
405
406        let chan_pos: Vec<f32> = pos_arr.iter().copied().collect();
407
408        epochs.push(PreprocessedEpoch { eeg_tokens, tok_idx, chan_pos, s, tf, n_channels: c, tc });
409    }
410
411    let info = FifInfo {
412        ch_names, ch_pos_mm, sfreq: src_sfreq, n_times_raw, duration_s,
413        n_epochs, target_sfreq: preproc_cfg.target_sfreq, epoch_dur_s: preproc_cfg.epoch_dur,
414    };
415
416    Ok(PreprocessedFif { epochs, info })
417}
418
419/// Convert a [`PreprocessedEpoch`] to a burn [`InputBatch`] on the given device.
420pub fn preprocessed_to_batch<B: Backend>(
421    ep:     PreprocessedEpoch,
422    device: &B::Device,
423) -> InputBatch<B> {
424    let s  = ep.s;
425    let tf = ep.tf;
426    let c  = ep.n_channels;
427
428    let encoder_input = Tensor::<B, 2>::from_data(
429        TensorData::new(ep.eeg_tokens, vec![s, tf]), device,
430    ).unsqueeze_dim::<3>(0); // [1, S, tf]
431
432    let tok_idx = Tensor::<B, 2, Int>::from_data(
433        TensorData::new(ep.tok_idx, vec![s, 4]), device,
434    );
435
436    let chan_pos = Tensor::<B, 2>::from_data(
437        TensorData::new(ep.chan_pos, vec![c, 3]), device,
438    );
439
440    InputBatch { encoder_input, tok_idx, chan_pos, n_channels: c, tc: ep.tc }
441}
442
443// ── Helpers ───────────────────────────────────────────────────────────────────
444
445fn repeat_interleave_rows_f<B: Backend>(t: Tensor<B, 2>, repeats: usize) -> Tensor<B, 2> {
446    let [s, c] = t.dims();
447    t.unsqueeze_dim::<3>(1).expand([s, repeats, c]).reshape([s * repeats, c])
448}
449
450fn repeat_interleave_rows_i<B: Backend>(
451    t: Tensor<B, 2, Int>,
452    repeats: usize,
453) -> Tensor<B, 2, Int> {
454    let [s, c] = t.dims();
455    t.unsqueeze_dim::<3>(1).expand([s, repeats, c]).reshape([s * repeats, c])
456}
457
458fn bytes_to_f32(data: &[u8], dtype: safetensors::Dtype) -> anyhow::Result<Vec<f32>> {
459    match dtype {
460        safetensors::Dtype::F32 =>
461            Ok(data.chunks_exact(4)
462                .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
463                .collect()),
464        safetensors::Dtype::BF16 =>
465            Ok(data.chunks_exact(2)
466                .map(|b| half::bf16::from_le_bytes([b[0], b[1]]).to_f32())
467                .collect()),
468        other => anyhow::bail!("unsupported dtype {:?}", other),
469    }
470}