Skip to main content

ultrahdr_core/gainmap/
apply.rs

1//! Gain map application for HDR reconstruction.
2
3use alloc::boxed::Box;
4use alloc::vec;
5
6use crate::color::transfer::{srgb_eotf, srgb_oetf};
7use crate::types::{
8    GainMap, GainMapMetadata, PixelBuffer, PixelFormat, PixelSlice, Result, TransferFunction,
9    new_pixel_buffer,
10};
11use enough::Stop;
12
/// Precomputed byte → linear-gain lookup table used when applying a gain map.
///
/// Decoding a gain-map sample normally costs a `powf()` (to undo the map
/// gamma) and an `exp()` (to leave the log domain) per pixel. An 8-bit gain
/// map only has 256 distinct codes per channel, so all of that work is done
/// once up front and per-pixel decoding becomes a table read (roughly a 10x
/// speedup for `apply_gainmap`).
#[derive(Debug)]
pub struct GainMapLut {
    /// Linear gain for every (channel, code) pair, stored channel-major:
    /// entries `0..256` are red, `256..512` green and `512..768` blue.
    table: Box<[f32; 3 * 256]>,
}
24
25impl GainMapLut {
26    /// Create a new gain map LUT for the given metadata and display boost.
27    ///
28    /// The `weight` parameter is typically calculated from `display_boost` and
29    /// the metadata's `base_hdr_headroom`/`alternate_hdr_headroom`.
30    pub fn new(metadata: &GainMapMetadata, weight: f32) -> Self {
31        let mut table = Box::new([0.0f32; 256 * 3]);
32
33        for channel in 0..3 {
34            let gamma = metadata.channels[channel].gamma as f32;
35            // Convert log2 domain to natural log for exp() math
36            let ln2 = core::f64::consts::LN_2;
37            let log_min = (metadata.channels[channel].min * ln2) as f32;
38            let log_max = (metadata.channels[channel].max * ln2) as f32;
39            let log_range = log_max - log_min;
40
41            for i in 0..256 {
42                // Convert byte to normalized [0,1]
43                let normalized = i as f32 / 255.0;
44
45                // Undo gamma
46                let linear = if gamma != 1.0 && gamma > 0.0 {
47                    normalized.powf(1.0 / gamma)
48                } else {
49                    normalized
50                };
51
52                // Convert from normalized to log gain, apply weight, convert to linear
53                let log_gain = log_min + linear * log_range;
54                let gain = (log_gain * weight).exp();
55
56                table[channel * 256 + i] = gain;
57            }
58        }
59
60        Self { table }
61    }
62
63    /// Look up the gain multiplier for a single channel.
64    #[inline(always)]
65    pub fn lookup(&self, byte_value: u8, channel: usize) -> f32 {
66        // Safety: channel is always 0..3 and byte_value is u8 (0..255)
67        debug_assert!(channel < 3);
68        self.table[channel * 256 + byte_value as usize]
69    }
70
71    /// Look up gain multipliers for all 3 channels from a single byte (luminance mode).
72    #[inline(always)]
73    pub fn lookup_luminance(&self, byte_value: u8) -> [f32; 3] {
74        let g = self.table[byte_value as usize]; // Channel 0
75        [g, g, g]
76    }
77
78    /// Look up gain multipliers for RGB from 3 bytes.
79    #[inline(always)]
80    pub fn lookup_rgb(&self, r: u8, g: u8, b: u8) -> [f32; 3] {
81        [
82            self.table[r as usize],
83            self.table[256 + g as usize],
84            self.table[512 + b as usize],
85        ]
86    }
87}
88
/// Pixel layout produced by HDR reconstruction.
///
/// The variants correspond to the three decode outputs offered by libultrahdr:
///
/// | Variant | Layout | libultrahdr counterpart / intended use |
/// |---|---|---|
/// | [`LinearFloat`](Self::LinearFloat) | linear `RgbaF32` | same content as `UHDR_IMG_FMT_64bppRGBAHalfFloat` but at f32 precision; for float pipelines |
/// | [`LinearF16`](Self::LinearF16) | linear `RgbaF16` | exactly `UHDR_IMG_FMT_64bppRGBAHalfFloat`; for compositor / GPU texture handoff |
/// | [`Srgb8`](Self::Srgb8) | sRGB-encoded `Rgba8` | `UHDR_IMG_FMT_32bppRGBA8888`; for SDR consumers (typically `display_boost = 1.0`) |
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum HdrOutputFormat {
    /// Linear-light f32 RGBA, 16 bytes per pixel (`RgbaF32`). 1.0 is SDR
    /// white (203 nits); values range over `[0, ~display_boost]`.
    LinearFloat,
    /// Linear-light IEEE 754 half-precision RGBA, 8 bytes per pixel
    /// (`RgbaF16`), with 1.0 as SDR white. Matches libultrahdr's
    /// `UHDR_IMG_FMT_64bppRGBAHalfFloat`.
    LinearF16,
    /// sRGB-encoded 8-bit RGBA, 4 bytes per pixel (`Rgba8`). Values above
    /// SDR white are clipped, so this carries no HDR boost.
    Srgb8,
}
112
113/// Apply a gain map to an SDR image to reconstruct HDR.
114///
115/// The `display_boost` parameter controls how much HDR effect to apply:
116/// - 1.0 = SDR output (no boost)
117/// - 2.0 = 2x brightness capability
118/// - 4.0 = 4x brightness capability (typical HDR display)
119///
120/// The `stop` parameter enables cooperative cancellation. Pass `Unstoppable`
121/// when cancellation is not needed.
122pub fn apply_gainmap(
123    sdr: &PixelBuffer,
124    gainmap: &GainMap,
125    metadata: &GainMapMetadata,
126    display_boost: f32,
127    output_format: HdrOutputFormat,
128    stop: impl Stop,
129) -> Result<PixelBuffer> {
130    let sdr_slice = sdr.as_slice();
131    apply_gainmap_slice(
132        sdr_slice,
133        gainmap,
134        metadata,
135        display_boost,
136        output_format,
137        stop,
138    )
139}
140
141/// [`apply_gainmap`] variant that takes a borrowed [`PixelSlice`] directly.
142///
143/// Useful when the caller already has a slice view over pixel bytes (e.g.
144/// from a cropped region or a foreign allocation) and doesn't want to copy
145/// into a [`PixelBuffer`] first.
146pub fn apply_gainmap_slice(
147    sdr: PixelSlice<'_>,
148    gainmap: &GainMap,
149    metadata: &GainMapMetadata,
150    display_boost: f32,
151    output_format: HdrOutputFormat,
152    stop: impl Stop,
153) -> Result<PixelBuffer> {
154    crate::types::validate_ultrahdr_slice(&sdr)?;
155
156    let width = sdr.width();
157    let height = sdr.rows();
158    let sdr_primaries = sdr.descriptor().primaries;
159
160    // Calculate weight factor based on display capability
161    let weight = calculate_weight(display_boost, metadata);
162
163    // Create precomputed LUT for fast gain decoding
164    let lut = GainMapLut::new(metadata, weight);
165
166    // Build the Shepard's weight table if image-to-gainmap ratio is integer
167    // (the common case — ISO 21496-1 maps are typically 1/2, 1/4, 1/8, 1/16).
168    // `None` means we fall back to per-pixel sqrt with row-hoisted constants.
169    let shepards = ShepardsLut::try_new(width, height, gainmap.width, gainmap.height);
170
171    // Create output image
172    let mut output = match output_format {
173        HdrOutputFormat::LinearFloat => new_pixel_buffer(
174            width,
175            height,
176            PixelFormat::RgbaF32,
177            sdr_primaries,
178            TransferFunction::Linear,
179        )?,
180        HdrOutputFormat::LinearF16 => new_pixel_buffer(
181            width,
182            height,
183            PixelFormat::RgbaF16,
184            sdr_primaries,
185            TransferFunction::Linear,
186        )?,
187        HdrOutputFormat::Srgb8 => new_pixel_buffer(
188            width,
189            height,
190            PixelFormat::Rgba8,
191            sdr_primaries,
192            TransferFunction::Srgb,
193        )?,
194    };
195
196    // Row-reusable scratch buffers
197    let row_pixels = width as usize;
198    let mut sdr_row = vec![[0.0f32; 3]; row_pixels];
199    let mut gains_row = vec![[0.0f32; 3]; row_pixels];
200    let mut hdr_row = vec![[0.0f32; 3]; row_pixels];
201
202    // Pre-broadcast metadata offsets into `[f32; 3]` arrays once per image.
203    let base_offset = [
204        metadata.channels[0].base_offset as f32,
205        metadata.channels[1].base_offset as f32,
206        metadata.channels[2].base_offset as f32,
207    ];
208    let alternate_offset = [
209        metadata.channels[0].alternate_offset as f32,
210        metadata.channels[1].alternate_offset as f32,
211        metadata.channels[2].alternate_offset as f32,
212    ];
213
214    let out_stride = output.stride();
215    let out_format = output.descriptor().pixel_format();
216    let mut out_slice = output.as_slice_mut();
217    let out_data = out_slice.as_strided_bytes_mut();
218
219    // Process each row, checking for cancellation periodically
220    for y in 0..height {
221        // Check for cancellation once per row (not per pixel for performance)
222        stop.check()?;
223
224        read_sdr_row_linear(&sdr, y, &mut sdr_row);
225        sample_gainmap_row_lut(
226            gainmap,
227            &lut,
228            shepards.as_ref(),
229            y,
230            width,
231            height,
232            &mut gains_row,
233        );
234        super::apply_simd::apply_gain_row_presampled(
235            &sdr_row,
236            &gains_row,
237            base_offset,
238            alternate_offset,
239            &mut hdr_row,
240        );
241        write_hdr_row(out_data, out_stride, out_format, y, &hdr_row, output_format);
242    }
243
244    drop(out_slice);
245    Ok(output)
246}
247
248/// Read a row of the SDR image into linear f32 RGB.
249///
250/// `out.len()` must equal `sdr.width() as usize`. Supports the pixel formats
251/// that the per-pixel `get_sdr_linear` supports — other formats yield
252/// `[0, 0, 0]` per-pixel as a fallback.
253fn read_sdr_row_linear(sdr: &PixelSlice<'_>, y: u32, out: &mut [[f32; 3]]) {
254    debug_assert_eq!(out.len(), sdr.width() as usize);
255    for (x, pixel) in out.iter_mut().enumerate() {
256        *pixel = get_sdr_linear(sdr, x as u32, y);
257    }
258}
259
260/// Sample a full row of gains from the gain map (Shepard's IDW, LUT-accelerated).
261///
262/// `out.len()` must equal `img_width as usize`. For single-channel gain maps,
263/// the same gain is broadcast to R/G/B.
264///
265/// `shepards` MUST be `Some(_)` when `img_width % gainmap.width == 0` and
266/// `img_height % gainmap.height == 0`; the caller builds it once via
267/// [`ShepardsLut::try_new`] and passes it to every row. When `None`, the
268/// per-pixel sqrt fallback runs (still computes weights once per pixel and
269/// shares them across channels).
270pub(crate) fn sample_gainmap_row_lut(
271    gainmap: &GainMap,
272    lut: &GainMapLut,
273    shepards: Option<&ShepardsLut>,
274    y: u32,
275    img_width: u32,
276    img_height: u32,
277    out: &mut [[f32; 3]],
278) {
279    debug_assert_eq!(out.len(), img_width as usize);
280    match shepards {
281        Some(shep) => sample_row_lut_int(gainmap, lut, shep, y, out),
282        None => sample_row_lut_float(gainmap, lut, y, img_width, img_height, out),
283    }
284}
285
286/// Write a row of HDR pixels to the output image in the requested format.
287///
288/// `out_data` must be the strided-byte buffer of the output image; `out_stride`
289/// is its row stride; `out_format` is the format (Rgba8 or RgbaF32).
290fn write_hdr_row(
291    out_data: &mut [u8],
292    out_stride: usize,
293    out_format: PixelFormat,
294    y: u32,
295    hdr_row: &[[f32; 3]],
296    format: HdrOutputFormat,
297) {
298    let row_start = (y as usize) * out_stride;
299    match format {
300        HdrOutputFormat::LinearFloat => {
301            debug_assert_eq!(out_format, PixelFormat::RgbaF32);
302            for (x, &hdr) in hdr_row.iter().enumerate() {
303                let idx = row_start + x * 16;
304                out_data[idx..idx + 4].copy_from_slice(&hdr[0].to_le_bytes());
305                out_data[idx + 4..idx + 8].copy_from_slice(&hdr[1].to_le_bytes());
306                out_data[idx + 8..idx + 12].copy_from_slice(&hdr[2].to_le_bytes());
307                out_data[idx + 12..idx + 16].copy_from_slice(&1.0f32.to_le_bytes());
308            }
309        }
310        HdrOutputFormat::LinearF16 => {
311            debug_assert_eq!(out_format, PixelFormat::RgbaF16);
312            // 8 bytes/pixel: 4 channels × f16 (2 bytes). Alpha is 1.0 (constant).
313            const F16_ONE: u16 = 0x3C00; // half::f16::ONE.to_bits()
314            for (x, &hdr) in hdr_row.iter().enumerate() {
315                let idx = row_start + x * 8;
316                let r = half::f16::from_f32(hdr[0]).to_bits().to_le_bytes();
317                let g = half::f16::from_f32(hdr[1]).to_bits().to_le_bytes();
318                let b = half::f16::from_f32(hdr[2]).to_bits().to_le_bytes();
319                let a = F16_ONE.to_le_bytes();
320                out_data[idx..idx + 2].copy_from_slice(&r);
321                out_data[idx + 2..idx + 4].copy_from_slice(&g);
322                out_data[idx + 4..idx + 6].copy_from_slice(&b);
323                out_data[idx + 6..idx + 8].copy_from_slice(&a);
324            }
325        }
326        HdrOutputFormat::Srgb8 => {
327            debug_assert_eq!(out_format, PixelFormat::Rgba8);
328            for (x, &hdr) in hdr_row.iter().enumerate() {
329                let r = srgb_oetf(hdr[0].clamp(0.0, 1.0));
330                let g = srgb_oetf(hdr[1].clamp(0.0, 1.0));
331                let b = srgb_oetf(hdr[2].clamp(0.0, 1.0));
332                let idx = row_start + x * 4;
333                out_data[idx] = (r * 255.0).round() as u8;
334                out_data[idx + 1] = (g * 255.0).round() as u8;
335                out_data[idx + 2] = (b * 255.0).round() as u8;
336                out_data[idx + 3] = 255;
337            }
338        }
339    }
340}
341
342/// Calculate the weight factor for gain map application.
343///
344/// Headroom values are in log2 domain. `display_boost` is linear.
345///
346/// Mirrors `avifGetGainMapWeight` in libavif and the equivalent in
347/// libultrahdr. The output is `clamp((log2(display_boost) - base) /
348/// (alt - base), 0, 1)` where `base` and `alt` are the metadata's
349/// HDR-headroom log2 values.
350pub fn calculate_weight(display_boost: f32, metadata: &GainMapMetadata) -> f32 {
351    let log_display = display_boost.max(1.0).log2() as f64;
352    let log_min = metadata.base_hdr_headroom.max(0.0);
353    let log_max = metadata.alternate_hdr_headroom.max(0.0);
354
355    if log_max <= log_min {
356        return 1.0;
357    }
358
359    ((log_display - log_min) / (log_max - log_min)).clamp(0.0, 1.0) as f32
360}
361
362/// Get linear RGB from SDR image.
363fn get_sdr_linear(sdr: &PixelSlice<'_>, x: u32, y: u32) -> [f32; 3] {
364    let format = sdr.descriptor().pixel_format();
365    let stride = sdr.stride();
366    let data = sdr.as_strided_bytes();
367    match format {
368        PixelFormat::Rgba8 | PixelFormat::Rgb8 => {
369            let bpp = if format == PixelFormat::Rgba8 { 4 } else { 3 };
370            let idx = (y as usize) * stride + (x as usize) * bpp;
371            let r = data[idx] as f32 / 255.0;
372            let g = data[idx + 1] as f32 / 255.0;
373            let b = data[idx + 2] as f32 / 255.0;
374            [srgb_eotf(r), srgb_eotf(g), srgb_eotf(b)]
375        }
376        PixelFormat::RgbaF32 => {
377            let idx = (y as usize) * stride + (x as usize) * 16;
378            let r = f32::from_le_bytes([data[idx], data[idx + 1], data[idx + 2], data[idx + 3]]);
379            let g =
380                f32::from_le_bytes([data[idx + 4], data[idx + 5], data[idx + 6], data[idx + 7]]);
381            let b =
382                f32::from_le_bytes([data[idx + 8], data[idx + 9], data[idx + 10], data[idx + 11]]);
383            [r, g, b]
384        }
385        PixelFormat::RgbaF16 | PixelFormat::RgbF16 => {
386            let bpp = if format == PixelFormat::RgbaF16 { 8 } else { 6 };
387            let idx = (y as usize) * stride + (x as usize) * bpp;
388            let r = half::f16::from_le_bytes([data[idx], data[idx + 1]]).to_f32();
389            let g = half::f16::from_le_bytes([data[idx + 2], data[idx + 3]]).to_f32();
390            let b = half::f16::from_le_bytes([data[idx + 4], data[idx + 5]]).to_f32();
391            [r, g, b]
392        }
393        PixelFormat::Gray8 => {
394            let idx = (y as usize) * stride + (x as usize);
395            let v = data[idx] as f32 / 255.0;
396            [v, v, v]
397        }
398        _ => [0.0, 0.0, 0.0],
399    }
400}
401
/// Precomputed Shepard's IDW weight tables for integer-scale gain map upsample.
///
/// Mirrors libultrahdr's `ShepardsIDW` struct (see `gainmapmath.h:228`,
/// `gainmapmath.cpp:49`). Per-pixel sample-time work drops from "4 sqrt
/// plus 4 div" to "4 mul plus 3 add" by precomputing weights for every
/// distinct sub-pixel position in the unit cell. Holds four tables
/// (interior, no-right edge, no-bottom edge, corner), so gain-map boundary
/// clamping needs no per-pixel bounds branches.
///
/// Only valid when image dimensions are an integer multiple of gain-map
/// dimensions. Storage is `4 tables * scale_x * scale_y * 4` floats, i.e.
/// `64 * scale_x * scale_y` bytes (16 KiB at 16x16).
/// [`ShepardsLut::try_new`] declines scales above 128x128 cells (1 MiB) so
/// that tiny gain maps on large images fall back to per-pixel weights
/// instead of allocating tables larger than the image.
#[derive(Debug)]
pub struct ShepardsLut {
    scale_x: u32,
    scale_y: u32,
    /// Indexed `[oy * scale_x + ox] * 4 + corner` where corner is
    /// 0=TL, 1=BL, 2=TR, 3=BR. Same memory layout in all four tables.
    full: Box<[f32]>,
    /// Right neighbour collapsed onto the left one (last gain-map column).
    no_right: Box<[f32]>,
    /// Bottom neighbour collapsed onto the top one (last gain-map row).
    no_bottom: Box<[f32]>,
    /// Both neighbours collapsed (bottom-right gain-map sample).
    corner: Box<[f32]>,
}

