1use burn::prelude::*;
9use safetensors::SafeTensors;
10use crate::config::DataConfig;
11
12pub fn discretize_chan_pos<B: Backend>(
19 chan_pos: Tensor<B, 2>,
20 cfg: &DataConfig,
21 device: &B::Device,
22) -> Tensor<B, 2, Int> {
23 let [_c, _] = chan_pos.dims();
24 let xyz_min = Tensor::<B, 2>::from_data(
25 TensorData::new(cfg.xyz_min.to_vec(), vec![1, 3]), device,
26 );
27 let xyz_max = Tensor::<B, 2>::from_data(
28 TensorData::new(cfg.xyz_max.to_vec(), vec![1, 3]), device,
29 );
30
31 let norm = (chan_pos - xyz_min.clone()) / (xyz_max - xyz_min); let bins = cfg.num_bins as f32;
33 norm.mul_scalar(bins)
34 .int()
35 .clamp(0i32, cfg.num_bins as i32 - 1)
36}
37
38pub fn chop_and_reshape<B: Backend>(
46 eeg: Tensor<B, 2>, chan_pos: Tensor<B, 2>, chan_pos_disc: Tensor<B, 2, Int>, tf: usize,
50) -> (Tensor<B, 2>, Tensor<B, 2>, Tensor<B, 2, Int>, Tensor<B, 2, Int>) {
51 let [c, t_total] = eeg.dims();
52 assert_eq!(t_total % tf, 0, "T must be divisible by tf");
53 let tc = t_total / tf;
54 let s = c * tc;
55 let device = eeg.device();
56
57 let eeg_tokens = eeg.reshape([c, tc, tf]).reshape([s, tf]);
59
60 let pos = repeat_interleave_rows_f(chan_pos, tc);
62 let posd = repeat_interleave_rows_i(chan_pos_disc, tc);
63
64 let tc_vals: Vec<i32> = (0..tc as i32)
66 .cycle()
67 .take(s)
68 .collect();
69 let t_coarse = Tensor::<B, 1, Int>::from_data(
70 TensorData::new(tc_vals, vec![s]),
71 &device,
72 )
73 .reshape([s, 1]);
74
75 (eeg_tokens, pos, posd, t_coarse)
76}
77
78pub fn build_tok_idx<B: Backend>(
83 chan_pos_disc: Tensor<B, 2, Int>, t_coarse: Tensor<B, 2, Int>, ) -> Tensor<B, 2, Int> {
86 Tensor::cat(vec![chan_pos_disc, t_coarse], 1) }
88
89pub struct InputBatch<B: Backend> {
92 pub encoder_input: Tensor<B, 3>,
94 pub tok_idx: Tensor<B, 2, Int>,
96 pub chan_pos: Tensor<B, 2>,
98 pub n_channels: usize,
99 pub tc: usize,
100}
101
102pub fn load_batch<B: Backend>(
111 path: &str,
112 cfg: &DataConfig,
113 device: &B::Device,
114) -> anyhow::Result<Vec<InputBatch<B>>> {
115 let bytes = std::fs::read(path)?;
116 let st = SafeTensors::deserialize(&bytes)?;
117
118 let n_samples = {
119 let v = st.tensor("n_samples")?;
120 match v.dtype() {
121 safetensors::Dtype::I32 => {
123 let b: [u8; 4] = v.data().get(..4)
124 .and_then(|s| s.try_into().ok())
125 .ok_or_else(|| anyhow::anyhow!("n_samples I32 too short"))?;
126 i32::from_le_bytes(b) as usize
127 }
128 safetensors::Dtype::F32 => {
129 let b: [u8; 4] = v.data().get(..4)
130 .and_then(|s| s.try_into().ok())
131 .ok_or_else(|| anyhow::anyhow!("n_samples F32 too short"))?;
132 f32::from_le_bytes(b) as usize
133 }
134 other => anyhow::bail!("unexpected dtype for n_samples: {:?}", other),
135 }
136 };
137
138 let mut batches = Vec::with_capacity(n_samples);
139
140 for i in 0..n_samples {
141 let eeg_view = st.tensor(&format!("eeg_{i}"))?;
143 let [c, t]: [usize; 2] = eeg_view.shape().try_into()
144 .map_err(|_| anyhow::anyhow!("eeg_{i} must be 2-D"))?;
145 let eeg_f32 = bytes_to_f32(eeg_view.data(), eeg_view.dtype())?;
146 let eeg = Tensor::<B, 2>::from_data(TensorData::new(eeg_f32, vec![c, t]), device);
147
148 let pos_view = st.tensor(&format!("chan_pos_{i}"))?;
150 let pos_f32 = bytes_to_f32(pos_view.data(), pos_view.dtype())?;
151 let chan_pos = Tensor::<B, 2>::from_data(TensorData::new(pos_f32, vec![c, 3]), device);
152
153 let chan_pos_disc = discretize_chan_pos(chan_pos.clone(), cfg, device);
154 let tc = t / cfg.num_fine_time_pts;
155
156 let (eeg_tokens, _, posd, t_coarse) =
157 chop_and_reshape(eeg.clone(), chan_pos.clone(), chan_pos_disc, cfg.num_fine_time_pts);
158
159 let tok_idx = build_tok_idx(posd, t_coarse);
160 let encoder_input = eeg_tokens.unsqueeze_dim::<3>(0); batches.push(InputBatch { encoder_input, tok_idx, chan_pos, n_channels: c, tc });
163 }
164
165 Ok(batches)
166}
167
168pub fn invert_reshape<B: Backend>(
172 tokens: Tensor<B, 2>,
173 n_channels: usize,
174 tc: usize,
175 tf: usize,
176) -> Tensor<B, 2> {
177 tokens.reshape([n_channels, tc, tf]).reshape([n_channels, tc * tf])
178}
179
/// Summary of a FIF recording and of the preprocessing applied to it.
#[derive(Debug, Clone)]
pub struct FifInfo {
    /// Channel names, in file order.
    pub ch_names: Vec<String>,
    /// First three `loc` components of every channel, multiplied by 1000.
    pub ch_pos_mm: Vec<[f32; 3]>,
    /// Sampling frequency of the source recording.
    pub sfreq: f32,
    /// Number of time samples in the raw recording.
    pub n_times_raw: usize,
    /// `n_times_raw / sfreq`.
    pub duration_s: f32,
    /// Number of epochs produced by preprocessing.
    pub n_epochs: usize,
    /// Target sampling frequency taken from the preprocessing config.
    pub target_sfreq: f32,
    /// Epoch duration taken from the preprocessing config.
    pub epoch_dur_s: f32,
}
201
202pub fn load_from_fif<B: Backend>(
213 path: &std::path::Path,
214 data_cfg: &DataConfig,
215 data_norm: f32,
216 device: &B::Device,
217) -> anyhow::Result<(Vec<InputBatch<B>>, FifInfo)> {
218 use exg::{
219 fiff::raw::open_raw,
220 PipelineConfig,
221 };
222 use ndarray::Array2;
223
224 let raw_fif = open_raw(path)?;
226 let src_sfreq = raw_fif.info.sfreq as f32;
227 let n_ch = raw_fif.info.n_chan;
228 let n_times_raw = raw_fif.n_times();
229 let duration_s = n_times_raw as f32 / src_sfreq;
230
231 let ch_names: Vec<String> = raw_fif.info.chs.iter()
233 .map(|ch| ch.name.clone())
234 .collect();
235 let ch_pos_mm: Vec<[f32; 3]> = raw_fif.info.chs.iter()
236 .map(|ch| [ch.loc[0] * 1000.0, ch.loc[1] * 1000.0, ch.loc[2] * 1000.0])
237 .collect();
238
239 let pos_flat: Vec<f32> = raw_fif.info.chs.iter()
240 .flat_map(|ch| [ch.loc[0], ch.loc[1], ch.loc[2]])
241 .collect();
242 let chan_pos_arr = Array2::from_shape_vec((n_ch, 3), pos_flat)?;
243
244 let data_f64 = raw_fif.read_all_data()?;
246 let data_f32: Array2<f32> = data_f64.mapv(|v| v as f32);
247
248 let preproc_cfg = PipelineConfig {
250 data_norm,
251 ..PipelineConfig::default()
252 };
253
254 let epochs = exg::preprocess(data_f32, chan_pos_arr, src_sfreq, &preproc_cfg)?;
255 let n_epochs = epochs.len();
256
257 let mut batches = Vec::with_capacity(n_epochs);
259
260 for (eeg_arr, pos_arr) in epochs {
261 let (c, t) = eeg_arr.dim();
262
263 let eeg_data: Vec<f32> = eeg_arr.iter().copied().collect();
264 let eeg = Tensor::<B, 2>::from_data(TensorData::new(eeg_data, vec![c, t]), device);
265
266 let pos_data: Vec<f32> = pos_arr.iter().copied().collect();
267 let chan_pos_t = Tensor::<B, 2>::from_data(TensorData::new(pos_data, vec![c, 3]), device);
268
269 let chan_pos_disc = discretize_chan_pos(chan_pos_t.clone(), data_cfg, device);
270 let tc = t / data_cfg.num_fine_time_pts;
271
272 let (eeg_tokens, _, posd, t_coarse) = chop_and_reshape(
273 eeg,
274 chan_pos_t.clone(),
275 chan_pos_disc,
276 data_cfg.num_fine_time_pts,
277 );
278
279 let tok_idx = build_tok_idx(posd, t_coarse);
280 let encoder_input = eeg_tokens.unsqueeze_dim::<3>(0); batches.push(InputBatch {
283 encoder_input,
284 tok_idx,
285 chan_pos: chan_pos_t,
286 n_channels: c,
287 tc,
288 });
289 }
290
291 let info = FifInfo {
292 ch_names,
293 ch_pos_mm,
294 sfreq: src_sfreq,
295 n_times_raw,
296 duration_s,
297 n_epochs,
298 target_sfreq: preproc_cfg.target_sfreq,
299 epoch_dur_s: preproc_cfg.epoch_dur,
300 };
301
302 Ok((batches, info))
303}
304
/// One preprocessed epoch held in plain host memory, ready to be turned into
/// tensors by `preprocessed_to_batch`.
#[derive(Debug, Clone)]
pub struct PreprocessedEpoch {
    /// Row-major `[s, tf]` token matrix.
    pub eeg_tokens: Vec<f32>,
    /// Row-major `[s, 4]` index matrix: three discretised position bins
    /// followed by the coarse time index of the token.
    pub tok_idx: Vec<i32>,
    /// Row-major `[n_channels, 3]` continuous channel positions.
    pub chan_pos: Vec<f32>,
    /// Number of tokens (`n_channels * tc`).
    pub s: usize,
    /// Number of samples per token.
    pub tf: usize,
    /// Number of channels.
    pub n_channels: usize,
    /// Number of coarse time steps per channel.
    pub tc: usize,
}
325
326pub struct PreprocessedFif {
328 pub epochs: Vec<PreprocessedEpoch>,
329 pub info: FifInfo,
330}
331
332pub fn preprocess_fif_cpu(
338 path: &std::path::Path,
339 data_cfg: &DataConfig,
340 data_norm: f32,
341) -> anyhow::Result<PreprocessedFif> {
342 use exg::{fiff::raw::open_raw, PipelineConfig};
343 use ndarray::Array2;
344
345 let raw_fif = open_raw(path)?;
346 let src_sfreq = raw_fif.info.sfreq as f32;
347 let n_ch = raw_fif.info.n_chan;
348 let n_times_raw = raw_fif.n_times();
349 let duration_s = n_times_raw as f32 / src_sfreq;
350
351 let ch_names: Vec<String> = raw_fif.info.chs.iter().map(|ch| ch.name.clone()).collect();
352 let ch_pos_mm: Vec<[f32; 3]> = raw_fif.info.chs.iter()
353 .map(|ch| [ch.loc[0] * 1000.0, ch.loc[1] * 1000.0, ch.loc[2] * 1000.0])
354 .collect();
355 let pos_flat: Vec<f32> = raw_fif.info.chs.iter()
356 .flat_map(|ch| [ch.loc[0], ch.loc[1], ch.loc[2]])
357 .collect();
358 let chan_pos_arr = Array2::from_shape_vec((n_ch, 3), pos_flat)?;
359
360 let data_f64 = raw_fif.read_all_data()?;
361 let data_f32: Array2<f32> = data_f64.mapv(|v| v as f32);
362
363 let preproc_cfg = PipelineConfig { data_norm, ..PipelineConfig::default() };
364 let exg_epochs = exg::preprocess(data_f32, chan_pos_arr, src_sfreq, &preproc_cfg)?;
365 let n_epochs = exg_epochs.len();
366
367 let tf = data_cfg.num_fine_time_pts;
369 let mut epochs = Vec::with_capacity(n_epochs);
370
371 for (eeg_arr, pos_arr) in exg_epochs {
372 let (c, t) = eeg_arr.dim();
373 let tc = t / tf;
374
375 let bins = data_cfg.num_bins as f32;
377 let disc: Vec<i32> = pos_arr.iter().enumerate().map(|(i, &v)| {
378 let axis = i % 3;
379 let lo = data_cfg.xyz_min[axis];
380 let hi = data_cfg.xyz_max[axis];
381 let norm = (v - lo) / (hi - lo);
382 (norm * bins).min(bins - 1.0).max(0.0) as i32
383 }).collect();
384
385 let s = c * tc;
388 let mut eeg_tokens = vec![0f32; s * tf];
389 let mut tok_idx = vec![0i32; s * 4];
390
391 for ch in 0..c {
392 for ti in 0..tc {
393 let token = ch * tc + ti;
394 for f in 0..tf {
396 eeg_tokens[token * tf + f] = eeg_arr[[ch, ti * tf + f]];
397 }
398 tok_idx[token * 4] = disc[ch * 3];
400 tok_idx[token * 4 + 1] = disc[ch * 3 + 1];
401 tok_idx[token * 4 + 2] = disc[ch * 3 + 2];
402 tok_idx[token * 4 + 3] = ti as i32;
403 }
404 }
405
406 let chan_pos: Vec<f32> = pos_arr.iter().copied().collect();
407
408 epochs.push(PreprocessedEpoch { eeg_tokens, tok_idx, chan_pos, s, tf, n_channels: c, tc });
409 }
410
411 let info = FifInfo {
412 ch_names, ch_pos_mm, sfreq: src_sfreq, n_times_raw, duration_s,
413 n_epochs, target_sfreq: preproc_cfg.target_sfreq, epoch_dur_s: preproc_cfg.epoch_dur,
414 };
415
416 Ok(PreprocessedFif { epochs, info })
417}
418
419pub fn preprocessed_to_batch<B: Backend>(
421 ep: PreprocessedEpoch,
422 device: &B::Device,
423) -> InputBatch<B> {
424 let s = ep.s;
425 let tf = ep.tf;
426 let c = ep.n_channels;
427
428 let encoder_input = Tensor::<B, 2>::from_data(
429 TensorData::new(ep.eeg_tokens, vec![s, tf]), device,
430 ).unsqueeze_dim::<3>(0); let tok_idx = Tensor::<B, 2, Int>::from_data(
433 TensorData::new(ep.tok_idx, vec![s, 4]), device,
434 );
435
436 let chan_pos = Tensor::<B, 2>::from_data(
437 TensorData::new(ep.chan_pos, vec![c, 3]), device,
438 );
439
440 InputBatch { encoder_input, tok_idx, chan_pos, n_channels: c, tc: ep.tc }
441}
442
443fn repeat_interleave_rows_f<B: Backend>(t: Tensor<B, 2>, repeats: usize) -> Tensor<B, 2> {
446 let [s, c] = t.dims();
447 t.unsqueeze_dim::<3>(1).expand([s, repeats, c]).reshape([s * repeats, c])
448}
449
450fn repeat_interleave_rows_i<B: Backend>(
451 t: Tensor<B, 2, Int>,
452 repeats: usize,
453) -> Tensor<B, 2, Int> {
454 let [s, c] = t.dims();
455 t.unsqueeze_dim::<3>(1).expand([s, repeats, c]).reshape([s * repeats, c])
456}
457
458fn bytes_to_f32(data: &[u8], dtype: safetensors::Dtype) -> anyhow::Result<Vec<f32>> {
459 match dtype {
460 safetensors::Dtype::F32 =>
461 Ok(data.chunks_exact(4)
462 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
463 .collect()),
464 safetensors::Dtype::BF16 =>
465 Ok(data.chunks_exact(2)
466 .map(|b| half::bf16::from_le_bytes([b[0], b[1]]).to_f32())
467 .collect()),
468 other => anyhow::bail!("unsupported dtype {:?}", other),
469 }
470}