Skip to main content

zuna_rs/
decoder.rs

//! Standalone ZUNA decoder — reconstruct EEG signals from latent embeddings.
//!
//! Use [`ZunaDecoder`] when you want to run only the decoder half of the model,
//! for example to reconstruct signals from embeddings that were previously
//! computed with [`crate::encoder::ZunaEncoder`] and saved to disk.
//!
//! # Example
//!
//! ```rust,ignore
//! use zuna_rs::decoder::ZunaDecoder;
//! use zuna_rs::encoder::EncodingResult;
//!
//! // Load stored embeddings from a previous encode pass.
//! let embeddings = EncodingResult::load_safetensors("data/embeddings.safetensors")?;
//!
//! // Load decoder weights only.
//! let (dec, ms) = ZunaDecoder::<B>::load(
//!     Path::new("config.json"),
//!     Path::new("model.safetensors"),
//!     device,
//! )?;
//!
//! // Decode: run the rectified-flow diffusion loop conditioned on embeddings.
//! let result = dec.decode_embeddings(&embeddings, 50, 1.0, 10.0)?;
//! result.save_safetensors("output.safetensors")?;
//! ```

28use std::{path::Path, time::Instant};
29
30use anyhow::Context;
31use burn::{prelude::*, tensor::Distribution};
32
33use crate::{
34    config::{DataConfig, ModelConfig},
35    data::invert_reshape,
36    encoder::{EncodingResult, EpochEmbedding},
37    inference::{EpochOutput, InferenceResult},
38    model::{decoder::DecoderTransformer, rope::RotaryEmbedding},
39    weights::load_decoder_weights,
40};
41
// ── ZunaDecoder ───────────────────────────────────────────────────────────────

/// Standalone ZUNA decoder.
///
/// Reconstructs EEG signals from latent embeddings produced by
/// [`crate::encoder::ZunaEncoder`] using a rectified-flow diffusion loop.
///
/// Load with [`ZunaDecoder::load`] (decoder weights only — saves ~50 % memory
/// vs the full [`crate::ZunaInference`]).
pub struct ZunaDecoder<B: Backend> {
    /// Decoder transformer; its forward output is used as the flow velocity
    /// in the Euler sampling loop (see `decode_tensor`).
    decoder:       DecoderTransformer<B>,
    /// Rotary position embeddings, precomputed once at load time and shared
    /// across all forward passes.
    rope:          RotaryEmbedding<B>,
    /// Architecture hyperparameters (from config.json).
    pub model_cfg: ModelConfig,
    /// Preprocessing / tokenisation parameters.
    pub data_cfg:  DataConfig,
    /// Diffusion noise standard deviation (σ), taken from
    /// `model_cfg.stft_global_sigma` at load time.
    pub global_sigma: f32,
    /// Device every tensor (noise, embeddings, indices) is created on.
    device:        B::Device,
}
62
impl<B: Backend> ZunaDecoder<B> {
    // ── Construction ──────────────────────────────────────────────────────────

    /// Load decoder weights from a HuggingFace `config.json` and
    /// `model.safetensors`.  Encoder tensors are read from disk but not kept
    /// in memory.
    ///
    /// Returns `(decoder, weight_load_ms)` — the timing covers only the
    /// safetensors load, not config parsing or RoPE construction.
    ///
    /// # Errors
    /// Fails if the config file cannot be read or parsed, if the weights path
    /// is not valid UTF-8, or if weight loading itself fails.
    pub fn load(
        config_path:  &Path,
        weights_path: &Path,
        device:       B::Device,
    ) -> anyhow::Result<(Self, f64)> {
        let cfg_str = std::fs::read_to_string(config_path)
            .with_context(|| format!("config: {}", config_path.display()))?;
        let hf_val: serde_json::Value = serde_json::from_str(&cfg_str)?;
        // Indexing a `serde_json::Value` with a missing key yields `Null`, so
        // an absent "model" section surfaces as the context error below
        // rather than a panic.
        let model_cfg: ModelConfig = serde_json::from_value(hf_val["model"].clone())
            .context("parsing model config")?;

        // RoPE tables are precomputed once for the model's maximum sequence
        // length and reused for every decode pass.
        let rope = RotaryEmbedding::<B>::new(
            model_cfg.head_dim, model_cfg.rope_dim,
            model_cfg.max_seqlen, model_cfg.rope_theta, &device,
        );

        let t = Instant::now();
        let (decoder, n_heads) = load_decoder_weights::<B>(
            &model_cfg,
            weights_path.to_str().context("weights path not valid UTF-8")?,
            &device,
        )?;
        let ms = t.elapsed().as_secs_f64() * 1000.0;

        println!("Detected n_heads = {n_heads}");

        // σ doubles as the standard deviation of the initial noise in
        // `decode_tensor`.
        let global_sigma = model_cfg.stft_global_sigma as f32;

        Ok((Self { decoder, rope, model_cfg, data_cfg: DataConfig::default(), global_sigma, device }, ms))
    }

