// zuna_rs/csv_loader.rs
1//! CSV and raw-tensor loading for ZUNA inference.
2//!
3//! Three entry points, all producing the same `Vec<InputBatch<B>>` that
4//! [`ZunaEncoder`](crate::encoder::ZunaEncoder) consumes:
5//!
6//! | Function | Input |
7//! |---|---|
8//! | [`load_from_csv`] | CSV file: timestamp column + channel columns |
9//! | [`load_from_raw_tensor`] | `ndarray::Array2<f32>` + explicit `[f32;3]` positions |
10//! | [`load_from_named_tensor`] | `ndarray::Array2<f32>` + channel names (auto-lookup) |
11//!
12//! ## CSV format
13//!
14//! ```text
15//! timestamp,Fp1,Fp2,F3,F4,C3,C4
16//! 0.000000000e0,2.0721e-05,8.38e-07,...
17//! 3.906250000e-3,...
18//! ```
19//!
20//! - First column must be timestamps in **seconds** (column name is ignored;
21//!   any leading column whose name contains "time" or is index 0 is treated as
22//!   the timestamp).
23//! - Remaining columns are EEG channel values in **volts**.
24//! - Lines starting with `#` are ignored.
25//! - Scientific notation (`1.23e-5`) and plain decimals both accepted.
26//!
27//! ## Padding
28//!
29//! When `target_channels` is set in [`CsvLoadOptions`], channels present in
30//! the target list but absent from the CSV are synthesised:
31//!
32//! | [`PaddingStrategy`] | Data | Position |
33//! |---|---|---|
34//! | `Zero` | all-zero row | overrides → database → centroid |
35//! | `CloneChannel(src)` | copy of the named channel's row | overrides → database → src's pos |
36//! | `CloneNearest` | copy of nearest loaded channel by xyz | overrides → database → centroid |
37//! | `InterpWeighted { k }` | inverse-distance–weighted mean of k nearest real channels | same as CloneNearest |
38//! | `Mirror` | copy of nearest real channel on the opposite hemisphere (X flipped) | database → centroid |
39//! | `MeanRef` | per-sample mean of all real channels (common average reference) | database → centroid |
40//! | `NoPadding` | missing channels are **dropped** — output has fewer channels than the target list | n/a |
41
42use std::collections::HashMap;
43use std::path::Path;
44
45use anyhow::{bail, Context};
46use burn::prelude::*;
47use ndarray::Array2;
48
49use crate::channel_positions::{channel_xyz, nearest_channel, normalise};
50use crate::config::DataConfig;
51use crate::data::{build_tok_idx, chop_and_reshape, discretize_chan_pos, InputBatch};
52
// ─────────────────────────────────────────────────────────────────────────────
// Public types
// ─────────────────────────────────────────────────────────────────────────────

/// How to synthesise EEG channels that are missing from the CSV.
///
/// See the module-level table for a one-line summary of each strategy's data
/// and position rules.  The default is [`PaddingStrategy::Zero`].
#[derive(Debug, Clone)]
pub enum PaddingStrategy {
    /// Fill the missing channel with zeros.
    /// Its scalp position is taken from `position_overrides`, then the
    /// channel-position database, then the centroid of existing channels.
    Zero,

    /// Clone the data from a specific named channel.
    /// Position of the new channel: `position_overrides[missing]` →
    /// database lookup of the *missing* channel name → centroid.
    CloneChannel(String),

    /// Clone the data from whichever loaded channel is nearest (by Euclidean
    /// distance) to the missing channel's known position.
    /// Position of the new channel: `position_overrides[missing]` →
    /// database lookup of the *missing* channel name → centroid.
    CloneNearest,

    /// Synthesise by inverse-distance–weighted averaging of the `k` nearest
    /// real channels.  Uses all real channels when `k` ≥ number of real
    /// channels.  This is a simple form of scalp-surface interpolation.
    /// Position: same as [`CloneNearest`](Self::CloneNearest).
    InterpWeighted { k: usize },

    /// Copy the signal of the nearest real channel on the **opposite**
    /// hemisphere (the target channel's X coordinate is negated to find
    /// the "mirror" point, then the closest real channel to that point is
    /// used).  Useful for symmetric montages where the contralateral
    /// homologue is the best available substitute.
    /// Position: database → centroid.
    Mirror,

    /// Fill with the per-sample mean across **all** real channels.
    /// This is equivalent to injecting the common-average-reference (CAR)
    /// signal, which is the least-informative but spectrally neutral choice.
    /// Position: database → centroid.
    MeanRef,

