// zuna_rs/data.rs
/// Data preparation for ZUNA inference (burn 0.20.1)
///
/// Two entry points:
///   • `load_from_fif`  — read a raw .fif file through the exg
///                        pipeline and return `InputBatch` structs directly.
///   • `load_batch`     — read a pre-exported safetensors batch
///                        (legacy / Python-compatible path).
use burn::prelude::*;
use safetensors::SafeTensors;
use crate::config::DataConfig;

12// ── 1. Discretise channel positions ─────────────────────────────────────────
13
14/// Map continuous scalp xyz positions to integer bin indices [0, num_bins-1].
15/// Python equivalent: `discretize_chan_pos` in eeg_data.py.
16///
17/// chan_pos: [C, 3], cfg.xyz_min/max in metres, cfg.num_bins = 50.
18pub fn discretize_chan_pos<B: Backend>(
19    chan_pos: Tensor<B, 2>,
20    cfg:      &DataConfig,
21    device:   &B::Device,
22) -> Tensor<B, 2, Int> {
23    let [_c, _] = chan_pos.dims();
24    let xyz_min = Tensor::<B, 2>::from_data(
25        TensorData::new(cfg.xyz_min.to_vec(), vec![1, 3]), device,
26    );
27    let xyz_max = Tensor::<B, 2>::from_data(
28        TensorData::new(cfg.xyz_max.to_vec(), vec![1, 3]), device,
29    );
30
31    let norm = (chan_pos - xyz_min.clone()) / (xyz_max - xyz_min); // [C, 3] in [0,1]
32    let bins = cfg.num_bins as f32;
33    norm.mul_scalar(bins)
34        .int()
35        .clamp(0i32, cfg.num_bins as i32 - 1)
36}
37
38// ── 2. Chop-and-reshape (mode "B") ──────────────────────────────────────────
39
40/// Reshape [C, T] → [C*tc, tf] token matrix.
41/// Python: `chop_and_reshape_signals(..., use_coarse_time="B")`.
42///
43/// Returns (eeg_tokens [C*tc, tf], chan_pos_rep [C*tc, 3],
44///          chan_pos_disc_rep [C*tc, 3], t_coarse [C*tc, 1])
45pub fn chop_and_reshape<B: Backend>(
46    eeg:          Tensor<B, 2>,       // [C, T]
47    chan_pos:     Tensor<B, 2>,       // [C, 3]
48    chan_pos_disc: Tensor<B, 2, Int>, // [C, 3]
49    tf:           usize,
50) -> (Tensor<B, 2>, Tensor<B, 2>, Tensor<B, 2, Int>, Tensor<B, 2, Int>) {
51    let [c, t_total] = eeg.dims();
52    assert_eq!(t_total % tf, 0, "T must be divisible by tf");
53    let tc = t_total / tf;
54    let s  = c * tc;
55    let device = eeg.device();
56
57    // [C, T] → [C, tc, tf] → [C*tc, tf]
58    let eeg_tokens = eeg.reshape([c, tc, tf]).reshape([s, tf]);
59
60    // Repeat each channel position tc times: [C, 3] → [C*tc, 3]
61    let pos  = repeat_interleave_rows_f(chan_pos,       tc);
62    let posd = repeat_interleave_rows_i(chan_pos_disc,  tc);
63
64    // t_coarse: [0,1,...,tc-1] repeated C times → [C*tc, 1]
65    let tc_vals: Vec<i32> = (0..tc as i32)
66        .cycle()
67        .take(s)
68        .collect();
69    let t_coarse = Tensor::<B, 1, Int>::from_data(
70        TensorData::new(tc_vals, vec![s]),
71        &device,
72    )
73    .reshape([s, 1]);
74
75    (eeg_tokens, pos, posd, t_coarse)
76}
77
78// ── 3. Build token index [S, 4] ──────────────────────────────────────────────
79
80/// Concatenate discrete (x,y,z) and t_coarse → [S, 4].
81/// Python: `cat((chan_pos_discrete, t_coarse), dim=2)` (we drop the batch dim).
82pub fn build_tok_idx<B: Backend>(
83    chan_pos_disc: Tensor<B, 2, Int>,  // [S, 3]
84    t_coarse:     Tensor<B, 2, Int>,  // [S, 1]
85) -> Tensor<B, 2, Int> {
86    Tensor::cat(vec![chan_pos_disc, t_coarse], 1)  // [S, 4]
87}
88
// ── 4. InputBatch ─────────────────────────────────────────────────────────────

