1use crate::{Error, Result};
8use candle_core::{DType, Device, Tensor};
9use candle_nn::{conv1d, conv_transpose1d, Conv1d, Conv1dConfig, Module, VarBuilder, VarMap};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tokio::sync::RwLock;
15use tracing::{debug, error, info, trace, warn};
16
/// Configuration for the neural audio codec: encoder/decoder topology,
/// residual vector quantizer size, and training-loss weights.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NeuralCodecConfig {
    /// Audio sample rate in Hz.
    pub sample_rate: u32,
    /// Number of audio channels (1 = mono).
    pub channels: usize,
    /// Hidden width of the deepest encoder layer.
    pub encoder_dim: usize,
    /// Hidden width of the deepest decoder layer.
    pub decoder_dim: usize,
    /// Number of residual quantizer stages.
    pub num_quantizers: usize,
    /// Entries per codebook; validated to be a power of two.
    pub codebook_size: usize,
    /// Dimensionality of each codebook vector (latent width).
    pub codebook_dim: usize,
    /// Target bitrate in kbps.
    pub target_bitrate: f32,
    /// Nominal size reduction factor (original bytes / compressed bytes).
    pub compression_ratio: usize,
    /// Number of encoder conv layers; must equal kernel_sizes.len() and strides.len().
    pub encoder_layers: usize,
    /// Number of decoder conv layers.
    pub decoder_layers: usize,
    /// Per-layer convolution kernel sizes.
    pub kernel_sizes: Vec<usize>,
    /// Per-layer convolution strides (temporal down/upsampling factors).
    pub strides: Vec<usize>,
    /// Per-layer dilations; layers beyond this list fall back to 1.
    pub dilations: Vec<usize>,
    /// Whether to apply the residual block stacks.
    pub use_residual: bool,
    /// Whether skip connections are enabled (not read by the visible code).
    pub use_skip_connections: bool,
    /// Dropout probability used inside residual blocks.
    pub dropout_rate: f32,
    /// Loss weight: perceptual term (training-time only).
    pub perceptual_loss_weight: f32,
    /// Loss weight: adversarial term (training-time only).
    pub adversarial_loss_weight: f32,
    /// Loss weight: waveform reconstruction term (training-time only).
    pub reconstruction_loss_weight: f32,
    /// Loss weight: quantizer commitment term (training-time only).
    pub quantization_loss_weight: f32,
}
63
64impl Default for NeuralCodecConfig {
65 fn default() -> Self {
66 Self {
67 sample_rate: 24000,
68 channels: 1,
69 encoder_dim: 512,
70 decoder_dim: 512,
71 num_quantizers: 8,
72 codebook_size: 1024,
73 codebook_dim: 256,
74 target_bitrate: 6.0, compression_ratio: 32,
76 encoder_layers: 5,
77 decoder_layers: 5,
78 kernel_sizes: vec![7, 7, 7, 7, 7],
79 strides: vec![1, 2, 2, 4, 4],
80 dilations: vec![1, 1, 1, 1, 1],
81 use_residual: true,
82 use_skip_connections: true,
83 dropout_rate: 0.1,
84 perceptual_loss_weight: 1.0,
85 adversarial_loss_weight: 1.0,
86 reconstruction_loss_weight: 45.0,
87 quantization_loss_weight: 1.0,
88 }
89 }
90}
91
92impl NeuralCodecConfig {
93 pub fn high_quality() -> Self {
95 Self {
96 encoder_dim: 768,
97 decoder_dim: 768,
98 num_quantizers: 12,
99 codebook_size: 2048,
100 codebook_dim: 384,
101 target_bitrate: 12.0,
102 compression_ratio: 16,
103 encoder_layers: 8,
104 decoder_layers: 8,
105 kernel_sizes: vec![7, 7, 7, 7, 7, 7, 7, 7],
106 strides: vec![1, 2, 2, 2, 2, 2, 2, 2],
107 ..Default::default()
108 }
109 }
110
111 pub fn low_bitrate() -> Self {
113 Self {
114 encoder_dim: 256,
115 decoder_dim: 256,
116 num_quantizers: 4,
117 codebook_size: 512,
118 codebook_dim: 128,
119 target_bitrate: 2.0,
120 compression_ratio: 64,
121 encoder_layers: 4,
122 decoder_layers: 4,
123 ..Default::default()
124 }
125 }
126
127 pub fn realtime_optimized() -> Self {
129 Self {
130 encoder_dim: 384,
131 decoder_dim: 384,
132 num_quantizers: 6,
133 codebook_size: 1024,
134 codebook_dim: 192,
135 target_bitrate: 8.0,
136 compression_ratio: 24,
137 encoder_layers: 4,
138 decoder_layers: 4,
139 kernel_sizes: vec![3, 3, 3, 3],
140 strides: vec![1, 2, 3, 4],
141 ..Default::default()
142 }
143 }
144
145 pub fn validate(&self) -> Result<()> {
147 if self.sample_rate == 0 {
148 return Err(Error::Config(
149 "sample_rate must be greater than 0".to_string(),
150 ));
151 }
152 if self.channels == 0 {
153 return Err(Error::Config("channels must be greater than 0".to_string()));
154 }
155 if self.encoder_dim == 0 || self.decoder_dim == 0 {
156 return Err(Error::Config(
157 "encoder_dim and decoder_dim must be greater than 0".to_string(),
158 ));
159 }
160 if self.num_quantizers == 0 {
161 return Err(Error::Config(
162 "num_quantizers must be greater than 0".to_string(),
163 ));
164 }
165 if self.codebook_size == 0 || (self.codebook_size & (self.codebook_size - 1)) != 0 {
166 return Err(Error::Config(
167 "codebook_size must be a power of 2".to_string(),
168 ));
169 }
170 if self.compression_ratio == 0 {
171 return Err(Error::Config(
172 "compression_ratio must be greater than 0".to_string(),
173 ));
174 }
175 if self.kernel_sizes.len() != self.encoder_layers {
176 return Err(Error::Config(
177 "kernel_sizes length must match encoder_layers".to_string(),
178 ));
179 }
180 if self.strides.len() != self.encoder_layers {
181 return Err(Error::Config(
182 "strides length must match encoder_layers".to_string(),
183 ));
184 }
185 Ok(())
186 }
187}
188
/// A single compression job: raw samples plus per-request quality knobs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodecCompressionRequest {
    /// Raw PCM samples (interleaved if multi-channel).
    pub audio: Vec<f32>,
    /// Sample rate of `audio` in Hz.
    pub sample_rate: u32,
    /// Optional bitrate override in kbps; falls back to the codec config.
    pub target_bitrate: Option<f32>,
    /// Requested quality in [0, 1]; recorded in the result metadata.
    pub quality_level: f32,
    /// Enable perceptual optimization (not read by the visible code).
    pub perceptual_optimization: bool,
    /// Enable temporal-consistency processing (not read by the visible code).
    pub temporal_consistency: bool,
    /// Allow the bitrate to vary over time (not read by the visible code).
    pub variable_bitrate: bool,
}
207
208impl Default for CodecCompressionRequest {
209 fn default() -> Self {
210 Self {
211 audio: Vec::new(),
212 sample_rate: 24000,
213 target_bitrate: None,
214 quality_level: 0.8,
215 perceptual_optimization: true,
216 temporal_consistency: true,
217 variable_bitrate: false,
218 }
219 }
220}
221
/// Output of a compression run: the codes plus size/quality bookkeeping.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodecCompressionResult {
    /// Quantizer codes, one `Vec<u32>` per quantizer layer.
    pub codes: Vec<Vec<u32>>,
    /// Indices of the quantizer stages that were configured (0..num_quantizers).
    pub quantizer_indices: Vec<usize>,
    /// Achieved size reduction (original bytes / compressed bytes).
    pub compression_ratio: f32,
    /// Achieved bitrate in kbps.
    pub actual_bitrate: f32,
    /// Wall-clock compression time in milliseconds.
    pub compression_time_ms: u64,
    /// Quality metrics measured on a decoder round-trip.
    pub quality_metrics: CodecQualityMetrics,
    /// Bookkeeping needed to decompress and audit the result.
    pub metadata: CodecMetadata,
}
240
/// Output of a decompression run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodecDecompressionResult {
    /// Reconstructed PCM samples.
    pub audio: Vec<f32>,
    /// Sample rate of `audio` in Hz (taken from the codec config).
    pub sample_rate: u32,
    /// Duration of the reconstructed audio in seconds.
    pub duration: f32,
    /// Wall-clock decompression time in milliseconds.
    pub decompression_time_ms: u64,
    /// Quality metrics; `None` when no reference signal is available.
    pub quality_metrics: Option<CodecQualityMetrics>,
}
255
/// Quality estimates for a compressed/reconstructed signal.
///
/// NOTE(review): several fields are produced by `evaluate_quality` from
/// simple MSE-derived heuristics, not real PESQ/STOI implementations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodecQualityMetrics {
    /// Signal-to-noise ratio in dB (derived from reconstruction MSE).
    pub snr_db: f32,
    /// Estimated PESQ-like score (heuristic, not ITU-T P.862).
    pub pesq_score: f32,
    /// Estimated STOI-like intelligibility score (heuristic).
    pub stoi_score: f32,
    /// Spectral distortion estimate in dB.
    pub spectral_distortion_db: f32,
    /// Quality-per-bitrate figure (SNR / target bitrate).
    pub bitrate_efficiency: f32,
    /// Overall perceptual quality in [0, 1].
    pub perceptual_quality: f32,
    /// Temporal consistency score in [0, 1].
    pub temporal_consistency: f32,
    /// Artifact severity estimate (lower is better).
    pub artifacts_score: f32,
}
276
/// Side information stored alongside compressed codes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodecMetadata {
    /// Number of samples in the original audio.
    pub original_length: usize,
    /// Compressed payload size in bytes.
    pub compressed_size: usize,
    /// Codec implementation/version string.
    pub codec_version: String,
    /// Free-form numeric encoding parameters (e.g. quality_level).
    pub encoding_params: HashMap<String, f32>,
    /// Unix timestamp (seconds) of when the compression ran.
    pub timestamp: u64,
}
291
/// Convolutional encoder: waveform -> latent of width `codebook_dim`.
#[derive(Debug)]
pub struct NeuralEncoder {
    config: NeuralCodecConfig,
    /// Strided 1-D convolutions that downsample in time.
    conv_layers: Vec<Conv1d>,
    /// Residual refinement stack applied after the conv layers.
    residual_layers: Vec<ResidualBlock>,
    /// Final linear projection from `encoder_dim` to `codebook_dim`.
    output_projection: candle_nn::Linear,
}
300
impl NeuralEncoder {
    /// Builds the encoder's conv stack, residual stack and output
    /// projection from `vb` according to `config`.
    ///
    /// Channel schedule: layer 0 emits `encoder_dim / 4`, the final layer
    /// emits `encoder_dim`, and all middle layers emit `encoder_dim / 2`.
    /// NOTE(review): with `encoder_layers == 1` the `i == 0` branch wins,
    /// so the single layer emits `encoder_dim / 4` while the residual
    /// blocks expect `encoder_dim` — confirm single-layer configs are
    /// never used.
    pub fn new(config: &NeuralCodecConfig, vb: VarBuilder) -> Result<Self> {
        let mut conv_layers = Vec::new();
        let mut in_channels = config.channels;

        // One strided conv per (kernel_size, stride) pair; validate()
        // guarantees these vectors have length encoder_layers.
        for (i, (&kernel_size, &stride)) in config
            .kernel_sizes
            .iter()
            .zip(config.strides.iter())
            .enumerate()
        {
            let out_channels = if i == 0 {
                config.encoder_dim / 4
            } else if i == config.encoder_layers - 1 {
                config.encoder_dim
            } else {
                config.encoder_dim / 2
            };

            let conv_config = Conv1dConfig {
                // "Same"-style padding for odd kernels.
                padding: kernel_size / 2,
                stride,
                // Layers past the dilations list fall back to dilation 1.
                dilation: config.dilations.get(i).copied().unwrap_or(1),
                groups: 1,
                cudnn_fwd_algo: None,
            };

            let conv = conv1d(
                in_channels,
                out_channels,
                kernel_size,
                conv_config,
                vb.pp(format!("conv_layers.{}", i)),
            )?;

            conv_layers.push(conv);
            in_channels = out_channels;
        }

        // Residual blocks all operate at the full encoder width.
        let mut residual_layers = Vec::new();
        for i in 0..config.encoder_layers {
            residual_layers.push(ResidualBlock::new(
                config.encoder_dim,
                config.encoder_dim,
                config.dropout_rate,
                vb.pp(format!("residual.{}", i)),
            )?);
        }

        let output_projection = candle_nn::linear(
            config.encoder_dim,
            config.codebook_dim,
            vb.pp("output_projection"),
        )?;

        Ok(Self {
            config: config.clone(),
            conv_layers,
            residual_layers,
            output_projection,
        })
    }

