Skip to main content

zuna_rs/
lib.rs

1//! # zuna-rs — ZUNA EEG Foundation Model inference in Rust
2//!
3//! Pure-Rust inference for the [ZUNA](https://huggingface.co/Zyphra/ZUNA)
4//! EEG foundation model, built on [Burn 0.20](https://burn.dev) and
5//! [exg](https://github.com/eugenehp/exg) for FIF preprocessing.
6//!
7//! ## Three entry points
8//!
9//! | Type | Loads | Use case |
10//! |---|---|---|
11//! | [`ZunaInference`] | encoder + decoder | full encode → diffuse → decode pipeline |
12//! | [`ZunaEncoder`]   | encoder only      | produce latent embeddings, save memory |
13//! | [`ZunaDecoder`]   | decoder only      | reconstruct from stored embeddings |
14//!
15//! ## Quick start — full pipeline
16//!
17//! ```rust,ignore
18//! use zuna_rs::{ZunaInference, InferenceResult};
19//!
20//! let (model, _ms) = ZunaInference::<B>::load(
21//!     Path::new("config.json"),
22//!     Path::new("model.safetensors"),
23//!     device,
24//! )?;
25//! let result: InferenceResult = model.run_fif(Path::new("recording.fif"), 50, 1.0, 10.0)?;
26//! result.save_safetensors("output.safetensors")?;
27//! ```
28//!
29//! ## Quick start — encode only
30//!
31//! ```rust,ignore
32//! use zuna_rs::{ZunaEncoder, EncodingResult};
33//!
34//! let (enc, _ms) = ZunaEncoder::<B>::load(
35//!     Path::new("config.json"),
36//!     Path::new("model.safetensors"),
37//!     device,
38//! )?;
39//! let result: EncodingResult = enc.encode_fif(Path::new("recording.fif"), 10.0)?;
40//! result.save_safetensors("data/embeddings.safetensors")?;
41//! ```
42//!
43//! ## Quick start — decode from stored embeddings
44//!
45//! ```rust,ignore
46//! use zuna_rs::{ZunaDecoder, EncodingResult};
47//!
48//! let embeddings = EncodingResult::load_safetensors("data/embeddings.safetensors")?;
49//! let (dec, _ms) = ZunaDecoder::<B>::load(
50//!     Path::new("config.json"),
51//!     Path::new("model.safetensors"),
52//!     device,
53//! )?;
54//! let result = dec.decode_embeddings(&embeddings, 50, 1.0, 10.0)?;
55//! result.save_safetensors("output.safetensors")?;
56//! ```
57//!
58//! ## Embedding regularisation
59//!
60//! The encoder uses an **MMD (Maximum Mean Discrepancy) bottleneck**: during
61//! training, an MMD loss pushes the embedding distribution toward **N(0, I)**.
62//! At inference time the bottleneck is a pure passthrough — no reparameterisation is
63//! applied.  Embeddings from [`ZunaEncoder`] or [`ZunaInference::encode_fif`]
64//! are therefore already in the regularised latent space and can be used
65//! directly for downstream tasks.
66
67// ── Internal modules ─────────────────────────────────────────────────────────
68
69pub mod channel_positions;
70pub mod config;
71pub mod csv_export;
72pub mod csv_loader;
73pub mod data;
74pub mod encoder;
75pub mod decoder;
76pub mod inference;
77pub mod model;
78pub mod weights;
79
80// ── Flat re-exports ───────────────────────────────────────────────────────────
81//
82// Everything a downstream user needs is available as `zuna_rs::Foo` without
83// knowing the internal module layout.
84
85// Full pipeline
86pub use inference::{ZunaInference, EpochOutput, InferenceResult};
87
88// Encoder-only
89pub use encoder::{ZunaEncoder, EpochEmbedding, EncodingResult};
90
91// Decoder-only
92pub use decoder::ZunaDecoder;
93
94// Configs
95pub use config::{ModelConfig, DataConfig, InferConfig};
96
97// Data types needed for the lower-level API
98pub use data::{InputBatch, FifInfo};
99
100// Channel position lookup
101pub use channel_positions::{channel_xyz, MontageLayout, montage_channels, nearest_channel, normalise};
102
103// CSV / tensor data loading
104pub use csv_loader::{
105    load_from_csv, load_from_raw_tensor, load_from_named_tensor,
106    PaddingStrategy, CsvLoadOptions, CsvInfo,
107};
108
109// CSV export from FIF
110pub use csv_export::fif_to_csv;