/// One ready-to-run inference sample (a single epoch / safetensors entry).
pub struct InputBatch<B: Backend> {
    /// [1, S, tf]  — encoder input (normalised, zeroed = dropped channel)
    pub encoder_input: Tensor<B, 3>,
    /// [S, 4]  — 4-D RoPE token indices
    pub tok_idx: Tensor<B, 2, Int>,
    /// [C, 3]  — continuous channel positions (metres)
    pub chan_pos: Tensor<B, 2>,
    // Number of channels C in this sample.
    pub n_channels: usize,
    // Number of coarse-time windows (T / tf); token count S = C * tc.
    pub tc: usize,
}

102// ── 5. Load a safetensors batch file ─────────────────────────────────────────
103
104/// Load a safetensors file created by `scripts/export_batch.py`.
105///
106/// Expected keys:
107///   `n_samples`       int32 scalar
108///   `eeg_{i}`         float32 [C, T]  (already /data_norm)
109///   `chan_pos_{i}`    float32 [C, 3]
110pub fn load_batch<B: Backend>(
111    path:   &str,
112    cfg:    &DataConfig,
113    device: &B::Device,
114) -> anyhow::Result<Vec<InputBatch<B>>> {
115    let bytes = std::fs::read(path)?;
116    let st    = SafeTensors::deserialize(&bytes)?;
117
118    let n_samples = {
119        let v = st.tensor("n_samples")?;
120        match v.dtype() {
121            // preprocess_fif.py writes I32; infer binary writes F32
122            safetensors::Dtype::I32 =>
123                i32::from_le_bytes(v.data()[..4].try_into().unwrap()) as usize,
124            safetensors::Dtype::F32 =>
125                f32::from_le_bytes(v.data()[..4].try_into().unwrap()) as usize,
126            other => anyhow::bail!("unexpected dtype for n_samples: {:?}", other),
127        }
128    };
129
130    let mut batches = Vec::with_capacity(n_samples);
131
132    for i in 0..n_samples {
133        // EEG signal [C, T]
134        let eeg_view = st.tensor(&format!("eeg_{i}"))?;
135        let [c, t]: [usize; 2] = eeg_view.shape().try_into()
136            .map_err(|_| anyhow::anyhow!("eeg_{i} must be 2-D"))?;
137        let eeg_f32 = bytes_to_f32(eeg_view.data(), eeg_view.dtype())?;
138        let eeg = Tensor::<B, 2>::from_data(TensorData::new(eeg_f32, vec![c, t]), device);
139
140        // Channel positions [C, 3]
141        let pos_view = st.tensor(&format!("chan_pos_{i}"))?;
142        let pos_f32  = bytes_to_f32(pos_view.data(), pos_view.dtype())?;
143        let chan_pos = Tensor::<B, 2>::from_data(TensorData::new(pos_f32, vec![c, 3]), device);
144
145        let chan_pos_disc = discretize_chan_pos(chan_pos.clone(), cfg, device);
146        let tc = t / cfg.num_fine_time_pts;
147
148        let (eeg_tokens, _, posd, t_coarse) =
149            chop_and_reshape(eeg.clone(), chan_pos.clone(), chan_pos_disc, cfg.num_fine_time_pts);
150
151        let tok_idx     = build_tok_idx(posd, t_coarse);
152        let encoder_input = eeg_tokens.unsqueeze_dim::<3>(0); // [1, S, tf]
153
154        batches.push(InputBatch { encoder_input, tok_idx, chan_pos, n_channels: c, tc });
155    }
156
157    Ok(batches)
158}
159
160// ── 6. Invert reshape ─────────────────────────────────────────────────────────
161
162/// [C*tc, tf] → [C, T]  (inverse of `chop_and_reshape` mode "B")
163pub fn invert_reshape<B: Backend>(
164    tokens:     Tensor<B, 2>,
165    n_channels: usize,
166    tc:         usize,
167    tf:         usize,
168) -> Tensor<B, 2> {
169    tokens.reshape([n_channels, tc, tf]).reshape([n_channels, tc * tf])
170}
171
// ── 7. FIF metadata (for verbose printing) ───────────────────────────────────