    // ── Accessors ─────────────────────────────────────────────────────────────

    /// One-line description of the loaded decoder (dims, layer count, σ) for
    /// logging / CLI output.
    pub fn describe(&self) -> String {
        let c = &self.model_cfg;
        format!(
            "ZUNA decoder  dim={}  layers={}  head_dim={}  t_dim={}  σ={}",
            c.dim, c.n_layers, c.head_dim, c.t_dim, self.global_sigma,
        )
    }

    // ── High-level decode API ─────────────────────────────────────────────────

    /// Reconstruct EEG signals from pre-computed embeddings.
    ///
    /// This is the pure decode path — no FIF loading or preprocessing happens
    /// here.  The [`EncodingResult`] must have been produced by
    /// [`ZunaEncoder::encode_fif`] or [`ZunaEncoder::encode_batch`].
    ///
    /// Epochs are decoded sequentially; the first failure aborts the whole
    /// batch.  `ms_preproc` in the result is 0.0 because no preprocessing
    /// occurs on this path.
    ///
    /// # Arguments
    /// - `embeddings`  — pre-computed encoder latents
    /// - `steps`       — diffusion denoising steps (50 = full quality, 10 = fast)
    /// - `cfg`         — classifier-free guidance scale (1.0 = disabled)
    /// - `data_norm`   — divisor used during preprocessing; multiplied back into
    ///                   the output to restore the original signal scale
    pub fn decode_embeddings(
        &self,
        embeddings: &EncodingResult,
        steps:      usize,
        cfg:        f32,
        data_norm:  f32,
    ) -> anyhow::Result<InferenceResult> {
        let t_dec = Instant::now();
        let epochs = embeddings.epochs
            .iter()
            .map(|ep| self.decode_one(ep, steps, cfg, data_norm))
            .collect::<anyhow::Result<Vec<_>>>()?;
        let ms_infer = t_dec.elapsed().as_secs_f64() * 1000.0;

        Ok(InferenceResult {
            epochs,
            fif_info:   None,
            ms_preproc: 0.0,
            ms_infer,
        })
    }

