1use std::path::Path;
16
17use serde::Deserialize;
18use snafu::ResultExt;
19use svod_arch::ctc::{CtcDecoder, GreedyDecoder};
20
21use super::error::{ConfigIoSnafu, ConfigSnafu, Error, Result};
22
23#[derive(Clone, Deserialize)]
24#[serde(rename_all = "snake_case")]
25pub enum SubsamplingMode {
26 Conv1d,
27 Conv2d,
28}
29
30#[derive(Clone, Deserialize)]
31#[serde(rename_all = "snake_case")]
32pub enum ConvNormType {
33 LayerNorm,
34 BatchNorm,
35}
36
/// Validated GigaAM model configuration, built from the `/cfg/model/cfg` object
/// of a `config.json` via [`GigaAmConfig::from_json`].
#[derive(Clone)]
pub struct GigaAmConfig {
    /// Taken from `encoder.max_batch_size` (defaults to 32 when absent).
    pub max_batch_size: usize,
    /// Taken from `preprocessor.features`.
    pub n_mels: usize,
    /// Taken from `encoder.d_model`.
    pub d_model: usize,
    /// Taken from `encoder.n_heads`.
    pub n_heads: usize,
    /// Taken from `encoder.n_layers`.
    pub n_layers: usize,
    /// Computed as `encoder.d_model * encoder.ff_expansion_factor`.
    pub d_ff: usize,
    /// Taken from `encoder.conv_kernel_size`.
    pub conv_kernel: usize,
    /// Taken from `encoder.subsampling_factor`; validation only accepts 4.
    pub subsampling_factor: usize,
    /// Taken from `encoder.subsampling` (defaults to `Conv2d` when absent).
    pub subsampling_mode: SubsamplingMode,
    /// Taken from `encoder.subs_kernel_size` (defaults to 3 when absent).
    pub subs_kernel_size: usize,
    /// Taken from `encoder.conv_norm_type` (defaults to `BatchNorm` when absent).
    pub conv_norm_type: ConvNormType,
    /// Class count resolved from `head.num_classes`, then
    /// `head.decoder.num_classes`, then `head.joint.num_classes`.
    pub vocab_size: usize,
    /// Taken from `preprocessor.sample_rate`.
    pub sample_rate: usize,
    /// Taken from `preprocessor.n_fft`; validation requires it to equal `win_length`.
    pub n_fft: usize,
    /// Taken from `preprocessor.hop_length`.
    pub hop_length: usize,
    /// Taken from `preprocessor.win_length`.
    pub win_length: usize,
    /// Taken from `preprocessor.center` (defaults to `true` when absent).
    pub mel_center: bool,
    /// `encoder.max_mel_frames`, else `encoder.max_seq_len`, else
    /// `pos_emb_max_len * subsampling_factor`. Checked so that it never
    /// subsamples to more than `max_encoder_frames`.
    pub max_mel_frames: usize,
    /// Taken from `encoder.pos_emb_max_len` (defaults to 5000 when absent).
    pub max_encoder_frames: usize,
    /// CTC decoder built from the optional `decoding` block.
    pub decoder: CtcDecoder,
    /// Transducer settings; `Some` only when the config is recognised as RNN-T
    /// (decoding `_target_` contains `"RNNT"`, or both `head.decoder` and
    /// `head.joint` are present).
    pub transducer: Option<TransducerConfig>,
}
67
/// Settings for the RNN-T (transducer) head, extracted from the `head` and
/// `decoding` config blocks.
#[derive(Clone, Debug)]
pub struct TransducerConfig {
    /// Taken from `head.decoder.pred_hidden`.
    pub pred_hidden: usize,
    /// Taken from `head.decoder.pred_rnn_layers`.
    pub pred_rnn_layers: usize,
    /// Taken from `head.joint.joint_hidden`.
    pub joint_hidden: usize,
    /// Same value as the resolved `GigaAmConfig::vocab_size`.
    pub num_classes: usize,
    /// Taken from `decoding.max_symbols_per_step`; 10 when absent or not an
    /// unsigned integer.
    pub max_symbols_per_step: usize,
    /// String entries of `decoding.vocabulary`; empty when the key is absent.
    pub vocabulary: Vec<String>,
    /// `true` when `decoding.model_path` is a non-empty string.
    pub sentencepiece: bool,
}
84
85impl GigaAmConfig {
86 pub fn from_json(path: &Path) -> Result<Self> {
87 let data = std::fs::read_to_string(path).context(ConfigIoSnafu)?;
88 let root: serde_json::Value = serde_json::from_str(&data).context(ConfigSnafu)?;
89 let leaf = root.pointer("/cfg/model/cfg").ok_or_else(|| Error::DecoderConfig {
90 message: "config.json missing required path /cfg/model/cfg".into(),
91 })?;
92 let raw: RawModelCfg = serde_json::from_value(leaf.clone()).context(ConfigSnafu)?;
93 Self::from_raw(raw)
94 }
95
96 fn from_raw(raw: RawModelCfg) -> Result<Self> {
97 validate_preprocessor(&raw.preprocessor)?;
98 validate_encoder(&raw.encoder)?;
99
100 let max_encoder_frames = raw.encoder.pos_emb_max_len;
105 let max_mel_frames = raw
106 .encoder
107 .max_mel_frames
108 .or(raw.encoder.max_seq_len)
109 .unwrap_or(max_encoder_frames * raw.encoder.subsampling_factor);
110 let subs_kernel = match &raw.encoder.subsampling {
111 SubsamplingMode::Conv1d => raw.encoder.subs_kernel_size,
112 SubsamplingMode::Conv2d => 3,
113 };
114 let max_sub_frames = subsampled_len(subs_kernel, max_mel_frames);
115 if max_sub_frames > max_encoder_frames {
116 return Err(Error::DecoderConfig {
117 message: format!(
118 "max_mel_frames ({max_mel_frames}) subsamples to {max_sub_frames} encoder frames, exceeding pos_emb_max_len ({max_encoder_frames})"
119 ),
120 });
121 }
122 let vocab_size = raw
125 .head
126 .num_classes
127 .or_else(|| raw.head.decoder.as_ref().and_then(|d| d.num_classes))
128 .or_else(|| raw.head.joint.as_ref().and_then(|j| j.num_classes))
129 .ok_or_else(|| Error::DecoderConfig {
130 message: "missing num_classes (head.num_classes or head.{decoder,joint}.num_classes)".into(),
131 })?;
132 let decoder = raw_to_decoder(raw.decoding.as_ref(), vocab_size)?;
133 let transducer = raw_to_transducer(&raw.head, raw.decoding.as_ref(), vocab_size)?;
134 Ok(Self {
135 max_batch_size: raw.encoder.max_batch_size,
136 n_mels: raw.preprocessor.features,
137 d_model: raw.encoder.d_model,
138 n_heads: raw.encoder.n_heads,
139 n_layers: raw.encoder.n_layers,
140 d_ff: raw.encoder.d_model * raw.encoder.ff_expansion_factor,
141 conv_kernel: raw.encoder.conv_kernel_size,
142 subsampling_factor: raw.encoder.subsampling_factor,
143 subsampling_mode: raw.encoder.subsampling,
144 subs_kernel_size: raw.encoder.subs_kernel_size,
145 conv_norm_type: raw.encoder.conv_norm_type,
146 vocab_size,
147 sample_rate: raw.preprocessor.sample_rate,
148 n_fft: raw.preprocessor.n_fft,
149 hop_length: raw.preprocessor.hop_length,
150 win_length: raw.preprocessor.win_length,
151 mel_center: raw.preprocessor.center,
152 max_mel_frames,
153 max_encoder_frames,
154 decoder,
155 transducer,
156 })
157 }
158}
159
/// Serde mirror of the JSON object located at `/cfg/model/cfg`.
#[derive(Deserialize)]
struct RawModelCfg {
    /// Required `preprocessor` block.
    preprocessor: RawPreprocessor,
    /// Required `encoder` block.
    encoder: RawEncoder,
    /// Required `head` block.
    head: RawHead,
    /// Optional `decoding` block, kept as untyped JSON because its shape depends
    /// on its `_target_` value.
    #[serde(default)]
    decoding: Option<serde_json::Value>,
}
174
/// Serde mirror of the `preprocessor` config block.
#[derive(Deserialize)]
struct RawPreprocessor {
    /// Becomes `GigaAmConfig::n_mels`.
    features: usize,
    /// Becomes `GigaAmConfig::sample_rate`.
    sample_rate: usize,
    /// Must equal `win_length` to pass validation.
    n_fft: usize,
    /// Becomes `GigaAmConfig::hop_length`.
    hop_length: usize,
    /// Must equal `n_fft` to pass validation.
    win_length: usize,
    /// Becomes `GigaAmConfig::mel_center`; `true` when absent.
    #[serde(default = "default_true")]
    center: bool,
    /// Only absent/null or `"htk"` passes validation.
    #[serde(default)]
    mel_scale: Option<String>,
    /// Only absent/null passes validation.
    #[serde(default)]
    mel_norm: Option<String>,
}
189
/// Serde mirror of the `encoder` config block.
#[derive(Deserialize)]
struct RawEncoder {
    /// Becomes `GigaAmConfig::d_model`.
    d_model: usize,
    /// Multiplied by `d_model` to obtain `GigaAmConfig::d_ff`.
    ff_expansion_factor: usize,
    /// Becomes `GigaAmConfig::n_heads`.
    n_heads: usize,
    /// Becomes `GigaAmConfig::n_layers`.
    n_layers: usize,
    /// Becomes `GigaAmConfig::conv_kernel`.
    conv_kernel_size: usize,
    /// Must be 4 to pass validation.
    subsampling_factor: usize,
    /// Must be `"rotary"` to pass validation; `"rotary"` when absent.
    #[serde(default = "default_self_attention_model")]
    self_attention_model: String,
    /// Kernel size used for `Conv1d` subsampling; 3 when absent.
    #[serde(default = "default_subs_kernel_size")]
    subs_kernel_size: usize,
    /// Subsampling variant; `Conv2d` when absent.
    #[serde(default = "default_subsampling_mode")]
    subsampling: SubsamplingMode,
    /// Normalization variant; `BatchNorm` when absent.
    #[serde(default = "default_conv_norm_type")]
    conv_norm_type: ConvNormType,
    /// Becomes `GigaAmConfig::max_encoder_frames`; 5000 when absent.
    #[serde(default = "default_pos_emb_max_len")]
    pos_emb_max_len: usize,
    /// Preferred explicit source for `GigaAmConfig::max_mel_frames`.
    #[serde(default)]
    max_mel_frames: Option<usize>,
    /// Fallback source for `GigaAmConfig::max_mel_frames` when `max_mel_frames`
    /// is absent.
    #[serde(default)]
    max_seq_len: Option<usize>,
    /// Becomes `GigaAmConfig::max_batch_size`; 32 when absent.
    #[serde(default = "default_max_batch_size")]
    max_batch_size: usize,
}
215
/// Serde mirror of the `head` config block.
#[derive(Deserialize)]
struct RawHead {
    /// First-choice source for the class count.
    #[serde(default)]
    num_classes: Option<usize>,
    /// Optional `decoder` sub-block; required for an RNN-T config.
    #[serde(default)]
    decoder: Option<RawHeadDecoder>,
    /// Optional `joint` sub-block; required for an RNN-T config.
    #[serde(default)]
    joint: Option<RawHeadJoint>,
}
225
/// Serde mirror of the `head.decoder` config block.
#[derive(Deserialize)]
struct RawHeadDecoder {
    /// Becomes `TransducerConfig::pred_hidden`.
    pred_hidden: usize,
    /// Becomes `TransducerConfig::pred_rnn_layers`.
    pred_rnn_layers: usize,
    /// Second-choice source for the class count (after `head.num_classes`).
    #[serde(default)]
    num_classes: Option<usize>,
}
233
/// Serde mirror of the `head.joint` config block.
#[derive(Deserialize)]
struct RawHeadJoint {
    /// Becomes `TransducerConfig::joint_hidden`.
    joint_hidden: usize,
    /// Last-choice source for the class count.
    #[serde(default)]
    num_classes: Option<usize>,
}
240
/// Serde default for `RawPreprocessor::center` when the field is absent.
fn default_true() -> bool {
    true
}
/// Serde default for `RawEncoder::subs_kernel_size` when the field is absent.
fn default_subs_kernel_size() -> usize {
    3
}
/// Serde default for `RawEncoder::subsampling` when the field is absent.
fn default_subsampling_mode() -> SubsamplingMode {
    SubsamplingMode::Conv2d
}
/// Serde default for `RawEncoder::conv_norm_type` when the field is absent.
fn default_conv_norm_type() -> ConvNormType {
    ConvNormType::BatchNorm
}
/// Serde default for `RawEncoder::pos_emb_max_len` when the field is absent.
fn default_pos_emb_max_len() -> usize {
    5000
}
/// Serde default for `RawEncoder::self_attention_model` when the field is absent.
///
/// Returns `"rotary"`, the only value that `validate_encoder` accepts.
fn default_self_attention_model() -> String {
    String::from("rotary")
}
/// Serde default for `RawEncoder::max_batch_size` when the field is absent.
fn default_max_batch_size() -> usize {
    32
}
262
263fn validate_preprocessor(pre: &RawPreprocessor) -> Result<()> {
264 if let Some(scale) = pre.mel_scale.as_deref()
265 && scale != "htk"
266 {
267 return Err(Error::DecoderConfig {
268 message: format!(
269 "unsupported mel_scale {scale:?}; Svod GigaAM currently matches torchaudio's HTK mel frontend"
270 ),
271 });
272 }
273 if let Some(norm) = pre.mel_norm.as_deref() {
274 return Err(Error::DecoderConfig {
275 message: format!(
276 "unsupported mel_norm {norm:?}; Svod GigaAM currently supports only null/no mel normalization"
277 ),
278 });
279 }
280 if pre.n_fft != pre.win_length {
281 return Err(Error::DecoderConfig {
282 message: format!(
283 "unsupported mel frontend n_fft ({}) != win_length ({}); current GigaAM parity path requires equal FFT/window lengths",
284 pre.n_fft, pre.win_length
285 ),
286 });
287 }
288 Ok(())
289}
290
291fn validate_encoder(encoder: &RawEncoder) -> Result<()> {
292 if encoder.self_attention_model != "rotary" {
293 return Err(Error::DecoderConfig {
294 message: format!(
295 "unsupported self_attention_model {:?}; Svod GigaAM currently implements rotary attention only",
296 encoder.self_attention_model
297 ),
298 });
299 }
300 if encoder.subsampling_factor != 4 {
301 return Err(Error::DecoderConfig {
302 message: format!(
303 "unsupported subsampling_factor {}; Svod GigaAM currently implements exactly two stride-2 subsampling layers",
304 encoder.subsampling_factor
305 ),
306 });
307 }
308 Ok(())
309}
310
/// Returns the sequence length left after two stride-2 convolution stages.
///
/// Each stage maps `len` to `(len + 2 * pad - kernel_size) / 2 + 1` with
/// `pad = (kernel_size - 1) / 2`. All arithmetic saturates, so a stage's
/// numerator is clamped at zero and a stage never yields less than 1.
///
/// A `kernel_size` of 0 is degenerate; it is treated as having no padding
/// instead of underflowing on `kernel_size - 1`.
fn subsampled_len(kernel_size: usize, mel_frames: usize) -> usize {
    let pad = kernel_size.saturating_sub(1) / 2;
    (0..2).fold(mel_frames, |len, _| {
        len.saturating_add(2 * pad).saturating_sub(kernel_size) / 2 + 1
    })
}
319
320fn raw_to_decoder(decoding: Option<&serde_json::Value>, vocab_size: usize) -> Result<CtcDecoder> {
323 let Some(decoding) = decoding else {
324 return Ok(CtcDecoder::Greedy(GreedyDecoder::new(Vec::new())));
325 };
326 if decoding.is_null() {
327 return Ok(CtcDecoder::Greedy(GreedyDecoder::new(Vec::new())));
328 }
329 let target = decoding["_target_"].as_str().unwrap_or("");
330 let decoder: CtcDecoder = if target.contains("CTCGreedyDecoding") {
331 let g: GreedyDecoder = serde_json::from_value(decoding.clone()).context(ConfigSnafu)?;
332 CtcDecoder::Greedy(g)
333 } else if target.contains("CTCBeamDecoding") {
334 let b: svod_arch::ctc::BeamDecoder = serde_json::from_value(decoding.clone()).context(ConfigSnafu)?;
335 CtcDecoder::Beam(Box::new(b))
336 } else {
337 let vocab: Vec<String> = decoding["vocabulary"]
340 .as_array()
341 .map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
342 .unwrap_or_default();
343 CtcDecoder::Greedy(GreedyDecoder::new(vocab))
344 };
345 if !decoder.vocabulary().is_empty() && decoder.total_vocab() != vocab_size {
346 return Err(Error::DecoderConfig {
347 message: format!(
348 "decoder vocabulary length + 1 ({}) != head.num_classes ({}); \
349 CTC convention is one blank token appended after the vocabulary",
350 decoder.total_vocab(),
351 vocab_size
352 ),
353 });
354 }
355 Ok(decoder)
356}
357
358fn raw_to_transducer(
359 head: &RawHead,
360 decoding: Option<&serde_json::Value>,
361 vocab_size: usize,
362) -> Result<Option<TransducerConfig>> {
363 let target = decoding.and_then(|d| d["_target_"].as_str()).unwrap_or("");
364 let has_decoder = head.decoder.is_some();
365 let has_joint = head.joint.is_some();
366 if !(target.contains("RNNT") || (has_decoder && has_joint)) {
367 return Ok(None);
368 }
369 let dec = head
370 .decoder
371 .as_ref()
372 .ok_or_else(|| Error::DecoderConfig { message: "RNN-T config: missing head.decoder block".into() })?;
373 let joint = head
374 .joint
375 .as_ref()
376 .ok_or_else(|| Error::DecoderConfig { message: "RNN-T config: missing head.joint block".into() })?;
377 let max_symbols_per_step = decoding.and_then(|d| d["max_symbols_per_step"].as_u64()).unwrap_or(10) as usize;
378 let vocabulary: Vec<String> = decoding
384 .and_then(|d| d["vocabulary"].as_array())
385 .map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
386 .unwrap_or_default();
387 let sentencepiece =
390 decoding.and_then(|d| d.get("model_path")).and_then(|v| v.as_str()).map(|s| !s.is_empty()).unwrap_or(false);
391 Ok(Some(TransducerConfig {
392 pred_hidden: dec.pred_hidden,
393 pred_rnn_layers: dec.pred_rnn_layers,
394 joint_hidden: joint.joint_hidden,
395 num_classes: vocab_size,
396 max_symbols_per_step,
397 vocabulary,
398 sentencepiece,
399 }))
400}