impl ShepardsLut {
    /// Largest `scale_x * scale_y` for which [`Self::try_new`] builds tables
    /// (128x128 cells = 1 MiB across all four tables).
    const MAX_CELLS: u64 = 128 * 128;

    /// Build weight tables for an arbitrary integer scale (`scale_x`, `scale_y`).
    /// Most callers should use [`Self::try_new`], which infers the scale from
    /// image and gain-map dimensions and bounds the table size.
    ///
    /// # Panics
    ///
    /// Panics if `scale_x * scale_y * 4` overflows `usize`. Memory use is
    /// `64 * scale_x * scale_y` bytes, so keep scales modest.
    pub fn new(scale_x: u32, scale_y: u32) -> Self {
        debug_assert!(scale_x >= 1 && scale_y >= 1);
        // Size in usize with overflow checks; u32 math here used to wrap in
        // release builds and produce undersized tables.
        let n = (scale_x as usize)
            .checked_mul(scale_y as usize)
            .and_then(|cells| cells.checked_mul(4))
            .expect("ShepardsLut: scale_x * scale_y * 4 overflows usize");
        let build = |inc_r, inc_b| {
            let mut table = vec![0.0f32; n].into_boxed_slice();
            fill_shepards(&mut table, scale_x, scale_y, inc_r, inc_b);
            table
        };
        Self {
            scale_x,
            scale_y,
            full: build(1, 1),
            no_right: build(0, 1),
            no_bottom: build(1, 0),
            corner: build(0, 0),
        }
    }