/// Metadata extracted from a FIF file header — returned alongside batches.
pub struct FifInfo {
    /// Channel names in order.
    pub ch_names: Vec<String>,
    /// Scalp positions in **millimetres** `[C, 3]` (x=right, y=anterior, z=superior).
    pub ch_pos_mm: Vec<[f32; 3]>,
    /// Original sampling rate (Hz).
    pub sfreq: f32,
    /// Number of time points in the raw file (before resampling).
    pub n_times_raw: usize,
    /// Duration in seconds.
    pub duration_s: f32,
    /// Number of epochs produced by the pipeline.
    pub n_epochs: usize,
    /// Target sfreq used by the pipeline.
    pub target_sfreq: f32,
    /// Epoch duration (s).
    pub epoch_dur_s: f32,
}

// ── 8. Load directly from a FIF file ─────────────────────────────────────────

/// Preprocess a `.fif` file through the exg pipeline and return
/// ready-to-run `InputBatch` structs plus metadata — no Python required.
///
/// Pipeline applied (matches `preprocess_fif.py`):
///   resample → 0.5 Hz highpass FIR → average reference → global z-score
///   → epoch (5 s @ 256 Hz = 1280 pts) → ÷ data_norm
///
/// Channel positions are read from the FIF `ch_info.loc[0..3]` (metres).
pub fn load_from_fif<B: Backend>(
    path:      &std::path::Path,
    data_cfg:  &DataConfig,
    data_norm: f32,
    device:    &B::Device,
) -> anyhow::Result<(Vec<InputBatch<B>>, FifInfo)> {
    use exg::{
        fiff::raw::open_raw,
        PipelineConfig,
    };
    use ndarray::Array2;

    // ── 1. Open FIF ─────────────────────────────────────────────────────────
    let raw_fif      = open_raw(path)?;
    let src_sfreq    = raw_fif.info.sfreq as f32;
    let n_ch         = raw_fif.info.n_chan;
    let n_times_raw  = raw_fif.n_times();
    let duration_s   = n_times_raw as f32 / src_sfreq;

    // ── 2. Channel names & positions ────────────────────────────────────────
    let ch_names: Vec<String> = raw_fif.info.chs.iter()
        .map(|ch| ch.name.clone())
        .collect();
    // NOTE(review): assumes `ch.loc[0..3]` holds sensor xyz in metres (per
    // the doc comment above) — the ×1000 here is only for the mm display
    // copy in `FifInfo`; the pipeline below consumes the metre values.
    let ch_pos_mm: Vec<[f32; 3]> = raw_fif.info.chs.iter()
        .map(|ch| [ch.loc[0] * 1000.0, ch.loc[1] * 1000.0, ch.loc[2] * 1000.0])
        .collect();

    // Flat row-major [C*3] buffer of metre positions for the pipeline.
    let pos_flat: Vec<f32> = raw_fif.info.chs.iter()
        .flat_map(|ch| [ch.loc[0], ch.loc[1], ch.loc[2]])
        .collect();
    let chan_pos_arr = Array2::from_shape_vec((n_ch, 3), pos_flat)?;

    // ── 3. Read raw data [C, T] ─────────────────────────────────────────────
    let data_f64  = raw_fif.read_all_data()?;
    let data_f32: Array2<f32> = data_f64.mapv(|v| v as f32);

    // ── 4. Preprocessing pipeline ───────────────────────────────────────────
    // Only data_norm is overridden; resample rate, filter, and epoching come
    // from `PipelineConfig::default()`.
    let preproc_cfg = PipelineConfig {
        data_norm,
        ..PipelineConfig::default()
    };

    // NOTE(review): `exg::preprocess` presumably returns per-epoch
    // (signal [C, t], positions [C, 3]) pairs — verify against the exg crate.
    let epochs = exg::preprocess(data_f32, chan_pos_arr, src_sfreq, &preproc_cfg)?;
    let n_epochs = epochs.len();

    // ── 5. Convert each epoch to InputBatch<B> ──────────────────────────────
    let mut batches = Vec::with_capacity(n_epochs);

    for (eeg_arr, pos_arr) in epochs {
        let (c, t) = eeg_arr.dim();

        // ndarray → burn tensor (row-major copy).
        let eeg_data: Vec<f32> = eeg_arr.iter().copied().collect();
        let eeg = Tensor::<B, 2>::from_data(TensorData::new(eeg_data, vec![c, t]), device);

        let pos_data: Vec<f32> = pos_arr.iter().copied().collect();
        let chan_pos_t = Tensor::<B, 2>::from_data(TensorData::new(pos_data, vec![c, 3]), device);

        let chan_pos_disc = discretize_chan_pos(chan_pos_t.clone(), data_cfg, device);
        // NOTE(review): relies on the pipeline emitting epochs whose length is
        // a multiple of num_fine_time_pts; otherwise `chop_and_reshape`
        // asserts. TODO confirm the exg epocher guarantees this.
        let tc = t / data_cfg.num_fine_time_pts;

        let (eeg_tokens, _, posd, t_coarse) = chop_and_reshape(
            eeg,
            chan_pos_t.clone(),
            chan_pos_disc,
            data_cfg.num_fine_time_pts,
        );

        let tok_idx       = build_tok_idx(posd, t_coarse);
        let encoder_input = eeg_tokens.unsqueeze_dim::<3>(0); // [1, S, tf]

        batches.push(InputBatch {
            encoder_input,
            tok_idx,
            chan_pos: chan_pos_t,
            n_channels: c,
            tc,
        });
    }

    let info = FifInfo {
        ch_names,
        ch_pos_mm,
        sfreq: src_sfreq,
        n_times_raw,
        duration_s,
        n_epochs,
        target_sfreq: preproc_cfg.target_sfreq,
        epoch_dur_s:  preproc_cfg.epoch_dur,
    };

    Ok((batches, info))
}