    /// **No padding** — channels that are absent from the CSV are silently
    /// dropped from the output instead of being synthesised.
    ///
    /// The returned data will have fewer channels than `target_channels` when
    /// any targets are missing.  The encoder handles variable-length inputs
    /// natively, so the resulting [`InputBatch`](crate::data::InputBatch) is
    /// fully valid.
    NoPadding,
}
105
106impl Default for PaddingStrategy {
107    fn default() -> Self { Self::Zero }
108}
109
/// Options for [`load_from_csv`].
///
/// Construct via [`Default`] and override individual fields with struct-update
/// syntax.
#[derive(Debug, Clone)]
pub struct CsvLoadOptions {
    /// Sampling rate of the CSV data in Hz.  Default: `256.0`.
    pub sample_rate: f32,

    /// Signal normalisation divisor applied after z-scoring.  Default: `10.0`.
    pub data_norm: f32,

    /// If set, the output channels are reordered / padded to match this list.
    /// Channels in the CSV but *not* in this list are discarded.
    /// Channels in the list but *not* in the CSV are synthesised with [`padding`](Self::padding).
    pub target_channels: Option<Vec<String>>,

    /// Strategy for synthesising missing channels.  Default: [`PaddingStrategy::Zero`].
    pub padding: PaddingStrategy,

    /// Per-channel XYZ position overrides (metres).
    ///
    /// Keys are matched case-insensitively.  Use this to supply
    /// *fuzzy coordinates* for channels not in the standard montage database,
    /// or to override database positions for `CloneNearest` distance queries.
    pub position_overrides: HashMap<String, [f32; 3]>,