    /// Encodes a (batch, channels, time) waveform tensor into the latent
    /// space: convs with ELU activations, optional residual refinement,
    /// then the linear output projection.
    ///
    /// NOTE(review): `Linear` acts on the last axis; after the convs the
    /// last axis is time, not `encoder_dim` — a transpose before/after the
    /// projection may be missing. Confirm against the intended layout.
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let mut hidden = x.clone();

        for (i, conv) in self.conv_layers.iter().enumerate() {
            hidden = conv
                .forward(&hidden)
                .map_err(|e| Error::Processing(e.to_string()))?;

            // ELU(alpha = 1.0) after every conv layer.
            hidden = hidden
                .elu(1.0)
                .map_err(|e| Error::Processing(e.to_string()))?;
        }

        if self.config.use_residual {
            for residual in &self.residual_layers {
                hidden = residual.forward(&hidden)?;
            }
        }

        self.output_projection
            .forward(&hidden)
            .map_err(|e| Error::Processing(e.to_string()))
    }
}
394
/// Convolutional decoder: latent of width `codebook_dim` -> waveform.
///
/// NOTE(review): despite the field name, `conv_transpose_layers` holds
/// regular `Conv1d`s; upsampling is done manually in
/// `NeuralDecoder::upsample` (the imported `conv_transpose1d` is unused).
#[derive(Debug)]
pub struct NeuralDecoder {
    config: NeuralCodecConfig,
    /// Linear projection from `codebook_dim` up to `decoder_dim`.
    input_projection: candle_nn::Linear,
    /// Conv layers applied in reverse of the encoder's layer order.
    conv_transpose_layers: Vec<Conv1d>,
    /// Residual refinement stack applied before the conv layers.
    residual_layers: Vec<ResidualBlock>,
    /// Final conv mapping `decoder_dim` down to the output channel count.
    output_conv: Conv1d,
}
404
impl NeuralDecoder {
    /// Builds the decoder: input projection, a mirrored conv stack
    /// (encoder layer order reversed via `.rev()`), residual blocks and a
    /// final output conv.
    ///
    /// NOTE(review): the last conv applied in `forward` is the `i == 0`
    /// layer, which emits `config.channels`, yet `output_conv` is built
    /// with `decoder_dim` input channels — these disagree unless
    /// `channels == decoder_dim`. Confirm the intended channel schedule.
    pub fn new(config: &NeuralCodecConfig, vb: VarBuilder) -> Result<Self> {
        let input_projection = candle_nn::linear(
            config.codebook_dim,
            config.decoder_dim,
            vb.pp("input_projection"),
        )?;

        let mut conv_transpose_layers = Vec::new();
        let mut in_channels = config.decoder_dim;

        // Iterate the encoder geometry in reverse so layer 0's params are
        // applied last; `i` is still the original (un-reversed) index.
        for (i, (&kernel_size, &stride)) in config
            .kernel_sizes
            .iter()
            .zip(config.strides.iter())
            .enumerate()
            .rev()
        {
            let out_channels = if i == 0 {
                config.channels
            } else if i == config.decoder_layers - 1 {
                config.decoder_dim
            } else {
                config.decoder_dim / 2
            };

            let conv_config = Conv1dConfig {
                padding: kernel_size / 2,
                // NOTE(review): stride > 1 here makes the conv *downsample*,
                // which cancels the manual `upsample` call in `forward` —
                // the decoder may produce latent-length output. Verify.
                stride,
                dilation: config.dilations.get(i).copied().unwrap_or(1),
                groups: 1,
                cudnn_fwd_algo: None,
            };

            let conv = conv1d(
                in_channels,
                out_channels,
                kernel_size,
                conv_config,
                vb.pp(format!("conv_transpose.{}", i)),
            )?;

            conv_transpose_layers.push(conv);
            in_channels = out_channels;
        }

        // Residual blocks operate at the full decoder width.
        let mut residual_layers = Vec::new();
        for i in 0..config.decoder_layers {
            residual_layers.push(ResidualBlock::new(
                config.decoder_dim,
                config.decoder_dim,
                config.dropout_rate,
                vb.pp(format!("residual.{}", i)),
            )?);
        }

        // Fixed kernel-7 output conv with "same" padding.
        let output_conv = conv1d(
            config.decoder_dim,
            config.channels,
            7, Conv1dConfig {
                padding: 3,
                stride: 1,
                dilation: 1,
                groups: 1,
                cudnn_fwd_algo: None,
            },
            vb.pp("output_conv"),
        )?;

        Ok(Self {
            config: config.clone(),
            input_projection,
            conv_transpose_layers,
            residual_layers,
            output_conv,
        })
    }