297// ── Helpers ───────────────────────────────────────────────────────────────────
298
299fn repeat_interleave_rows_f<B: Backend>(t: Tensor<B, 2>, repeats: usize) -> Tensor<B, 2> {
300    let [s, c] = t.dims();
301    t.unsqueeze_dim::<3>(1).expand([s, repeats, c]).reshape([s * repeats, c])
302}
303
304fn repeat_interleave_rows_i<B: Backend>(
305    t: Tensor<B, 2, Int>,
306    repeats: usize,
307) -> Tensor<B, 2, Int> {
308    let [s, c] = t.dims();
309    t.unsqueeze_dim::<3>(1).expand([s, repeats, c]).reshape([s * repeats, c])
310}
311
312fn bytes_to_f32(data: &[u8], dtype: safetensors::Dtype) -> anyhow::Result<Vec<f32>> {
313    match dtype {
314        safetensors::Dtype::F32 =>
315            Ok(data.chunks_exact(4)
316                .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
317                .collect()),
318        safetensors::Dtype::BF16 =>
319            Ok(data.chunks_exact(2)
320                .map(|b| half::bf16::from_le_bytes([b[0], b[1]]).to_f32())
321                .collect()),
322        other => anyhow::bail!("unsupported dtype {:?}", other),
323    }
324}