    /// If set, only CSV columns whose normalised name appears in this list are
    /// treated as **present**.  Other CSV columns are silently ignored — they
    /// will be synthesised as missing channels if they appear in
    /// `target_channels`.
    ///
    /// When `target_channels` is `None`, this list also serves as the target
    /// channel list itself.
    ///
    /// Use this to simulate recordings with fewer channels without modifying
    /// the CSV file (e.g. `--n-channels 6` in the `csv_embed` example).
    pub channel_whitelist: Option<Vec<String>>,
}
143
144impl Default for CsvLoadOptions {
145    fn default() -> Self {
146        Self {
147            sample_rate: 256.0,
148            data_norm:   10.0,
149            target_channels:    None,
150            padding:            PaddingStrategy::Zero,
151            position_overrides: HashMap::new(),
152            channel_whitelist:  None,
153        }
154    }
155}
156
/// Metadata returned alongside the batches by [`load_from_csv`].
#[derive(Debug)]
pub struct CsvInfo {
    /// Final channel names after reordering and padding.
    pub ch_names: Vec<String>,
    /// Scalp positions in metres `[C, 3]` after reordering and padding.
    pub ch_pos_m: Vec<[f32; 3]>,
    /// Sample rate used (from [`CsvLoadOptions::sample_rate`]).
    pub sample_rate: f32,
    /// Number of raw time-samples read from the CSV.
    pub n_samples_raw: usize,
    /// Recording duration in seconds.
    pub duration_s: f32,
    /// Number of 5-second epochs produced.
    pub n_epochs: usize,
    /// Number of channels added by padding.  Under
    /// [`PaddingStrategy::NoPadding`] this instead counts the target channels
    /// that were *dropped* because they were absent from the CSV.
    pub n_padded: usize,
}
175
176// ─────────────────────────────────────────────────────────────────────────────
177// Entry point 1 — CSV file
178// ─────────────────────────────────────────────────────────────────────────────
179
180/// Load EEG data from a CSV file and run the full ZUNA preprocessing pipeline.
181///
182/// The pipeline is identical to [`load_from_fif`](crate::data::load_from_fif):
183/// resample (if needed) → 0.5 Hz highpass FIR → average reference →
184/// global z-score → epoch (5 s) → baseline correction → ÷ data_norm.
185pub fn load_from_csv<B: Backend>(
186    path:     &Path,
187    opts:     &CsvLoadOptions,
188    data_cfg: &DataConfig,
189    device:   &B::Device,
190) -> anyhow::Result<(Vec<InputBatch<B>>, CsvInfo)> {
191    // ── Parse CSV ─────────────────────────────────────────────────────────────
192    let (csv_names, raw_data) = parse_csv(path)
193        .with_context(|| format!("parsing CSV {}", path.display()))?;
194    let (_n_ch_raw, n_t) = raw_data.dim();
195
196    // ── Look up positions for loaded channels ─────────────────────────────────
197    let raw_positions = resolve_positions(&csv_names, &opts.position_overrides);
198
199    // ── Apply target-channel reordering / padding ─────────────────────────────
200    let (padded_data, padded_names, padded_positions, n_padded) =
201        if let Some(ref targets) = opts.target_channels {
202            apply_padding(
203                &raw_data,
204                &csv_names,
205                &raw_positions,
206                targets,
207                &opts.padding,
208                &opts.position_overrides,
209                opts.channel_whitelist.as_deref(),
210            )?
211        } else if let Some(ref wl) = opts.channel_whitelist {
212            // No explicit target — whitelist acts as the target list itself
213            apply_padding(
214                &raw_data,
215                &csv_names,
216                &raw_positions,
217                wl,
218                &opts.padding,
219                &opts.position_overrides,
220                Some(wl),
221            )?
222        } else {
223            (raw_data, csv_names.clone(), raw_positions, 0)
224        };
225
226    let n_ch_final = padded_data.nrows();
227    let duration_s = n_t as f32 / opts.sample_rate;
228
229    // ── Minimum epoch size guard ──────────────────────────────────────────────
230    let min_dur = 5.0_f32;
231    if duration_s < min_dur {
232        bail!(
233            "CSV recording is {duration_s:.2} s, shorter than the minimum \
234             epoch duration of {min_dur} s"
235        );
236    }
237
238    // ── Run exg preprocessing pipeline ───────────────────────────────────────
239    let pos_arr = positions_to_array(&padded_positions, n_ch_final);
240    let batches = run_pipeline(
241        padded_data, pos_arr, opts.sample_rate, opts.data_norm, data_cfg, device,
242    )?;
243    let n_epochs = batches.len();
244
245    let info = CsvInfo {
246        ch_names:      padded_names,
247        ch_pos_m:      padded_positions,
248        sample_rate:   opts.sample_rate,
249        n_samples_raw: n_t,
250        duration_s,
251        n_epochs,
252        n_padded,
253    };
254
255    Ok((batches, info))
256}
257
258// ─────────────────────────────────────────────────────────────────────────────
259// Entry point 2 — raw tensor with explicit XYZ positions
260// ─────────────────────────────────────────────────────────────────────────────
261
262/// Load from a pre-assembled `Array2<f32>` with one **explicit** `[x,y,z]`
263/// position per channel row.
264///
265/// The data must be raw (unprocessed) EEG in volts; the full exg pipeline is
266/// applied internally.  The shape is `[n_channels, n_samples]`.
267pub fn load_from_raw_tensor<B: Backend>(
268    data:      Array2<f32>,
269    positions: &[[f32; 3]],
270    sample_rate: f32,
271    data_norm:   f32,
272    data_cfg:    &DataConfig,
273    device:      &B::Device,
274) -> anyhow::Result<Vec<InputBatch<B>>> {
275    let n_ch = data.nrows();
276    anyhow::ensure!(
277        positions.len() == n_ch,
278        "positions.len() = {} must equal data.nrows() = {}", positions.len(), n_ch
279    );
280
281    let duration_s = data.ncols() as f32 / sample_rate;
282    if duration_s < 5.0 {
283        bail!("recording is {duration_s:.2} s, shorter than the 5 s minimum epoch");
284    }
285
286    let pos_arr = positions_to_array(positions, n_ch);
287    run_pipeline(data, pos_arr, sample_rate, data_norm, data_cfg, device)
288}
289
290// ─────────────────────────────────────────────────────────────────────────────
291// Entry point 3 — raw tensor with channel names (auto position lookup)
292// ─────────────────────────────────────────────────────────────────────────────
293
294/// Load from a pre-assembled `Array2<f32>` using **channel names** to look up
295/// scalp positions from the bundled montage database.
296///
297/// Channels not found in any montage (e.g. custom names) get the centroid of
298/// the remaining channels as their position, which keeps them encodable.
299/// Pass explicit XYZ via `position_overrides` to override any channel.
300pub fn load_from_named_tensor<B: Backend>(
301    data:               Array2<f32>,
302    channel_names:      &[&str],
303    sample_rate:        f32,
304    data_norm:          f32,
305    position_overrides: &HashMap<String, [f32; 3]>,
306    data_cfg:           &DataConfig,
307    device:             &B::Device,
308) -> anyhow::Result<Vec<InputBatch<B>>> {
309    let n_ch = data.nrows();
310    anyhow::ensure!(
311        channel_names.len() == n_ch,
312        "channel_names.len() = {} must equal data.nrows() = {}",
313        channel_names.len(), n_ch
314    );
315
316    let duration_s = data.ncols() as f32 / sample_rate;
317    if duration_s < 5.0 {
318        bail!("recording is {duration_s:.2} s, shorter than the 5 s minimum epoch");
319    }
320
321    let names: Vec<String> = channel_names.iter().map(|s| s.to_string()).collect();
322    let positions = resolve_positions(&names, position_overrides);
323    let pos_arr   = positions_to_array(&positions, n_ch);
324
325    run_pipeline(data, pos_arr, sample_rate, data_norm, data_cfg, device)
326}
327
328// ─────────────────────────────────────────────────────────────────────────────
329// CSV parser (no external dependencies)
330// ─────────────────────────────────────────────────────────────────────────────
331
332/// Parse a CSV file into `(channel_names, data [C, T])`.
333///
334/// Rules:
335/// - Lines starting with `#` are skipped.
336/// - First non-blank, non-comment line is the header.
337/// - The first column is the timestamp column (identified by the header name
338///   containing "time" case-insensitively, or simply by being column index 0).
339/// - All remaining columns are EEG channels.
340fn parse_csv(path: &Path) -> anyhow::Result<(Vec<String>, Array2<f32>)> {
341    let content = std::fs::read_to_string(path)
342        .with_context(|| format!("reading {}", path.display()))?;
343
344    let mut lines = content.lines()
345        .filter(|l| { let t = l.trim(); !t.is_empty() && !t.starts_with('#') });
346
347    // ── Header ────────────────────────────────────────────────────────────────
348    let header_line = lines.next()
349        .ok_or_else(|| anyhow::anyhow!("CSV file is empty"))?;
350    let header: Vec<&str> = header_line.split(',').collect();
351    anyhow::ensure!(header.len() >= 2, "CSV must have at least a timestamp and one channel column");
352
353    // Identify timestamp column (first column, OR first whose name ≈ "time")
354    let ts_col = header.iter().position(|h| {
355        let n = h.trim().to_ascii_lowercase();
356        n.contains("time") || n == "t" || n == "ts"
357    }).unwrap_or(0);
358
359    // Channel names: all columns except the timestamp column
360    let ch_names: Vec<String> = header.iter().enumerate()
361        .filter(|&(i, _)| i != ts_col)
362        .map(|(_, h)| h.trim().to_string())
363        .collect();
364    let n_ch = ch_names.len();
365    anyhow::ensure!(n_ch >= 1, "CSV has no channel columns after timestamp");
366
367    // ── Data rows ─────────────────────────────────────────────────────────────
368    let mut rows: Vec<Vec<f32>> = Vec::new();
369    for (row_idx, line) in lines.enumerate() {
370        let parts: Vec<&str> = line.split(',').collect();
371        anyhow::ensure!(
372            parts.len() == header.len(),
373            "row {row_idx}: expected {} columns, got {}", header.len(), parts.len()
374        );
375        let eeg: Vec<f32> = parts.iter().enumerate()
376            .filter(|&(i, _)| i != ts_col)
377            .map(|(_, s)| {
378                s.trim().parse::<f32>()
379                    .with_context(|| format!("row {row_idx}: cannot parse '{}'", s.trim()))
380            })
381            .collect::<anyhow::Result<Vec<f32>>>()?;
382        rows.push(eeg);
383    }
384
385    let n_t = rows.len();
386    anyhow::ensure!(n_t >= 1, "CSV has no data rows");
387
388    // ── Assemble [C, T] array ─────────────────────────────────────────────────
389    // rows is currently [T, C]; transpose to [C, T]
390    let mut flat = vec![0f32; n_ch * n_t];
391    for (t, row) in rows.iter().enumerate() {
392        for (c, &v) in row.iter().enumerate() {
393            flat[c * n_t + t] = v;
394        }
395    }
396    let data = Array2::from_shape_vec((n_ch, n_t), flat)
397        .context("assembling data array")?;
398
399    Ok((ch_names, data))
400}
401
402// ─────────────────────────────────────────────────────────────────────────────
403// Position helpers
404// ─────────────────────────────────────────────────────────────────────────────
405
406/// Resolve XYZ positions for a list of channel names.
407///
408/// Priority per channel:
409/// 1. `overrides` map (case-insensitive normalised key)
410/// 2. [`channel_xyz`] database
411/// 3. `[0.0, 0.0, 0.0]` placeholder — will be replaced by centroid after all
412///    known channels are resolved.
413fn resolve_positions(
414    names:     &[String],
415    overrides: &HashMap<String, [f32; 3]>,
416) -> Vec<[f32; 3]> {
417    let mut positions: Vec<[f32; 3]> = names.iter().map(|name| {
418        // 1. override map
419        let key = normalise(name);
420        if let Some(&xyz) = overrides.iter().find(|(k, _)| normalise(k) == key).map(|(_, v)| v) {
421            return xyz;
422        }
423        // 2. database
424        if let Some(xyz) = channel_xyz(name) {
425            return xyz;
426        }
427        // 3. placeholder
428        [f32::NAN, f32::NAN, f32::NAN]
429    }).collect();
430
431    // Replace NaN placeholders with centroid of known positions
432    let centroid = centroid_of(&positions);
433    for p in &mut positions {
434        if p[0].is_nan() { *p = centroid; }
435    }
436
437    positions
438}
439
/// Euclidean distance between two 3-D points.
#[inline]
fn dist3(a: [f32; 3], b: [f32; 3]) -> f32 {
    // Sum of squared per-axis differences, in x→y→z order, then sqrt.
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| (x - y) * (x - y))
        .sum::<f32>()
        .sqrt()
}
448
/// Compute the centroid of the valid positions; returns `[0,0,0]` if none.
///
/// A position is considered invalid — and excluded from the mean — when *any*
/// of its coordinates is NaN.  (The previous check inspected only the X
/// coordinate, so a partially-NaN entry such as `[0.5, NaN, 0.3]` would have
/// poisoned the whole centroid.  Placeholders produced in this file are
/// all-NaN, so they are filtered identically under both rules.)
fn centroid_of(positions: &[[f32; 3]]) -> [f32; 3] {
    let valid: Vec<&[f32; 3]> = positions.iter()
        .filter(|p| p.iter().all(|c| !c.is_nan()))
        .collect();
    if valid.is_empty() { return [0.0, 0.0, 0.0]; }
    let n = valid.len() as f32;
    let x = valid.iter().map(|p| p[0]).sum::<f32>() / n;
    let y = valid.iter().map(|p| p[1]).sum::<f32>() / n;
    let z = valid.iter().map(|p| p[2]).sum::<f32>() / n;
    [x, y, z]
}
459
460fn positions_to_array(positions: &[[f32; 3]], n_ch: usize) -> Array2<f32> {
461    let flat: Vec<f32> = positions.iter().flat_map(|p| p.iter().copied()).collect();
462    Array2::from_shape_vec((n_ch, 3), flat).expect("positions_to_array shape mismatch")
463}
464
465// ─────────────────────────────────────────────────────────────────────────────
466// Padding
467// ─────────────────────────────────────────────────────────────────────────────
468
469/// Reorder and pad channels to match `target_channels`.
470///
471/// If `whitelist` is `Some`, only CSV channels whose normalised name appears
472/// in the whitelist are considered "present"; others are ignored.
473///
474/// Returns `(padded_data [C_out, T], padded_names, padded_positions, n_padded)`.
475fn apply_padding(
476    data:      &Array2<f32>,
477    names:     &[String],
478    positions: &[[f32; 3]],
479    targets:   &[String],
480    strategy:  &PaddingStrategy,
481    overrides: &HashMap<String, [f32; 3]>,
482    whitelist: Option<&[String]>,
483) -> anyhow::Result<(Array2<f32>, Vec<String>, Vec<[f32; 3]>, usize)> {
484    let n_t = data.ncols();
485    let mut out_rows:  Vec<Vec<f32>>   = Vec::with_capacity(targets.len());
486    let mut out_names: Vec<String>     = Vec::with_capacity(targets.len());
487    let mut out_pos:   Vec<[f32; 3]>   = Vec::with_capacity(targets.len());
488    let mut n_padded = 0usize;
489
490    // Build a normalised-name → source-index map for loaded channels.
491    // If a whitelist is provided, only whitelisted channels count as "present".
492    let wl_keys: Option<std::collections::HashSet<String>> = whitelist.map(|wl| {
493        wl.iter().map(|n| normalise(n)).collect()
494    });
495    let src_index: HashMap<String, usize> = names.iter().enumerate()
496        .filter(|(_, n)| {
497            wl_keys.as_ref().map_or(true, |wl| wl.contains(&normalise(n)))
498        })
499        .map(|(i, n)| (normalise(n), i))
500        .collect();
501
502    // Positions of loaded channels, useful for CloneNearest.
503    // Restricted to whitelisted channels when whitelist is active.
504    let loaded_xyz_with_idx: Vec<([f32; 3], usize)> = positions.iter().copied()
505        .enumerate()
506        .filter(|(i, _)| src_index.values().any(|&si| si == *i))
507        .map(|(i, xyz)| (xyz, i))
508        .collect();
509
510    for target in targets {
511        let key = normalise(target);
512        if let Some(&src) = src_index.get(&key) {
513            // Channel present in CSV — use it as-is
514            out_rows.push(data.row(src).to_vec());
515            out_names.push(target.clone());
516            out_pos.push(positions[src]);
517        } else if matches!(strategy, PaddingStrategy::NoPadding) {
518            // Drop the missing channel entirely — no synthesis, no row added.
519            n_padded += 1;
520            continue;
521        } else {
522            // Channel missing — synthesise
523            n_padded += 1;
524
525            // Position for the new channel
526            let new_pos = position_for_missing(target, overrides, positions);
527
528            let new_row = match strategy {
529                PaddingStrategy::Zero => {
530                    vec![0f32; n_t]
531                }
532                PaddingStrategy::CloneChannel(src_name) => {
533                    let src_key = normalise(src_name);
534                    let src_idx = src_index.get(&src_key).copied()
535                        .ok_or_else(|| anyhow::anyhow!(
536                            "CloneChannel source '{}' not found in CSV", src_name
537                        ))?;
538                    data.row(src_idx).to_vec()
539                }
540                PaddingStrategy::CloneNearest => {
541                    // Find loaded channel whose position is closest to `new_pos`
542                    let nearest_idx = nearest_channel(new_pos, &loaded_xyz_with_idx)
543                        .unwrap_or(0);
544                    data.row(nearest_idx).to_vec()
545                }
546
547                PaddingStrategy::InterpWeighted { k } => {
548                    // Sort real channels by L2 distance, keep k nearest, then
549                    // form an inverse-distance–weighted average.
550                    let mut dists: Vec<(f32, usize)> = loaded_xyz_with_idx.iter()
551                        .map(|&(xyz, idx)| (dist3(xyz, new_pos), idx))
552                        .collect();
553                    dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
554                    let k_actual = (*k).min(dists.len()).max(1);
555                    let k_slice  = &dists[..k_actual];
556                    // weight_i = 1/d_i  (replace exact-zero distance with large weight)
557                    let weights: Vec<f32> = k_slice.iter()
558                        .map(|(d, _)| if *d < 1e-6 { 1e6_f32 } else { 1.0 / d })
559                        .collect();
560                    let w_sum: f32 = weights.iter().sum();
561                    let mut interp = vec![0f32; n_t];
562                    for ((_, idx), w) in k_slice.iter().zip(weights.iter()) {
563                        let wn = w / w_sum;
564                        for (o, &v) in interp.iter_mut().zip(data.row(*idx).iter()) {
565                            *o += wn * v;
566                        }
567                    }
568                    interp
569                }
570
571                PaddingStrategy::Mirror => {
572                    // Flip the target's X coordinate to the opposite hemisphere,
573                    // then find the nearest real channel to that mirror position.
574                    let mirror_pos = [-new_pos[0], new_pos[1], new_pos[2]];
575                    let nearest_idx = nearest_channel(mirror_pos, &loaded_xyz_with_idx)
576                        .unwrap_or_else(|| loaded_xyz_with_idx.first().map(|&(_, i)| i).unwrap_or(0));
577                    data.row(nearest_idx).to_vec()
578                }
579
580                PaddingStrategy::MeanRef => {
581                    // Per-sample mean of all real channels.
582                    let n_real = loaded_xyz_with_idx.len().max(1);
583                    let mut mean_sig = vec![0f32; n_t];
584                    for &(_, idx) in &loaded_xyz_with_idx {
585                        for (m, &v) in mean_sig.iter_mut().zip(data.row(idx).iter()) {
586                            *m += v;
587                        }
588                    }
589                    for m in &mut mean_sig { *m /= n_real as f32; }
590                    mean_sig
591                }
592
593                // Handled by the early `continue` branch above.
594                PaddingStrategy::NoPadding => unreachable!(),
595            };
596
597            out_rows.push(new_row);
598            out_names.push(target.clone());
599            out_pos.push(new_pos);
600        }
601    }
602
603    let n_out = out_rows.len();
604    let flat: Vec<f32> = out_rows.into_iter().flatten().collect();
605    let padded = Array2::from_shape_vec((n_out, n_t), flat)
606        .context("assembling padded data array")?;
607
608    Ok((padded, out_names, out_pos, n_padded))
609}
610
611/// Determine the XYZ position for a missing channel.
612///
613/// Priority: position_overrides → database lookup → centroid of existing.
614fn position_for_missing(
615    name:      &str,
616    overrides: &HashMap<String, [f32; 3]>,
617    existing:  &[[f32; 3]],
618) -> [f32; 3] {
619    let key = normalise(name);
620    if let Some(&xyz) = overrides.iter().find(|(k, _)| normalise(k) == key).map(|(_, v)| v) {
621        return xyz;
622    }
623    if let Some(xyz) = channel_xyz(name) {
624        return xyz;
625    }
626    centroid_of(existing)
627}
628
// ─────────────────────────────────────────────────────────────────────────────
// Shared preprocessing pipeline
// ─────────────────────────────────────────────────────────────────────────────