    /// Decode a single epoch from a raw encoder output tensor `[1, S, output_dim]`.
    ///
    /// `tok_idx` is `[S, 4]`.  Returns the reconstructed token matrix
    /// `[1, S, input_dim]` **before** inversion of the chop-and-reshape.
    ///
    /// NOTE(review): the sampled tensor `z` is created with the same last dim
    /// as `enc_out`, so this code assumes input_dim == output_dim for this
    /// architecture — confirm against the model config.
    ///
    /// Caller is expected to pass `steps >= 1`; with `steps == 0` the loop
    /// body never runs and raw noise is returned (`dt` becomes +inf but is
    /// never used).
    pub fn decode_tensor(
        &self,
        enc_out:   Tensor<B, 3>,
        tok_idx:   Tensor<B, 2, Int>,
        steps:     usize,
        cfg:       f32,
    ) -> Tensor<B, 3> {
        let device = enc_out.device();
        let [b, s, d] = enc_out.dims();
        // Euler step size: integrate t from 1.0 down to 0 in `steps` steps.
        let dt = 1.0_f32 / steps as f32;

        // Initial noise z ~ N(0, σ²)
        let sigma = self.global_sigma as f64;
        let mut z = Tensor::<B, 3>::random(
            [b, s, d],
            Distribution::Normal(0.0, sigma),
            &device,
        );

        // Rectified-flow Euler sampling loop: t runs 1.0, 1-dt, …, dt
        // (reversed range), stepping z ← z − v(z, t)·dt toward t = 0.
        for i in (1..=steps).rev() {
            let t_val  = dt * i as f32;
            // Scalar time broadcast as [b, 1, 1] for the decoder.
            let time_t = Tensor::<B, 3>::full([b, 1, 1], t_val, &device);

            // Conditioned velocity prediction.
            let vc = self.decoder.forward(
                z.clone(), enc_out.clone(), time_t.clone(), tok_idx.clone(), &self.rope,
            );

            let vc = if (cfg - 1.0).abs() > 1e-4 {
                // Classifier-free guidance: run unconditioned pass with zeros
                // in place of the encoder latents, then extrapolate:
                // v = v_uncond + cfg · (v_cond − v_uncond).
                let enc_zeros  = Tensor::zeros([b, s, d], &device);
                let vc_uncond  = self.decoder.forward(
                    z.clone(), enc_zeros, time_t, tok_idx.clone(), &self.rope,
                );
                vc_uncond.clone() + (vc - vc_uncond).mul_scalar(cfg)
            } else {
                // cfg ≈ 1.0: guidance is a no-op, skip the second forward pass.
                vc
            };

            z = z - vc.mul_scalar(dt);
        }

        z
    }

    // ── Internal ──────────────────────────────────────────────────────────────

    /// Decode one stored epoch: rebuild tensors from the flat `Vec`s saved in
    /// [`EpochEmbedding`], run the diffusion loop, invert the chop-and-reshape,
    /// and rescale by `data_norm`.
    fn decode_one(
        &self,
        ep:        &EpochEmbedding,
        steps:     usize,
        cfg:       f32,
        data_norm: f32,
    ) -> anyhow::Result<EpochOutput> {
        let n_tokens = ep.n_tokens();
        let dc       = &self.data_cfg;

        // Reconstruct enc_out [1, S, output_dim] from stored Vec<f32>.
        // NOTE(review): assumes `ep.shape` is the 2-D [S, output_dim] shape —
        // confirm against the encoder's save path.
        let enc_out = Tensor::<B, 2>::from_data(
            TensorData::new(ep.embeddings.clone(), ep.shape.clone()),
            &self.device,
        )
        .unsqueeze_dim::<3>(0);  // [1, S, output_dim]

        // Reconstruct tok_idx [S, 4] from stored Vec<i64>; assumes
        // `ep.tok_idx.len() == n_tokens * 4`.
        let tok_idx = Tensor::<B, 2, Int>::from_data(
            TensorData::new(ep.tok_idx.clone(), vec![n_tokens, 4]),
            &self.device,
        );

        // Run diffusion; returns [1, S, input_dim].
        let z = self.decode_tensor(enc_out, tok_idx, steps, cfg);

        // Invert chop-and-reshape: [1, S, tf] → [C, T].
        let [_, s, tf] = z.dims();
        let recon = invert_reshape(
            z.reshape([s, tf]),
            ep.n_channels,
            ep.tc,
            dc.num_fine_time_pts,
        );
        // Undo the preprocessing normalisation to restore original amplitude.
        let recon = recon.mul_scalar(data_norm);

        // Flatten to Vec<f32> + shape for storage in EpochOutput.
        let shape         = recon.dims().to_vec();
        let reconstructed = recon
            .into_data()
            .to_vec::<f32>()
            .map_err(|e| anyhow::anyhow!("recon→vec: {e:?}"))?;
        let chan_pos = ep.chan_pos.clone();

        Ok(EpochOutput { reconstructed, shape, chan_pos, n_channels: ep.n_channels })
    }
}