    /// Build a LUT iff image dims are an exact integer multiple of gain-map
    /// dims and the scale is at most 128x128 cells. `None` means the caller
    /// must take the per-pixel sqrt fallback, which handles every scale.
    pub fn try_new(img_width: u32, img_height: u32, gm_width: u32, gm_height: u32) -> Option<Self> {
        if gm_width == 0 || gm_height == 0 {
            return None;
        }
        if !img_width.is_multiple_of(gm_width) || !img_height.is_multiple_of(gm_height) {
            return None;
        }
        let sx = img_width / gm_width;
        let sy = img_height / gm_height;
        // Zero only for an empty image dimension.
        if sx == 0 || sy == 0 {
            return None;
        }
        // Very large ratios (e.g. a 1x1 gain map) would make the tables
        // bigger than the image and can overflow index math; fall back.
        if u64::from(sx) * u64::from(sy) > Self::MAX_CELLS {
            return None;
        }
        Some(Self::new(sx, sy))
    }

    /// Selects the table matching which neighbours were clamped at the
    /// gain-map edge.
    #[inline(always)]
    fn pick(&self, no_right: bool, no_bottom: bool) -> &[f32] {
        match (no_right, no_bottom) {
            (false, false) => &self.full,
            (true, false) => &self.no_right,
            (false, true) => &self.no_bottom,
            (true, true) => &self.corner,
        }
    }
}

