Skip to main content

signinum_transcode/
htj2k97_codeblock_oracle.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Shared scalar oracle: float 9/7 bands into prequantized HTJ2K code-blocks.
4//!
5//! This module uses the native encoder's public irreversible 9/7 quantization
6//! step helper plus the native code-block layout rules, so both GPU backends can
7//! compare their fused code-block kernels against one authoritative CPU
8//! reference instead of each re-deriving the math.
9//!
//! The oracle is anchored to native truth by a codestream pin test (see the
//! module tests): encoding the oracle's prequantized output reproduces the
//! native precomputed-DWT codestream byte-for-byte.
13
14use crate::accelerator::Htj2k97CodeBlockOptions;
15use crate::dct97_2d::Dwt97TwoDimensional;
16use signinum_j2k::{
17    J2kSubBandType, PrequantizedHtj2k97CodeBlock, PrequantizedHtj2k97Component,
18    PrequantizedHtj2k97Resolution, PrequantizedHtj2k97Subband,
19};
20use signinum_j2k_native::irreversible_quantization_step_for_subband;
21
22/// Quantize one level of float 9/7 bands into a prequantized HTJ2K component.
23///
24/// Resolution nesting matches the native encoder for a single decomposition
25/// level: resolution 0 holds `[LL]`, resolution 1 holds `[HL, LH, HH]`.
26#[must_use]
27pub fn prequantized_component_from_dwt97(
28    dwt: &Dwt97TwoDimensional<f64>,
29    options: Htj2k97CodeBlockOptions,
30    x_rsiz: u8,
31    y_rsiz: u8,
32) -> PrequantizedHtj2k97Component {
33    PrequantizedHtj2k97Component {
34        x_rsiz,
35        y_rsiz,
36        resolutions: vec![
37            PrequantizedHtj2k97Resolution {
38                subbands: vec![quantize_codeblock_subband(
39                    &dwt.ll,
40                    dwt.low_width,
41                    dwt.low_height,
42                    J2kSubBandType::LowLow,
43                    options,
44                )],
45            },
46            PrequantizedHtj2k97Resolution {
47                subbands: vec![
48                    quantize_codeblock_subband(
49                        &dwt.hl,
50                        dwt.high_width,
51                        dwt.low_height,
52                        J2kSubBandType::HighLow,
53                        options,
54                    ),
55                    quantize_codeblock_subband(
56                        &dwt.lh,
57                        dwt.low_width,
58                        dwt.high_height,
59                        J2kSubBandType::LowHigh,
60                        options,
61                    ),
62                    quantize_codeblock_subband(
63                        &dwt.hh,
64                        dwt.high_width,
65                        dwt.high_height,
66                        J2kSubBandType::HighHigh,
67                        options,
68                    ),
69                ],
70            },
71        ],
72    }
73}
74
75/// Quantize a single float subband and slice it into code-block-major layout.
76///
77/// Code-blocks are emitted outer `cby`, inner `cbx`; each block's coefficients
78/// are row-major, matching the native encoder's `copy_code_block_coefficients`.
79#[must_use]
80pub fn quantize_codeblock_subband(
81    coefficients: &[f64],
82    width: usize,
83    height: usize,
84    sub_band_type: J2kSubBandType,
85    options: Htj2k97CodeBlockOptions,
86) -> PrequantizedHtj2k97Subband {
87    let quantized = quantize_subband_coefficients(coefficients, sub_band_type, options);
88    let cb_width = htj2k97_code_block_dim(options.code_block_width_exp);
89    let cb_height = htj2k97_code_block_dim(options.code_block_height_exp);
90    let num_cbs_x = width.div_ceil(cb_width);
91    let num_cbs_y = height.div_ceil(cb_height);
92    let mut code_blocks = Vec::with_capacity(num_cbs_x * num_cbs_y);
93
94    for cby in 0..num_cbs_y {
95        for cbx in 0..num_cbs_x {
96            let x0 = cbx * cb_width;
97            let y0 = cby * cb_height;
98            let block_width = (width - x0).min(cb_width);
99            let block_height = (height - y0).min(cb_height);
100            let mut block_coefficients = Vec::with_capacity(block_width * block_height);
101            for y in 0..block_height {
102                let row_start = (y0 + y) * width + x0;
103                block_coefficients
104                    .extend_from_slice(&quantized[row_start..row_start + block_width]);
105            }
106            code_blocks.push(PrequantizedHtj2k97CodeBlock {
107                coefficients: block_coefficients,
108                width: block_width as u32,
109                height: block_height as u32,
110            });
111        }
112    }
113
114    PrequantizedHtj2k97Subband {
115        sub_band_type,
116        num_cbs_x: num_cbs_x as u32,
117        num_cbs_y: num_cbs_y as u32,
118        total_bitplanes: htj2k97_subband_total_bitplanes(options, sub_band_type),
119        code_blocks,
120    }
121}
122
123/// Deadzone quantization step size `Δ` for a subband.
124///
125/// `Δ = 2^(range_bits − exponent) · (1 + mantissa/2048)`, with
126/// `range_bits = bit_depth + {LL:0, HL:1, LH:1, HH:2}` and the shared
127/// `(exponent, mantissa)` derived by this module's quantizer.
128#[must_use]
129pub fn htj2k97_subband_delta(
130    options: Htj2k97CodeBlockOptions,
131    sub_band_type: J2kSubBandType,
132) -> f64 {
133    let log_gain = match sub_band_type {
134        J2kSubBandType::LowLow => 0,
135        J2kSubBandType::HighLow | J2kSubBandType::LowHigh => 1,
136        J2kSubBandType::HighHigh => 2,
137    };
138    let range_bits = i32::from(options.bit_depth) + log_gain;
139    let (exponent, mantissa) = htj2k97_step(options, sub_band_type);
140    pow2i_f64(range_bits - i32::from(exponent)) * (1.0 + f64::from(mantissa) / 2048.0)
141}
142
143/// Total declared bitplanes for every code-block in a subband.
144///
145/// `saturating(guard_bits + exponent - 1)`. The exponent is derived from the
146/// effective global plus per-subband quantization profile, so callers must pass
147/// the actual subband kind.
148#[must_use]
149pub fn htj2k97_subband_total_bitplanes(
150    options: Htj2k97CodeBlockOptions,
151    sub_band_type: J2kSubBandType,
152) -> u8 {
153    let (exponent, _) = htj2k97_step(options, sub_band_type);
154    options
155        .guard_bits
156        .saturating_add(exponent)
157        .saturating_sub(1)
158}
159
160/// Validate 9/7 code-block options against the numeric limits both GPU
161/// backends must agree on, returning the decoded `(cb_width, cb_height)`.
162///
163/// One shared implementation keeps Metal and CUDA from drifting: the same
164/// options must be accepted or rejected identically by every backend. Errors
165/// are backend-neutral static strings for the caller's unsupported-job error.
166///
167/// # Errors
168/// Rejects zero/oversized bit depths and guard bits, non-finite or
169/// non-positive quantization scales, code-block dimensions beyond the HTJ2K
170/// limits (sides ≤ 1024, area ≤ 4096), and subband deltas or total bitplane
171/// counts outside the supported range.
172pub fn validate_htj2k97_codeblock_options(
173    options: Htj2k97CodeBlockOptions,
174) -> Result<(usize, usize), &'static str> {
175    if options.bit_depth == 0
176        || options.bit_depth > 30
177        || options.guard_bits > 30
178        || !options.irreversible_quantization_scale.is_finite()
179        || options.irreversible_quantization_scale <= 0.0
180    {
181        return Err("9/7 code-block options are outside supported numeric range");
182    }
183    let subband_scales = options.irreversible_quantization_subband_scales;
184    if [
185        subband_scales.low_low,
186        subband_scales.high_low,
187        subband_scales.low_high,
188        subband_scales.high_high,
189    ]
190    .iter()
191    .any(|scale| !scale.is_finite() || *scale <= 0.0)
192    {
193        return Err("9/7 code-block quantization options are outside supported range");
194    }
195
196    let cb_width = checked_code_block_dim(options.code_block_width_exp)?;
197    let cb_height = checked_code_block_dim(options.code_block_height_exp)?;
198    if cb_width > 1024
199        || cb_height > 1024
200        || cb_width
201            .checked_mul(cb_height)
202            .is_none_or(|area| area > 4096)
203    {
204        return Err("9/7 code-block dimensions exceed HTJ2K limits");
205    }
206
207    for subband in [
208        J2kSubBandType::LowLow,
209        J2kSubBandType::HighLow,
210        J2kSubBandType::LowHigh,
211        J2kSubBandType::HighHigh,
212    ] {
213        let delta = htj2k97_subband_delta(options, subband);
214        if !delta.is_finite()
215            || delta <= 0.0
216            || htj2k97_subband_total_bitplanes(options, subband) > 30
217        {
218            return Err("9/7 code-block quantization options are outside supported range");
219        }
220    }
221
222    Ok((cb_width, cb_height))
223}
224
/// Decode a code-block exponent (stored minus two) into `2^(exp + 2)`.
///
/// Returns an error string when the shift amount is at least the bit width
/// of `usize`.
fn checked_code_block_dim(exp_minus_two: u8) -> Result<usize, &'static str> {
    let shift = u32::from(exp_minus_two) + 2;
    (shift < usize::BITS)
        .then(|| 1usize << shift)
        .ok_or("9/7 code-block dimension exponent is unsupported")
}
230
231fn quantize_subband_coefficients(
232    coefficients: &[f64],
233    sub_band_type: J2kSubBandType,
234    options: Htj2k97CodeBlockOptions,
235) -> Vec<i32> {
236    let delta = htj2k97_subband_delta(options, sub_band_type);
237    let inv_delta = 1.0 / delta;
238
239    coefficients
240        .iter()
241        .map(|&coefficient| {
242            // Deadzone quantization: q = sign(c) · floor(|c| · (1/Δ)), sign(0) = +1.
243            let sign = if coefficient < 0.0 { -1 } else { 1 };
244            sign * (coefficient.abs() * inv_delta).floor() as i32
245        })
246        .collect()
247}
248
249/// Shared `(exponent, mantissa)` for the irreversible 9/7 quantizer.
250fn htj2k97_step(options: Htj2k97CodeBlockOptions, sub_band_type: J2kSubBandType) -> (u8, u16) {
251    let step = irreversible_quantization_step_for_subband(
252        options.bit_depth,
253        options.guard_bits,
254        options.irreversible_quantization_scale,
255        options.irreversible_quantization_subband_scales,
256        sub_band_type,
257    );
258    (step.exponent, step.mantissa)
259}
260
/// `2^exp` as an `f64`, for a signed integer exponent.
fn pow2i_f64(exp: i32) -> f64 {
    f64::powi(2.0, exp)
}
264
/// Decode a code-block exponent (stored minus two) into `2^(exp + 2)`,
/// falling back to `usize::MAX` when the shift does not fit in `usize`.
fn htj2k97_code_block_dim(exp_minus_two: u8) -> usize {
    let shift = u32::from(exp_minus_two) + 2;
    if shift < usize::BITS {
        1usize << shift
    } else {
        usize::MAX
    }
}
270
271#[cfg(test)]
272mod tests {
273    use super::*;
274    use signinum_j2k::{IrreversibleQuantizationSubbandScales, PrequantizedHtj2k97Image};
275    use signinum_j2k_native::{
276        encode_precomputed_htj2k_97, encode_prequantized_htj2k_97, EncodeOptions,
277        J2kForwardDwt97Level, J2kForwardDwt97Output, PrecomputedHtj2k97Component,
278        PrecomputedHtj2k97Image,
279    };
280
281    // Boundary-free coefficients on a 0.25 grid: exact in both f32 and f64, and
282    // every product with the scale-1.0 inverse deltas (4, 2, 1) lands on an exact
283    // integer/half-integer. So the f64 oracle and native's f32 quantizer agree
284    // bit-for-bit here and the codestream pin is exact, not merely close.
285    fn sample_band(len: usize, offset: f64) -> Vec<f64> {
286        (0..len)
287            .map(|idx| ((idx % 17) as f64 - 8.0) * 0.5 + offset)
288            .collect()
289    }
290
291    #[test]
292    fn oracle_prequantized_component_matches_native_precomputed_codestream() {
293        let width = 17u32;
294        let height = 13u32;
295        let low_width = width.div_ceil(2) as usize;
296        let low_height = height.div_ceil(2) as usize;
297        let high_width = (width / 2) as usize;
298        let high_height = (height / 2) as usize;
299
300        let ll = sample_band(low_width * low_height, 0.25);
301        let hl = sample_band(high_width * low_height, -0.75);
302        let lh = sample_band(low_width * high_height, 1.25);
303        let hh = sample_band(high_width * high_height, -1.5);
304
305        let options = EncodeOptions {
306            num_decomposition_levels: 1,
307            reversible: false,
308            guard_bits: 2,
309            code_block_width_exp: 2,
310            code_block_height_exp: 2,
311            ..EncodeOptions::default()
312        };
313
314        // Native precomputed-DWT path quantizes the f32 bands internally.
315        let precomputed_image = PrecomputedHtj2k97Image {
316            width,
317            height,
318            bit_depth: 8,
319            signed: false,
320            components: vec![PrecomputedHtj2k97Component {
321                x_rsiz: 1,
322                y_rsiz: 1,
323                dwt: J2kForwardDwt97Output {
324                    ll: ll.iter().map(|&v| v as f32).collect(),
325                    ll_width: low_width as u32,
326                    ll_height: low_height as u32,
327                    levels: vec![J2kForwardDwt97Level {
328                        hl: hl.iter().map(|&v| v as f32).collect(),
329                        lh: lh.iter().map(|&v| v as f32).collect(),
330                        hh: hh.iter().map(|&v| v as f32).collect(),
331                        width,
332                        height,
333                        low_width: low_width as u32,
334                        low_height: low_height as u32,
335                        high_width: high_width as u32,
336                        high_height: high_height as u32,
337                    }],
338                },
339            }],
340        };
341
342        // Oracle prequantized path (f64) over the same bands.
343        let dwt = Dwt97TwoDimensional {
344            ll,
345            hl,
346            lh,
347            hh,
348            low_width,
349            low_height,
350            high_width,
351            high_height,
352        };
353        let codeblock_options = Htj2k97CodeBlockOptions {
354            bit_depth: 8,
355            guard_bits: 2,
356            code_block_width_exp: 2,
357            code_block_height_exp: 2,
358            irreversible_quantization_scale: 1.0,
359            irreversible_quantization_subband_scales:
360                IrreversibleQuantizationSubbandScales::default(),
361        };
362        let component = prequantized_component_from_dwt97(&dwt, codeblock_options, 1, 1);
363        let prequantized_image = PrequantizedHtj2k97Image {
364            width,
365            height,
366            bit_depth: 8,
367            signed: false,
368            components: vec![component],
369        };
370
371        let expected = encode_precomputed_htj2k_97(&precomputed_image, &options)
372            .expect("native precomputed 9/7 encode");
373        let native_prequantized_image = native_prequantized_image(prequantized_image);
374        let actual = encode_prequantized_htj2k_97(&native_prequantized_image, &options)
375            .expect("oracle prequantized 9/7 encode");
376
377        assert_eq!(
378            actual, expected,
379            "oracle prequantized component must reproduce the native precomputed-DWT codestream"
380        );
381    }
382
383    #[test]
384    fn shared_validator_accepts_standard_options_and_returns_dims() {
385        let options = Htj2k97CodeBlockOptions {
386            bit_depth: 8,
387            guard_bits: 2,
388            code_block_width_exp: 4,
389            code_block_height_exp: 4,
390            irreversible_quantization_scale: 1.0,
391            irreversible_quantization_subband_scales:
392                IrreversibleQuantizationSubbandScales::default(),
393        };
394        assert_eq!(validate_htj2k97_codeblock_options(options), Ok((64, 64)));
395    }
396
397    #[test]
398    fn shared_validator_rejects_out_of_spec_options_on_every_backend() {
399        let valid = Htj2k97CodeBlockOptions {
400            bit_depth: 8,
401            guard_bits: 2,
402            code_block_width_exp: 4,
403            code_block_height_exp: 4,
404            irreversible_quantization_scale: 1.0,
405            irreversible_quantization_subband_scales:
406                IrreversibleQuantizationSubbandScales::default(),
407        };
408
409        // Each case was accepted by the old Metal-only validator.
410        let oversized_bit_depth = Htj2k97CodeBlockOptions {
411            bit_depth: 31,
412            ..valid
413        };
414        let oversized_guard_bits = Htj2k97CodeBlockOptions {
415            guard_bits: 31,
416            ..valid
417        };
418        // 1024x1024: each side passes the per-side cap, area breaks the
419        // HTJ2K 4096 limit.
420        let oversized_area = Htj2k97CodeBlockOptions {
421            code_block_width_exp: 8,
422            code_block_height_exp: 8,
423            ..valid
424        };
425        for options in [oversized_bit_depth, oversized_guard_bits, oversized_area] {
426            assert!(
427                validate_htj2k97_codeblock_options(options).is_err(),
428                "options must be rejected: {options:?}"
429            );
430        }
431
432        // guard_bits == 0 stays accepted (the old Metal validator rejected it,
433        // CUDA and the native encoder accept it).
434        let zero_guard_bits = Htj2k97CodeBlockOptions {
435            guard_bits: 0,
436            ..valid
437        };
438        assert!(validate_htj2k97_codeblock_options(zero_guard_bits).is_ok());
439    }
440
441    #[test]
442    fn oracle_subband_profile_changes_only_selected_delta_and_bitplanes() {
443        let mut options = Htj2k97CodeBlockOptions {
444            bit_depth: 8,
445            guard_bits: 2,
446            code_block_width_exp: 2,
447            code_block_height_exp: 2,
448            irreversible_quantization_scale: 1.9,
449            irreversible_quantization_subband_scales:
450                IrreversibleQuantizationSubbandScales::default(),
451        };
452        let high_low_delta = htj2k97_subband_delta(options, J2kSubBandType::HighLow);
453        let high_high_delta = htj2k97_subband_delta(options, J2kSubBandType::HighHigh);
454        let default_hh_bitplanes =
455            htj2k97_subband_total_bitplanes(options, J2kSubBandType::HighHigh);
456
457        options.irreversible_quantization_subband_scales.high_high = 1.5;
458
459        assert_eq!(
460            htj2k97_subband_delta(options, J2kSubBandType::HighLow).to_bits(),
461            high_low_delta.to_bits()
462        );
463        assert!(htj2k97_subband_delta(options, J2kSubBandType::HighHigh) > high_high_delta);
464        assert_ne!(
465            htj2k97_subband_total_bitplanes(options, J2kSubBandType::HighHigh),
466            default_hh_bitplanes
467        );
468    }
469
470    fn native_prequantized_image(
471        image: PrequantizedHtj2k97Image,
472    ) -> signinum_j2k_native::PrequantizedHtj2k97Image {
473        signinum_j2k_native::PrequantizedHtj2k97Image {
474            width: image.width,
475            height: image.height,
476            bit_depth: image.bit_depth,
477            signed: image.signed,
478            components: image
479                .components
480                .into_iter()
481                .map(
482                    |component| signinum_j2k_native::PrequantizedHtj2k97Component {
483                        x_rsiz: component.x_rsiz,
484                        y_rsiz: component.y_rsiz,
485                        resolutions: component
486                            .resolutions
487                            .into_iter()
488                            .map(
489                                |resolution| {
490                                    signinum_j2k_native::PrequantizedHtj2k97Resolution {
491                                        subbands: resolution
492                                            .subbands
493                                            .into_iter()
494                                            .map(|subband| {
495                                                signinum_j2k_native::PrequantizedHtj2k97Subband {
496                                                    sub_band_type: subband.sub_band_type,
497                                                    num_cbs_x: subband.num_cbs_x,
498                                                    num_cbs_y: subband.num_cbs_y,
499                                                    total_bitplanes: subband.total_bitplanes,
500                                                    code_blocks: subband
501                                                        .code_blocks
502                                                        .into_iter()
503                                                        .map(|block| {
504                                                            signinum_j2k_native::PrequantizedHtj2k97CodeBlock {
505                                                                coefficients: block.coefficients,
506                                                                width: block.width,
507                                                                height: block.height,
508                                                            }
509                                                        })
510                                                        .collect(),
511                                                }
512                                            })
513                                            .collect(),
514                                    }
515                                },
516                            )
517                            .collect(),
518                    },
519                )
520                .collect(),
521        }
522    }
523}