Skip to main content

signinum_transcode/
dct97_2d.rs

1// SPDX-License-Identifier: Apache-2.0
2
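//! Constrained conversion of a 2D grid of 8x8 DCT blocks into a single-level
//! irreversible 9/7 wavelet decomposition.
//!
//! The production float path performs a separable 8x8 IDCT into a reusable
//! spatial plane, then applies the separable single-level 9/7 transform. A
//! slower reference path evaluates the IDCT one sample at a time instead.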
7
8use rayon::prelude::*;
9
10use crate::dct_grid::{high_len, idct8_basis, idct8_basis_table, low_len, validate_dct_block_grid};
11pub use crate::DctGridError as Dct97GridError;
12
/// First predict coefficient of the irreversible 9/7 lifting; added to every
/// odd sample as `ALPHA * (left + right)` in `forward_lift_97`.
const ALPHA: f64 = -1.586_134_342_059_924;
/// First update coefficient; added to every even sample as
/// `BETA * (left + right)`.
const BETA: f64 = -0.052_980_118_572_961;
/// Second predict coefficient (odd samples).
const GAMMA: f64 = 0.882_911_075_530_934;
/// Second update coefficient (even samples).
const DELTA: f64 = 0.443_506_852_043_971;
/// Band scaling factor: after lifting, odd (high-pass) samples are multiplied
/// by `KAPPA` and even (low-pass) samples by `INV_KAPPA`.
const KAPPA: f64 = 1.230_174_104_914_001;
/// Reciprocal of `KAPPA`, applied to the even (low-pass) samples.
const INV_KAPPA: f64 = 1.0 / KAPPA;
/// Minimum plane size, in samples, at which `idct8x8_blocks_to_samples`
/// converts 8-row bands on rayon's pool; smaller planes are converted
/// serially.
const PARALLEL_IDCT_MIN_SAMPLES: usize = 64 * 64;
20
/// One separable single-level 2D 9/7 transform result.
///
/// As built by this module, every band is stored row-major: sample `(x, y)`
/// of a band lives at `y * band_width + x`, where `band_width` is
/// `low_width` for `ll`/`lh` and `high_width` for `hl`/`hh`.
#[derive(Debug, Clone, PartialEq)]
pub struct Dwt97TwoDimensional<T> {
    /// Low-horizontal, low-vertical band (`low_width * low_height` samples).
    pub ll: Vec<T>,
    /// High-horizontal, low-vertical band (`high_width * low_height` samples).
    pub hl: Vec<T>,
    /// Low-horizontal, high-vertical band (`low_width * high_height` samples).
    pub lh: Vec<T>,
    /// High-horizontal, high-vertical band (`high_width * high_height` samples).
    pub hh: Vec<T>,
    /// Width of horizontally low-pass bands.
    pub low_width: usize,
    /// Height of vertically low-pass bands.
    pub low_height: usize,
    /// Width of horizontally high-pass bands.
    pub high_width: usize,
    /// Height of vertically high-pass bands.
    pub high_height: usize,
}
41
/// Scratch storage for repeated DCT-grid to 9/7 transform calls.
///
/// Reusing one value across calls keeps the allocations of the intermediate
/// spatial plane and of the lifting buffers. Every call clears and resizes
/// these buffers before use, so no data carries over between calls. The
/// returned bands are still allocated fresh on each call.
#[derive(Debug, Default)]
pub struct Dct97GridScratch {
    // IDCT output plane: `width * height` samples, row-major.
    spatial_samples: Vec<f64>,
    // Buffers for the separable 9/7 passes over that plane.
    plane: Dwt97PlaneScratch,
}
48
/// Reusable buffers for one separable 9/7 transform of a plane.
#[derive(Debug, Default)]
struct Dwt97PlaneScratch {
    // Horizontally low-pass half of every row: `height * low_width`, row-major.
    row_low: Vec<f64>,
    // Horizontally high-pass half of every row: `height * high_width`, row-major.
    row_high: Vec<f64>,
    // The single row or column currently being lifted in place.
    lift_workspace: Vec<f64>,
}
55
56impl Dct97GridScratch {
57    /// Capacity of the reusable spatial-sample buffer used by the IDCT-then
58    /// 9/7 path.
59    #[must_use]
60    pub fn spatial_sample_capacity(&self) -> usize {
61        self.spatial_samples.capacity()
62    }
63}
64
65/// Reference path for a DCT block grid:
66/// DCT coefficients -> float IDCT samples -> separable linearized 9/7.
67pub fn dct8x8_blocks_then_dwt97_float(
68    blocks: &[[[f64; 8]; 8]],
69    block_cols: usize,
70    block_rows: usize,
71    width: usize,
72    height: usize,
73) -> Result<Dwt97TwoDimensional<f64>, Dct97GridError> {
74    validate_grid(blocks.len(), block_cols, block_rows, width, height)?;
75
76    let mut samples = Vec::with_capacity(width * height);
77    for y in 0..height {
78        let block_y = y / 8;
79        let local_y = y % 8;
80        for x in 0..width {
81            let block_x = x / 8;
82            let local_x = x % 8;
83            let block = &blocks[block_y * block_cols + block_x];
84            samples.push(idct8x8_sample(block, local_x, local_y));
85        }
86    }
87
88    Ok(linearized_97_2d_from_plane(&samples, width, height))
89}
90
91/// Reference 9/7 path with caller-owned spatial-sample scratch:
92/// DCT coefficients -> float IDCT samples -> separable linearized 9/7.
93pub fn dct8x8_blocks_then_dwt97_float_with_scratch(
94    blocks: &[[[f64; 8]; 8]],
95    block_cols: usize,
96    block_rows: usize,
97    width: usize,
98    height: usize,
99    scratch: &mut Dct97GridScratch,
100) -> Result<Dwt97TwoDimensional<f64>, Dct97GridError> {
101    validate_grid(blocks.len(), block_cols, block_rows, width, height)?;
102
103    let sample_count = width.saturating_mul(height);
104    scratch.spatial_samples.clear();
105    scratch.spatial_samples.resize(sample_count, 0.0);
106    idct8x8_blocks_to_samples(
107        blocks,
108        block_cols,
109        width,
110        height,
111        &mut scratch.spatial_samples,
112    );
113
114    Ok(linearized_97_2d_from_plane_with_plane_scratch(
115        &scratch.spatial_samples,
116        width,
117        height,
118        &mut scratch.plane,
119    ))
120}
121
122pub(crate) fn linearized_97_2d_from_plane(
123    samples: &[f64],
124    width: usize,
125    height: usize,
126) -> Dwt97TwoDimensional<f64> {
127    let mut scratch = Dct97GridScratch::default();
128    linearized_97_2d_from_plane_with_scratch(samples, width, height, &mut scratch)
129}
130
131pub(crate) fn linearized_97_2d_from_plane_with_scratch(
132    samples: &[f64],
133    width: usize,
134    height: usize,
135    scratch: &mut Dct97GridScratch,
136) -> Dwt97TwoDimensional<f64> {
137    linearized_97_2d_from_plane_with_plane_scratch(samples, width, height, &mut scratch.plane)
138}
139
140fn linearized_97_2d_from_plane_with_plane_scratch(
141    samples: &[f64],
142    width: usize,
143    height: usize,
144    scratch: &mut Dwt97PlaneScratch,
145) -> Dwt97TwoDimensional<f64> {
146    debug_assert_eq!(samples.len(), width * height);
147
148    let low_width = low_len(width);
149    let low_height = low_len(height);
150    let high_width = high_len(width);
151    let high_height = high_len(height);
152
153    scratch.row_low.clear();
154    scratch.row_low.resize(height * low_width, 0.0);
155    scratch.row_high.clear();
156    scratch.row_high.resize(height * high_width, 0.0);
157
158    for y in 0..height {
159        let start = y * width;
160        let row = &samples[start..start + width];
161        let low_start = y * low_width;
162        let high_start = y * high_width;
163        linearized_97_split_contiguous_into(
164            row,
165            &mut scratch.row_low[low_start..low_start + low_width],
166            &mut scratch.row_high[high_start..high_start + high_width],
167            &mut scratch.lift_workspace,
168        );
169    }
170
171    let mut ll = vec![0.0; low_width * low_height];
172    let mut lh = vec![0.0; low_width * high_height];
173    for x in 0..low_width {
174        linearized_97_split_strided_into(
175            &scratch.row_low,
176            low_width,
177            x,
178            height,
179            &mut ll,
180            &mut lh,
181            low_width,
182            &mut scratch.lift_workspace,
183        );
184    }
185
186    let mut hl = vec![0.0; high_width * low_height];
187    let mut hh = vec![0.0; high_width * high_height];
188    for x in 0..high_width {
189        linearized_97_split_strided_into(
190            &scratch.row_high,
191            high_width,
192            x,
193            height,
194            &mut hl,
195            &mut hh,
196            high_width,
197            &mut scratch.lift_workspace,
198        );
199    }
200
201    Dwt97TwoDimensional {
202        ll,
203        hl,
204        lh,
205        hh,
206        low_width,
207        low_height,
208        high_width,
209        high_height,
210    }
211}
212
213fn idct8x8_sample(block: &[[f64; 8]; 8], x: usize, y: usize) -> f64 {
214    let mut sample = 0.0;
215    for (freq_y, row) in block.iter().enumerate() {
216        let y_basis = idct8_basis(y, freq_y);
217        for (freq_x, coefficient) in row.iter().copied().enumerate() {
218            sample += coefficient * y_basis * idct8_basis(x, freq_x);
219        }
220    }
221    sample
222}
223
224fn idct8x8_blocks_to_samples(
225    blocks: &[[[f64; 8]; 8]],
226    block_cols: usize,
227    width: usize,
228    height: usize,
229    samples: &mut [f64],
230) {
231    debug_assert_eq!(samples.len(), width * height);
232    let basis = idct8_basis_table();
233    let active_block_cols = width.div_ceil(8);
234    let active_block_rows = height.div_ceil(8);
235
236    if width * height >= PARALLEL_IDCT_MIN_SAMPLES {
237        samples
238            .par_chunks_mut(width * 8)
239            .enumerate()
240            .take(active_block_rows)
241            .for_each(|(block_y, sample_rows)| {
242                idct8x8_block_row_to_samples(
243                    blocks,
244                    block_cols,
245                    width,
246                    height,
247                    basis,
248                    active_block_cols,
249                    block_y,
250                    sample_rows,
251                );
252            });
253    } else {
254        for block_y in 0..active_block_rows {
255            let block_sample_y = block_y * 8;
256            let output_rows = (height - block_sample_y).min(8);
257            let row_start = block_sample_y * width;
258            let row_end = row_start + output_rows * width;
259            idct8x8_block_row_to_samples(
260                blocks,
261                block_cols,
262                width,
263                height,
264                basis,
265                active_block_cols,
266                block_y,
267                &mut samples[row_start..row_end],
268            );
269        }
270    }
271}
272
/// Writes the IDCT of block row `block_y` into `sample_rows`.
///
/// `sample_rows` is the row-major slice of the plane (`width` samples per
/// row) holding that block row's spatial rows. `basis` is indexed as
/// `basis[sample][freq]`. Each block is transformed vertically first, then
/// horizontally. Only the rows and columns that fall inside `width x height`
/// are written.
#[allow(clippy::too_many_arguments)]
fn idct8x8_block_row_to_samples(
    blocks: &[[[f64; 8]; 8]],
    block_cols: usize,
    width: usize,
    height: usize,
    basis: &[[f64; 8]; 8],
    active_block_cols: usize,
    block_y: usize,
    sample_rows: &mut [f64],
) {
    let rows_here = (height - block_y * 8).min(8);

    for block_x in 0..active_block_cols {
        let first_col = block_x * 8;
        let cols_here = (width - first_col).min(8);
        let coefficients = &blocks[block_y * block_cols + block_x];

        // Vertical 1D IDCT: partial[local_y][freq_x]. Rows past the plane's
        // bottom edge are never written out, so they are not computed.
        let mut partial = [[0.0; 8]; 8];
        for (partial_row, basis_row) in partial.iter_mut().zip(basis.iter()).take(rows_here) {
            for (freq_x, slot) in partial_row.iter_mut().enumerate() {
                *slot = coefficients
                    .iter()
                    .zip(basis_row.iter())
                    .fold(0.0, |acc, (coefficient_row, &weight)| {
                        acc + weight * coefficient_row[freq_x]
                    });
            }
        }

        // Horizontal 1D IDCT straight into the output plane.
        for (local_y, partial_row) in partial.iter().enumerate().take(rows_here) {
            let out = &mut sample_rows[local_y * width + first_col..][..cols_here];
            for (target, basis_row) in out.iter_mut().zip(basis.iter()) {
                *target = partial_row
                    .iter()
                    .zip(basis_row.iter())
                    .fold(0.0, |acc, (&value, &weight)| acc + value * weight);
            }
        }
    }
}
314
/// Test helper: 1D forward 9/7 of `samples`, de-interleaved into the even
/// (low-pass) and odd (high-pass) positions of the lifted signal.
#[cfg(test)]
fn linearized_97_from_sample_slice(samples: &[f64]) -> Dwt97OneDimensional {
    let mut interleaved = Vec::from(samples);
    forward_lift_97(&mut interleaved);

    let (mut low, mut high) = (Vec::new(), Vec::new());
    for (index, value) in interleaved.into_iter().enumerate() {
        if index % 2 == 0 {
            low.push(value);
        } else {
            high.push(value);
        }
    }
    Dwt97OneDimensional { low, high }
}
325
326fn forward_lift_97(data: &mut [f64]) {
327    let n = data.len();
328    if n < 2 {
329        return;
330    }
331
332    let last_even = if n.is_multiple_of(2) { n - 2 } else { n - 1 };
333
334    for i in (1..n).step_by(2) {
335        let left = data[i - 1];
336        let right = if i + 1 < n {
337            data[i + 1]
338        } else {
339            data[last_even]
340        };
341        data[i] += ALPHA * (left + right);
342    }
343
344    for i in (0..n).step_by(2) {
345        let left = if i > 0 { data[i - 1] } else { data[1] };
346        let right = if i + 1 < n { data[i + 1] } else { left };
347        data[i] += BETA * (left + right);
348    }
349
350    for i in (1..n).step_by(2) {
351        let left = data[i - 1];
352        let right = if i + 1 < n {
353            data[i + 1]
354        } else {
355            data[last_even]
356        };
357        data[i] += GAMMA * (left + right);
358    }
359
360    for i in (0..n).step_by(2) {
361        let left = if i > 0 { data[i - 1] } else { data[1] };
362        let right = if i + 1 < n { data[i + 1] } else { left };
363        data[i] += DELTA * (left + right);
364    }
365
366    for i in (0..n).step_by(2) {
367        data[i] *= INV_KAPPA;
368    }
369    for i in (1..n).step_by(2) {
370        data[i] *= KAPPA;
371    }
372}
373
374fn linearized_97_split_contiguous_into(
375    samples: &[f64],
376    low: &mut [f64],
377    high: &mut [f64],
378    workspace: &mut Vec<f64>,
379) {
380    debug_assert_eq!(low.len(), low_len(samples.len()));
381    debug_assert_eq!(high.len(), high_len(samples.len()));
382
383    workspace.clear();
384    workspace.extend_from_slice(samples);
385    forward_lift_97(workspace);
386
387    for (target, value) in low.iter_mut().zip(workspace.iter().step_by(2)) {
388        *target = *value;
389    }
390    for (target, value) in high.iter_mut().zip(workspace.iter().skip(1).step_by(2)) {
391        *target = *value;
392    }
393}
394
395#[allow(clippy::too_many_arguments)]
396fn linearized_97_split_strided_into(
397    samples: &[f64],
398    stride: usize,
399    x: usize,
400    height: usize,
401    low: &mut [f64],
402    high: &mut [f64],
403    band_width: usize,
404    workspace: &mut Vec<f64>,
405) {
406    debug_assert_eq!(low.len(), band_width * low_len(height));
407    debug_assert_eq!(high.len(), band_width * high_len(height));
408
409    workspace.clear();
410    workspace.extend((0..height).map(|y| samples[y * stride + x]));
411    forward_lift_97(workspace);
412
413    for (low_y, value) in workspace.iter().step_by(2).enumerate() {
414        low[low_y * band_width + x] = *value;
415    }
416    for (high_y, value) in workspace.iter().skip(1).step_by(2).enumerate() {
417        high[high_y * band_width + x] = *value;
418    }
419}
420
421fn validate_grid(
422    block_count: usize,
423    block_cols: usize,
424    block_rows: usize,
425    width: usize,
426    height: usize,
427) -> Result<(), Dct97GridError> {
428    validate_dct_block_grid(block_count, block_cols, block_rows, width, height)
429}
430
/// Test-only 1D transform result built by `linearized_97_from_sample_slice`:
/// the lifted signal split into its even and odd positions.
#[cfg(test)]
struct Dwt97OneDimensional {
    /// Low-pass coefficients (even positions of the lifted signal).
    low: Vec<f64>,
    /// High-pass coefficients (odd positions of the lifted signal).
    high: Vec<f64>,
}
436
#[cfg(test)]
mod tests {
    use core::f64::consts::PI;

    use super::*;

    fn assert_all_close(values: &[f64], expected: f64, epsilon: f64) {
        for &value in values {
            assert!(
                (value - expected).abs() < epsilon,
                "value={value} expected={expected} values={values:?}"
            );
        }
    }

    /// Band dimensions as `(low_width, low_height, high_width, high_height)`.
    fn band_dims(bands: &Dwt97TwoDimensional<f64>) -> (usize, usize, usize, usize) {
        (
            bands.low_width,
            bands.low_height,
            bands.high_width,
            bands.high_height,
        )
    }

    #[test]
    fn linearized_97_from_constant_signal_places_dc_in_low_pass() {
        for len in [2usize, 3, 8, 9, 64, 65] {
            let samples = vec![50.0; len];

            let transformed = linearized_97_from_sample_slice(&samples);

            assert_all_close(&transformed.low, 50.0, 0.001);
            assert_all_close(&transformed.high, 0.0, 0.001);
        }
    }

    #[test]
    fn linearized_97_2d_from_constant_plane_places_dc_in_ll() {
        for (width, height) in [(8usize, 8usize), (9, 7)] {
            let samples = vec![50.0; width * height];

            let transformed = linearized_97_2d_from_plane(&samples, width, height);

            assert_all_close(&transformed.ll, 50.0, 0.001);
            assert_all_close(&transformed.hl, 0.0, 0.001);
            assert_all_close(&transformed.lh, 0.0, 0.001);
            assert_all_close(&transformed.hh, 0.0, 0.001);
        }
    }

    // -------------------------------------------------------------------------
    // Independent CDF 9/7 ground truth.
    //
    // The CUDA 9/7 kernel is parity-tested against `forward_lift_97` /
    // `linearized_97_2d_from_plane`, so a bug in the lifting would be faithfully
    // reproduced by the kernel and pass that parity test unnoticed. These tests
    // close that gap by validating the lifting against an *independent*
    // implementation: a direct FIR filter bank using the canonical, fully
    // normalized CDF 9/7 analysis taps and JPEG2000 whole-sample symmetric
    // extension. Different arithmetic, same transform.
    //
    // The taps themselves are checked against their defining mathematical
    // properties (DC and Nyquist gains, high-pass vanishing moments) so they
    // cannot silently drift to "match" a buggy lifting.

    /// Canonical CDF 9/7 analysis low-pass filter (9 taps, even-symmetric).
    /// Normalized to DC gain 1, so a constant maps unchanged into the low band;
    /// this matches the lifting's `INV_KAPPA` scaling. Its Nyquist gain is 0.
    const REF_LP: [f64; 9] = [
        0.026_748_757_410_810,
        -0.016_864_118_442_875,
        -0.078_223_266_528_990,
        0.266_864_118_442_875,
        0.602_949_018_236_360,
        0.266_864_118_442_875,
        -0.078_223_266_528_990,
        -0.016_864_118_442_875,
        0.026_748_757_410_810,
    ];

    /// Canonical CDF 9/7 analysis high-pass filter (7 taps, even-symmetric).
    /// Like any high-pass it has DC gain 0; it has four vanishing moments. The
    /// normalization that matches the lifting's `KAPPA` scaling is its Nyquist
    /// gain of 2: for an alternating signal, each high-band coefficient is
    /// twice the sample under the center tap.
    const REF_HP: [f64; 7] = [
        0.091_271_763_114_250,
        -0.057_543_526_228_500,
        -0.591_271_763_114_247,
        1.115_087_052_456_994,
        -0.591_271_763_114_247,
        -0.057_543_526_228_500,
        0.091_271_763_114_250,
    ];

    /// Whole-sample symmetric reflection: mirror about index 0 and `n - 1`
    /// without repeating the endpoints. This is the boundary extension
    /// `forward_lift_97` implements at the array edges.
    fn ws_reflect(i: isize, n: usize) -> usize {
        debug_assert!(n >= 1);
        if n == 1 {
            return 0;
        }
        let n = isize::try_from(n).expect("signal length fits in isize");
        let period = 2 * (n - 1);
        let mut k = i.rem_euclid(period);
        if k >= n {
            k = period - k;
        }
        usize::try_from(k).expect("reflected index is non-negative")
    }

    /// Independent single-level forward 9/7 analysis via direct convolution.
    /// Returns `(low, high)` interleaved-position bands matching `forward_lift_97`
    /// (`low[m]` centered at sample `2m`, `high[m]` centered at sample `2m + 1`).
    fn ref_analysis_1d(signal: &[f64]) -> (Vec<f64>, Vec<f64>) {
        let n = signal.len();
        if n < 2 {
            // The lifting leaves <2-length signals unchanged (low = the sample).
            return (signal.to_vec(), Vec::new());
        }
        let mut low = vec![0.0; low_len(n)];
        let mut high = vec![0.0; high_len(n)];
        for (m, out) in low.iter_mut().enumerate() {
            let center = 2 * isize::try_from(m).unwrap();
            *out = REF_LP
                .iter()
                .enumerate()
                .map(|(t, &tap)| {
                    tap * signal[ws_reflect(center + isize::try_from(t).unwrap() - 4, n)]
                })
                .sum();
        }
        for (m, out) in high.iter_mut().enumerate() {
            let center = 2 * isize::try_from(m).unwrap() + 1;
            *out = REF_HP
                .iter()
                .enumerate()
                .map(|(t, &tap)| {
                    tap * signal[ws_reflect(center + isize::try_from(t).unwrap() - 3, n)]
                })
                .sum();
        }
        (low, high)
    }

    /// Independent separable 2D forward 9/7 (rows then columns) producing the
    /// same four-band layout as `linearized_97_2d_from_plane`.
    fn ref_analysis_2d(samples: &[f64], width: usize, height: usize) -> Dwt97TwoDimensional<f64> {
        let low_width = low_len(width);
        let high_width = high_len(width);
        let low_height = low_len(height);
        let high_height = high_len(height);

        let mut row_low = vec![0.0; height * low_width];
        let mut row_high = vec![0.0; height * high_width];
        for y in 0..height {
            let (lo, hi) = ref_analysis_1d(&samples[y * width..y * width + width]);
            row_low[y * low_width..y * low_width + low_width].copy_from_slice(&lo);
            row_high[y * high_width..y * high_width + high_width].copy_from_slice(&hi);
        }

        let vertical_split = |source: &[f64], band_width: usize| -> (Vec<f64>, Vec<f64>) {
            let mut low = vec![0.0; band_width * low_height];
            let mut high = vec![0.0; band_width * high_height];
            for x in 0..band_width {
                let column: Vec<f64> = (0..height).map(|y| source[y * band_width + x]).collect();
                let (lo, hi) = ref_analysis_1d(&column);
                for (vy, &value) in lo.iter().enumerate() {
                    low[vy * band_width + x] = value;
                }
                for (vy, &value) in hi.iter().enumerate() {
                    high[vy * band_width + x] = value;
                }
            }
            (low, high)
        };

        let (ll, lh) = vertical_split(&row_low, low_width);
        let (hl, hh) = vertical_split(&row_high, high_width);

        Dwt97TwoDimensional {
            ll,
            hl,
            lh,
            hh,
            low_width,
            low_height,
            high_width,
            high_height,
        }
    }

    /// Small deterministic PRNG (LCG) for reproducible test signals in [-1, 1).
    fn next_unit(state: &mut u64) -> f64 {
        *state = state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        ((*state >> 11) as f64 / (1u64 << 53) as f64).mul_add(2.0, -1.0)
    }

    fn assert_bands_close(actual: &[f64], expected: &[f64], label: &str, epsilon: f64) {
        assert_eq!(actual.len(), expected.len(), "{label} band length");
        for (i, (a, b)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (a - b).abs() <= epsilon,
                "{label}[{i}] diverged: lifting={a} reference={b} (diff {})",
                (a - b).abs()
            );
        }
    }

    /// Alternating sum of `taps` about `center`: the filter's response to a
    /// Nyquist-frequency signal whose sample under the center tap is +1.
    fn nyquist_gain(taps: &[f64], center: usize) -> f64 {
        taps.iter()
            .enumerate()
            .map(|(k, &tap)| if k % 2 == center % 2 { tap } else { -tap })
            .sum()
    }

    #[test]
    fn reference_cdf97_taps_satisfy_their_defining_properties() {
        // Low-pass DC gain 1 pins the low band's `INV_KAPPA` scaling. High-pass
        // DC gain 0 holds for every high-pass, whatever its scaling, so the
        // high band's `KAPPA` scaling is pinned by the Nyquist gain below.
        let lp_dc: f64 = REF_LP.iter().sum();
        assert!((lp_dc - 1.0).abs() < 1e-9, "low-pass DC gain = {lp_dc}");
        let hp_dc: f64 = REF_HP.iter().sum();
        assert!(hp_dc.abs() < 1e-9, "high-pass DC gain = {hp_dc}");

        // Nyquist gains: 0 for the low-pass, 2 for the high-pass.
        let lp_nyquist = nyquist_gain(&REF_LP, 4);
        assert!(lp_nyquist.abs() < 1e-9, "low-pass Nyquist gain = {lp_nyquist}");
        let hp_nyquist = nyquist_gain(&REF_HP, 3);
        assert!(
            (hp_nyquist - 2.0).abs() < 1e-9,
            "high-pass Nyquist gain = {hp_nyquist}"
        );

        // Even symmetry.
        for k in 0..4 {
            assert!(
                (REF_LP[k] - REF_LP[8 - k]).abs() < 1e-15,
                "low-pass asymmetric at {k}"
            );
        }
        for k in 0..3 {
            assert!(
                (REF_HP[k] - REF_HP[6 - k]).abs() < 1e-15,
                "high-pass asymmetric at {k}"
            );
        }

        // Four vanishing moments (moment 0 is the DC gain above): the high-pass
        // annihilates polynomials of degree <= 3 (so a wrong predict
        // coefficient or sign cannot pass).
        for m in 1..=3 {
            let moment: f64 = REF_HP
                .iter()
                .enumerate()
                .map(|(k, &tap)| (k as f64 - 3.0).powi(m) * tap)
                .sum();
            assert!(moment.abs() < 1e-9, "high-pass moment {m} = {moment}");
        }
    }

    #[test]
    fn forward_lift_97_matches_independent_filter_bank_1d() {
        let mut state = 0x1234_5678_9abc_def0u64;
        for n in [2usize, 3, 4, 5, 8, 9, 12, 15, 16, 23, 32, 33, 64, 65] {
            let signal: Vec<f64> = (0..n).map(|_| next_unit(&mut state) * 100.0).collect();
            let lifted = linearized_97_from_sample_slice(&signal);
            let (low, high) = ref_analysis_1d(&signal);
            assert_bands_close(&lifted.low, &low, &format!("n={n} low"), 1e-9);
            assert_bands_close(&lifted.high, &high, &format!("n={n} high"), 1e-9);
        }
    }

    #[test]
    fn forward_lift_97_maps_nyquist_signal_into_high_pass() {
        // Whole-sample symmetric extension preserves an alternating signal
        // exactly, so every coefficient (edges included) sees the pure Nyquist
        // response. The low band gets 0; each high coefficient gets twice its
        // odd-position sample. This pins the KAPPA scaling directly on the
        // lifting, independently of the filter bank.
        for len in [2usize, 3, 8, 9, 64, 65] {
            let samples: Vec<f64> = (0..len)
                .map(|i| if i % 2 == 0 { 10.0 } else { -10.0 })
                .collect();
            let transformed = linearized_97_from_sample_slice(&samples);
            assert_all_close(&transformed.low, 0.0, 1e-9);
            assert_all_close(&transformed.high, -20.0, 1e-9);
        }
    }

    #[test]
    fn forward_lift_97_annihilates_low_degree_polynomials() {
        // Independent of the filter bank: a correct 9/7 high-pass kills cubics in
        // the interior (boundary coefficients use symmetric extension). This pins
        // the predict-step coefficients and signs directly from wavelet theory.
        let n = 40usize;
        let polynomials: [[f64; 4]; 4] = [
            [5.0, 0.0, 0.0, 0.0],
            [0.0, 2.5, 0.0, 0.0],
            [1.0, -0.7, 0.3, 0.0],
            [0.0, 0.0, 0.0, 0.05],
        ];
        for coeffs in polynomials {
            let signal: Vec<f64> = (0..n)
                .map(|i| {
                    let x = i as f64;
                    coeffs[3].mul_add(
                        x * x * x,
                        coeffs[2].mul_add(x * x, coeffs[1].mul_add(x, coeffs[0])),
                    )
                })
                .collect();
            let lifted = linearized_97_from_sample_slice(&signal);
            // Skip the first/last high-pass coefficients (boundary support).
            let interior = &lifted.high[3..lifted.high.len() - 3];
            assert_all_close(interior, 0.0, 1e-6);
        }
    }

    #[test]
    fn linearized_97_2d_matches_independent_separable_filter_bank() {
        let mut state = 0xfeed_face_dead_beefu64;
        for (width, height) in [
            (8usize, 8usize),
            (16, 16),
            (24, 16),
            (15, 13),
            (16, 23),
            (9, 7),
            (32, 32),
        ] {
            let samples: Vec<f64> = (0..width * height)
                .map(|_| next_unit(&mut state) * 100.0)
                .collect();
            let got = linearized_97_2d_from_plane(&samples, width, height);
            let want = ref_analysis_2d(&samples, width, height);
            assert_eq!(
                band_dims(&got),
                band_dims(&want),
                "band dimensions for {width}x{height}"
            );
            assert_bands_close(&got.ll, &want.ll, &format!("{width}x{height} ll"), 1e-9);
            assert_bands_close(&got.hl, &want.hl, &format!("{width}x{height} hl"), 1e-9);
            assert_bands_close(&got.lh, &want.lh, &format!("{width}x{height} lh"), 1e-9);
            assert_bands_close(&got.hh, &want.hh, &format!("{width}x{height} hh"), 1e-9);
        }
    }

    #[test]
    fn linearized_97_2d_separates_horizontal_and_vertical_detail() {
        // Catches an HL/LH swap or a row/column transpose independently of the
        // filter bank: a plane that varies only along x has no vertical detail
        // (LH and HH must vanish), and vice versa.
        let (width, height) = (16usize, 16usize);

        let varies_in_x: Vec<f64> = (0..width * height)
            .map(|i| ((i % width) as f64).sin().mul_add(30.0, 5.0))
            .collect();
        let t = linearized_97_2d_from_plane(&varies_in_x, width, height);
        assert_all_close(&t.lh, 0.0, 1e-9);
        assert_all_close(&t.hh, 0.0, 1e-9);

        let varies_in_y: Vec<f64> = (0..width * height)
            .map(|i| ((i / width) as f64).cos().mul_add(30.0, 5.0))
            .collect();
        let t = linearized_97_2d_from_plane(&varies_in_y, width, height);
        assert_all_close(&t.hl, 0.0, 1e-9);
        assert_all_close(&t.hh, 0.0, 1e-9);
    }

    // -------------------------------------------------------------------------
    // Ground truth: exact mathematical inverse DCT for the float 9/7 path.
    //
    // The 9/7 transcode oracle (`dct8x8_blocks_then_dwt97_float`) feeds
    // `idct8x8_sample` into the wavelet. Validate that IDCT against the defining
    // DCT-III cosine sum so a basis/normalization/transpose bug cannot hide
    // inside both the oracle and its CUDA port.
    fn exact_idct_sample(block: &[[f64; 8]; 8], x: usize, y: usize) -> f64 {
        let alpha = |k: usize| {
            if k == 0 {
                (1.0_f64 / 8.0).sqrt()
            } else {
                (2.0_f64 / 8.0).sqrt()
            }
        };
        let cos_term = |sample: usize, freq: usize| {
            (((2 * sample + 1) as f64) * freq as f64 * PI / 16.0).cos()
        };
        let mut acc = 0.0;
        for (v, row) in block.iter().enumerate() {
            for (u, &coeff) in row.iter().enumerate() {
                acc += alpha(u) * alpha(v) * coeff * cos_term(x, u) * cos_term(y, v);
            }
        }
        acc
    }

    #[test]
    fn idct8x8_sample_matches_exact_cosine_sum() {
        let mut state = 0x5151_aaaa_bbbb_ccccu64;
        for _ in 0..64 {
            let mut block = [[0.0f64; 8]; 8];
            for row in &mut block {
                for coeff in row {
                    *coeff = next_unit(&mut state) * 64.0;
                }
            }
            for y in 0..8 {
                for x in 0..8 {
                    let got = idct8x8_sample(&block, x, y);
                    let want = exact_idct_sample(&block, x, y);
                    assert!(
                        (got - want).abs() < 1e-9,
                        "idct8x8_sample({x},{y})={got} exact={want}"
                    );
                }
            }
        }
    }

    #[test]
    fn idct8x8_sample_dc_only_is_uniform() {
        // DC-only block -> uniform plane equal to F(0,0) / 8.
        let mut block = [[0.0f64; 8]; 8];
        block[0][0] = 320.0;
        for y in 0..8 {
            for x in 0..8 {
                assert!((idct8x8_sample(&block, x, y) - 40.0).abs() < 1e-9);
            }
        }
    }

    // -------------------------------------------------------------------------
    // Production path parity.
    //
    // `dct8x8_blocks_then_dwt97_float_with_scratch` uses the separable,
    // table-driven `idct8x8_blocks_to_samples` (including a rayon branch for
    // large planes) and reuses scratch buffers. Pin it to the per-sample
    // reference path end to end.

    /// `count` blocks of deterministic pseudo-random coefficients in [-64, 64).
    fn random_blocks(count: usize, state: &mut u64) -> Vec<[[f64; 8]; 8]> {
        let mut blocks = vec![[[0.0f64; 8]; 8]; count];
        for coeff in blocks.iter_mut().flatten().flatten() {
            *coeff = next_unit(state) * 64.0;
        }
        blocks
    }

    #[test]
    fn scratch_path_matches_reference_path() {
        // Exact block multiples, ragged right/bottom edges, and two sizes that
        // take the parallel IDCT branch. A small size follows the large ones so
        // the reused scratch must be fully reset between calls.
        let sizes = [
            (8usize, 8usize),
            (16, 24),
            (13, 11),
            (72, 64),
            (65, 70),
            (9, 7),
        ];
        assert!(
            sizes
                .iter()
                .any(|&(w, h)| w * h >= PARALLEL_IDCT_MIN_SAMPLES),
            "parity sizes no longer reach the parallel IDCT branch"
        );

        let mut state = 0x0ddb_a11c_0ffe_e000u64;
        let mut scratch = Dct97GridScratch::default();
        for (width, height) in sizes {
            let block_cols = width.div_ceil(8);
            let block_rows = height.div_ceil(8);
            let blocks = random_blocks(block_cols * block_rows, &mut state);

            let Ok(want) =
                dct8x8_blocks_then_dwt97_float(&blocks, block_cols, block_rows, width, height)
            else {
                panic!("reference path rejected a {width}x{height} grid");
            };
            let Ok(got) = dct8x8_blocks_then_dwt97_float_with_scratch(
                &blocks,
                block_cols,
                block_rows,
                width,
                height,
                &mut scratch,
            ) else {
                panic!("scratch path rejected a {width}x{height} grid");
            };

            assert!(scratch.spatial_sample_capacity() >= width * height);
            assert_eq!(
                band_dims(&got),
                band_dims(&want),
                "band dimensions for {width}x{height}"
            );
            assert_bands_close(&got.ll, &want.ll, &format!("{width}x{height} ll"), 1e-9);
            assert_bands_close(&got.hl, &want.hl, &format!("{width}x{height} hl"), 1e-9);
            assert_bands_close(&got.lh, &want.lh, &format!("{width}x{height} lh"), 1e-9);
            assert_bands_close(&got.hh, &want.hh, &format!("{width}x{height} hh"), 1e-9);
        }
    }
}