1use burn::prelude::*;
9use safetensors::SafeTensors;
10use crate::config::DataConfig;
11
12pub fn discretize_chan_pos<B: Backend>(
19 chan_pos: Tensor<B, 2>,
20 cfg: &DataConfig,
21 device: &B::Device,
22) -> Tensor<B, 2, Int> {
23 let [_c, _] = chan_pos.dims();
24 let xyz_min = Tensor::<B, 2>::from_data(
25 TensorData::new(cfg.xyz_min.to_vec(), vec![1, 3]), device,
26 );
27 let xyz_max = Tensor::<B, 2>::from_data(
28 TensorData::new(cfg.xyz_max.to_vec(), vec![1, 3]), device,
29 );
30
31 let norm = (chan_pos - xyz_min.clone()) / (xyz_max - xyz_min); let bins = cfg.num_bins as f32;
33 norm.mul_scalar(bins)
34 .int()
35 .clamp(0i32, cfg.num_bins as i32 - 1)
36}
37
38pub fn chop_and_reshape<B: Backend>(
46 eeg: Tensor<B, 2>, chan_pos: Tensor<B, 2>, chan_pos_disc: Tensor<B, 2, Int>, tf: usize,
50) -> (Tensor<B, 2>, Tensor<B, 2>, Tensor<B, 2, Int>, Tensor<B, 2, Int>) {
51 let [c, t_total] = eeg.dims();
52 assert_eq!(t_total % tf, 0, "T must be divisible by tf");
53 let tc = t_total / tf;
54 let s = c * tc;
55 let device = eeg.device();
56
57 let eeg_tokens = eeg.reshape([c, tc, tf]).reshape([s, tf]);
59
60 let pos = repeat_interleave_rows_f(chan_pos, tc);
62 let posd = repeat_interleave_rows_i(chan_pos_disc, tc);
63
64 let tc_vals: Vec<i32> = (0..tc as i32)
66 .cycle()
67 .take(s)
68 .collect();
69 let t_coarse = Tensor::<B, 1, Int>::from_data(
70 TensorData::new(tc_vals, vec![s]),
71 &device,
72 )
73 .reshape([s, 1]);
74
75 (eeg_tokens, pos, posd, t_coarse)
76}
77
78pub fn build_tok_idx<B: Backend>(
83 chan_pos_disc: Tensor<B, 2, Int>, t_coarse: Tensor<B, 2, Int>, ) -> Tensor<B, 2, Int> {
86 Tensor::cat(vec![chan_pos_disc, t_coarse], 1) }
88
/// One model-ready sample: tokenized EEG plus the indices and positions
/// needed to embed each token. Produced by `load_batch` / `load_from_fif`.
pub struct InputBatch<B: Backend> {
    // Token matrix with a leading batch axis of 1: [1, S, tf], where
    // S = n_channels * tc (see `chop_and_reshape` and the unsqueeze in the loaders).
    pub encoder_input: Tensor<B, 3>,
    // Per-token integer index [S, 4]: three discretized spatial bins plus
    // the coarse time step (see `build_tok_idx`).
    pub tok_idx: Tensor<B, 2, Int>,
    // Continuous channel positions, [n_channels, 3].
    pub chan_pos: Tensor<B, 2>,
    // Number of EEG channels in this sample.
    pub n_channels: usize,
    // Number of coarse time steps (T / num_fine_time_pts).
    pub tc: usize,
}
101
102pub fn load_batch<B: Backend>(
111 path: &str,
112 cfg: &DataConfig,
113 device: &B::Device,
114) -> anyhow::Result<Vec<InputBatch<B>>> {
115 let bytes = std::fs::read(path)?;
116 let st = SafeTensors::deserialize(&bytes)?;
117
118 let n_samples = {
119 let v = st.tensor("n_samples")?;
120 match v.dtype() {
121 safetensors::Dtype::I32 =>
123 i32::from_le_bytes(v.data()[..4].try_into().unwrap()) as usize,
124 safetensors::Dtype::F32 =>
125 f32::from_le_bytes(v.data()[..4].try_into().unwrap()) as usize,
126 other => anyhow::bail!("unexpected dtype for n_samples: {:?}", other),
127 }
128 };
129
130 let mut batches = Vec::with_capacity(n_samples);
131
132 for i in 0..n_samples {
133 let eeg_view = st.tensor(&format!("eeg_{i}"))?;
135 let [c, t]: [usize; 2] = eeg_view.shape().try_into()
136 .map_err(|_| anyhow::anyhow!("eeg_{i} must be 2-D"))?;
137 let eeg_f32 = bytes_to_f32(eeg_view.data(), eeg_view.dtype())?;
138 let eeg = Tensor::<B, 2>::from_data(TensorData::new(eeg_f32, vec![c, t]), device);
139
140 let pos_view = st.tensor(&format!("chan_pos_{i}"))?;
142 let pos_f32 = bytes_to_f32(pos_view.data(), pos_view.dtype())?;
143 let chan_pos = Tensor::<B, 2>::from_data(TensorData::new(pos_f32, vec![c, 3]), device);
144
145 let chan_pos_disc = discretize_chan_pos(chan_pos.clone(), cfg, device);
146 let tc = t / cfg.num_fine_time_pts;
147
148 let (eeg_tokens, _, posd, t_coarse) =
149 chop_and_reshape(eeg.clone(), chan_pos.clone(), chan_pos_disc, cfg.num_fine_time_pts);
150
151 let tok_idx = build_tok_idx(posd, t_coarse);
152 let encoder_input = eeg_tokens.unsqueeze_dim::<3>(0); batches.push(InputBatch { encoder_input, tok_idx, chan_pos, n_channels: c, tc });
155 }
156
157 Ok(batches)
158}
159
160pub fn invert_reshape<B: Backend>(
164 tokens: Tensor<B, 2>,
165 n_channels: usize,
166 tc: usize,
167 tf: usize,
168) -> Tensor<B, 2> {
169 tokens.reshape([n_channels, tc, tf]).reshape([n_channels, tc * tf])
170}
171
/// Metadata describing a loaded FIF recording and the preprocessing
/// applied to it (see `load_from_fif`).
pub struct FifInfo {
    // Channel names, in recording order.
    pub ch_names: Vec<String>,
    // Channel positions: raw `loc` values scaled by 1000 (millimeters,
    // assuming the FIF `loc` fields are meters — TODO confirm).
    pub ch_pos_mm: Vec<[f32; 3]>,
    // Sampling frequency of the raw recording, in Hz.
    pub sfreq: f32,
    // Number of time samples in the raw (un-epoched) recording.
    pub n_times_raw: usize,
    // Raw recording duration in seconds (n_times_raw / sfreq).
    pub duration_s: f32,
    // Number of epochs produced by preprocessing.
    pub n_epochs: usize,
    // Post-preprocessing sampling frequency, copied from the pipeline config.
    pub target_sfreq: f32,
    // Epoch duration in seconds, copied from the pipeline config.
    pub epoch_dur_s: f32,
}
193
/// Loads a raw `.fif` recording, runs it through the `exg` preprocessing
/// pipeline, and converts every resulting epoch into an [`InputBatch`],
/// returning the batches alongside recording metadata.
///
/// `data_norm` is forwarded into `exg::PipelineConfig`; its exact
/// semantics live in that crate (presumably an amplitude normalization
/// factor — confirm against `exg`).
///
/// # Errors
/// Propagates failures from FIF parsing, channel-position array shaping,
/// raw-data reading, and preprocessing.
pub fn load_from_fif<B: Backend>(
    path: &std::path::Path,
    data_cfg: &DataConfig,
    data_norm: f32,
    device: &B::Device,
) -> anyhow::Result<(Vec<InputBatch<B>>, FifInfo)> {
    use exg::{
        fiff::raw::open_raw,
        PipelineConfig,
    };
    use ndarray::Array2;

    let raw_fif = open_raw(path)?;
    let src_sfreq = raw_fif.info.sfreq as f32;
    let n_ch = raw_fif.info.n_chan;
    let n_times_raw = raw_fif.n_times();
    let duration_s = n_times_raw as f32 / src_sfreq;

    let ch_names: Vec<String> = raw_fif.info.chs.iter()
        .map(|ch| ch.name.clone())
        .collect();
    // First three `loc` entries are the sensor position, scaled by 1000
    // (meters -> millimeters, assuming FIF's usual units — TODO confirm).
    let ch_pos_mm: Vec<[f32; 3]> = raw_fif.info.chs.iter()
        .map(|ch| [ch.loc[0] * 1000.0, ch.loc[1] * 1000.0, ch.loc[2] * 1000.0])
        .collect();

    // Unscaled positions, flattened row-major for the (n_ch, 3) ndarray below.
    let pos_flat: Vec<f32> = raw_fif.info.chs.iter()
        .flat_map(|ch| [ch.loc[0], ch.loc[1], ch.loc[2]])
        .collect();
    let chan_pos_arr = Array2::from_shape_vec((n_ch, 3), pos_flat)?;

    let data_f64 = raw_fif.read_all_data()?;
    let data_f32: Array2<f32> = data_f64.mapv(|v| v as f32);

    // Defaults for everything except the caller-supplied normalization.
    let preproc_cfg = PipelineConfig {
        data_norm,
        ..PipelineConfig::default()
    };

    // Each epoch is an (eeg, chan_pos) array pair.
    let epochs = exg::preprocess(data_f32, chan_pos_arr, src_sfreq, &preproc_cfg)?;
    let n_epochs = epochs.len();

    let mut batches = Vec::with_capacity(n_epochs);

    for (eeg_arr, pos_arr) in epochs {
        let (c, t) = eeg_arr.dim();

        // ndarray -> burn tensor, preserving the [C, T] layout.
        let eeg_data: Vec<f32> = eeg_arr.iter().copied().collect();
        let eeg = Tensor::<B, 2>::from_data(TensorData::new(eeg_data, vec![c, t]), device);

        let pos_data: Vec<f32> = pos_arr.iter().copied().collect();
        let chan_pos_t = Tensor::<B, 2>::from_data(TensorData::new(pos_data, vec![c, 3]), device);

        let chan_pos_disc = discretize_chan_pos(chan_pos_t.clone(), data_cfg, device);
        // NOTE(review): `chop_and_reshape` asserts divisibility; this
        // assumes exg epochs are sized as a whole number of fine windows.
        let tc = t / data_cfg.num_fine_time_pts;

        let (eeg_tokens, _, posd, t_coarse) = chop_and_reshape(
            eeg,
            chan_pos_t.clone(),
            chan_pos_disc,
            data_cfg.num_fine_time_pts,
        );

        let tok_idx = build_tok_idx(posd, t_coarse);
        // Add a leading batch axis: [S, tf] -> [1, S, tf].
        let encoder_input = eeg_tokens.unsqueeze_dim::<3>(0);
        batches.push(InputBatch {
            encoder_input,
            tok_idx,
            chan_pos: chan_pos_t,
            n_channels: c,
            tc,
        });
    }

    let info = FifInfo {
        ch_names,
        ch_pos_mm,
        sfreq: src_sfreq,
        n_times_raw,
        duration_s,
        n_epochs,
        target_sfreq: preproc_cfg.target_sfreq,
        epoch_dur_s: preproc_cfg.epoch_dur,
    };

    Ok((batches, info))
}
296
297fn repeat_interleave_rows_f<B: Backend>(t: Tensor<B, 2>, repeats: usize) -> Tensor<B, 2> {
300 let [s, c] = t.dims();
301 t.unsqueeze_dim::<3>(1).expand([s, repeats, c]).reshape([s * repeats, c])
302}
303
304fn repeat_interleave_rows_i<B: Backend>(
305 t: Tensor<B, 2, Int>,
306 repeats: usize,
307) -> Tensor<B, 2, Int> {
308 let [s, c] = t.dims();
309 t.unsqueeze_dim::<3>(1).expand([s, repeats, c]).reshape([s * repeats, c])
310}
311
312fn bytes_to_f32(data: &[u8], dtype: safetensors::Dtype) -> anyhow::Result<Vec<f32>> {
313 match dtype {
314 safetensors::Dtype::F32 =>
315 Ok(data.chunks_exact(4)
316 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
317 .collect()),
318 safetensors::Dtype::BF16 =>
319 Ok(data.chunks_exact(2)
320 .map(|b| half::bf16::from_le_bytes([b[0], b[1]]).to_f32())
321 .collect()),
322 other => anyhow::bail!("unsupported dtype {:?}", other),
323 }
324}