    /// Decodes a latent tensor back to a waveform: projection, optional
    /// residual refinement, the reversed conv stack (with nearest-neighbor
    /// upsampling before each strided layer), final conv, then `tanh` to
    /// clamp output into [-1, 1].
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let mut hidden = self
            .input_projection
            .forward(x)
            .map_err(|e| Error::Processing(e.to_string()))?;

        if self.config.use_residual {
            for residual in &self.residual_layers {
                hidden = residual.forward(&hidden)?;
            }
        }

        for (i, conv) in self.conv_transpose_layers.iter().enumerate() {
            // Layer at position i was built from original index
            // (len - 1 - i), so look up its stride the same way.
            let stride = self.config.strides[self.config.strides.len() - 1 - i];
            if stride > 1 {
                hidden = self.upsample(&hidden, stride)?;
            }

            hidden = conv
                .forward(&hidden)
                .map_err(|e| Error::Processing(e.to_string()))?;

            // No activation after the last conv before output_conv.
            if i < self.conv_transpose_layers.len() - 1 {
                hidden = hidden
                    .elu(1.0)
                    .map_err(|e| Error::Processing(e.to_string()))?;
            }
        }

        let output = self
            .output_conv
            .forward(&hidden)
            .map_err(|e| Error::Processing(e.to_string()))?;