/// Fills one `sx` x `sy` weight table.
///
/// For each sub-pixel offset `(x, y)`, the position `(x / sx, y / sy)` is
/// weighted against corners TL `(0, 0)`, BL `(0, inc_b)`, TR `(inc_r, 0)` and
/// BR `(inc_r, inc_b)`. Passing `0` for `inc_r`/`inc_b` collapses the
/// right/bottom corners onto the left/top ones, which is how the edge tables
/// model clamped neighbours. Weights are stored in TL, BL, TR, BR order and
/// normalized to sum to 1; a position exactly on TL gets weight 1.
fn fill_shepards(weights: &mut [f32], sx: u32, sy: u32, inc_r: u32, inc_b: u32) {
    let sx_f = sx as f32;
    let sy_f = sy as f32;
    let next_x = inc_r as f32;
    let next_y = inc_b as f32;
    for y in 0..sy {
        let pos_y = y as f32 / sy_f;
        let dy_b = pos_y - next_y;
        for x in 0..sx {
            let pos_x = x as f32 / sx_f;
            // usize indexing so large scales cannot wrap.
            let idx = (y as usize * sx as usize + x as usize) * 4;
            let cell = &mut weights[idx..idx + 4];
            let d_tl = (pos_x * pos_x + pos_y * pos_y).sqrt();
            if d_tl == 0.0 {
                cell.copy_from_slice(&[1.0, 0.0, 0.0, 0.0]);
                continue;
            }
            // Positions are < 1 and TL is non-zero here, so the other three
            // distances are non-zero too (a collapsed corner equals d_tl).
            let dx_r = pos_x - next_x;
            let d_bl = (pos_x * pos_x + dy_b * dy_b).sqrt();
            let d_tr = (dx_r * dx_r + pos_y * pos_y).sqrt();
            let d_br = (dx_r * dx_r + dy_b * dy_b).sqrt();
            let w_tl = 1.0 / d_tl;
            let w_bl = 1.0 / d_bl;
            let w_tr = 1.0 / d_tr;
            let w_br = 1.0 / d_br;
            let inv_total = 1.0 / (w_tl + w_bl + w_tr + w_br);
            cell.copy_from_slice(&[
                w_tl * inv_total,
                w_bl * inv_total,
                w_tr * inv_total,
                w_br * inv_total,
            ]);
        }
    }
}
514
/// Per-pixel Shepard's inverse-distance weights for a point in a unit cell.
///
/// `(fx, fy)` is the fractional position relative to the top-left sample.
/// Returns normalized weights in TL, BL, TR, BR order. If the point sits
/// exactly on a corner, the first such corner (in TL, BL, TR, BR order)
/// gets weight 1 and the rest 0, which avoids dividing by zero.
///
/// This is the fallback for non-integer image-to-gain-map ratios, where the
/// [`ShepardsLut`] tables cannot be used; callers compute these once per
/// pixel and reuse them for every channel.
#[inline(always)]
fn shepards_weights(fx: f32, fy: f32) -> [f32; 4] {
    let gx = 1.0 - fx;
    let gy = 1.0 - fy;
    let distances = [
        (fx * fx + fy * fy).sqrt(),
        (fx * fx + gy * gy).sqrt(),
        (gx * gx + fy * fy).sqrt(),
        (gx * gx + gy * gy).sqrt(),
    ];
    if let Some(hit) = distances.iter().position(|&d| d == 0.0) {
        let mut snapped = [0.0; 4];
        snapped[hit] = 1.0;
        return snapped;
    }
    let inverse = distances.map(|d| 1.0 / d);
    let norm = 1.0 / (inverse[0] + inverse[1] + inverse[2] + inverse[3]);
    inverse.map(|w| w * norm)
}
552
/// Four-term dot product used to blend four corner samples with their
/// Shepard's weights; terms are accumulated left to right.
#[inline(always)]
fn dot4(c: [f32; 4], w: [f32; 4]) -> f32 {
    let [c0, c1, c2, c3] = c;
    let [w0, w1, w2, w3] = w;
    c0 * w0 + c1 * w1 + c2 * w2 + c3 * w3
}
557
558/// Fast integer-scale row sampler. Picks weights from `shepards`
559/// (no per-pixel sqrt/div) and shares them across channels.
560fn sample_row_lut_int(
561    gainmap: &GainMap,
562    lut: &GainMapLut,
563    shepards: &ShepardsLut,
564    y: u32,
565    out: &mut [[f32; 3]],
566) {
567    let sx = shepards.scale_x;
568    let sy = shepards.scale_y;
569    let gw = gainmap.width;
570    let gh = gainmap.height;
571    debug_assert!(gw > 0 && gh > 0);
572
573    // Row-constant pieces: y0/y1 = enclosing gainmap rows; oy = sub-pixel
574    // offset; no_bottom = row sits at the gainmap's bottom edge so y1 was
575    // clamped back to y0.
576    let y0 = (y / sy).min(gh - 1);
577    let y1 = (y0 + 1).min(gh - 1);
578    let oy = y % sy;
579    let no_bottom = y0 == y1;
580
581    let row0_off = (y0 * gw) as usize;
582    let row1_off = (y1 * gw) as usize;
583
584    if gainmap.channels == 1 {
585        for (x_out, gain) in out.iter_mut().enumerate() {
586            let x = x_out as u32;
587            let x0 = (x / sx).min(gw - 1);
588            let x1 = (x0 + 1).min(gw - 1);
589            let ox = x % sx;
590            let no_right = x0 == x1;
591
592            let table = shepards.pick(no_right, no_bottom);
593            let base = ((oy * sx + ox) * 4) as usize;
594            let w = [
595                table[base],
596                table[base + 1],
597                table[base + 2],
598                table[base + 3],
599            ];
600
601            let g_tl = lut.lookup(gainmap.data[row0_off + x0 as usize], 0);
602            let g_bl = lut.lookup(gainmap.data[row1_off + x0 as usize], 0);
603            let g_tr = lut.lookup(gainmap.data[row0_off + x1 as usize], 0);
604            let g_br = lut.lookup(gainmap.data[row1_off + x1 as usize], 0);
605            let g = dot4([g_tl, g_bl, g_tr, g_br], w);
606            *gain = [g, g, g];
607        }
608    } else {
609        for (x_out, gain) in out.iter_mut().enumerate() {
610            let x = x_out as u32;
611            let x0 = (x / sx).min(gw - 1);
612            let x1 = (x0 + 1).min(gw - 1);
613            let ox = x % sx;
614            let no_right = x0 == x1;
615
616            let table = shepards.pick(no_right, no_bottom);
617            let base = ((oy * sx + ox) * 4) as usize;
618            let w = [
619                table[base],
620                table[base + 1],
621                table[base + 2],
622                table[base + 3],
623            ];
624
625            let tl = (row0_off + x0 as usize) * 3;
626            let bl = (row1_off + x0 as usize) * 3;
627            let tr = (row0_off + x1 as usize) * 3;
628            let br = (row1_off + x1 as usize) * 3;
629            for (c, dst) in gain.iter_mut().enumerate() {
630                let corners = [
631                    lut.lookup(gainmap.data[tl + c], c),
632                    lut.lookup(gainmap.data[bl + c], c),
633                    lut.lookup(gainmap.data[tr + c], c),
634                    lut.lookup(gainmap.data[br + c], c),
635                ];
636                *dst = dot4(corners, w);
637            }
638        }
639    }
640}
641
642/// Non-integer scale fallback. Computes weights once per pixel (4 sqrt
643/// plus 1 reciprocal-divide) and shares them across channels. Hoists
644/// `gm_y`/`y0`/`y1`/`fy` to row constants.
645fn sample_row_lut_float(
646    gainmap: &GainMap,
647    lut: &GainMapLut,
648    y: u32,
649    img_width: u32,
650    img_height: u32,
651    out: &mut [[f32; 3]],
652) {
653    let gw = gainmap.width;
654    let gh = gainmap.height;
655    debug_assert!(gw > 0 && gh > 0);
656    debug_assert!(img_width > 0 && img_height > 0);
657
658    let inv_iw = 1.0 / img_width as f32;
659    let inv_ih = 1.0 / img_height as f32;
660    let gw_f = gw as f32;
661    let gh_f = gh as f32;
662
663    let gm_y = (y as f32 * inv_ih) * gh_f;
664    let gm_y_floor = gm_y.floor();
665    let y0 = (gm_y_floor as u32).min(gh - 1);
666    let y1 = (y0 + 1).min(gh - 1);
667    let fy = gm_y - gm_y_floor;
668    let row0_off = (y0 * gw) as usize;
669    let row1_off = (y1 * gw) as usize;
670
671    if gainmap.channels == 1 {
672        for (x_out, gain) in out.iter_mut().enumerate() {
673            let gm_x = (x_out as f32 * inv_iw) * gw_f;
674            let gm_x_floor = gm_x.floor();
675            let x0 = (gm_x_floor as u32).min(gw - 1);
676            let x1 = (x0 + 1).min(gw - 1);
677            let fx = gm_x - gm_x_floor;
678            let w = shepards_weights(fx, fy);
679
680            let g_tl = lut.lookup(gainmap.data[row0_off + x0 as usize], 0);
681            let g_bl = lut.lookup(gainmap.data[row1_off + x0 as usize], 0);
682            let g_tr = lut.lookup(gainmap.data[row0_off + x1 as usize], 0);
683            let g_br = lut.lookup(gainmap.data[row1_off + x1 as usize], 0);
684            let g = dot4([g_tl, g_bl, g_tr, g_br], w);
685            *gain = [g, g, g];
686        }
687    } else {
688        for (x_out, gain) in out.iter_mut().enumerate() {
689            let gm_x = (x_out as f32 * inv_iw) * gw_f;
690            let gm_x_floor = gm_x.floor();
691            let x0 = (gm_x_floor as u32).min(gw - 1);
692            let x1 = (x0 + 1).min(gw - 1);
693            let fx = gm_x - gm_x_floor;
694            let w = shepards_weights(fx, fy);
695
696            let tl = (row0_off + x0 as usize) * 3;
697            let bl = (row1_off + x0 as usize) * 3;
698            let tr = (row0_off + x1 as usize) * 3;
699            let br = (row1_off + x1 as usize) * 3;
700            for (c, dst) in gain.iter_mut().enumerate() {
701                let corners = [
702                    lut.lookup(gainmap.data[tl + c], c),
703                    lut.lookup(gainmap.data[bl + c], c),
704                    lut.lookup(gainmap.data[tr + c], c),
705                    lut.lookup(gainmap.data[br + c], c),
706                ];
707                *dst = dot4(corners, w);
708            }
709        }
710    }
711}
712
713#[cfg(test)]
714mod tests {
715    use super::*;
716
717    /// Single-pixel convenience wrapper over `apply_gain_row_presampled` for tests.
718    fn apply_gain_one(metadata: &GainMapMetadata, sdr: [f32; 3], gain: [f32; 3]) -> [f32; 3] {
719        let base = [
720            metadata.channels[0].base_offset as f32,
721            metadata.channels[1].base_offset as f32,
722            metadata.channels[2].base_offset as f32,
723        ];
724        let alt = [
725            metadata.channels[0].alternate_offset as f32,
726            metadata.channels[1].alternate_offset as f32,
727            metadata.channels[2].alternate_offset as f32,
728        ];
729        let sdr_row = [sdr];
730        let gains_row = [gain];
731        let mut out_row = [[0.0f32; 3]];
732        super::super::apply_simd::apply_gain_row_presampled(
733            &sdr_row,
734            &gains_row,
735            base,
736            alt,
737            &mut out_row,
738        );
739        out_row[0]
740    }
741    use crate::types::ColorPrimaries;
742
743    #[test]
744    fn test_calculate_weight() {
745        let mut metadata = GainMapMetadata::default();
746        metadata.base_hdr_headroom = 0.0;
747        metadata.alternate_hdr_headroom = 2.0;
748
749        // No boost
750        let w = calculate_weight(1.0, &metadata);
751        assert!((w - 0.0).abs() < 0.01);
752
753        // Full boost
754        let w = calculate_weight(4.0, &metadata);
755        assert!((w - 1.0).abs() < 0.01);
756
757        // Half boost (log scale)
758        let w = calculate_weight(2.0, &metadata);
759        assert!(w > 0.4 && w < 0.6);
760    }
761
762    #[test]
763    fn test_gain_map_lut() {
764        let mut metadata = GainMapMetadata::default();
765        for ch in &mut metadata.channels {
766            ch.max = 2.0;
767        }
768
769        let lut = GainMapLut::new(&metadata, 1.0);
770
771        // Min gain (byte 0 = normalized 0.0)
772        let gain = lut.lookup(0, 0);
773        assert!((gain - 1.0).abs() < 0.01, "min gain: {}", gain);
774
775        // Max gain (byte 255 = normalized 1.0)
776        let gain = lut.lookup(255, 0);
777        assert!((gain - 4.0).abs() < 0.1, "max gain: {}", gain);
778
779        // Mid gain should be between min and max
780        let gain = lut.lookup(128, 0);
781        assert!(gain > 1.5 && gain < 2.5, "mid gain: {}", gain);
782    }
783
784    #[test]
785    fn test_apply_gainmap_basic() {
786        // Create SDR image
787        let mut sdr = crate::types::new_pixel_buffer(
788            4,
789            4,
790            PixelFormat::Rgba8,
791            ColorPrimaries::Bt709,
792            TransferFunction::Srgb,
793        )
794        .unwrap();
795        {
796            let mut slice = sdr.as_slice_mut();
797            let bytes = slice.as_strided_bytes_mut();
798            for i in 0..bytes.len() / 4 {
799                bytes[i * 4] = 128;
800                bytes[i * 4 + 1] = 128;
801                bytes[i * 4 + 2] = 128;
802                bytes[i * 4 + 3] = 255;
803            }
804        }
805
806        // Create gain map (all same boost)
807        let mut gainmap = GainMap::new(2, 2).unwrap();
808        for v in &mut gainmap.data {
809            *v = 200; // High gain
810        }
811
812        let metadata = crate::types::metadata_from_arrays(
813            [0.0; 3],
814            [2.0; 3],
815            [1.0; 3],
816            [0.015625; 3],
817            [0.015625; 3],
818            0.0,
819            2.0,
820            true,
821            false,
822        );
823
824        let result = apply_gainmap(
825            &sdr,
826            &gainmap,
827            &metadata,
828            4.0,
829            HdrOutputFormat::Srgb8,
830            enough::Unstoppable,
831        )
832        .unwrap();
833
834        assert_eq!(result.width(), 4);
835        assert_eq!(result.height(), 4);
836        assert_eq!(result.descriptor().pixel_format(), PixelFormat::Rgba8);
837    }
838
839    // ========================================================================
840    // Gain application reference values (C++ libultrahdr parity)
841    //
    // Checks the LUT-based gain application against expected values derived
    // independently from the closed-form formula below.
843    // The LUT maps byte values to linear gain multipliers:
844    //   normalized = byte / 255.0
845    //   linear = normalized^(1/gamma)  [undo gamma]
846    //   log_gain = ln(min_boost) + linear * (ln(max_boost) - ln(min_boost))
847    //   gain = exp(log_gain * weight)
848    //
849    // Then HDR = (sdr + offset_sdr) * gain - offset_hdr
850    // ========================================================================
851
852    /// Test gain application at 5 weight levels for white pixel.
853    ///
854    /// White (sdr=1.0) at gain map value 255 (max boost),
855    /// with weight from 0.0 to 1.0 in steps of 0.25.
856    #[test]
857    fn test_gain_application_weight_levels() {
858        let metadata = crate::types::metadata_from_arrays(
859            [0.0; 3],
860            [2.0; 3],
861            [1.0; 3],
862            [1.0 / 64.0; 3],
863            [1.0 / 64.0; 3],
864            0.0,
865            2.0,
866            true,
867            false,
868        );
869
870        let sdr_val = 1.0_f32; // White pixel (linear)
871        let offset = 1.0_f32 / 64.0;
872        let log_min = 1.0_f32.ln(); // 0.0
873        let log_max = 4.0_f32.ln(); // ~1.386
874
875        // At byte=255 (normalized=1.0, gamma=1.0 → linear=1.0):
876        //   log_gain = 0.0 + 1.0 * (ln(4) - ln(1)) = ln(4) ≈ 1.386
877        //   gain = exp(log_gain * weight)
878        //   hdr = (sdr + offset) * gain - offset
879
880        let weights: [(f32, &str); 5] = [
881            (0.0, "SDR (no boost)"),
882            (0.25, "25% boost"),
883            (0.5, "50% boost"),
884            (0.75, "75% boost"),
885            (1.0, "full boost"),
886        ];
887
888        for &(weight, desc) in &weights {
889            let lut = GainMapLut::new(&metadata, weight);
890            let gain = lut.lookup(255, 0);
891
892            let log_gain = log_min + 1.0 * (log_max - log_min);
893            let expected_gain = (log_gain * weight).exp();
894            let expected_hdr = (sdr_val + offset) * expected_gain - offset;
895
896            // Verify LUT gain matches formula
897            assert!(
898                (gain - expected_gain).abs() < 0.01,
899                "{}: LUT gain={}, expected={}",
900                desc,
901                gain,
902                expected_gain
903            );
904
905            // Verify HDR output
906            let hdr = apply_gain_one(&metadata, [sdr_val; 3], [gain; 3]);
907            assert!(
908                (hdr[0] - expected_hdr).abs() < 0.02,
909                "{}: hdr={}, expected={}",
910                desc,
911                hdr[0],
912                expected_hdr
913            );
914        }
915    }
916
917    /// Test gain application for black pixel (sdr=0.0).
918    ///
919    /// Black pixels should remain close to black regardless of gain,
920    /// because the offset dominates: hdr = (0 + 1/64) * gain - 1/64
921    #[test]
922    fn test_gain_application_black_pixel() {
923        let metadata = crate::types::metadata_from_arrays(
924            [0.0; 3],
925            [2.0; 3],
926            [1.0; 3],
927            [1.0 / 64.0; 3],
928            [1.0 / 64.0; 3],
929            0.0,
930            2.0,
931            true,
932            false,
933        );
934
935        let offset = 1.0_f32 / 64.0;
936
937        // At full weight with max gain byte
938        let lut = GainMapLut::new(&metadata, 1.0);
939        let gain = lut.lookup(255, 0);
940
941        // hdr = (0 + 1/64) * 4.0 - 1/64 = 4/64 - 1/64 = 3/64 ≈ 0.047
942        let expected_hdr = offset * gain - offset;
943        let hdr = apply_gain_one(&metadata, [0.0; 3], [gain; 3]);
944
945        assert!(
946            (hdr[0] - expected_hdr).abs() < 0.01,
947            "Black pixel HDR: {} vs expected {}",
948            hdr[0],
949            expected_hdr
950        );
951
952        // Black with zero gain (byte=0) should stay near zero
953        let gain_min = lut.lookup(0, 0);
954        let hdr_min = apply_gain_one(&metadata, [0.0; 3], [gain_min; 3]);
955        // gain_min = exp(0 * 1.0) = 1.0 for weight=1.0 and min_boost=1.0
956        // hdr = (0 + 1/64) * 1.0 - 1/64 = 0
957        assert!(
958            hdr_min[0].abs() < 0.01,
959            "Black at min gain should be ~0, got {}",
960            hdr_min[0]
961        );
962    }
963
964    /// Verify gain LUT covers the full [min_boost, max_boost] range.
965    #[test]
966    fn test_gain_lut_range_coverage() {
967        let metadata = crate::types::metadata_from_arrays(
968            [-1.0; 3],
969            [3.0; 3],
970            [1.0; 3],
971            [1.0 / 64.0; 3],
972            [1.0 / 64.0; 3],
973            0.0,
974            3.0,
975            true,
976            false,
977        );
978
979        let lut = GainMapLut::new(&metadata, 1.0);
980
981        // Byte 0 → min gain = exp(ln(0.5)) = 0.5
982        let gain_0 = lut.lookup(0, 0);
983        assert!(
984            (gain_0 - 0.5).abs() < 0.01,
985            "Byte 0 should give min gain 0.5, got {}",
986            gain_0
987        );
988
989        // Byte 255 → max gain = exp(ln(8)) = 8.0
990        let gain_255 = lut.lookup(255, 0);
991        assert!(
992            (gain_255 - 8.0).abs() < 0.1,
993            "Byte 255 should give max gain 8.0, got {}",
994            gain_255
995        );
996
997        // Monotonically increasing
998        for i in 1..=255u8 {
999            let prev = lut.lookup(i - 1, 0);
1000            let curr = lut.lookup(i, 0);
1001            assert!(
1002                curr >= prev,
1003                "LUT not monotonic at byte {}: {} < {}",
1004                i,
1005                curr,
1006                prev
1007            );
1008        }
1009    }
1010
1011    /// Helper: create a 4x4 SDR image (Rgba8, Srgb, BT.709) filled with a uniform color.
1012    fn make_sdr_4x4(r: u8, g: u8, b: u8) -> PixelBuffer {
1013        let mut data = vec![0u8; 4 * 4 * 4];
1014        for i in 0..16 {
1015            data[i * 4] = r;
1016            data[i * 4 + 1] = g;
1017            data[i * 4 + 2] = b;
1018            data[i * 4 + 3] = 255;
1019        }
1020        crate::types::pixel_buffer_from_vec(
1021            data,
1022            4,
1023            4,
1024            PixelFormat::Rgba8,
1025            ColorPrimaries::Bt709,
1026            TransferFunction::Srgb,
1027        )
1028        .unwrap()
1029    }
1030
1031    /// Helper: create a 2x2 single-channel gain map filled with a uniform value.
1032    fn make_gainmap_2x2(value: u8) -> GainMap {
1033        let mut gm = GainMap::new(2, 2).unwrap();
1034        for v in &mut gm.data {
1035            *v = value;
1036        }
1037        gm
1038    }
1039
1040    /// Helper: create standard test metadata.
1041    fn test_metadata() -> GainMapMetadata {
1042        // log2(1.0)=0.0, log2(4.0)=2.0
1043        crate::types::metadata_from_arrays(
1044            [0.0; 3],
1045            [2.0; 3],
1046            [1.0; 3],
1047            [1.0 / 64.0; 3],
1048            [1.0 / 64.0; 3],
1049            0.0,
1050            2.0,
1051            true,
1052            false,
1053        )
1054    }
1055
1056    #[test]
1057    fn test_apply_gainmap_linear_float_format() {
1058        let sdr = make_sdr_4x4(128, 128, 128);
1059        let gainmap = make_gainmap_2x2(128);
1060        let metadata = test_metadata();
1061
1062        let result = apply_gainmap(
1063            &sdr,
1064            &gainmap,
1065            &metadata,
1066            4.0,
1067            HdrOutputFormat::LinearFloat,
1068            enough::Unstoppable,
1069        )
1070        .unwrap();
1071
1072        assert_eq!(result.descriptor().pixel_format(), PixelFormat::RgbaF32);
1073        assert_eq!(result.width(), 4);
1074        assert_eq!(result.height(), 4);
1075        // RgbaF32: 16 bytes per pixel (4 f32 channels)
1076        assert_eq!(result.as_slice().as_strided_bytes().len(), 4 * 4 * 16);
1077    }
1078
1079    #[test]
1080    fn test_apply_gainmap_linear_f16_format() {
1081        let sdr = make_sdr_4x4(128, 128, 128);
1082        let gainmap = make_gainmap_2x2(128);
1083        let metadata = test_metadata();
1084
1085        let f32_out = apply_gainmap(
1086            &sdr,
1087            &gainmap,
1088            &metadata,
1089            4.0,
1090            HdrOutputFormat::LinearFloat,
1091            enough::Unstoppable,
1092        )
1093        .unwrap();
1094        let f16_out = apply_gainmap(
1095            &sdr,
1096            &gainmap,
1097            &metadata,
1098            4.0,
1099            HdrOutputFormat::LinearF16,
1100            enough::Unstoppable,
1101        )
1102        .unwrap();
1103
1104        assert_eq!(f16_out.descriptor().pixel_format(), PixelFormat::RgbaF16);
1105        assert_eq!(f16_out.width(), 4);
1106        assert_eq!(f16_out.height(), 4);
1107        // RgbaF16: 8 bytes per pixel (4 f16 channels).
1108        assert_eq!(f16_out.as_slice().as_strided_bytes().len(), 4 * 4 * 8);
1109
1110        // f32 vs f16 must agree within f16 rounding (~1e-3 for values near 1).
1111        let f32_bytes = f32_out.as_slice();
1112        let f32_data = f32_bytes.as_strided_bytes();
1113        let f16_bytes = f16_out.as_slice();
1114        let f16_data = f16_bytes.as_strided_bytes();
1115        for px in 0..16 {
1116            let f32_idx = px * 16;
1117            let f16_idx = px * 8;
1118            for ch in 0..3 {
1119                let want = f32::from_le_bytes(
1120                    f32_data[f32_idx + ch * 4..f32_idx + ch * 4 + 4]
1121                        .try_into()
1122                        .unwrap(),
1123                );
1124                let got = half::f16::from_le_bytes(
1125                    f16_data[f16_idx + ch * 2..f16_idx + ch * 2 + 2]
1126                        .try_into()
1127                        .unwrap(),
1128                )
1129                .to_f32();
1130                let err = (want - got).abs();
1131                // f16 has ~3-4 sig figs near 1.0; allow generous tolerance for
1132                // values up to ~50 (full HDR boost range).
1133                let tol = (want.abs() * 5e-4).max(5e-4);
1134                assert!(
1135                    err < tol,
1136                    "px {px} ch {ch}: f32={want} f16={got} err={err} tol={tol}",
1137                );
1138            }
1139        }
1140    }
1141
1142    #[test]
1143    fn test_apply_gainmap_srgb8_format() {
1144        let sdr = make_sdr_4x4(128, 128, 128);
1145        let gainmap = make_gainmap_2x2(128);
1146        let metadata = test_metadata();
1147
1148        let result = apply_gainmap(
1149            &sdr,
1150            &gainmap,
1151            &metadata,
1152            4.0,
1153            HdrOutputFormat::Srgb8,
1154            enough::Unstoppable,
1155        )
1156        .unwrap();
1157
1158        assert_eq!(result.descriptor().pixel_format(), PixelFormat::Rgba8);
1159        assert_eq!(result.width(), 4);
1160        assert_eq!(result.height(), 4);
1161    }
1162
1163    #[test]
1164    fn test_apply_gainmap_boost_1() {
1165        // display_boost=1.0 → weight=0.0 → gain=1.0 everywhere → output ≈ SDR
1166        let sdr = make_sdr_4x4(128, 128, 128);
1167        let gainmap = make_gainmap_2x2(200); // High gain value, but weight=0 should negate it
1168        let metadata = test_metadata();
1169
1170        let result = apply_gainmap(
1171            &sdr,
1172            &gainmap,
1173            &metadata,
1174            1.0,
1175            HdrOutputFormat::Srgb8,
1176            enough::Unstoppable,
1177        )
1178        .unwrap();
1179
1180        // With boost=1.0, weight=0.0, gain=exp(0)=1.0 for all LUT entries.
1181        // HDR = (sdr_linear + offset) * 1.0 - offset = sdr_linear
1182        // So output should be very close to the input SDR values.
1183        let result_bytes = result.as_slice().as_strided_bytes();
1184        for i in 0..16 {
1185            let r = result_bytes[i * 4];
1186            let g = result_bytes[i * 4 + 1];
1187            let b = result_bytes[i * 4 + 2];
1188            assert!(
1189                (r as i16 - 128).unsigned_abs() <= 2,
1190                "boost=1 R should be ~128, got {}",
1191                r
1192            );
1193            assert!(
1194                (g as i16 - 128).unsigned_abs() <= 2,
1195                "boost=1 G should be ~128, got {}",
1196                g
1197            );
1198            assert!(
1199                (b as i16 - 128).unsigned_abs() <= 2,
1200                "boost=1 B should be ~128, got {}",
1201                b
1202            );
1203        }
1204    }
1205
1206    #[test]
1207    fn test_apply_gainmap_boost_max() {
1208        // display_boost = hdr_capacity_max → weight=1.0 → full HDR enhancement
1209        let sdr = make_sdr_4x4(128, 128, 128);
1210        let gainmap = make_gainmap_2x2(255); // Max gain
1211        let metadata = test_metadata();
1212
1213        let result_max = apply_gainmap(
1214            &sdr,
1215            &gainmap,
1216            &metadata,
1217            2.0f32.powf(metadata.alternate_hdr_headroom as f32), // linear display boost
1218            HdrOutputFormat::LinearFloat,
1219            enough::Unstoppable,
1220        )
1221        .unwrap();
1222
1223        // Also compute with boost=1.0 for comparison
1224        let result_sdr = apply_gainmap(
1225            &sdr,
1226            &gainmap,
1227            &metadata,
1228            1.0,
1229            HdrOutputFormat::LinearFloat,
1230            enough::Unstoppable,
1231        )
1232        .unwrap();
1233
1234        // Read first pixel from each
1235        let max_bytes = result_max.as_slice().as_strided_bytes();
1236        let sdr_bytes = result_sdr.as_slice().as_strided_bytes();
1237        let hdr_r = f32::from_le_bytes([max_bytes[0], max_bytes[1], max_bytes[2], max_bytes[3]]);
1238        let sdr_r = f32::from_le_bytes([sdr_bytes[0], sdr_bytes[1], sdr_bytes[2], sdr_bytes[3]]);
1239
1240        // Full boost should produce significantly brighter output than no boost
1241        assert!(
1242            hdr_r > sdr_r * 1.5,
1243            "max boost ({}) should be much brighter than sdr ({})",
1244            hdr_r,
1245            sdr_r
1246        );
1247    }
1248
1249    #[test]
1250    fn test_gain_map_lut_monotonic() {
1251        let metadata = test_metadata();
1252        let lut = GainMapLut::new(&metadata, 1.0);
1253
1254        // LUT values should be monotonically non-decreasing from byte 0 to 255
1255        for channel in 0..3 {
1256            for i in 1..=255u8 {
1257                let prev = lut.lookup(i - 1, channel);
1258                let curr = lut.lookup(i, channel);
1259                assert!(
1260                    curr >= prev,
1261                    "LUT not monotonic at byte {} channel {}: {} < {}",
1262                    i,
1263                    channel,
1264                    curr,
1265                    prev
1266                );
1267            }
1268        }
1269    }
1270
1271    #[test]
1272    fn test_gain_map_lut_endpoints() {
1273        let metadata = test_metadata();
1274        let lut = GainMapLut::new(&metadata, 1.0);
1275
1276        // At weight=1.0:
1277        // Byte 0 → normalized=0.0 → gain = 2^gain_map_min = 2^0 = 1.0
1278        let gain_0 = lut.lookup(0, 0);
1279        let expected_min = 2.0f32.powf(metadata.channels[0].min as f32);
1280        assert!(
1281            (gain_0 - expected_min).abs() < 0.01,
1282            "byte 0 should give 2^gain_map_min={}, got {}",
1283            expected_min,
1284            gain_0
1285        );
1286
1287        // Byte 255 → normalized=1.0 → gain = 2^gain_map_max = 2^2 = 4.0
1288        let gain_255 = lut.lookup(255, 0);
1289        let expected_max = 2.0f32.powf(metadata.channels[0].max as f32);
1290        assert!(
1291            (gain_255 - expected_max).abs() < 0.1,
1292            "byte 255 should give 2^gain_map_max={}, got {}",
1293            expected_max,
1294            gain_255
1295        );
1296    }
1297
1298    #[test]
1299    fn test_apply_gainmap_multichannel() {
1300        let sdr = make_sdr_4x4(128, 128, 128);
1301
1302        // Create a 2x2 multichannel (3-channel) gain map
1303        let mut gainmap = GainMap::new_multichannel(2, 2).unwrap();
1304        assert_eq!(gainmap.channels, 3);
1305        // Fill with different values per channel
1306        for i in 0..(2 * 2) {
1307            gainmap.data[i * 3] = 200; // R channel - high gain
1308            gainmap.data[i * 3 + 1] = 128; // G channel - mid gain
1309            gainmap.data[i * 3 + 2] = 50; // B channel - low gain
1310        }
1311
1312        let metadata = test_metadata();
1313
1314        let result = apply_gainmap(
1315            &sdr,
1316            &gainmap,
1317            &metadata,
1318            4.0,
1319            HdrOutputFormat::LinearFloat,
1320            enough::Unstoppable,
1321        )
1322        .unwrap();
1323
1324        assert_eq!(result.width(), 4);
1325        assert_eq!(result.height(), 4);
1326        assert_eq!(result.descriptor().pixel_format(), PixelFormat::RgbaF32);
1327        assert_eq!(result.as_slice().as_strided_bytes().len(), 4 * 4 * 16);
1328    }
1329
1330    #[test]
1331    fn test_apply_gainmap_invalid_boost() {
1332        // display_boost=0.5 (< 1.0) is clamped to 1.0 internally, not an error.
1333        // Verify it behaves exactly like boost=1.0.
1334        let sdr = make_sdr_4x4(128, 128, 128);
1335        let gainmap = make_gainmap_2x2(200);
1336        let metadata = test_metadata();
1337
1338        let result_low = apply_gainmap(
1339            &sdr,
1340            &gainmap,
1341            &metadata,
1342            0.5,
1343            HdrOutputFormat::Srgb8,
1344            enough::Unstoppable,
1345        )
1346        .unwrap();
1347
1348        let result_one = apply_gainmap(
1349            &sdr,
1350            &gainmap,
1351            &metadata,
1352            1.0,
1353            HdrOutputFormat::Srgb8,
1354            enough::Unstoppable,
1355        )
1356        .unwrap();
1357
1358        // Both should produce identical output since 0.5 is clamped to 1.0
1359        assert_eq!(
1360            result_low.as_slice().as_strided_bytes(),
1361            result_one.as_slice().as_strided_bytes()
1362        );
1363    }
1364
1365    #[test]
1366    fn test_apply_gainmap_cancellation() {
1367        /// A Stop implementation that cancels immediately
1368        struct ImmediateCancel;
1369
1370        impl enough::Stop for ImmediateCancel {
1371            fn check(&self) -> std::result::Result<(), enough::StopReason> {
1372                Err(enough::StopReason::Cancelled)
1373            }
1374        }
1375
1376        // Create minimal images
1377        let sdr = crate::types::new_pixel_buffer(
1378            4,
1379            4,
1380            PixelFormat::Rgba8,
1381            ColorPrimaries::Bt709,
1382            TransferFunction::Srgb,
1383        )
1384        .unwrap();
1385        let gainmap = GainMap::new(2, 2).unwrap();
1386        let metadata = GainMapMetadata::default();
1387
1388        // Should return Stopped error due to cancellation
1389        let result = apply_gainmap(
1390            &sdr,
1391            &gainmap,
1392            &metadata,
1393            4.0,
1394            HdrOutputFormat::Srgb8,
1395            ImmediateCancel,
1396        );
1397
1398        assert!(matches!(
1399            result,
1400            Err(crate::Error::Stopped(enough::StopReason::Cancelled))
1401        ));
1402    }
1403
1404    #[test]
1405    fn shepards_lut_try_new_rejects_non_integer_ratio() {
1406        // 5x5 image with 2x2 gainmap → no exact integer scale.
1407        assert!(ShepardsLut::try_new(5, 5, 2, 2).is_none());
1408        // 0-dim gainmap is rejected.
1409        assert!(ShepardsLut::try_new(8, 8, 0, 2).is_none());
1410        // Exact 4x scale on both axes.
1411        assert!(ShepardsLut::try_new(8, 8, 2, 2).is_some());
1412        // Asymmetric integer scales are still valid.
1413        assert!(ShepardsLut::try_new(8, 12, 2, 3).is_some());
1414    }
1415
1416    #[test]
1417    fn shepards_lut_weights_at_sample_center_collapse_to_nearest() {
1418        // Sub-pixel offset (0, 0) sits exactly on the top-left sample.
1419        // C++ libultrahdr short-circuits to weights = [1, 0, 0, 0]; we mirror
1420        // that so the LUT and per-pixel paths agree at sample centers.
1421        let lut = ShepardsLut::new(4, 4);
1422        let table = lut.pick(false, false);
1423        assert_eq!(table[0], 1.0);
1424        assert_eq!(table[1], 0.0);
1425        assert_eq!(table[2], 0.0);
1426        assert_eq!(table[3], 0.0);
1427    }
1428
1429    #[test]
1430    fn shepards_weights_normalize_to_one() {
1431        // For any non-degenerate (fx, fy) the four weights must sum to 1.0
1432        // (within f32 rounding). This is what makes the result independent
1433        // of the underlying gain values.
1434        for &fx in &[0.1f32, 0.25, 0.5, 0.75, 0.9] {
1435            for &fy in &[0.1f32, 0.25, 0.5, 0.75, 0.9] {
1436                let w = shepards_weights(fx, fy);
1437                let total: f32 = w.iter().sum();
1438                assert!(
1439                    (total - 1.0).abs() < 1e-5,
1440                    "weights at ({fx}, {fy}) sum to {total}",
1441                );
1442            }
1443        }
1444    }
1445
1446    #[test]
1447    fn shepards_int_lut_matches_float_path_at_sample_centers() {
1448        // Walk a 4x4 image over a 2x2 gainmap (scale=2). At every output
1449        // pixel that lands on a gainmap sample (offsets 0 in both axes —
1450        // here that's every other pixel), both paths must agree.
1451        let mut gainmap = GainMap::new(2, 2).unwrap();
1452        gainmap.data = vec![10, 200, 50, 150];
1453        let metadata = GainMapMetadata::default();
1454        let lut = GainMapLut::new(&metadata, 1.0);
1455        let shepards = ShepardsLut::try_new(4, 4, 2, 2).unwrap();
1456
1457        let mut row_int = vec![[0.0f32; 3]; 4];
1458        let mut row_float = vec![[0.0f32; 3]; 4];
1459
1460        for y in [0u32, 2u32] {
1461            sample_row_lut_int(&gainmap, &lut, &shepards, y, &mut row_int);
1462            sample_row_lut_float(&gainmap, &lut, y, 4, 4, &mut row_float);
1463            for x in [0usize, 2usize] {
1464                // Sample-center pixels: weights collapse to nearest, both
1465                // paths must produce identical output (no f32 drift).
1466                assert_eq!(
1467                    row_int[x], row_float[x],
1468                    "mismatch at ({x}, {y}): int={:?} float={:?}",
1469                    row_int[x], row_float[x]
1470                );
1471            }
1472        }
1473    }
1474
1475    #[test]
1476    fn shepards_int_lut_matches_float_path_within_rounding() {
1477        // Same setup, full-row comparison. Off-center pixels use precomputed
1478        // vs per-pixel weights; the operations differ in associativity so
1479        // bit-equality is not guaranteed, but values must agree to 1e-6.
1480        let mut gainmap = GainMap::new(2, 2).unwrap();
1481        gainmap.data = vec![10, 200, 50, 150];
1482        let metadata = GainMapMetadata::default();
1483        let lut = GainMapLut::new(&metadata, 1.0);
1484        let shepards = ShepardsLut::try_new(8, 8, 2, 2).unwrap();
1485
1486        for y in 0..8 {
1487            let mut row_int = vec![[0.0f32; 3]; 8];
1488            let mut row_float = vec![[0.0f32; 3]; 8];
1489            sample_row_lut_int(&gainmap, &lut, &shepards, y, &mut row_int);
1490            sample_row_lut_float(&gainmap, &lut, y, 8, 8, &mut row_float);
1491            for x in 0..8 {
1492                for c in 0..3 {
1493                    let diff = (row_int[x][c] - row_float[x][c]).abs();
1494                    assert!(
1495                        diff < 1e-6,
1496                        "({x}, {y})[{c}]: int={} float={} diff={}",
1497                        row_int[x][c],
1498                        row_float[x][c],
1499                        diff,
1500                    );
1501                }
1502            }
1503        }
1504    }
1505}