1use std::collections::HashMap;
43use std::path::Path;
44
45use anyhow::{bail, Context};
46use burn::prelude::*;
47use ndarray::Array2;
48
49use crate::channel_positions::{channel_xyz, nearest_channel, normalise};
50use crate::config::DataConfig;
51use crate::data::{build_tok_idx, chop_and_reshape, discretize_chan_pos, InputBatch};
52
/// How to synthesize a requested channel that is missing from the source data.
///
/// Used by `apply_padding` when `CsvLoadOptions::target_channels` (or the
/// whitelist) names channels the CSV does not contain.
#[derive(Debug, Clone)]
pub enum PaddingStrategy {
    /// Emit an all-zero signal for the missing channel.
    Zero,

    /// Copy the signal of the named source channel verbatim
    /// (errors if that channel is not loaded / whitelisted).
    CloneChannel(String),

    /// Copy the spatially nearest loaded channel's signal.
    CloneNearest,

    /// Inverse-distance-weighted average of the `k` nearest loaded channels.
    InterpWeighted { k: usize },

    /// Copy the loaded channel nearest to the missing channel's position
    /// mirrored across the x = 0 plane (the contralateral side — assumes x
    /// is the lateral axis; see `apply_padding`).
    Mirror,

    /// Average of all loaded channels (a mean-reference-style signal).
    MeanRef,

    /// Synthesize nothing: missing channels are dropped from the output
    /// (they are still counted in `CsvInfo::n_padded`).
    NoPadding,
}
105
106impl Default for PaddingStrategy {
107 fn default() -> Self { Self::Zero }
108}
109
/// Options controlling CSV loading, channel selection, and padding.
#[derive(Debug, Clone)]
pub struct CsvLoadOptions {
    /// Sampling rate of the CSV data in Hz (used for duration / epoch math).
    pub sample_rate: f32,

    /// Normalisation factor forwarded to the preprocessing pipeline's
    /// `data_norm`.
    pub data_norm: f32,

    /// When set, the exact output montage: channels missing from the CSV
    /// are synthesized according to `padding`.
    pub target_channels: Option<Vec<String>>,

    /// How to synthesize channels requested in `target_channels` (or the
    /// whitelist) but absent from the CSV.
    pub padding: PaddingStrategy,

    /// Per-channel position overrides, matched by normalised channel name;
    /// these take precedence over the built-in electrode database.
    pub position_overrides: HashMap<String, [f32; 3]>,