/// Run the full exg preprocessing pipeline and assemble `InputBatch` structs.
///
/// Pipeline (identical to [`load_from_fif`](crate::data::load_from_fif)):
/// resample → 0.5 Hz HP FIR → average reference → global z-score →
/// epoch (5 s) → baseline correction → ÷ data_norm
///
/// Produces one [`InputBatch`] per epoch returned by `exg::preprocess`;
/// errors if preprocessing yields zero epochs.
fn run_pipeline<B: Backend>(
    data:        Array2<f32>,    // [C, T] raw EEG in volts
    pos_arr:     Array2<f32>,    // [C, 3] metres
    sample_rate: f32,
    data_norm:   f32,
    data_cfg:    &DataConfig,
    device:      &B::Device,
) -> anyhow::Result<Vec<InputBatch<B>>> {
    use exg::PipelineConfig;

    // Only the normalisation divisor deviates from the exg defaults.
    let cfg = PipelineConfig { data_norm, ..PipelineConfig::default() };
    let epochs = exg::preprocess(data, pos_arr, sample_rate, &cfg)?;

    if epochs.is_empty() {
        bail!("recording produced zero epochs (likely shorter than the 5 s minimum epoch)");
    }

    let mut batches = Vec::with_capacity(epochs.len());
    for (eeg_arr, pos_out) in epochs {
        // Copy the epoch ([C, T]) and its positions ([C, 3]) into device tensors.
        let (c, t) = eeg_arr.dim();
        let eeg_data: Vec<f32> = eeg_arr.iter().copied().collect();
        let eeg = Tensor::<B, 2>::from_data(TensorData::new(eeg_data, vec![c, t]), device);

        let pos_data: Vec<f32> = pos_out.iter().copied().collect();
        let chan_pos = Tensor::<B, 2>::from_data(TensorData::new(pos_data, vec![c, 3]), device);

        // Discretise positions and chop the signal into fine-time tokens.
        // NOTE(review): `tc` assumes `t` is a multiple of `num_fine_time_pts`
        // (integer division truncates otherwise) — confirm against the
        // contract of `chop_and_reshape`.
        let chan_pos_disc = discretize_chan_pos(chan_pos.clone(), data_cfg, device);
        let tc = t / data_cfg.num_fine_time_pts;

        let (eeg_tokens, _, posd, t_coarse) =
            chop_and_reshape(eeg, chan_pos.clone(), chan_pos_disc, data_cfg.num_fine_time_pts);

        // Token index from discretised positions + coarse time; add a leading
        // batch dimension of size 1 to the token tensor for the encoder.
        let tok_idx       = build_tok_idx(posd, t_coarse);
        let encoder_input = eeg_tokens.unsqueeze_dim::<3>(0);

        batches.push(InputBatch { encoder_input, tok_idx, chan_pos, n_channels: c, tc });
    }

    Ok(batches)
}
678
// ─────────────────────────────────────────────────────────────────────────────
// Unit tests
// ─────────────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Write a minimal CSV to a temp file and verify it round-trips.
    #[test]
    fn parse_csv_basic() {
        let content = "timestamp,Fp1,Fp2\n0.0,1e-5,2e-5\n0.004,3e-5,4e-5\n";
        let path = std::env::temp_dir().join("zuna_test_basic.csv");
        std::fs::write(&path, content).unwrap();
        let (names, data) = parse_csv(&path).unwrap();
        assert_eq!(names, ["Fp1", "Fp2"]);
        // data is [C, T]: row 0 = Fp1 over time, row 1 = Fp2 over time.
        assert_eq!(data.dim(), (2, 2));
        assert!((data[[0, 0]] - 1e-5_f32).abs() < 1e-10);
        assert!((data[[1, 1]] - 4e-5_f32).abs() < 1e-10);
    }

    /// Lines starting with `#` must be skipped before header detection.
    #[test]
    fn parse_csv_skips_comments() {
        let content = "# comment\ntimestamp,C3\n0.0,0.5\n0.004,-0.3\n";
        let path = std::env::temp_dir().join("zuna_test_comments.csv");
        std::fs::write(&path, content).unwrap();
        let (names, data) = parse_csv(&path).unwrap();
        assert_eq!(names, ["C3"]);
        assert_eq!(data.dim(), (1, 2));
    }

    /// A standard-montage name (Cz) should resolve from the bundled database
    /// to a plausible head-sized coordinate (well under 12 cm on every axis).
    #[test]
    fn resolve_positions_uses_database() {
        let pos = resolve_positions(&["Cz".to_string()], &HashMap::new());
        assert_eq!(pos.len(), 1);
        let [x, y, z] = pos[0];
        assert!(x.abs() < 0.12 && y.abs() < 0.12 && z.abs() < 0.12);
    }

    /// Overrides are matched case-insensitively ("CZ" key applies to "Cz")
    /// and take priority over the database position.
    #[test]
    fn resolve_positions_override_wins() {
        let mut ov = HashMap::new();
        ov.insert("CZ".to_string(), [0.01, 0.02, 0.09]);
        let pos = resolve_positions(&["Cz".to_string()], &ov);
        assert_eq!(pos[0], [0.01, 0.02, 0.09]);
    }

    #[test]
    fn resolve_positions_unknown_gets_centroid() {
        let names = vec!["UNKNOWN_XYZ".to_string(), "Cz".to_string()];
        let pos = resolve_positions(&names, &HashMap::new());
        // Unknown channel should get centroid of known channels, which is Cz
        let cz = channel_xyz("Cz").unwrap();
        let centroid = pos[0]; // unknown channel
        // centroid of [unknown_placeholder, cz] → when unknown is NaN, centroid = cz
        assert!((centroid[0] - cz[0]).abs() < 1e-5);
    }

    /// `Zero` padding appends an all-zero row with a database-resolved position.
    #[test]
    fn padding_zero_adds_zero_rows() {
        let data = Array2::from_shape_vec((2, 4), vec![1f32; 8]).unwrap();
        let names = vec!["Fp1".to_string(), "Fp2".to_string()];
        let pos = resolve_positions(&names, &HashMap::new());
        let targets = vec!["Fp1".to_string(), "Fp2".to_string(), "Fz".to_string()];
        let (out, out_names, out_pos, n_padded) = apply_padding(
            &data, &names, &pos, &targets, &PaddingStrategy::Zero, &HashMap::new(), None
        ).unwrap();
        assert_eq!(out.dim(), (3, 4));
        assert_eq!(n_padded, 1);
        assert_eq!(out_names[2], "Fz");
        // Fz row must be all zeros
        assert!(out.row(2).iter().all(|&v| v == 0.0));
        // Fz must have a known position (from database)
        let [x, y, z] = out_pos[2];
        assert!(x.abs() < 0.12 && y.abs() < 0.12 && z.abs() < 0.12);
    }

    /// `CloneChannel("Fp1")` copies Fp1's samples into the missing Cz slot.
    #[test]
    fn padding_clone_channel() {
        let data = Array2::from_shape_vec((2, 4), (0..8).map(|i| i as f32).collect()).unwrap();
        let names = vec!["Fp1".to_string(), "Fp2".to_string()];
        let pos = resolve_positions(&names, &HashMap::new());
        let targets = vec!["Fp1".to_string(), "Cz".to_string()];  // Cz missing
        let (out, _, _, n_padded) = apply_padding(
            &data, &names, &pos, &targets,
            &PaddingStrategy::CloneChannel("Fp1".to_string()), &HashMap::new(), None
        ).unwrap();
        assert_eq!(n_padded, 1);
        // Cz row should equal Fp1 row
        assert_eq!(out.row(0).to_vec(), out.row(1).to_vec());
    }

    #[test]
    fn padding_clone_nearest() {
        // Fp1 and Fp2 are close together; Fz is between them and Cz
        let data = Array2::from_shape_vec((2, 4), (0..8).map(|i| i as f32 * 0.1).collect()).unwrap();
        let names = vec!["Fp1".to_string(), "Fp2".to_string()];
        let pos = resolve_positions(&names, &HashMap::new());
        let targets = vec!["Fp1".to_string(), "Fp2".to_string(), "AF7".to_string()];
        let (out, _, _, n_padded) = apply_padding(
            &data, &names, &pos, &targets,
            &PaddingStrategy::CloneNearest, &HashMap::new(), None
        ).unwrap();
        assert_eq!(n_padded, 1);
        // AF7 is near Fp1/Fp2 front — cloned from one of them, must be nonzero
        assert!(out.row(2).iter().any(|&v| v != 0.0));
    }
}