        output.tanh().map_err(|e| Error::Processing(e.to_string()))
    }

    /// Nearest-neighbor temporal upsampling: each sample is repeated
    /// `factor` times. Implemented by copying through host memory, so it
    /// round-trips device tensors via `Vec<f32>`.
    fn upsample(&self, x: &Tensor, factor: usize) -> Result<Tensor> {
        let (batch_size, channels, length) = x.dims3()?;
        let new_length = length * factor;

        let mut upsampled_data = Vec::new();
        let data = x
            .flatten_all()?
            .to_vec1::<f32>()
            .map_err(|e| Error::Processing(e.to_string()))?;

        // Row-major (batch, channel, time) traversal; repeat each value.
        for b in 0..batch_size {
            for c in 0..channels {
                for t in 0..length {
                    let idx = b * channels * length + c * length + t;
                    let value = data[idx];
                    for _ in 0..factor {
                        upsampled_data.push(value);
                    }
                }
            }
        }

        Tensor::from_vec(
            upsampled_data,
            (batch_size, channels, new_length),
            x.device(),
        )
        .map_err(|e| Error::Processing(e.to_string()))
    }
}
562
/// Residual vector quantizer: a cascade of codebooks where each stage
/// quantizes the residual left by the previous stages.
#[derive(Debug)]
pub struct VectorQuantizer {
    config: NeuralCodecConfig,
    /// One embedding table per quantizer stage (codebook_size x codebook_dim).
    codebooks: Vec<candle_nn::Embedding>,
    /// Weight applied to the commitment loss term (fixed at 0.25).
    commitment_weight: f32,
}
570
571impl VectorQuantizer {
572 pub fn new(config: &NeuralCodecConfig, vb: VarBuilder) -> Result<Self> {
573 let mut codebooks = Vec::new();
574
575 for i in 0..config.num_quantizers {
576 let codebook = candle_nn::embedding(
577 config.codebook_size,
578 config.codebook_dim,
579 vb.pp(format!("codebook.{}", i)),
580 )?;
581 codebooks.push(codebook);
582 }
583
584 Ok(Self {
585 config: config.clone(),
586 codebooks,
587 commitment_weight: 0.25,
588 })
589 }
590
591 pub fn forward(&self, x: &Tensor) -> Result<QuantizationResult> {
592 let (batch_size, dim, seq_len) = x.dims3()?;
593 let device = x.device();
594
595 let mut quantized_layers = Vec::new();
596 let mut indices_layers = Vec::new();
597 let mut losses = Vec::new();
598
599 let mut residual = x.clone();
600
601 for (i, codebook) in self.codebooks.iter().enumerate() {
603 let (quantized, indices, loss) = self.quantize_layer(&residual, codebook)?;
604
605 quantized_layers.push(quantized.clone());
606 indices_layers.push(indices);
607 losses.push(loss);
608
609 residual = (&residual - &quantized)?;
611
612 if i >= (self.config.num_quantizers as f32 * 0.6) as usize
614 && residual
615 .sqr()?
616 .mean_all()?
617 .to_scalar::<f32>()
618 .map_err(|e| Error::Processing(e.to_string()))?
619 < 0.01
620 {
621 break;
622 }
623 }
624
625 let mut quantized_sum = quantized_layers[0].clone();
627 for quantized in &quantized_layers[1..] {
628 quantized_sum = (&quantized_sum + quantized)?;
629 }
630
631 let avg_loss = losses.iter().sum::<f32>() / losses.len() as f32;
633
634 let perplexity = self.calculate_perplexity(&indices_layers)?;
635
636 Ok(QuantizationResult {
637 quantized: quantized_sum,
638 indices: indices_layers,
639 quantization_loss: avg_loss,
640 perplexity,
641 })
642 }
643
644 fn quantize_layer(
645 &self,
646 x: &Tensor,
647 codebook: &candle_nn::Embedding,
648 ) -> Result<(Tensor, Vec<u32>, f32)> {
649 let (batch_size, dim, seq_len) = x.dims3()?;
650
651 let x_flat = x.transpose(1, 2)?.reshape((batch_size * seq_len, dim))?;
653
654 let codebook_indices = Tensor::arange(0u32, self.config.codebook_size as u32, x.device())?;
656 let codebook_vectors = codebook.forward(&codebook_indices)?;
657
658 let distances = self.compute_distances(&x_flat, &codebook_vectors)?;
660
661 let indices = distances.argmin(1)?;
663 let indices_data = indices
664 .to_vec1::<u32>()
665 .map_err(|e| Error::Processing(e.to_string()))?;
666
667 let quantized_flat = codebook.forward(&indices)?;
669 let quantized = quantized_flat
670 .reshape((batch_size, seq_len, dim))?
671 .transpose(1, 2)?;
672
673 let commitment_loss = (&x_flat - &quantized_flat.detach())?
675 .sqr()?
676 .mean_all()?
677 .to_scalar::<f32>()
678 .map_err(|e| Error::Processing(e.to_string()))?;
679
680 Ok((
681 quantized,
682 indices_data,
683 commitment_loss * self.commitment_weight,
684 ))
685 }
686
687 fn compute_distances(&self, x: &Tensor, codebook: &Tensor) -> Result<Tensor> {
688 let x_norm = x.sqr()?.sum_keepdim(1)?;
690 let codebook_norm = codebook.sqr()?.sum_keepdim(1)?.t()?;
691 let dot_product = x.matmul(&codebook.t()?)?;
692
693 let distances = (x_norm + codebook_norm - &(&dot_product * 2.0)?)?;
695
696 Ok(distances)
697 }
698
699 fn calculate_perplexity(&self, indices_layers: &[Vec<u32>]) -> Result<f32> {
700 let mut total_entropy = 0.0;
701 let mut total_count = 0;
702
703 for indices in indices_layers {
704 if indices.is_empty() {
705 continue;
706 }
707
708 let mut counts = vec![0; self.config.codebook_size];
709 for &idx in indices {
710 if (idx as usize) < counts.len() {
711 counts[idx as usize] += 1;
712 }
713 }
714
715 let total = indices.len() as f32;
716 let mut entropy = 0.0;
717
718 for count in counts {
719 if count > 0 {
720 let prob = count as f32 / total;
721 entropy -= prob * prob.ln();
722 }
723 }
724
725 total_entropy += entropy;
726 total_count += 1;
727 }
728
729 let avg_entropy = if total_count > 0 {
730 total_entropy / total_count as f32
731 } else {
732 0.0
733 };
734 Ok(avg_entropy.exp())
735 }
736
737 pub fn decode(&self, indices: &[Vec<u32>]) -> Result<Tensor> {
738 if indices.is_empty() {
739 return Err(Error::Processing("Empty indices for decoding".to_string()));
740 }
741
742 let seq_len = indices[0].len();
743 let device = Device::Cpu; let mut decoded_sum: Option<Tensor> = None;
746
747 for (layer_idx, layer_indices) in indices.iter().enumerate() {
748 if layer_idx >= self.codebooks.len() {
749 break;
750 }
751
752 let indices_tensor = Tensor::from_vec(
753 layer_indices
754 .clone()
755 .into_iter()
756 .map(|x| x as i64)
757 .collect(),
758 (1, seq_len),
759 &device,
760 )?;
761
762 let decoded_layer = self.codebooks[layer_idx].forward(&indices_tensor)?;
763 let decoded_reshaped = decoded_layer.transpose(1, 2)?;
764
765 match decoded_sum {
766 None => decoded_sum = Some(decoded_reshaped),
767 Some(ref sum) => {
768 decoded_sum = Some((sum + &decoded_reshaped)?);
769 }
770 }
771 }
772
773 decoded_sum.ok_or_else(|| Error::Processing("Failed to decode any layers".to_string()))
774 }
775}
776
/// Output of `VectorQuantizer::forward`.
#[derive(Debug)]
pub struct QuantizationResult {
    /// Sum of all stage reconstructions (same shape as the input latent).
    pub quantized: Tensor,
    /// Selected codebook indices, one `Vec<u32>` per quantizer stage.
    pub indices: Vec<Vec<u32>>,
    /// Mean commitment loss across the stages that ran.
    pub quantization_loss: f32,
    /// Average codebook perplexity (utilization measure).
    pub perplexity: f32,
}
785
/// Two-conv residual block with ELU activations and dropout; the skip
/// path gets a 1x1 projection when channel counts differ.
#[derive(Debug)]
pub struct ResidualBlock {
    conv1: Conv1d,
    conv2: Conv1d,
    /// 1x1 conv on the shortcut, present only when in != out channels.
    skip_conv: Option<Conv1d>,
    dropout: candle_nn::Dropout,
}
794
795impl ResidualBlock {
796 pub fn new(
797 in_channels: usize,
798 out_channels: usize,
799 dropout_rate: f32,
800 vb: VarBuilder,
801 ) -> Result<Self> {
802 let conv1 = conv1d(
803 in_channels,
804 out_channels,
805 3,
806 Conv1dConfig {
807 padding: 1,
808 stride: 1,
809 dilation: 1,
810 groups: 1,
811 cudnn_fwd_algo: None,
812 },
813 vb.pp("conv1"),
814 )?;
815
816 let conv2 = conv1d(
817 out_channels,
818 out_channels,
819 3,
820 Conv1dConfig {
821 padding: 1,
822 stride: 1,
823 dilation: 1,
824 groups: 1,
825 cudnn_fwd_algo: None,
826 },
827 vb.pp("conv2"),
828 )?;
829
830 let skip_conv = if in_channels != out_channels {
831 Some(conv1d(
832 in_channels,
833 out_channels,
834 1,
835 Conv1dConfig {
836 padding: 0,
837 stride: 1,
838 dilation: 1,
839 groups: 1,
840 cudnn_fwd_algo: None,
841 },
842 vb.pp("skip"),
843 )?)
844 } else {
845 None
846 };
847
848 let dropout = candle_nn::Dropout::new(dropout_rate);
849
850 Ok(Self {
851 conv1,
852 conv2,
853 skip_conv,
854 dropout,
855 })
856 }
857
858 pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
859 let mut residual = x.clone();
860
861 let mut out = self
863 .conv1
864 .forward(x)
865 .map_err(|e| Error::Processing(e.to_string()))?;
866 out = out.elu(1.0).map_err(|e| Error::Processing(e.to_string()))?;
867 out = self.dropout.forward(&out, false)?;
868
869 out = self
871 .conv2
872 .forward(&out)
873 .map_err(|e| Error::Processing(e.to_string()))?;
874
875 if let Some(ref skip) = self.skip_conv {
877 residual = skip
878 .forward(&residual)
879 .map_err(|e| Error::Processing(e.to_string()))?;
880 }
881
882 let result = (&out + &residual)?;
884
885 result
887 .elu(1.0)
888 .map_err(|e| Error::Processing(e.to_string()))
889 }
890}
891
/// End-to-end codec: encoder -> residual vector quantizer -> decoder.
#[derive(Debug)]
pub struct NeuralCodec {
    config: NeuralCodecConfig,
    encoder: NeuralEncoder,
    decoder: NeuralDecoder,
    quantizer: VectorQuantizer,
    /// Device all input tensors are created on.
    device: Device,
}
901
902impl NeuralCodec {
903 pub fn new(config: NeuralCodecConfig, vb: VarBuilder, device: Device) -> Result<Self> {
904 config.validate()?;
905
906 let encoder = NeuralEncoder::new(&config, vb.pp("encoder"))?;
907 let decoder = NeuralDecoder::new(&config, vb.pp("decoder"))?;
908 let quantizer = VectorQuantizer::new(&config, vb.pp("quantizer"))?;
909
910 Ok(Self {
911 config,
912 encoder,
913 decoder,
914 quantizer,
915 device,
916 })
917 }
918
919 pub fn compress(&self, request: &CodecCompressionRequest) -> Result<CodecCompressionResult> {
921 let start_time = Instant::now();
922
923 if request.audio.is_empty() {
925 return Err(Error::InvalidInput("Empty audio input".to_string()));
926 }
927
928 let audio_tensor = self.prepare_audio_tensor(&request.audio)?;
930
931 let encoded = self.encoder.forward(&audio_tensor)?;
933
934 let quantization_result = self.quantizer.forward(&encoded)?;
936
937 let compression_time = start_time.elapsed();
939 let original_size = request.audio.len() * 4; let compressed_size = self.calculate_compressed_size(&quantization_result.indices);
941 let compression_ratio = original_size as f32 / compressed_size as f32;
942 let actual_bitrate =
943 self.calculate_bitrate(compressed_size, request.audio.len(), request.sample_rate);
944
945 let reconstructed = self.decoder.forward(&quantization_result.quantized)?;
947 let quality_metrics = self.evaluate_quality(&audio_tensor, &reconstructed)?;
948
949 Ok(CodecCompressionResult {
950 codes: quantization_result.indices,
951 quantizer_indices: (0..self.config.num_quantizers).collect(),
952 compression_ratio,
953 actual_bitrate,
954 compression_time_ms: compression_time.as_millis() as u64,
955 quality_metrics,
956 metadata: CodecMetadata {
957 original_length: request.audio.len(),
958 compressed_size,
959 codec_version: "NeuralCodec-1.0".to_string(),
960 encoding_params: {
961 let mut params = HashMap::new();
962 params.insert("quality_level".to_string(), request.quality_level);
963 params.insert(
964 "target_bitrate".to_string(),
965 request.target_bitrate.unwrap_or(self.config.target_bitrate),
966 );
967 params
968 },
969 timestamp: std::time::SystemTime::now()
970 .duration_since(std::time::UNIX_EPOCH)
971 .expect("SystemTime should be after UNIX_EPOCH")
972 .as_secs(),
973 },
974 })
975 }
976
977 pub fn decompress(
979 &self,
980 codes: &[Vec<u32>],
981 metadata: &CodecMetadata,
982 ) -> Result<CodecDecompressionResult> {
983 let start_time = Instant::now();
984
985 if codes.is_empty() {
986 return Err(Error::InvalidInput("Empty codes input".to_string()));
987 }
988
989 let quantized = self.quantizer.decode(codes)?;
991
992 let reconstructed = self.decoder.forward(&quantized)?;
994
995 let audio = self.tensor_to_audio(&reconstructed)?;
997
998 let decompression_time = start_time.elapsed();
999 let duration = audio.len() as f32 / self.config.sample_rate as f32;
1000
1001 Ok(CodecDecompressionResult {
1002 audio,
1003 sample_rate: self.config.sample_rate,
1004 duration,
1005 decompression_time_ms: decompression_time.as_millis() as u64,
1006 quality_metrics: None, })
1008 }
1009
1010 fn prepare_audio_tensor(&self, audio: &[f32]) -> Result<Tensor> {
1011 let batch_size = 1;
1012 let channels = self.config.channels;
1013 let length = audio.len() / channels;
1014
1015 Tensor::from_vec(audio.to_vec(), (batch_size, channels, length), &self.device)
1016 .map_err(|e| Error::Processing(e.to_string()))
1017 }
1018
1019 fn tensor_to_audio(&self, tensor: &Tensor) -> Result<Vec<f32>> {
1020 let data = tensor
1021 .flatten_all()?
1022 .to_vec1::<f32>()
1023 .map_err(|e| Error::Processing(e.to_string()))?;
1024 Ok(data)
1025 }
1026
1027 fn calculate_compressed_size(&self, indices: &[Vec<u32>]) -> usize {
1028 let bits_per_code = (self.config.codebook_size as f32).log2().ceil() as usize;
1029 let total_codes: usize = indices.iter().map(|layer| layer.len()).sum();
1030 total_codes * bits_per_code / 8 }
1032
1033 fn calculate_bitrate(
1034 &self,
1035 compressed_size: usize,
1036 audio_length: usize,
1037 sample_rate: u32,
1038 ) -> f32 {
1039 let duration_seconds = audio_length as f32 / sample_rate as f32;
1040 (compressed_size * 8) as f32 / duration_seconds / 1000.0 }
1042
1043 fn evaluate_quality(
1044 &self,
1045 original: &Tensor,
1046 reconstructed: &Tensor,
1047 ) -> Result<CodecQualityMetrics> {
1048 let mse = (original - reconstructed)?
1050 .sqr()?
1051 .mean_all()?
1052 .to_scalar::<f32>()
1053 .map_err(|e| Error::Processing(e.to_string()))?;
1054
1055 let snr_db = -10.0 * mse.log10();
1056
1057 Ok(CodecQualityMetrics {
1058 snr_db,
1059 pesq_score: 3.5 + (snr_db / 30.0).min(1.0), stoi_score: 0.8 + (snr_db / 50.0).min(0.2), spectral_distortion_db: mse.sqrt() * 20.0,
1062 bitrate_efficiency: snr_db / self.config.target_bitrate,
1063 perceptual_quality: (snr_db / 30.0).clamp(0.0, 1.0),
1064 temporal_consistency: 0.9, artifacts_score: mse.sqrt(),
1066 })
1067 }
1068}
1069
/// Async-friendly wrapper around `NeuralCodec` adding a result cache and
/// running performance statistics.
#[derive(Debug)]
pub struct NeuralCodecManager {
    codec: Arc<RwLock<NeuralCodec>>,
    config: NeuralCodecConfig,
    device: Device,
    /// Compression results keyed by a cheap request fingerprint.
    compression_cache: Arc<RwLock<HashMap<String, CodecCompressionResult>>>,
    performance_stats: Arc<RwLock<CodecPerformanceStats>>,
}
1079
/// Running counters maintained by `NeuralCodecManager`.
#[derive(Debug, Default, Clone)]
pub struct CodecPerformanceStats {
    pub total_compressions: u64,
    pub total_decompressions: u64,
    pub total_compression_time: Duration,
    pub total_decompression_time: Duration,
    /// Running mean of per-request compression ratios.
    pub average_compression_ratio: f32,
    /// Running mean of per-request SNR (dB).
    pub average_quality_score: f32,
    pub cache_hits: u64,
    pub cache_misses: u64,
}
1091
1092impl NeuralCodecManager {
1093 pub fn new(config: NeuralCodecConfig) -> Result<Self> {
1094 config.validate()?;
1095
1096 let cuda_available =
1097 std::panic::catch_unwind(candle_core::utils::cuda_is_available).unwrap_or(false);
1098 let device = if cuda_available {
1099 std::panic::catch_unwind(|| Device::new_cuda(0))
1100 .unwrap_or(Ok(Device::Cpu))
1101 .unwrap_or(Device::Cpu)
1102 } else {
1103 Device::Cpu
1104 };
1105
1106 info!("Initializing Neural Codec on device: {:?}", device);
1107
1108 let varmap = VarMap::new();
1110 let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
1111
1112 let codec = NeuralCodec::new(config.clone(), vb, device.clone())?;
1113
1114 Ok(Self {
1115 codec: Arc::new(RwLock::new(codec)),
1116 config,
1117 device,
1118 compression_cache: Arc::new(RwLock::new(HashMap::new())),
1119 performance_stats: Arc::new(RwLock::new(CodecPerformanceStats::default())),
1120 })
1121 }
1122
1123 pub fn high_quality() -> Result<Self> {
1125 Self::new(NeuralCodecConfig::high_quality())
1126 }
1127
1128 pub fn low_bitrate() -> Result<Self> {
1130 Self::new(NeuralCodecConfig::low_bitrate())
1131 }
1132
1133 pub fn realtime_optimized() -> Result<Self> {
1135 Self::new(NeuralCodecConfig::realtime_optimized())
1136 }
1137
1138 pub async fn compress(
1140 &self,
1141 request: CodecCompressionRequest,
1142 ) -> Result<CodecCompressionResult> {
1143 let cache_key = self.compute_cache_key(&request);
1144
1145 {
1147 let cache = self.compression_cache.read().await;
1148 if let Some(cached_result) = cache.get(&cache_key) {
1149 let mut stats = self.performance_stats.write().await;
1150 stats.cache_hits += 1;
1151 debug!("Cache hit for compression request");
1152 return Ok(cached_result.clone());
1153 }
1154 }
1155
1156 let codec = self.codec.read().await;
1158 let result = codec.compress(&request)?;
1159 drop(codec);
1160
1161 {
1163 let mut stats = self.performance_stats.write().await;
1164 stats.total_compressions += 1;
1165 stats.total_compression_time += Duration::from_millis(result.compression_time_ms);
1166 stats.average_compression_ratio = (stats.average_compression_ratio
1167 * (stats.total_compressions - 1) as f32
1168 + result.compression_ratio)
1169 / stats.total_compressions as f32;
1170 stats.average_quality_score = (stats.average_quality_score
1171 * (stats.total_compressions - 1) as f32
1172 + result.quality_metrics.snr_db)
1173 / stats.total_compressions as f32;
1174 stats.cache_misses += 1;
1175 }
1176
1177 {
1179 let mut cache = self.compression_cache.write().await;
1180 cache.insert(cache_key, result.clone());
1181
1182 if cache.len() > 1000 {
1184 let oldest_key = cache
1185 .keys()
1186 .next()
1187 .expect("cache is non-empty (len > 1000)")
1188 .clone();
1189 cache.remove(&oldest_key);
1190 }
1191 }
1192
1193 info!(
1194 "Neural codec compression completed: {:.2}x compression, {:.2} kbps, {:.2} dB SNR",
1195 result.compression_ratio, result.actual_bitrate, result.quality_metrics.snr_db
1196 );
1197
1198 Ok(result)
1199 }
1200
1201 pub async fn decompress(
1203 &self,
1204 codes: &[Vec<u32>],
1205 metadata: &CodecMetadata,
1206 ) -> Result<CodecDecompressionResult> {
1207 let codec = self.codec.read().await;
1208 let result = codec.decompress(codes, metadata)?;
1209 drop(codec);
1210
1211 {
1213 let mut stats = self.performance_stats.write().await;
1214 stats.total_decompressions += 1;
1215 stats.total_decompression_time += Duration::from_millis(result.decompression_time_ms);
1216 }
1217
1218 info!(
1219 "Neural codec decompression completed: {:.2}s audio in {}ms",
1220 result.duration, result.decompression_time_ms
1221 );
1222
1223 Ok(result)
1224 }
1225
1226 fn compute_cache_key(&self, request: &CodecCompressionRequest) -> String {
1227 use std::collections::hash_map::DefaultHasher;
1228 use std::hash::{Hash, Hasher};
1229
1230 let mut hasher = DefaultHasher::new();
1231
1232 if !request.audio.is_empty() {
1234 request.audio[0].to_bits().hash(&mut hasher);
1235 if request.audio.len() > 1 {
1236 request.audio[request.audio.len() - 1]
1237 .to_bits()
1238 .hash(&mut hasher);
1239 }
1240 request.audio.len().hash(&mut hasher);
1241 }
1242
1243 request.sample_rate.hash(&mut hasher);
1244 ((request.quality_level * 1000.0) as u32).hash(&mut hasher);
1245 request.perceptual_optimization.hash(&mut hasher);
1246 request.temporal_consistency.hash(&mut hasher);
1247
1248 format!("neural_codec_{:x}", hasher.finish())
1249 }
1250
1251 pub async fn get_performance_stats(&self) -> CodecPerformanceStats {
1253 (*self.performance_stats.read().await).clone()
1254 }
1255
1256 pub async fn clear_cache(&self) {
1258 self.compression_cache.write().await.clear();
1259 info!("Neural codec compression cache cleared");
1260 }
1261
1262 pub fn config(&self) -> &NeuralCodecConfig {
1264 &self.config
1265 }
1266}
1267
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_neural_codec_config_validation() {
        let config = NeuralCodecConfig::default();
        assert!(config.validate().is_ok());

        // A zero sample rate must be rejected.
        let invalid = NeuralCodecConfig {
            sample_rate: 0,
            ..config.clone()
        };
        assert!(invalid.validate().is_err());

        // A non-power-of-two codebook size must be rejected.
        let invalid = NeuralCodecConfig {
            codebook_size: 1023,
            ..config
        };
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn test_codec_compression_request_default() {
        let request = CodecCompressionRequest::default();
        assert_eq!(request.sample_rate, 24000);
        assert_eq!(request.quality_level, 0.8);
        assert!(request.perceptual_optimization);
    }

    #[tokio::test]
    async fn test_neural_codec_manager_creation() {
        // Construction may fail on machines without usable weights; both
        // outcomes are acceptable for this smoke test.
        if let Err(e) = NeuralCodecManager::new(NeuralCodecConfig::low_bitrate()) {
            eprintln!("Expected failure creating manager without weights: {}", e);
        }
    }

    #[tokio::test]
    async fn test_neural_codec_compression() {
        let Ok(manager) = NeuralCodecManager::low_bitrate() else {
            return; // environment cannot build the codec; skip
        };

        // 1024 samples cycling through a short deterministic pattern.
        let pattern = [0.1, -0.1, 0.2, -0.2, 0.0];
        let request = CodecCompressionRequest {
            audio: (0..1024usize).map(|i| pattern[i % pattern.len()]).collect(),
            sample_rate: 24000,
            quality_level: 0.8,
            ..Default::default()
        };

        let result = manager.compress(request).await;
        assert!(result.is_ok());

        let compression_result = result.unwrap();
        assert!(!compression_result.codes.is_empty());
        assert!(compression_result.compression_ratio > 1.0);
        assert!(compression_result.actual_bitrate > 0.0);
    }
}