    /// When set, only these CSV channels may be used (as outputs and as
    /// padding sources). If `target_channels` is unset, the whitelist also
    /// becomes the target montage.
    pub channel_whitelist: Option<Vec<String>>,
}
143
144impl Default for CsvLoadOptions {
145 fn default() -> Self {
146 Self {
147 sample_rate: 256.0,
148 data_norm: 10.0,
149 target_channels: None,
150 padding: PaddingStrategy::Zero,
151 position_overrides: HashMap::new(),
152 channel_whitelist: None,
153 }
154 }
155}
156
/// Summary of a loaded CSV recording, returned alongside the batches.
#[derive(Debug)]
pub struct CsvInfo {
    /// Final channel names (after padding / whitelisting), in output order.
    pub ch_names: Vec<String>,
    /// Final per-channel XYZ positions aligned with `ch_names`
    /// (the `_m` suffix suggests metres — units come from `channel_xyz`).
    pub ch_pos_m: Vec<[f32; 3]>,
    /// Sampling rate in Hz (copied from the load options).
    pub sample_rate: f32,
    /// Number of samples per channel in the raw CSV.
    pub n_samples_raw: usize,
    /// Recording duration in seconds (`n_samples_raw / sample_rate`).
    pub duration_s: f32,
    /// Number of epochs (batches) produced by the pipeline.
    pub n_epochs: usize,
    /// Number of target channels that had to be synthesized
    /// (or skipped, under `PaddingStrategy::NoPadding`).
    pub n_padded: usize,
}
175
/// Load an EEG recording from a CSV file and preprocess it into model-ready
/// batches.
///
/// The CSV needs a header row; a timestamp column is detected by name and
/// dropped (see `parse_csv`), and the remaining columns become channels.
/// Channel positions come from `opts.position_overrides` first, then the
/// built-in electrode database. The montage is then optionally rewritten:
/// explicit `target_channels`, whitelist-only, or pass-through.
///
/// Returns one `InputBatch` per epoch plus a `CsvInfo` summary.
///
/// # Errors
/// Fails when the file cannot be read or parsed, when the recording is
/// shorter than the 5 s minimum epoch, or when padding / the preprocessing
/// pipeline fails.
pub fn load_from_csv<B: Backend>(
    path: &Path,
    opts: &CsvLoadOptions,
    data_cfg: &DataConfig,
    device: &B::Device,
) -> anyhow::Result<(Vec<InputBatch<B>>, CsvInfo)> {
    let (csv_names, raw_data) = parse_csv(path)
        .with_context(|| format!("parsing CSV {}", path.display()))?;
    let (_n_ch_raw, n_t) = raw_data.dim();

    // Overrides win over the database; unknown channels get the centroid.
    let raw_positions = resolve_positions(&csv_names, &opts.position_overrides);

    // Three montage cases: explicit targets (optionally whitelisted sources),
    // whitelist-only (the whitelist doubles as the target list), or
    // pass-through with no channel manipulation at all.
    let (padded_data, padded_names, padded_positions, n_padded) =
        if let Some(ref targets) = opts.target_channels {
            apply_padding(
                &raw_data,
                &csv_names,
                &raw_positions,
                targets,
                &opts.padding,
                &opts.position_overrides,
                opts.channel_whitelist.as_deref(),
            )?
        } else if let Some(ref wl) = opts.channel_whitelist {
            apply_padding(
                &raw_data,
                &csv_names,
                &raw_positions,
                wl,
                &opts.padding,
                &opts.position_overrides,
                Some(wl),
            )?
        } else {
            (raw_data, csv_names.clone(), raw_positions, 0)
        };

    let n_ch_final = padded_data.nrows();
    // Duration is computed from the raw sample count; padding never changes
    // the time axis, only the channel axis.
    let duration_s = n_t as f32 / opts.sample_rate;

    // Reject recordings too short to yield even one epoch.
    let min_dur = 5.0_f32;
    if duration_s < min_dur {
        bail!(
            "CSV recording is {duration_s:.2} s, shorter than the minimum \
            epoch duration of {min_dur} s"
        );
    }

    let pos_arr = positions_to_array(&padded_positions, n_ch_final);
    let batches = run_pipeline(
        padded_data, pos_arr, opts.sample_rate, opts.data_norm, data_cfg, device,
    )?;
    let n_epochs = batches.len();

    let info = CsvInfo {
        ch_names: padded_names,
        ch_pos_m: padded_positions,
        sample_rate: opts.sample_rate,
        n_samples_raw: n_t,
        duration_s,
        n_epochs,
        n_padded,
    };

    Ok((batches, info))
}
257
258pub fn load_from_raw_tensor<B: Backend>(
268 data: Array2<f32>,
269 positions: &[[f32; 3]],
270 sample_rate: f32,
271 data_norm: f32,
272 data_cfg: &DataConfig,
273 device: &B::Device,
274) -> anyhow::Result<Vec<InputBatch<B>>> {
275 let n_ch = data.nrows();
276 anyhow::ensure!(
277 positions.len() == n_ch,
278 "positions.len() = {} must equal data.nrows() = {}", positions.len(), n_ch
279 );
280
281 let duration_s = data.ncols() as f32 / sample_rate;
282 if duration_s < 5.0 {
283 bail!("recording is {duration_s:.2} s, shorter than the 5 s minimum epoch");
284 }
285
286 let pos_arr = positions_to_array(positions, n_ch);
287 run_pipeline(data, pos_arr, sample_rate, data_norm, data_cfg, device)
288}
289
290pub fn load_from_named_tensor<B: Backend>(
301 data: Array2<f32>,
302 channel_names: &[&str],
303 sample_rate: f32,
304 data_norm: f32,
305 position_overrides: &HashMap<String, [f32; 3]>,
306 data_cfg: &DataConfig,
307 device: &B::Device,
308) -> anyhow::Result<Vec<InputBatch<B>>> {
309 let n_ch = data.nrows();
310 anyhow::ensure!(
311 channel_names.len() == n_ch,
312 "channel_names.len() = {} must equal data.nrows() = {}",
313 channel_names.len(), n_ch
314 );
315
316 let duration_s = data.ncols() as f32 / sample_rate;
317 if duration_s < 5.0 {
318 bail!("recording is {duration_s:.2} s, shorter than the 5 s minimum epoch");
319 }
320
321 let names: Vec<String> = channel_names.iter().map(|s| s.to_string()).collect();
322 let positions = resolve_positions(&names, position_overrides);
323 let pos_arr = positions_to_array(&positions, n_ch);
324
325 run_pipeline(data, pos_arr, sample_rate, data_norm, data_cfg, device)
326}
327
328fn parse_csv(path: &Path) -> anyhow::Result<(Vec<String>, Array2<f32>)> {
341 let content = std::fs::read_to_string(path)
342 .with_context(|| format!("reading {}", path.display()))?;
343
344 let mut lines = content.lines()
345 .filter(|l| { let t = l.trim(); !t.is_empty() && !t.starts_with('#') });
346
347 let header_line = lines.next()
349 .ok_or_else(|| anyhow::anyhow!("CSV file is empty"))?;
350 let header: Vec<&str> = header_line.split(',').collect();
351 anyhow::ensure!(header.len() >= 2, "CSV must have at least a timestamp and one channel column");
352
353 let ts_col = header.iter().position(|h| {
355 let n = h.trim().to_ascii_lowercase();
356 n.contains("time") || n == "t" || n == "ts"
357 }).unwrap_or(0);
358
359 let ch_names: Vec<String> = header.iter().enumerate()
361 .filter(|&(i, _)| i != ts_col)
362 .map(|(_, h)| h.trim().to_string())
363 .collect();
364 let n_ch = ch_names.len();
365 anyhow::ensure!(n_ch >= 1, "CSV has no channel columns after timestamp");
366
367 let mut rows: Vec<Vec<f32>> = Vec::new();
369 for (row_idx, line) in lines.enumerate() {
370 let parts: Vec<&str> = line.split(',').collect();
371 anyhow::ensure!(
372 parts.len() == header.len(),
373 "row {row_idx}: expected {} columns, got {}", header.len(), parts.len()
374 );
375 let eeg: Vec<f32> = parts.iter().enumerate()
376 .filter(|&(i, _)| i != ts_col)
377 .map(|(_, s)| {
378 s.trim().parse::<f32>()
379 .with_context(|| format!("row {row_idx}: cannot parse '{}'", s.trim()))
380 })
381 .collect::<anyhow::Result<Vec<f32>>>()?;
382 rows.push(eeg);
383 }
384
385 let n_t = rows.len();
386 anyhow::ensure!(n_t >= 1, "CSV has no data rows");
387
388 let mut flat = vec![0f32; n_ch * n_t];
391 for (t, row) in rows.iter().enumerate() {
392 for (c, &v) in row.iter().enumerate() {
393 flat[c * n_t + t] = v;
394 }
395 }
396 let data = Array2::from_shape_vec((n_ch, n_t), flat)
397 .context("assembling data array")?;
398
399 Ok((ch_names, data))
400}
401
402fn resolve_positions(
414 names: &[String],
415 overrides: &HashMap<String, [f32; 3]>,
416) -> Vec<[f32; 3]> {
417 let mut positions: Vec<[f32; 3]> = names.iter().map(|name| {
418 let key = normalise(name);
420 if let Some(&xyz) = overrides.iter().find(|(k, _)| normalise(k) == key).map(|(_, v)| v) {
421 return xyz;
422 }
423 if let Some(xyz) = channel_xyz(name) {
425 return xyz;
426 }
427 [f32::NAN, f32::NAN, f32::NAN]
429 }).collect();
430
431 let centroid = centroid_of(&positions);
433 for p in &mut positions {
434 if p[0].is_nan() { *p = centroid; }
435 }
436
437 positions
438}
439
/// Euclidean distance between two 3-D points.
#[inline]
fn dist3(a: [f32; 3], b: [f32; 3]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| (x - y) * (x - y))
        .sum::<f32>()
        .sqrt()
}
448
/// Mean of all resolved positions; entries marked unresolved (NaN in x) are
/// skipped. Returns the origin when no position is resolved.
fn centroid_of(positions: &[[f32; 3]]) -> [f32; 3] {
    let mut sum = [0.0f32; 3];
    let mut count = 0usize;
    for p in positions.iter().filter(|p| !p[0].is_nan()) {
        sum[0] += p[0];
        sum[1] += p[1];
        sum[2] += p[2];
        count += 1;
    }
    if count == 0 {
        return [0.0, 0.0, 0.0];
    }
    let n = count as f32;
    [sum[0] / n, sum[1] / n, sum[2] / n]
}
459
460fn positions_to_array(positions: &[[f32; 3]], n_ch: usize) -> Array2<f32> {
461 let flat: Vec<f32> = positions.iter().flat_map(|p| p.iter().copied()).collect();
462 Array2::from_shape_vec((n_ch, 3), flat).expect("positions_to_array shape mismatch")
463}
464
/// Build the output montage: for every `targets` entry, either copy the
/// matching source channel or synthesize it according to `strategy`.
///
/// `whitelist` restricts which source channels may be used at all (matched by
/// normalised name); non-whitelisted channels are invisible to every strategy.
/// Returns the (targets × samples) data, the emitted names and positions, and
/// how many targets had to be synthesized (or, for `NoPadding`, skipped).
///
/// # Errors
/// Fails when `CloneChannel` names a source that is not loaded/whitelisted,
/// or when the output array cannot be assembled.
fn apply_padding(
    data: &Array2<f32>,
    names: &[String],
    positions: &[[f32; 3]],
    targets: &[String],
    strategy: &PaddingStrategy,
    overrides: &HashMap<String, [f32; 3]>,
    whitelist: Option<&[String]>,
) -> anyhow::Result<(Array2<f32>, Vec<String>, Vec<[f32; 3]>, usize)> {
    let n_t = data.ncols();
    let mut out_rows: Vec<Vec<f32>> = Vec::with_capacity(targets.len());
    let mut out_names: Vec<String> = Vec::with_capacity(targets.len());
    let mut out_pos: Vec<[f32; 3]> = Vec::with_capacity(targets.len());
    let mut n_padded = 0usize;

    // Normalised whitelist keys for O(1) membership checks.
    let wl_keys: Option<std::collections::HashSet<String>> = whitelist.map(|wl| {
        wl.iter().map(|n| normalise(n)).collect()
    });
    // Normalised source name -> row index, restricted to the whitelist.
    // NOTE(review): duplicate normalised names collapse to one entry
    // (later index wins) — confirm source names are unique after normalise.
    let src_index: HashMap<String, usize> = names.iter().enumerate()
        .filter(|(_, n)| {
            wl_keys.as_ref().map_or(true, |wl| wl.contains(&normalise(n)))
        })
        .map(|(i, n)| (normalise(n), i))
        .collect();

    // Positions (with row indices) of every usable source channel; the
    // spatial strategies below search only over this set.
    let loaded_xyz_with_idx: Vec<([f32; 3], usize)> = positions.iter().copied()
        .enumerate()
        .filter(|(i, _)| src_index.values().any(|&si| si == *i))
        .map(|(i, xyz)| (xyz, i))
        .collect();

    for target in targets {
        let key = normalise(target);
        if let Some(&src) = src_index.get(&key) {
            // Present in the source: copy through, keeping the *target*
            // spelling of the name.
            out_rows.push(data.row(src).to_vec());
            out_names.push(target.clone());
            out_pos.push(positions[src]);
        } else if matches!(strategy, PaddingStrategy::NoPadding) {
            // Count the miss but emit nothing for it.
            n_padded += 1;
            continue;
        } else {
            n_padded += 1;

            // Where the synthesized channel will sit:
            // override -> electrode database -> montage centroid.
            let new_pos = position_for_missing(target, overrides, positions);

            let new_row = match strategy {
                PaddingStrategy::Zero => {
                    vec![0f32; n_t]
                }
                PaddingStrategy::CloneChannel(src_name) => {
                    // The clone source must itself be loaded and whitelisted.
                    let src_key = normalise(src_name);
                    let src_idx = src_index.get(&src_key).copied()
                        .ok_or_else(|| anyhow::anyhow!(
                            "CloneChannel source '{}' not found in CSV", src_name
                        ))?;
                    data.row(src_idx).to_vec()
                }
                PaddingStrategy::CloneNearest => {
                    // NOTE(review): falls back to raw row 0 when no nearest
                    // source is found — row 0 may not be whitelisted (and may
                    // not exist if `data` has zero rows); confirm intended.
                    let nearest_idx = nearest_channel(new_pos, &loaded_xyz_with_idx)
                        .unwrap_or(0);
                    data.row(nearest_idx).to_vec()
                }

                PaddingStrategy::InterpWeighted { k } => {
                    // Inverse-distance weighting over the k nearest sources.
                    let mut dists: Vec<(f32, usize)> = loaded_xyz_with_idx.iter()
                        .map(|&(xyz, idx)| (dist3(xyz, new_pos), idx))
                        .collect();
                    dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
                    // Clamp k into [1, available sources].
                    let k_actual = (*k).min(dists.len()).max(1);
                    let k_slice = &dists[..k_actual];
                    // Near-coincident channels get a large fixed weight
                    // instead of dividing by ~0.
                    let weights: Vec<f32> = k_slice.iter()
                        .map(|(d, _)| if *d < 1e-6 { 1e6_f32 } else { 1.0 / d })
                        .collect();
                    let w_sum: f32 = weights.iter().sum();
                    let mut interp = vec![0f32; n_t];
                    for ((_, idx), w) in k_slice.iter().zip(weights.iter()) {
                        let wn = w / w_sum;
                        for (o, &v) in interp.iter_mut().zip(data.row(*idx).iter()) {
                            *o += wn * v;
                        }
                    }
                    interp
                }

                PaddingStrategy::Mirror => {
                    // Reflect across x = 0 and take the nearest loaded channel
                    // to that reflected point (assumes x is the lateral axis —
                    // TODO confirm against channel_positions conventions).
                    let mirror_pos = [-new_pos[0], new_pos[1], new_pos[2]];
                    let nearest_idx = nearest_channel(mirror_pos, &loaded_xyz_with_idx)
                        .unwrap_or_else(|| loaded_xyz_with_idx.first().map(|&(_, i)| i).unwrap_or(0));
                    data.row(nearest_idx).to_vec()
                }

                PaddingStrategy::MeanRef => {
                    // Plain average of every loaded source channel
                    // (all-zero when no sources are loaded).
                    let n_real = loaded_xyz_with_idx.len().max(1);
                    let mut mean_sig = vec![0f32; n_t];
                    for &(_, idx) in &loaded_xyz_with_idx {
                        for (m, &v) in mean_sig.iter_mut().zip(data.row(idx).iter()) {
                            *m += v;
                        }
                    }
                    for m in &mut mean_sig { *m /= n_real as f32; }
                    mean_sig
                }

                // Handled by the branch above; listed to keep the match
                // exhaustive.
                PaddingStrategy::NoPadding => unreachable!(),
            };

            out_rows.push(new_row);
            out_names.push(target.clone());
            out_pos.push(new_pos);
        }
    }

    let n_out = out_rows.len();
    let flat: Vec<f32> = out_rows.into_iter().flatten().collect();
    let padded = Array2::from_shape_vec((n_out, n_t), flat)
        .context("assembling padded data array")?;

    Ok((padded, out_names, out_pos, n_padded))
}
610
611fn position_for_missing(
615 name: &str,
616 overrides: &HashMap<String, [f32; 3]>,
617 existing: &[[f32; 3]],
618) -> [f32; 3] {
619 let key = normalise(name);
620 if let Some(&xyz) = overrides.iter().find(|(k, _)| normalise(k) == key).map(|(_, v)| v) {
621 return xyz;
622 }
623 if let Some(xyz) = channel_xyz(name) {
624 return xyz;
625 }
626 centroid_of(existing)
627}
628
/// Run the shared `exg` preprocessing pipeline on one recording and convert
/// every resulting epoch into an `InputBatch` on `device`.
///
/// # Errors
/// Fails when preprocessing fails or yields zero epochs.
fn run_pipeline<B: Backend>(
    data: Array2<f32>, pos_arr: Array2<f32>, sample_rate: f32,
    data_norm: f32,
    data_cfg: &DataConfig,
    device: &B::Device,
) -> anyhow::Result<Vec<InputBatch<B>>> {
    use exg::PipelineConfig;

    // Only data_norm is customised; everything else keeps pipeline defaults.
    let cfg = PipelineConfig { data_norm, ..PipelineConfig::default() };
    let epochs = exg::preprocess(data, pos_arr, sample_rate, &cfg)?;

    if epochs.is_empty() {
        bail!("recording produced zero epochs (likely shorter than the 5 s minimum epoch)");
    }

    let mut batches = Vec::with_capacity(epochs.len());
    for (eeg_arr, pos_out) in epochs {
        // eeg_arr is (channels, samples); pos_out is used as (channels, 3)
        // below — assumed shapes from exg::preprocess, TODO confirm.
        let (c, t) = eeg_arr.dim();
        let eeg_data: Vec<f32> = eeg_arr.iter().copied().collect();
        let eeg = Tensor::<B, 2>::from_data(TensorData::new(eeg_data, vec![c, t]), device);

        let pos_data: Vec<f32> = pos_out.iter().copied().collect();
        let chan_pos = Tensor::<B, 2>::from_data(TensorData::new(pos_data, vec![c, 3]), device);

        let chan_pos_disc = discretize_chan_pos(chan_pos.clone(), data_cfg, device);
        // Integer division: assumes t is a multiple of num_fine_time_pts —
        // TODO confirm the pipeline guarantees this.
        let tc = t / data_cfg.num_fine_time_pts;

        let (eeg_tokens, _, posd, t_coarse) =
            chop_and_reshape(eeg, chan_pos.clone(), chan_pos_disc, data_cfg.num_fine_time_pts);

        let tok_idx = build_tok_idx(posd, t_coarse);
        // Add a leading batch dimension of size 1.
        let encoder_input = eeg_tokens.unsqueeze_dim::<3>(0);

        batches.push(InputBatch { encoder_input, tok_idx, chan_pos, n_channels: c, tc });
    }

    Ok(batches)
}
678
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): these tests write fixed file names into the shared temp
    // dir, which could collide across concurrent test processes and are not
    // cleaned up — consider unique names / tempfile-style handling.

    /// A minimal two-channel CSV parses into names plus a (2, 2) array.
    #[test]
    fn parse_csv_basic() {
        let content = "timestamp,Fp1,Fp2\n0.0,1e-5,2e-5\n0.004,3e-5,4e-5\n";
        let path = std::env::temp_dir().join("zuna_test_basic.csv");
        std::fs::write(&path, content).unwrap();
        let (names, data) = parse_csv(&path).unwrap();
        assert_eq!(names, ["Fp1", "Fp2"]);
        assert_eq!(data.dim(), (2, 2));
        assert!((data[[0, 0]] - 1e-5_f32).abs() < 1e-10);
        assert!((data[[1, 1]] - 4e-5_f32).abs() < 1e-10);
    }

    /// Lines starting with '#' are skipped before header detection.
    #[test]
    fn parse_csv_skips_comments() {
        let content = "# comment\ntimestamp,C3\n0.0,0.5\n0.004,-0.3\n";
        let path = std::env::temp_dir().join("zuna_test_comments.csv");
        std::fs::write(&path, content).unwrap();
        let (names, data) = parse_csv(&path).unwrap();
        assert_eq!(names, ["C3"]);
        assert_eq!(data.dim(), (1, 2));
    }

    /// A known electrode name resolves via the built-in database to a
    /// head-scale coordinate.
    #[test]
    fn resolve_positions_uses_database() {
        let pos = resolve_positions(&["Cz".to_string()], &HashMap::new());
        assert_eq!(pos.len(), 1);
        let [x, y, z] = pos[0];
        assert!(x.abs() < 0.12 && y.abs() < 0.12 && z.abs() < 0.12);
    }

    /// Overrides are matched on the normalised name ("CZ" matches "Cz")
    /// and take precedence over the database.
    #[test]
    fn resolve_positions_override_wins() {
        let mut ov = HashMap::new();
        ov.insert("CZ".to_string(), [0.01, 0.02, 0.09]);
        let pos = resolve_positions(&["Cz".to_string()], &ov);
        assert_eq!(pos[0], [0.01, 0.02, 0.09]);
    }

    /// An unknown channel falls back to the centroid of the resolved ones —
    /// here a single resolved channel, so the centroid equals its position.
    #[test]
    fn resolve_positions_unknown_gets_centroid() {
        let names = vec!["UNKNOWN_XYZ".to_string(), "Cz".to_string()];
        let pos = resolve_positions(&names, &HashMap::new());
        let cz = channel_xyz("Cz").unwrap();
        let centroid = pos[0];
        assert!((centroid[0] - cz[0]).abs() < 1e-5);
    }

    /// Zero padding appends an all-zero row with a plausible position.
    #[test]
    fn padding_zero_adds_zero_rows() {
        let data = Array2::from_shape_vec((2, 4), vec![1f32; 8]).unwrap();
        let names = vec!["Fp1".to_string(), "Fp2".to_string()];
        let pos = resolve_positions(&names, &HashMap::new());
        let targets = vec!["Fp1".to_string(), "Fp2".to_string(), "Fz".to_string()];
        let (out, out_names, out_pos, n_padded) = apply_padding(
            &data, &names, &pos, &targets, &PaddingStrategy::Zero, &HashMap::new(), None
        ).unwrap();
        assert_eq!(out.dim(), (3, 4));
        assert_eq!(n_padded, 1);
        assert_eq!(out_names[2], "Fz");
        assert!(out.row(2).iter().all(|&v| v == 0.0));
        // The synthesized channel's position comes from the database.
        let [x, y, z] = out_pos[2];
        assert!(x.abs() < 0.12 && y.abs() < 0.12 && z.abs() < 0.12);
    }

    /// CloneChannel copies the named source row verbatim.
    #[test]
    fn padding_clone_channel() {
        let data = Array2::from_shape_vec((2, 4), (0..8).map(|i| i as f32).collect()).unwrap();
        let names = vec!["Fp1".to_string(), "Fp2".to_string()];
        let pos = resolve_positions(&names, &HashMap::new());
        let targets = vec!["Fp1".to_string(), "Cz".to_string()];
        let (out, _, _, n_padded) = apply_padding(
            &data, &names, &pos, &targets,
            &PaddingStrategy::CloneChannel("Fp1".to_string()), &HashMap::new(), None
        ).unwrap();
        assert_eq!(n_padded, 1);
        // Row 1 (padded "Cz") equals row 0 (the "Fp1" clone source).
        assert_eq!(out.row(0).to_vec(), out.row(1).to_vec());
    }

    /// CloneNearest copies a real (non-zero) loaded channel.
    #[test]
    fn padding_clone_nearest() {
        let data = Array2::from_shape_vec((2, 4), (0..8).map(|i| i as f32 * 0.1).collect()).unwrap();
        let names = vec!["Fp1".to_string(), "Fp2".to_string()];
        let pos = resolve_positions(&names, &HashMap::new());
        let targets = vec!["Fp1".to_string(), "Fp2".to_string(), "AF7".to_string()];
        let (out, _, _, n_padded) = apply_padding(
            &data, &names, &pos, &targets,
            &PaddingStrategy::CloneNearest, &HashMap::new(), None
        ).unwrap();
        assert_eq!(n_padded, 1);
        assert!(out.row(2).iter().any(|&v| v != 0.0));
    }
}