// zuna_rs/decoder.rs
1//! Standalone ZUNA decoder — reconstruct EEG signals from latent embeddings.
2//!
3//! Use [`ZunaDecoder`] when you want to run only the decoder half of the model,
4//! for example to reconstruct signals from embeddings that were previously
5//! computed with [`crate::encoder::ZunaEncoder`] and saved to disk.
6//!
7//! # Example
8//!
9//! ```rust,ignore
10//! use zuna_rs::decoder::ZunaDecoder;
11//! use zuna_rs::encoder::EncodingResult;
12//!
13//! // Load stored embeddings from a previous encode pass.
14//! let embeddings = EncodingResult::load_safetensors("data/embeddings.safetensors")?;
15//!
16//! // Load decoder weights only.
17//! let (dec, ms) = ZunaDecoder::<B>::load(
18//! Path::new("config.json"),
19//! Path::new("model.safetensors"),
20//! device,
21//! )?;
22//!
23//! // Decode: run the rectified-flow diffusion loop conditioned on embeddings.
24//! let result = dec.decode_embeddings(&embeddings, 50, 1.0, 10.0)?;
25//! result.save_safetensors("output.safetensors")?;
26//! ```
27
28use std::{path::Path, time::Instant};
29
30use anyhow::Context;
31use burn::{prelude::*, tensor::Distribution};
32
33use crate::{
34 config::{DataConfig, ModelConfig},
35 data::invert_reshape,
36 encoder::{EncodingResult, EpochEmbedding},
37 inference::{EpochOutput, InferenceResult},
38 model::{decoder::DecoderTransformer, rope::RotaryEmbedding},
39 weights::load_decoder_weights,
40};
41
42// ── ZunaDecoder ───────────────────────────────────────────────────────────────
43
/// Standalone ZUNA decoder.
///
/// Reconstructs EEG signals from latent embeddings produced by
/// [`crate::encoder::ZunaEncoder`] using a rectified-flow diffusion loop.
///
/// Load with [`ZunaDecoder::load`] (decoder weights only — saves ~50 % memory
/// vs the full [`crate::ZunaInference`]).
pub struct ZunaDecoder<B: Backend> {
    /// Transformer that predicts the flow velocity during sampling
    /// (its output is used as the Euler step direction in `decode_tensor`).
    decoder: DecoderTransformer<B>,
    /// Rotary position-embedding tables, built once at load time and shared
    /// by every forward pass.
    rope: RotaryEmbedding<B>,
    /// Architecture hyperparameters (from config.json).
    pub model_cfg: ModelConfig,
    /// Preprocessing / tokenisation parameters.
    pub data_cfg: DataConfig,
    /// Diffusion noise standard deviation (σ).
    pub global_sigma: f32,
    /// Backend device on which all tensors are created.
    device: B::Device,
}
62
63impl<B: Backend> ZunaDecoder<B> {
64 // ── Construction ──────────────────────────────────────────────────────────
65
66 /// Load decoder weights from a HuggingFace `config.json` and
67 /// `model.safetensors`. Encoder tensors are read from disk but not kept
68 /// in memory.
69 ///
70 /// Returns `(decoder, weight_load_ms)`.
71 pub fn load(
72 config_path: &Path,
73 weights_path: &Path,
74 device: B::Device,
75 ) -> anyhow::Result<(Self, f64)> {
76 let cfg_str = std::fs::read_to_string(config_path)
77 .with_context(|| format!("config: {}", config_path.display()))?;
78 let hf_val: serde_json::Value = serde_json::from_str(&cfg_str)?;
79 let model_cfg: ModelConfig = serde_json::from_value(hf_val["model"].clone())
80 .context("parsing model config")?;
81
82 let rope = RotaryEmbedding::<B>::new(
83 model_cfg.head_dim, model_cfg.rope_dim,
84 model_cfg.max_seqlen, model_cfg.rope_theta, &device,
85 );
86
87 let t = Instant::now();
88 let (decoder, n_heads) = load_decoder_weights::<B>(
89 &model_cfg,
90 weights_path.to_str().context("weights path not valid UTF-8")?,
91 &device,
92 )?;
93 let ms = t.elapsed().as_secs_f64() * 1000.0;
94
95 println!("Detected n_heads = {n_heads}");
96
97 let global_sigma = model_cfg.stft_global_sigma as f32;
98
99 Ok((Self { decoder, rope, model_cfg, data_cfg: DataConfig::default(), global_sigma, device }, ms))
100 }
101
102 // ── Accessors ─────────────────────────────────────────────────────────────
103
104 /// One-line description of the loaded decoder.
105 pub fn describe(&self) -> String {
106 let c = &self.model_cfg;
107 format!(
108 "ZUNA decoder dim={} layers={} head_dim={} t_dim={} σ={}",
109 c.dim, c.n_layers, c.head_dim, c.t_dim, self.global_sigma,
110 )
111 }
112
113 // ── High-level decode API ─────────────────────────────────────────────────
114
115 /// Reconstruct EEG signals from pre-computed embeddings.
116 ///
117 /// This is the pure decode path — no FIF loading or preprocessing happens
118 /// here. The [`EncodingResult`] must have been produced by
119 /// [`ZunaEncoder::encode_fif`] or [`ZunaEncoder::encode_batch`].
120 ///
121 /// # Arguments
122 /// - `embeddings` — pre-computed encoder latents
123 /// - `steps` — diffusion denoising steps (50 = full quality, 10 = fast)
124 /// - `cfg` — classifier-free guidance scale (1.0 = disabled)
125 /// - `data_norm` — divisor used during preprocessing; multiplied back into
126 /// the output to restore the original signal scale
127 pub fn decode_embeddings(
128 &self,
129 embeddings: &EncodingResult,
130 steps: usize,
131 cfg: f32,
132 data_norm: f32,
133 ) -> anyhow::Result<InferenceResult> {
134 let t_dec = Instant::now();
135 let epochs = embeddings.epochs
136 .iter()
137 .map(|ep| self.decode_one(ep, steps, cfg, data_norm))
138 .collect::<anyhow::Result<Vec<_>>>()?;
139 let ms_infer = t_dec.elapsed().as_secs_f64() * 1000.0;
140
141 Ok(InferenceResult {
142 epochs,
143 fif_info: None,
144 ms_preproc: 0.0,
145 ms_infer,
146 })
147 }
148
149 /// Decode a single epoch from a raw encoder output tensor `[1, S, output_dim]`.
150 ///
151 /// `tok_idx` is `[S, 4]`. Returns the reconstructed token matrix
152 /// `[1, S, input_dim]` **before** inversion of the chop-and-reshape.
153 pub fn decode_tensor(
154 &self,
155 enc_out: Tensor<B, 3>,
156 tok_idx: Tensor<B, 2, Int>,
157 steps: usize,
158 cfg: f32,
159 ) -> Tensor<B, 3> {
160 let device = enc_out.device();
161 let [b, s, d] = enc_out.dims();
162 let dt = 1.0_f32 / steps as f32;
163
164 // Initial noise z ~ N(0, σ²)
165 let sigma = self.global_sigma as f64;
166 let mut z = Tensor::<B, 3>::random(
167 [b, s, d],
168 Distribution::Normal(0.0, sigma),
169 &device,
170 );
171
172 // Rectified-flow Euler sampling loop
173 for i in (1..=steps).rev() {
174 let t_val = dt * i as f32;
175 let time_t = Tensor::<B, 3>::full([b, 1, 1], t_val, &device);
176
177 let vc = self.decoder.forward(
178 z.clone(), enc_out.clone(), time_t.clone(), tok_idx.clone(), &self.rope,
179 );
180
181 let vc = if (cfg - 1.0).abs() > 1e-4 {
182 // Classifier-free guidance: run unconditioned pass with zeros
183 let enc_zeros = Tensor::zeros([b, s, d], &device);
184 let vc_uncond = self.decoder.forward(
185 z.clone(), enc_zeros, time_t, tok_idx.clone(), &self.rope,
186 );
187 vc_uncond.clone() + (vc - vc_uncond).mul_scalar(cfg)
188 } else {
189 vc
190 };
191
192 z = z - vc.mul_scalar(dt);
193 }
194
195 z
196 }
197
198 // ── Internal ──────────────────────────────────────────────────────────────
199
200 fn decode_one(
201 &self,
202 ep: &EpochEmbedding,
203 steps: usize,
204 cfg: f32,
205 data_norm: f32,
206 ) -> anyhow::Result<EpochOutput> {
207 let n_tokens = ep.n_tokens();
208 let dc = &self.data_cfg;
209
210 // Reconstruct enc_out [1, S, output_dim] from stored Vec<f32>.
211 let enc_out = Tensor::<B, 2>::from_data(
212 TensorData::new(ep.embeddings.clone(), ep.shape.clone()),
213 &self.device,
214 )
215 .unsqueeze_dim::<3>(0); // [1, S, output_dim]
216
217 // Reconstruct tok_idx [S, 4] from stored Vec<i64>.
218 let tok_idx = Tensor::<B, 2, Int>::from_data(
219 TensorData::new(ep.tok_idx.clone(), vec![n_tokens, 4]),
220 &self.device,
221 );
222
223 // Run diffusion; returns [1, S, input_dim].
224 let z = self.decode_tensor(enc_out, tok_idx, steps, cfg);
225
226 // Invert chop-and-reshape: [1, S, tf] → [C, T].
227 let [_, s, tf] = z.dims();
228 let recon = invert_reshape(
229 z.reshape([s, tf]),
230 ep.n_channels,
231 ep.tc,
232 dc.num_fine_time_pts,
233 );
234 let recon = recon.mul_scalar(data_norm);
235
236 let shape = recon.dims().to_vec();
237 let reconstructed = recon
238 .into_data()
239 .to_vec::<f32>()
240 .map_err(|e| anyhow::anyhow!("recon→vec: {e:?}"))?;
241 let chan_pos = ep.chan_pos.clone();
242
243 Ok(EpochOutput { reconstructed, shape, chan_pos, n_channels: ep.n_channels })
244 }
245}