1use crate::bitreader::BitReader;
2use crate::error::AecError;
3use crate::params::{AecFlags, AecParams};
4
/// How [`Decoder::decode`] treats running out of compressed input.
///
/// With `NoFlush` more input may still arrive, so a truncated unit yields
/// [`DecodeStatus::NeedInput`]. With `Flush` the caller asserts that all input has been
/// pushed, so a truncated unit is reported as an error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Flush {
    /// More input may follow; stop and ask for it when the buffer runs dry.
    NoFlush,
    /// No more input will follow; running dry is an error.
    Flush,
}
12
/// Why a call to [`Decoder::decode`] returned.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecodeStatus {
    /// The buffered compressed input is exhausted; push more and call again.
    NeedInput,
    /// The output buffer is full (or too small for the next sample); call again with space.
    NeedOutput,
    /// All requested samples have been produced.
    Finished,
}
22
23pub struct Decoder {
34 params: AecParams,
35 bytes_per_sample: usize,
36 id_len: usize,
37 preprocess: bool,
38
39 output_samples: usize,
40 samples_written: usize,
41
42 predictor_x: Option<i64>,
44 sample_index_within_rsi: u64,
45 block_index_within_rsi: u32,
46
47 reader: StreamBitReader,
49
50 pending: Vec<u8>,
52 pending_pos: usize,
53
54 pending_repeat: Option<PendingRepeat>,
56
57 total_in: usize,
58 total_out: usize,
59}
60
/// A run of identical coded values (currently only zero runs) that has been decoded from
/// the stream but not yet materialised into output bytes.
#[derive(Debug, Clone)]
struct PendingRepeat {
    /// Coded (pre-inverse-preprocessing) value to emit for every sample of the run.
    coded_value: u32,
    /// Number of samples of the run still to be emitted.
    remaining: usize,
}
66
67impl Decoder {
68 pub fn new(params: AecParams, output_samples: usize) -> Result<Self, AecError> {
69 validate_params(params)?;
70 let bytes_per_sample = bytes_per_sample(params)?;
71 let id_len = id_len(params)?;
72
73 Ok(Self {
74 params,
75 bytes_per_sample,
76 id_len,
77 preprocess: params.flags.contains(AecFlags::DATA_PREPROCESS),
78 output_samples,
79 samples_written: 0,
80 predictor_x: None,
81 sample_index_within_rsi: 0,
82 block_index_within_rsi: 0,
83 reader: StreamBitReader::new(),
84 pending: Vec::new(),
85 pending_pos: 0,
86 pending_repeat: None,
87 total_in: 0,
88 total_out: 0,
89 })
90 }
91
92 pub fn push_input(&mut self, input: &[u8]) {
94 self.reader.push(input);
95 }
96
97 pub fn total_in(&self) -> usize {
99 self.total_in
100 }
101
102 pub fn total_out(&self) -> usize {
104 self.total_out
105 }
106
107 pub fn avail_in(&self) -> usize {
109 self.reader.avail_bytes()
110 }
111
112 pub fn decode(&mut self, out: &mut [u8], flush: Flush) -> Result<(usize, DecodeStatus), AecError> {
114 if self.samples_written >= self.output_samples {
115 return Ok((0, DecodeStatus::Finished));
116 }
117
118 let mut written: usize = 0;
119
120 written += self.flush_pending(out, written);
122 if written >= out.len() {
123 self.total_out += written;
124 return Ok((written, DecodeStatus::NeedOutput));
125 }
126
127 if let Some(status) = self.flush_repeat(out, &mut written)? {
129 self.total_out += written;
130 return Ok((written, status));
131 }
132
133 while written < out.len() {
135 if self.samples_written >= self.output_samples {
136 self.total_out += written;
137 return Ok((written, DecodeStatus::Finished));
138 }
139
140 if self.preprocess && self.block_index_within_rsi == 0 {
142 self.predictor_x = None;
143 }
144
145 let snapshot = self.snapshot();
147 match self.decode_next_unit() {
148 Ok(()) => {
149 let consumed = self.reader.compact_consumed_bytes();
151 self.total_in += consumed;
152
153 written += self.flush_pending(out, written);
155 if written >= out.len() {
156 self.total_out += written;
157 return Ok((written, DecodeStatus::NeedOutput));
158 }
159
160 if let Some(status) = self.flush_repeat(out, &mut written)? {
161 self.total_out += written;
162 return Ok((written, status));
163 }
164
165 }
167 Err(AecError::UnexpectedEof { .. }) | Err(AecError::UnexpectedEofDuringDecode { .. }) => {
168 self.restore(snapshot);
170 self.total_out += written;
171 return match flush {
172 Flush::NoFlush => Ok((written, DecodeStatus::NeedInput)),
173 Flush::Flush => Err(AecError::UnexpectedEofDuringDecode {
174 bit_pos: self.reader.bits_read_total(),
175 samples_written: self.samples_written,
176 }),
177 };
178 }
179 Err(e) => {
180 self.restore(snapshot);
181 return Err(e);
182 }
183 }
184 }
185
186 self.total_out += written;
187 Ok((written, DecodeStatus::NeedOutput))
188 }
189
190 fn flush_pending(&mut self, out: &mut [u8], written: usize) -> usize {
191 if self.pending_pos >= self.pending.len() {
192 self.pending.clear();
193 self.pending_pos = 0;
194 return 0;
195 }
196
197 let available = out.len().saturating_sub(written);
198 let remaining = self.pending.len().saturating_sub(self.pending_pos);
199 let to_copy = available.min(remaining);
200
201 out[written..written + to_copy]
202 .copy_from_slice(&self.pending[self.pending_pos..self.pending_pos + to_copy]);
203 self.pending_pos += to_copy;
204 to_copy
205 }
206
207 fn flush_repeat(&mut self, out: &mut [u8], written: &mut usize) -> Result<Option<DecodeStatus>, AecError> {
208 let Some(rep) = self.pending_repeat.as_mut() else {
209 return Ok(None);
210 };
211
212 while *written < out.len() && rep.remaining > 0 {
213 if self.samples_written >= self.output_samples {
214 self.pending_repeat = None;
215 return Ok(Some(DecodeStatus::Finished));
216 }
217
218 let out_start = *written;
220 let out_end = out_start + self.bytes_per_sample;
221 if out_end > out.len() {
222 return Ok(Some(DecodeStatus::NeedOutput));
223 }
224
225 let mut tmp = OutBuf::new(&mut out[out_start..out_end], self.bytes_per_sample);
227 tmp.pos = 0;
228 emit_coded_value(
229 &mut tmp,
230 &mut self.predictor_x,
231 self.params,
232 self.bytes_per_sample,
233 rep.coded_value,
234 &mut self.sample_index_within_rsi,
235 usize::MAX,
236 )?;
237 *written += self.bytes_per_sample;
238 self.samples_written += 1;
239 rep.remaining -= 1;
240 }
241
242 if rep.remaining == 0 {
243 self.pending_repeat = None;
244 }
245
246 if *written >= out.len() {
247 return Ok(Some(DecodeStatus::NeedOutput));
248 }
249 Ok(None)
250 }
251
252 fn snapshot(&self) -> Snapshot {
253 Snapshot {
254 predictor_x: self.predictor_x,
255 sample_index_within_rsi: self.sample_index_within_rsi,
256 block_index_within_rsi: self.block_index_within_rsi,
257 samples_written: self.samples_written,
258 reader: self.reader.clone(),
259 pending: self.pending.clone(),
260 pending_pos: self.pending_pos,
261 pending_repeat: self.pending_repeat.clone(),
262 }
263 }
264
265 fn restore(&mut self, s: Snapshot) {
266 self.predictor_x = s.predictor_x;
267 self.sample_index_within_rsi = s.sample_index_within_rsi;
268 self.block_index_within_rsi = s.block_index_within_rsi;
269 self.samples_written = s.samples_written;
270 self.reader = s.reader;
271 self.pending = s.pending;
272 self.pending_pos = s.pending_pos;
273 self.pending_repeat = s.pending_repeat;
274 }
275
276 fn decode_next_unit(&mut self) -> Result<(), AecError> {
277 if self.pending_pos < self.pending.len() {
279 return Ok(());
280 }
281
282 let mut block_out: Vec<u8> = vec![0u8; self.bytes_per_sample * (self.params.block_size as usize)];
284 let mut out = OutBuf::new(&mut block_out, self.bytes_per_sample);
285
286 if self.preprocess && self.block_index_within_rsi == 0 {
288 self.predictor_x = None;
289 }
290
291 let at_rsi_start = self.preprocess && self.block_index_within_rsi == 0;
292 let ref_pending = at_rsi_start;
293 let mut reference_sample_consumed = false;
294
295 let id = self.reader.read_bits_u32(self.id_len)?;
297 let max_id = (1u32 << self.id_len) - 1;
298
299 let mut consume_reference = |this: &mut Self, out: &mut OutBuf<'_>| -> Result<(), AecError> {
301 let ref_raw = this.reader.read_bits_u32(this.params.bits_per_sample as usize)?;
302 let ref_val = if this.params.flags.contains(AecFlags::DATA_SIGNED) {
303 sign_extend(ref_raw, this.params.bits_per_sample)
304 } else {
305 ref_raw as i64
306 };
307 write_sample(out, ref_val, this.params)?;
308 this.predictor_x = Some(ref_val);
309 reference_sample_consumed = true;
310 this.sample_index_within_rsi += 1;
311 Ok(())
312 };
313
314 let remaining_total_samples = self.output_samples.saturating_sub(self.samples_written);
315 let max_samples_this_block = (self.params.block_size as usize).min(remaining_total_samples);
316
317 if id == 0 {
318 let selector = self.reader.read_bit()?;
320
321 if ref_pending {
323 consume_reference(self, &mut out)?;
324 self.samples_written += 1;
325 }
326
327 let remaining_total_samples = self.output_samples.saturating_sub(self.samples_written);
329
330 let mut remaining_in_block = self.params.block_size as usize;
331 if reference_sample_consumed {
332 remaining_in_block = remaining_in_block.saturating_sub(1);
333 }
334
335 if !selector {
336 let fs = read_unary_stream(&mut self.reader)?;
338 let mut z_blocks = fs + 1;
339 const ROS: u32 = 5;
340 if z_blocks == ROS {
341 let b = self.block_index_within_rsi;
342 let fill1 = self.params.rsi.saturating_sub(b);
343 let fill2 = 64u32.saturating_sub(b % 64);
344 z_blocks = fill1.min(fill2);
345 } else if z_blocks > ROS {
346 z_blocks = z_blocks.saturating_sub(1);
347 }
348
349 let mut zeros_samples = (z_blocks as usize)
350 .checked_mul(self.params.block_size as usize)
351 .ok_or(AecError::InvalidInput("zero-run overflow"))?;
352 if reference_sample_consumed {
353 zeros_samples = zeros_samples.saturating_sub(1);
354 }
355
356 zeros_samples = zeros_samples.min(remaining_total_samples);
358
359 let produced_len = out.len();
361 drop(out);
362 self.pending = block_out[..produced_len].to_vec();
363 self.pending_pos = 0;
364
365 if zeros_samples > 0 {
367 self.pending_repeat = Some(PendingRepeat { coded_value: 0, remaining: zeros_samples });
368 }
369
370 self.block_index_within_rsi = self.block_index_within_rsi.saturating_add(z_blocks);
372 if self.block_index_within_rsi >= self.params.rsi {
373 self.block_index_within_rsi %= self.params.rsi;
374 if self.params.flags.contains(AecFlags::PAD_RSI) {
375 self.reader.align_to_byte();
376 }
377 self.sample_index_within_rsi = 0;
378 }
379
380 return Ok(());
382 }
383
384 let mut produced_samples = 0usize;
386 while remaining_in_block > 0 && produced_samples < max_samples_this_block.saturating_sub(reference_sample_consumed as usize) {
387 let m = read_unary_stream(&mut self.reader)?;
388 if m > 90 {
389 return Err(AecError::InvalidInput("Second Extension unary symbol too large"));
390 }
391 let (a, b) = second_extension_pair(m);
392
393 if produced_samples < max_samples_this_block.saturating_sub(reference_sample_consumed as usize) {
395 emit_coded_value(
396 &mut out,
397 &mut self.predictor_x,
398 self.params,
399 self.bytes_per_sample,
400 a,
401 &mut self.sample_index_within_rsi,
402 usize::MAX,
403 )?;
404 produced_samples += 1;
405 self.samples_written += 1;
406 }
407
408 if remaining_in_block > 0 {
409 remaining_in_block = remaining_in_block.saturating_sub(1);
410 }
411 if produced_samples < max_samples_this_block.saturating_sub(reference_sample_consumed as usize) {
412 emit_coded_value(
413 &mut out,
414 &mut self.predictor_x,
415 self.params,
416 self.bytes_per_sample,
417 b,
418 &mut self.sample_index_within_rsi,
419 usize::MAX,
420 )?;
421 produced_samples += 1;
422 self.samples_written += 1;
423 }
424 if remaining_in_block > 0 {
425 remaining_in_block = remaining_in_block.saturating_sub(1);
426 }
427 }
428 } else if id == max_id {
429 if ref_pending {
431 consume_reference(self, &mut out)?;
432 self.samples_written += 1;
433 }
434
435 let mut remaining_in_block = self.params.block_size as usize;
436 if reference_sample_consumed {
437 remaining_in_block = remaining_in_block.saturating_sub(1);
438 }
439
440 for _ in 0..remaining_in_block {
441 if self.samples_written >= self.output_samples {
442 break;
443 }
444 let v = self.reader.read_bits_u32(self.params.bits_per_sample as usize)?;
445 emit_coded_value(
446 &mut out,
447 &mut self.predictor_x,
448 self.params,
449 self.bytes_per_sample,
450 v,
451 &mut self.sample_index_within_rsi,
452 usize::MAX,
453 )?;
454 self.samples_written += 1;
455 }
456 } else {
457 let k = (id - 1) as usize;
459 if ref_pending {
460 consume_reference(self, &mut out)?;
461 self.samples_written += 1;
462 }
463
464 let mut remaining_in_block = self.params.block_size as usize;
465 if reference_sample_consumed {
466 remaining_in_block = remaining_in_block.saturating_sub(1);
467 }
468 let n = remaining_in_block.min(self.output_samples.saturating_sub(self.samples_written));
469 let mut tmp: Vec<u32> = vec![0u32; n];
470
471 for i in 0..n {
472 let q = read_unary_stream(&mut self.reader)?;
473 tmp[i] = (q as u32)
474 .checked_shl(k as u32)
475 .ok_or(AecError::InvalidInput("rice shift overflow"))?;
476 }
477 if k > 0 {
478 for i in 0..n {
479 let rem = self.reader.read_bits_u32(k)?;
480 tmp[i] |= rem;
481 }
482 }
483 for v in tmp {
484 if self.samples_written >= self.output_samples {
485 break;
486 }
487 emit_coded_value(
488 &mut out,
489 &mut self.predictor_x,
490 self.params,
491 self.bytes_per_sample,
492 v,
493 &mut self.sample_index_within_rsi,
494 usize::MAX,
495 )?;
496 self.samples_written += 1;
497 }
498 }
499
500 let produced_len = out.len();
502 drop(out);
503 self.pending = block_out[..produced_len].to_vec();
504 self.pending_pos = 0;
505
506 self.block_index_within_rsi = self.block_index_within_rsi.saturating_add(1);
508 if self.preprocess && self.block_index_within_rsi >= self.params.rsi {
509 self.block_index_within_rsi = 0;
510 self.sample_index_within_rsi = 0;
511 if self.params.flags.contains(AecFlags::PAD_RSI) {
512 self.reader.align_to_byte();
513 }
514 }
515
516 Ok(())
517 }
518}
519
520#[derive(Clone)]
521struct Snapshot {
522 predictor_x: Option<i64>,
523 sample_index_within_rsi: u64,
524 block_index_within_rsi: u32,
525 samples_written: usize,
526 reader: StreamBitReader,
527 pending: Vec<u8>,
528 pending_pos: usize,
529 pending_repeat: Option<PendingRepeat>,
530}
531
/// MSB-first bit reader over an append-only byte buffer.
///
/// Fully consumed bytes can be discarded with `compact_consumed_bytes`; the number of
/// discarded bytes is kept so absolute bit positions stay meaningful.
#[derive(Debug, Clone)]
struct StreamBitReader {
    /// Buffered input bytes not yet discarded.
    buf: Vec<u8>,
    /// Read position in bits, relative to the start of `buf`.
    bit_pos: usize,
    /// Number of bytes removed from the front of `buf` so far.
    total_bytes_dropped: usize,
}
541
542impl StreamBitReader {
543 fn new() -> Self {
544 Self { buf: Vec::new(), bit_pos: 0, total_bytes_dropped: 0 }
545 }
546
547 fn push(&mut self, data: &[u8]) {
548 self.buf.extend_from_slice(data);
549 }
550
551 fn avail_bytes(&self) -> usize {
552 self.buf.len().saturating_sub(self.bit_pos / 8)
553 }
554
555 fn bits_read_total(&self) -> usize {
556 self.total_bytes_dropped * 8 + self.bit_pos
557 }
558
559 fn align_to_byte(&mut self) {
560 let rem = self.bit_pos % 8;
561 if rem != 0 {
562 self.bit_pos += 8 - rem;
563 }
564 }
565
566 fn read_bit(&mut self) -> Result<bool, AecError> {
567 Ok(self.read_bits_u32(1)? != 0)
568 }
569
570 fn read_bits_u32(&mut self, nbits: usize) -> Result<u32, AecError> {
571 if nbits == 0 {
572 return Ok(0);
573 }
574 if nbits > 32 {
575 return Err(AecError::InvalidInput("read_bits_u32 supports up to 32 bits"));
576 }
577
578 let mut out: u32 = 0;
579 for _ in 0..nbits {
580 let byte_idx = self.bit_pos / 8;
581 let bit_in_byte = self.bit_pos % 8;
582 let byte = *self
583 .buf
584 .get(byte_idx)
585 .ok_or(AecError::UnexpectedEof { bit_pos: self.bits_read_total() })?;
586 let bit = (byte >> (7 - bit_in_byte)) & 1;
587 out = (out << 1) | (bit as u32);
588 self.bit_pos += 1;
589 }
590 Ok(out)
591 }
592
593 fn compact_consumed_bytes(&mut self) -> usize {
594 let bytes = self.bit_pos / 8;
595 if bytes == 0 {
596 return 0;
597 }
598 self.buf.drain(0..bytes);
599 self.bit_pos -= bytes * 8;
600 self.total_bytes_dropped += bytes;
601 bytes
602 }
603}
604
605fn read_unary_stream(r: &mut StreamBitReader) -> Result<u32, AecError> {
606 let mut count: u32 = 0;
607 loop {
608 let bit = r.read_bit()?;
609 if bit {
610 return Ok(count);
611 }
612 count = count.saturating_add(1);
613 if count > 1_000_000 {
614 return Err(AecError::InvalidInput("unary run too long"));
615 }
616 }
617}
618
/// Write cursor over a caller-provided output byte slice.
struct OutBuf<'a> {
    /// Destination bytes.
    buf: &'a mut [u8],
    /// Number of bytes written so far (next write offset).
    pos: usize,
    /// Width of one sample in bytes.
    bytes_per_sample: usize,
}
624
625impl<'a> OutBuf<'a> {
626 fn new(buf: &'a mut [u8], bytes_per_sample: usize) -> Self {
627 Self { buf, pos: 0, bytes_per_sample }
628 }
629
630 fn len(&self) -> usize {
631 self.pos
632 }
633
634 fn capacity(&self) -> usize {
635 self.buf.len()
636 }
637
638 fn samples_written(&self) -> usize {
639 self.pos / self.bytes_per_sample
640 }
641}
642
643pub fn decode(input: &[u8], params: AecParams, output_samples: usize) -> Result<Vec<u8>, AecError> {
644 validate_params(params)?;
645
646 let bytes_per_sample = bytes_per_sample(params)?;
647 let output_bytes = output_samples
648 .checked_mul(bytes_per_sample)
649 .ok_or(AecError::InvalidInput("output too large"))?;
650
651 let mut out = vec![0u8; output_bytes];
652 decode_into(input, params, output_samples, &mut out)?;
653 Ok(out)
654}
655
656pub fn decode_into(
657 input: &[u8],
658 params: AecParams,
659 output_samples: usize,
660 output: &mut [u8],
661) -> Result<(), AecError> {
662 validate_params(params)?;
663
664 let trace_sample: Option<usize> = std::env::var("RUST_AEC_TRACE_SAMPLE")
665 .ok()
666 .and_then(|v| v.parse::<usize>().ok());
667
668 let bytes_per_sample = bytes_per_sample(params)?;
669 let output_bytes = output_samples
670 .checked_mul(bytes_per_sample)
671 .ok_or(AecError::InvalidInput("output too large"))?;
672
673 if output.len() != output_bytes {
674 return Err(AecError::InvalidInput("output buffer has wrong length"));
675 }
676
677 let mut out = OutBuf::new(output, bytes_per_sample);
678 let mut r = BitReader::new(input);
679
680 let id_len = id_len(params)?;
681
682 let preprocess = params.flags.contains(AecFlags::DATA_PREPROCESS);
683
684 let mut sample_index_within_rsi: u64 = 0;
685 let mut block_index_within_rsi: u32 = 0;
686
687 let mut predictor_x: Option<i64> = None;
689
690 while out.len() < output_bytes {
691 if preprocess && block_index_within_rsi == 0 {
693 predictor_x = None;
694 }
695
696 let at_rsi_start = preprocess && block_index_within_rsi == 0;
697 let ref_pending = at_rsi_start;
698 let mut reference_sample_consumed = false;
699
700 let block_start_sample = out.samples_written();
701
702 let id = match r.read_bits_u32(id_len) {
704 Ok(v) => v,
705 Err(AecError::UnexpectedEof { bit_pos }) => {
706 return Err(AecError::UnexpectedEofDuringDecode {
707 bit_pos,
708 samples_written: out.samples_written(),
709 });
710 }
711 Err(e) => return Err(e),
712 };
713
714 let max_id = (1u32 << id_len) - 1;
715
716 let mut remaining_in_block: usize;
719
720 let mut consume_reference = |r: &mut BitReader, out: &mut OutBuf<'_>| -> Result<(), AecError> {
722 let ref_raw = match r.read_bits_u32(params.bits_per_sample as usize) {
723 Ok(v) => v,
724 Err(AecError::UnexpectedEof { bit_pos }) => {
725 return Err(AecError::UnexpectedEofDuringDecode {
726 bit_pos,
727 samples_written: out.samples_written(),
728 });
729 }
730 Err(e) => return Err(e),
731 };
732 let ref_val = if params.flags.contains(AecFlags::DATA_SIGNED) {
733 sign_extend(ref_raw, params.bits_per_sample)
734 } else {
735 ref_raw as i64
736 };
737
738 write_sample(out, ref_val, params)?;
739 predictor_x = Some(ref_val);
740 reference_sample_consumed = true;
741 sample_index_within_rsi += 1;
742 Ok(())
743 };
744
745 if id == 0 {
746 let selector = match r.read_bit() {
748 Ok(v) => v,
749 Err(AecError::UnexpectedEof { bit_pos }) => {
750 return Err(AecError::UnexpectedEofDuringDecode {
751 bit_pos,
752 samples_written: out.samples_written(),
753 });
754 }
755 Err(e) => return Err(e),
756 };
757
758 if let Some(ts) = trace_sample {
759 let block_end = block_start_sample + params.block_size as usize;
760 if (block_start_sample..block_end).contains(&ts) {
761 eprintln!(
762 "TRACE sample={ts} rsi_block={block_index_within_rsi} bits={} id=0 mode=LE selector={} block_samples=[{}, {})",
763 r.bits_read(),
764 selector,
765 block_start_sample,
766 block_end
767 );
768 }
769 }
770
771 if ref_pending {
773 consume_reference(&mut r, &mut out)?;
774 if out.len() >= output_bytes {
775 break;
776 }
777 }
778
779 remaining_in_block = params.block_size as usize;
780 if reference_sample_consumed {
781 remaining_in_block = remaining_in_block.saturating_sub(1);
782 }
783
784 if !selector {
785 let fs = match read_unary(&mut r) {
787 Ok(v) => v,
788 Err(AecError::UnexpectedEof { bit_pos }) => {
789 return Err(AecError::UnexpectedEofDuringDecode {
790 bit_pos,
791 samples_written: out.samples_written(),
792 });
793 }
794 Err(e) => return Err(e),
795 };
796 let mut z_blocks = fs + 1;
797
798 const ROS: u32 = 5;
799
800 if z_blocks == ROS {
801 let b = block_index_within_rsi;
803 let fill1 = params.rsi.saturating_sub(b);
804 let fill2 = 64u32.saturating_sub(b % 64);
805 z_blocks = fill1.min(fill2);
806 } else if z_blocks > ROS {
807 z_blocks = z_blocks.saturating_sub(1);
808 }
809
810 let mut zeros_samples = z_blocks
811 .checked_mul(params.block_size)
812 .ok_or(AecError::InvalidInput("zero-run overflow"))? as usize;
813
814 if reference_sample_consumed {
817 zeros_samples = zeros_samples.saturating_sub(1);
818 }
819
820 if let Some(ts) = trace_sample {
821 let total_samples = (z_blocks as usize)
822 .checked_mul(params.block_size as usize)
823 .unwrap_or(usize::MAX);
824 let run_end = block_start_sample.saturating_add(total_samples);
825 if (block_start_sample..run_end).contains(&ts) {
826 eprintln!(
827 "TRACE sample={ts} rsi_block={block_index_within_rsi} bits={} id=0 mode=ZRUN fs={} z_blocks={} run_samples=[{}, {})",
828 r.bits_read(),
829 fs,
830 z_blocks,
831 block_start_sample,
832 run_end
833 );
834 }
835 }
836
837 emit_repeated_value(
838 &mut out,
839 &mut predictor_x,
840 params,
841 bytes_per_sample,
842 0,
843 zeros_samples,
844 &mut sample_index_within_rsi,
845 output_bytes,
846 )?;
847
848 block_index_within_rsi = block_index_within_rsi.saturating_add(z_blocks);
851 if block_index_within_rsi >= params.rsi {
852 block_index_within_rsi %= params.rsi;
853 if params.flags.contains(AecFlags::PAD_RSI) {
854 r.align_to_byte();
855 }
856 sample_index_within_rsi = 0;
857 }
858
859 continue;
860 }
861
862 emit_second_extension(
864 &mut r,
865 &mut out,
866 &mut predictor_x,
867 params,
868 bytes_per_sample,
869 remaining_in_block,
870 reference_sample_consumed,
871 &mut sample_index_within_rsi,
872 output_bytes,
873 )?;
874 } else if id == max_id {
875 if let Some(ts) = trace_sample {
877 let block_end = block_start_sample + params.block_size as usize;
878 if (block_start_sample..block_end).contains(&ts) {
879 eprintln!(
880 "TRACE sample={ts} rsi_block={block_index_within_rsi} bits={} id={} mode=UNCOMP block_samples=[{}, {})",
881 r.bits_read(),
882 id,
883 block_start_sample,
884 block_end
885 );
886 }
887 }
888 if ref_pending {
889 consume_reference(&mut r, &mut out)?;
891 if out.len() >= output_bytes {
892 break;
893 }
894 remaining_in_block = params.block_size as usize - 1;
895 } else {
896 remaining_in_block = params.block_size as usize;
897 }
898
899 for _ in 0..remaining_in_block {
900 let v = match r.read_bits_u32(params.bits_per_sample as usize) {
901 Ok(v) => v,
902 Err(AecError::UnexpectedEof { bit_pos }) => {
903 return Err(AecError::UnexpectedEofDuringDecode {
904 bit_pos,
905 samples_written: out.samples_written(),
906 });
907 }
908 Err(e) => return Err(e),
909 };
910 emit_coded_value(
911 &mut out,
912 &mut predictor_x,
913 params,
914 bytes_per_sample,
915 v,
916 &mut sample_index_within_rsi,
917 output_bytes,
918 )?;
919 if out.len() >= output_bytes {
920 break;
921 }
922 }
923 } else {
924 let k = (id - 1) as usize;
927
928 if let Some(ts) = trace_sample {
929 let block_end = block_start_sample + params.block_size as usize;
930 if (block_start_sample..block_end).contains(&ts) {
931 eprintln!(
932 "TRACE sample={ts} rsi_block={block_index_within_rsi} bits={} id={} mode=SPLIT k={} block_samples=[{}, {})",
933 r.bits_read(),
934 id,
935 k,
936 block_start_sample,
937 block_end
938 );
939 }
940 }
941
942 if ref_pending {
943 consume_reference(&mut r, &mut out)?;
944 if out.len() >= output_bytes {
945 break;
946 }
947 }
948
949 remaining_in_block = params.block_size as usize;
950 if reference_sample_consumed {
951 remaining_in_block = remaining_in_block.saturating_sub(1);
952 }
953
954 let n = remaining_in_block;
955 let mut tmp: Vec<u32> = vec![0u32; n];
956
957 let trace_offset_in_block: Option<usize> = trace_sample.and_then(|ts| {
960 let coded_start = out.samples_written();
961 if ts >= coded_start && ts < coded_start + n {
962 Some(ts - coded_start)
963 } else {
964 None
965 }
966 });
967 let mut trace_q: Option<u32> = None;
968 let mut trace_rem: Option<u32> = None;
969
970 for i in 0..n {
971 let q = match read_unary(&mut r) {
972 Ok(v) => v,
973 Err(AecError::UnexpectedEof { bit_pos }) => {
974 return Err(AecError::UnexpectedEofDuringDecode {
975 bit_pos,
976 samples_written: out.samples_written(),
977 });
978 }
979 Err(e) => return Err(e),
980 };
981 if trace_offset_in_block == Some(i) {
982 trace_q = Some(q);
983 }
984 tmp[i] = (q as u32)
985 .checked_shl(k as u32)
986 .ok_or(AecError::InvalidInput("rice shift overflow"))?;
987 }
988
989 if k > 0 {
990 for i in 0..n {
991 let rem_bitpos_before = if trace_offset_in_block
992 .map(|off| i + 2 >= off && i <= off + 2)
993 .unwrap_or(false)
994 {
995 Some(r.bits_read())
996 } else {
997 None
998 };
999
1000 let rem = match r.read_bits_u32(k) {
1001 Ok(v) => v,
1002 Err(AecError::UnexpectedEof { bit_pos }) => {
1003 return Err(AecError::UnexpectedEofDuringDecode {
1004 bit_pos,
1005 samples_written: out.samples_written(),
1006 });
1007 }
1008 Err(e) => return Err(e),
1009 };
1010
1011 if let (Some(off), Some(bitpos)) = (trace_offset_in_block, rem_bitpos_before) {
1012 if i + 2 >= off && i <= off + 2 {
1013 eprintln!(
1014 "TRACE rem i={} (off={}) bitpos={} bits={:0width$b} rem={}",
1015 i,
1016 off,
1017 bitpos,
1018 rem,
1019 rem,
1020 width = k
1021 );
1022 }
1023 }
1024
1025 if trace_offset_in_block == Some(i) {
1026 trace_rem = Some(rem);
1027 }
1028 tmp[i] |= rem;
1029 }
1030 }
1031
1032 if let Some(off) = trace_offset_in_block {
1033 let d = tmp[off];
1034 let w_start = off.saturating_sub(2);
1035 let w_end = (off + 3).min(n);
1036 let window = tmp[w_start..w_end].to_vec();
1037 eprintln!(
1038 "TRACE split-detail sample={} rsi_block={} id={} k={} off={} q={:?} rem={:?} d={} window[{}..{}]={:?}",
1039 trace_sample.unwrap_or(0),
1040 block_index_within_rsi,
1041 id,
1042 k,
1043 off,
1044 trace_q,
1045 trace_rem,
1046 d
1047 ,
1048 w_start,
1049 w_end,
1050 window
1051 );
1052 }
1053
1054 for v in tmp {
1055 emit_coded_value(
1056 &mut out,
1057 &mut predictor_x,
1058 params,
1059 bytes_per_sample,
1060 v,
1061 &mut sample_index_within_rsi,
1062 output_bytes,
1063 )?;
1064 if out.len() >= output_bytes {
1065 break;
1066 }
1067 }
1068 }
1069
1070 block_index_within_rsi = block_index_within_rsi.saturating_add(1);
1072 if preprocess && block_index_within_rsi >= params.rsi {
1073 block_index_within_rsi = 0;
1074 sample_index_within_rsi = 0;
1075 if params.flags.contains(AecFlags::PAD_RSI) {
1076 r.align_to_byte();
1077 }
1078 }
1079 }
1080
1081 Ok(())
1082}
1083
1084fn validate_params(params: AecParams) -> Result<(), AecError> {
1085 if !(1..=32).contains(¶ms.bits_per_sample) {
1086 return Err(AecError::InvalidInput("bits_per_sample must be 1..=32"));
1087 }
1088 if params.block_size == 0 {
1089 return Err(AecError::InvalidInput("block_size must be > 0"));
1090 }
1091 if params.rsi == 0 {
1092 return Err(AecError::InvalidInput("rsi must be > 0"));
1093 }
1094
1095 if ![8u32, 16, 32, 64].contains(¶ms.block_size) {
1097 return Err(AecError::Unsupported("block_size must be one of 8,16,32,64"));
1098 }
1099
1100 Ok(())
1101}
1102
1103fn bytes_per_sample(params: AecParams) -> Result<usize, AecError> {
1104 let bps = params.bits_per_sample;
1105
1106 let b = match bps {
1107 1..=8 => 1,
1108 9..=16 => 2,
1109 17..=24 => {
1110 if params.flags.contains(AecFlags::DATA_3BYTE) {
1111 3
1112 } else {
1113 4
1114 }
1115 }
1116 25..=32 => 4,
1117 _ => return Err(AecError::InvalidInput("invalid bits_per_sample")),
1118 };
1119
1120 Ok(b)
1121}
1122
1123fn id_len(params: AecParams) -> Result<usize, AecError> {
1124 let bps = params.bits_per_sample;
1125
1126 let mut id_len = if bps > 16 { 5 } else if bps > 8 { 4 } else { 3 };
1127
1128 if params.flags.contains(AecFlags::RESTRICTED) && bps <= 4 {
1129 id_len = if bps <= 2 { 1 } else { 2 };
1130 }
1131
1132 Ok(id_len)
1133}
1134
1135fn read_unary(r: &mut BitReader<'_>) -> Result<u32, AecError> {
1136 let mut count: u32 = 0;
1137 loop {
1138 let bit = r.read_bit()?;
1139 if bit {
1140 return Ok(count);
1141 }
1142 count = count.saturating_add(1);
1143 if count > 1_000_000 {
1147 return Err(AecError::InvalidInput("unary run too long"));
1148 }
1149 }
1150}
1151
1152fn emit_coded_value(
1153 out: &mut OutBuf<'_>,
1154 predictor_x: &mut Option<i64>,
1155 params: AecParams,
1156 _bytes_per_sample: usize,
1157 v: u32,
1158 sample_index_within_rsi: &mut u64,
1159 output_bytes: usize,
1160) -> Result<(), AecError> {
1161 if out.len() >= output_bytes {
1162 return Ok(());
1163 }
1164
1165 if params.flags.contains(AecFlags::DATA_PREPROCESS) {
1166 let x_prev = predictor_x.ok_or(AecError::InvalidInput("missing reference sample"))?;
1167 let x_next = inverse_preprocess_step(x_prev, v, params);
1168 write_sample(out, x_next, params)?;
1169 *predictor_x = Some(x_next);
1170 *sample_index_within_rsi += 1;
1171 return Ok(());
1172 }
1173
1174 write_sample(out, v as i64, params)?;
1176 *sample_index_within_rsi += 1;
1177 Ok(())
1178}
1179
1180fn emit_repeated_value(
1181 out: &mut OutBuf<'_>,
1182 predictor_x: &mut Option<i64>,
1183 params: AecParams,
1184 bytes_per_sample: usize,
1185 v: u32,
1186 count: usize,
1187 sample_index_within_rsi: &mut u64,
1188 output_bytes: usize,
1189) -> Result<(), AecError> {
1190 for _ in 0..count {
1191 if out.len() >= output_bytes {
1192 break;
1193 }
1194 emit_coded_value(
1195 out,
1196 predictor_x,
1197 params,
1198 bytes_per_sample,
1199 v,
1200 sample_index_within_rsi,
1201 output_bytes,
1202 )?;
1203 }
1204 Ok(())
1205}
1206
1207fn emit_second_extension(
1208 r: &mut BitReader<'_>,
1209 out: &mut OutBuf<'_>,
1210 predictor_x: &mut Option<i64>,
1211 params: AecParams,
1212 bytes_per_sample: usize,
1213 mut remaining_in_block: usize,
1214 reference_sample_consumed: bool,
1215 sample_index_within_rsi: &mut u64,
1216 output_bytes: usize,
1217) -> Result<(), AecError> {
1218 let mut need_odd_first = reference_sample_consumed;
1222
1223 while remaining_in_block > 0 && out.len() < output_bytes {
1224 let m = read_unary(r)?;
1225 if m > 90 {
1226 return Err(AecError::InvalidInput("Second Extension unary symbol too large"));
1227 }
1228
1229 let (a, b) = second_extension_pair(m);
1230
1231 if need_odd_first {
1232 emit_coded_value(
1234 out,
1235 predictor_x,
1236 params,
1237 bytes_per_sample,
1238 b,
1239 sample_index_within_rsi,
1240 output_bytes,
1241 )?;
1242 remaining_in_block = remaining_in_block.saturating_sub(1);
1243 need_odd_first = false;
1244 continue;
1245 }
1246
1247 emit_coded_value(
1249 out,
1250 predictor_x,
1251 params,
1252 bytes_per_sample,
1253 a,
1254 sample_index_within_rsi,
1255 output_bytes,
1256 )?;
1257 remaining_in_block = remaining_in_block.saturating_sub(1);
1258 if remaining_in_block == 0 || out.len() >= output_bytes {
1259 break;
1260 }
1261
1262 emit_coded_value(
1264 out,
1265 predictor_x,
1266 params,
1267 bytes_per_sample,
1268 b,
1269 sample_index_within_rsi,
1270 output_bytes,
1271 )?;
1272 remaining_in_block = remaining_in_block.saturating_sub(1);
1273 }
1274
1275 Ok(())
1276}
1277
/// Maps a Second Extension symbol `m` to its value pair `(first, second)`.
///
/// Symbols enumerate pairs diagonal by diagonal: for sum `s` (0..=12) the symbols
/// `s(s+1)/2 ..= s(s+1)/2 + s` map to `(s - k, k)` for `k = 0..=s`. Symbols beyond the
/// last diagonal (above 90) yield `(0, 0)`.
fn second_extension_pair(m: u32) -> (u32, u32) {
    let mut diagonal_start: u32 = 0;
    for sum in 0u32..=12 {
        let next_start = diagonal_start + sum + 1;
        if m < next_start {
            let k = m - diagonal_start;
            return (sum - k, k);
        }
        diagonal_start = next_start;
    }
    (0, 0)
}
1293
1294fn inverse_preprocess_step(x_prev: i64, d: u32, params: AecParams) -> i64 {
1295 let n = params.bits_per_sample;
1296
1297 let delta: i64 = ((d >> 1) as i64) ^ (!(((d & 1) as i64) - 1));
1302 let half_d: i64 = ((d >> 1) + (d & 1)) as i64;
1303
1304 if params.flags.contains(AecFlags::DATA_SIGNED) {
1305 let signed_max: i64 = (1i64 << (n - 1)) - 1;
1307 let data = x_prev;
1308
1309 if data < 0 {
1310 if half_d <= signed_max + data + 1 {
1311 data + delta
1312 } else {
1313 (d as i64) - signed_max - 1
1314 }
1315 } else {
1316 if half_d <= signed_max - data {
1317 data + delta
1318 } else {
1319 signed_max - (d as i64)
1320 }
1321 }
1322 } else {
1323 let unsigned_max: u64 = (1u64 << n) - 1;
1324 let data_u: u64 = x_prev as u64;
1325
1326 let med: u64 = unsigned_max / 2 + 1;
1328 let mask: u64 = if (data_u & med) != 0 { unsigned_max } else { 0 };
1329
1330 if (half_d as u64) <= (mask ^ data_u) {
1331 (x_prev + delta) as i64
1332 } else {
1333 (mask ^ (d as u64)) as i64
1334 }
1335 }
1336}
1337
1338fn write_sample(out: &mut OutBuf<'_>, value: i64, params: AecParams) -> Result<(), AecError> {
1339 let n = params.bits_per_sample as u32;
1340 let mask: u64 = if n == 32 { u64::MAX } else { (1u64 << n) - 1 };
1341
1342 let raw_u = if params.flags.contains(AecFlags::DATA_SIGNED) {
1343 (value as i64 as u64) & mask
1344 } else {
1345 (value.max(0) as u64) & mask
1346 };
1347
1348 let bytes_per_sample = out.bytes_per_sample;
1349 if out.pos.checked_add(bytes_per_sample).ok_or(AecError::InvalidInput("output too large"))? > out.capacity() {
1350 return Err(AecError::InvalidInput("output buffer too small"));
1351 }
1352
1353 let msb = params.flags.contains(AecFlags::MSB);
1354 if msb {
1355 for i in (0..bytes_per_sample).rev() {
1356 out.buf[out.pos] = ((raw_u >> (i * 8)) & 0xff) as u8;
1357 out.pos += 1;
1358 }
1359 } else {
1360 for i in 0..bytes_per_sample {
1361 out.buf[out.pos] = ((raw_u >> (i * 8)) & 0xff) as u8;
1362 out.pos += 1;
1363 }
1364 }
1365
1366 Ok(())
1367}
1368
/// Interprets the low `bits` bits of `raw` as a two's-complement number.
///
/// Bits above `bits` are ignored; `bits == 32` reinterprets `raw` as `i32` directly.
fn sign_extend(raw: u32, bits: u8) -> i64 {
    let unused = 32 - u32::from(bits);
    let left_aligned = (raw << unused) as i32;
    i64::from(left_aligned >> unused)
}
1375}