Skip to main content

yscv_video/
hevc_decoder.rs

1//! # H.265/HEVC Video Decoder
2//!
3//! Pure Rust implementation of the H.265/HEVC Main profile decoder.
4//!
5//! ## Architecture
6//!
7//! Implemented across 5 files (~6600 lines total):
8//!
9//! | File | Responsibility |
10//! |------|---------------|
11//! | `hevc_decoder.rs` | VPS/SPS/PPS parsing, CTU quad-tree, intra prediction (DC, planar, angular with fractional interpolation), inverse transforms (DST-4x4, DCT 4/8/16/32), dequantisation, top-level `HevcDecoder` with CABAC path |
12//! | `hevc_cabac.rs` | Full CABAC arithmetic engine, context models, state transitions, binarization (TR, FL, unary, EGk) |
13//! | `hevc_syntax.rs` | CU/PU/TU syntax parsing via CABAC: split_cu_flag, pred_mode, intra modes (MPM list), transform coeff parsing, residual decoding, reference sample construction |
14//! | `hevc_inter.rs` | DPB, motion compensation, merge candidates, AMVP, MVD parsing, bi-prediction framework |
15//! | `hevc_filter.rs` | Deblocking filter (boundary strength, tc/beta tables, luma/chroma edge filtering), SAO, chroma reconstruction, YCbCr-to-RGB conversion |
16//!
17//! ## Supported features
18//! - I-slices (intra prediction: DC, planar, all 33 angular modes)
19//! - CABAC entropy coding (full arithmetic engine with context adaptation)
20//! - CTU quad-tree partitioning (up to 64x64 CTU size)
21//! - Inverse transforms (DST-4x4, DCT 4x4/8x8/16x16/32x32)
22//! - Deblocking filter with boundary strength calculation
23//! - Sample Adaptive Offset (SAO) filtering framework
24//! - VPS/SPS/PPS parameter set parsing
25//! - YCbCr 4:2:0 to RGB8 conversion
26//!
27//! ## Not supported / limitations
28//! - Inter prediction defaults to mid-grey (no real reference frame DPB)
29//! - SAO parameters not parsed from bitstream (passed as None)
30//! - P/B slice support is framework-only (I-slice fully functional)
31//! - No weighted prediction
32//! - No Main 10 / Main 12 bit depth profiles
33//! - No 4:2:2 or 4:4:4 chroma formats
34//! - No WPP (Wavefront Parallel Processing)
35//! - No dependent slice segments
36//! - No tiles
37//!
38//! ## Error handling
39//! Malformed bitstreams return `VideoError` instead of panicking.
40//! However, this decoder has not been fuzz-tested and may not handle
41//! all adversarial inputs gracefully. For production video pipelines
42//! with untrusted input, consider FFI to libavcodec.
43//!
44//! ## End-to-end pipeline
45//! NAL -> CABAC -> CU parse -> intra/inter pred -> residual
46//! -> reconstruct -> deblock -> SAO -> chroma -> RGB output.
47
48use super::h264_decoder::BitstreamReader;
49use crate::VideoError;
50
51// ---------------------------------------------------------------------------
52// Video Parameter Set (VPS)
53// ---------------------------------------------------------------------------
54
/// HEVC Video Parameter Set (VPS).
///
/// Only the leading fixed-length fields of the VPS are retained; see
/// `parse_hevc_vps` for how each one is extracted.
#[derive(Debug, Clone)]
pub struct HevcVps {
    /// 4-bit VPS identifier.
    pub vps_id: u8,
    /// Layer count, stored as the coded 6-bit "minus 1" value plus one.
    pub max_layers: u8,
    /// Temporal sub-layer count, stored as the coded 3-bit "minus 1" value plus one.
    pub max_sub_layers: u8,
    /// Temporal-ID nesting flag (single bit).
    pub temporal_id_nesting: bool,
}
63
64// ---------------------------------------------------------------------------
65// Sequence Parameter Set (SPS)
66// ---------------------------------------------------------------------------
67
/// HEVC Sequence Parameter Set (SPS).
///
/// Fields hold decoded values, i.e. any "minus N" offset used by the
/// bitstream coding has already been added back by the parser.
#[derive(Debug, Clone)]
pub struct HevcSps {
    /// Identifier of this SPS.
    pub sps_id: u8,
    /// Identifier of the VPS this SPS refers to.
    pub vps_id: u8,
    /// Number of temporal sub-layers (coded value plus one).
    pub max_sub_layers: u8,
    pub chroma_format_idc: u8, // 0=mono, 1=4:2:0, 2=4:2:2, 3=4:4:4
    /// Coded picture width in luma samples.
    pub pic_width: u32,
    /// Coded picture height in luma samples.
    pub pic_height: u32,
    /// Luma bit depth (coded value plus 8).
    pub bit_depth_luma: u8,
    /// Chroma bit depth (coded value plus 8).
    pub bit_depth_chroma: u8,
    /// log2 of the picture-order-count LSB range (coded value plus 4).
    pub log2_max_pic_order_cnt: u8,
    /// log2 of the minimum luma coding block size (coded value plus 3).
    pub log2_min_cb_size: u8,
    /// log2 difference between maximum and minimum coding block size.
    pub log2_diff_max_min_cb_size: u8,
    /// log2 of the minimum transform block size (coded value plus 2).
    pub log2_min_transform_size: u8,
    /// log2 difference between maximum and minimum transform block size.
    pub log2_diff_max_min_transform_size: u8,
    /// Maximum transform hierarchy depth for inter-coded units.
    pub max_transform_hierarchy_depth_inter: u8,
    /// Maximum transform hierarchy depth for intra-coded units.
    pub max_transform_hierarchy_depth_intra: u8,
    /// Whether Sample Adaptive Offset is enabled for the sequence.
    pub sample_adaptive_offset_enabled: bool,
    /// Whether PCM coding is enabled.
    pub pcm_enabled: bool,
    /// Number of short-term reference picture sets signalled in the SPS.
    pub num_short_term_ref_pic_sets: u8,
    /// Long-term reference pictures flag.
    pub long_term_ref_pics_present: bool,
    /// Temporal motion vector prediction flag.
    pub sps_temporal_mvp_enabled: bool,
    /// Strong intra smoothing flag.
    pub strong_intra_smoothing_enabled: bool,
}
93
94// ---------------------------------------------------------------------------
95// Picture Parameter Set (PPS)
96// ---------------------------------------------------------------------------
97
/// HEVC Picture Parameter Set (PPS).
///
/// Holds the subset of PPS syntax elements the decoder keeps after parsing.
#[derive(Debug, Clone)]
pub struct HevcPps {
    /// Identifier of this PPS.
    pub pps_id: u8,
    /// Identifier of the SPS this PPS refers to.
    pub sps_id: u8,
    /// Dependent slice segments flag.
    pub dependent_slice_segments_enabled: bool,
    /// Output-flag-present flag.
    pub output_flag_present: bool,
    /// Number of extra slice header bits (3-bit field).
    pub num_extra_slice_header_bits: u8,
    /// Sign data hiding flag.
    pub sign_data_hiding_enabled: bool,
    /// CABAC-init-present flag.
    pub cabac_init_present: bool,
    /// Default active reference count for list 0 (coded value plus one).
    pub num_ref_idx_l0_default: u8,
    /// Default active reference count for list 1 (coded value plus one).
    pub num_ref_idx_l1_default: u8,
    /// Initial QP (26 plus the coded signed offset).
    pub init_qp: i8,
    /// Constrained intra prediction flag.
    pub constrained_intra_pred: bool,
    /// Transform skip flag.
    pub transform_skip_enabled: bool,
    /// CU-level QP delta flag.
    pub cu_qp_delta_enabled: bool,
    /// Cb QP offset.
    pub cb_qp_offset: i8,
    /// Cr QP offset.
    pub cr_qp_offset: i8,
    /// Deblocking filter override flag.
    pub deblocking_filter_override_enabled: bool,
    /// Deblocking-disabled flag.
    pub deblocking_filter_disabled: bool,
    /// Loop filtering across slice boundaries flag.
    pub loop_filter_across_slices_enabled: bool,
    /// Tiles flag.
    pub tiles_enabled: bool,
    /// Entropy coding synchronisation flag.
    pub entropy_coding_sync_enabled: bool,
}
122
123// ---------------------------------------------------------------------------
124// Slice header & slice types
125// ---------------------------------------------------------------------------
126
/// Simplified HEVC slice header: only the fields this decoder keeps.
#[derive(Debug, Clone)]
pub struct HevcSliceHeader {
    /// Whether this is the first slice of its picture.
    pub first_slice_in_pic: bool,
    /// Coding type of the slice.
    pub slice_type: HevcSliceType,
    /// Identifier of the PPS the slice refers to.
    pub pps_id: u8,
    /// Signed QP delta carried by the slice header.
    pub slice_qp_delta: i8,
}

/// Slice coding types; the discriminants are B = 0, P = 1, I = 2.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HevcSliceType {
    /// Bi-predictive slice.
    B = 0,
    /// Predictive slice.
    P = 1,
    /// Intra-only slice.
    I = 2,
}
143
144// ---------------------------------------------------------------------------
145// Coding Tree Unit (CTU)
146// ---------------------------------------------------------------------------
147
/// Coding Tree Unit: HEVC's basic processing unit (the counterpart of the
/// H.264 macroblock).
#[derive(Debug, Clone)]
pub struct CodingTreeUnit {
    /// Horizontal position of the CTU.
    pub x: usize,
    /// Vertical position of the CTU.
    pub y: usize,
    pub size: usize, // CTU size (typically 64)
    /// Quantisation parameter associated with the CTU.
    pub qp: i8,
}
156
157// ---------------------------------------------------------------------------
158// VPS parsing
159// ---------------------------------------------------------------------------
160
161/// Parse HEVC VPS from NAL unit payload (after the 2-byte NAL header).
162pub fn parse_hevc_vps(data: &[u8]) -> Result<HevcVps, VideoError> {
163    let mut reader = BitstreamReader::new(data);
164    let vps_id = reader.read_bits(4)? as u8;
165    reader.read_bits(2)?; // vps_base_layer_internal_flag + vps_base_layer_available_flag
166    let max_layers = reader.read_bits(6)? as u8 + 1;
167    let max_sub_layers = reader.read_bits(3)? as u8 + 1;
168    let temporal_id_nesting = reader.read_bit()? != 0;
169    Ok(HevcVps {
170        vps_id,
171        max_layers,
172        max_sub_layers,
173        temporal_id_nesting,
174    })
175}
176
177// ---------------------------------------------------------------------------
178// Profile-tier-level skipping
179// ---------------------------------------------------------------------------
180
181/// Skip profile_tier_level syntax element.
182fn skip_profile_tier_level(
183    reader: &mut BitstreamReader,
184    max_sub_layers: u8,
185) -> Result<(), VideoError> {
186    // general_profile_space(2) + general_tier_flag(1) + general_profile_idc(5) = 8 bits
187    reader.read_bits(8)?;
188    // general_profile_compatibility_flags (32 bits)
189    reader.read_bits(16)?;
190    reader.read_bits(16)?;
191    // general_constraint_indicator_flags (48 bits)
192    reader.read_bits(16)?;
193    reader.read_bits(16)?;
194    reader.read_bits(16)?;
195    // general_level_idc (8 bits)
196    reader.read_bits(8)?;
197    // sub_layer flags (if max_sub_layers > 1)
198    for _ in 1..max_sub_layers {
199        reader.read_bits(2)?; // sub_layer_profile_present + sub_layer_level_present
200    }
201    if max_sub_layers > 1 {
202        for _ in max_sub_layers..8 {
203            reader.read_bits(2)?; // reserved zero 2 bits
204        }
205    }
206    Ok(())
207}
208
209// ---------------------------------------------------------------------------
210// SPS parsing
211// ---------------------------------------------------------------------------
212
213/// Parse HEVC SPS from NAL unit payload (after the 2-byte NAL header).
214pub fn parse_hevc_sps(data: &[u8]) -> Result<HevcSps, VideoError> {
215    let mut reader = BitstreamReader::new(data);
216    let vps_id = reader.read_bits(4)? as u8;
217    let max_sub_layers = reader.read_bits(3)? as u8 + 1;
218    let _temporal_id_nesting = reader.read_bit()?;
219
220    // Skip profile_tier_level (simplified — fixed-length approximation)
221    skip_profile_tier_level(&mut reader, max_sub_layers)?;
222
223    let sps_id = reader.read_ue()? as u8;
224    let chroma_format_idc = reader.read_ue()? as u8;
225    if chroma_format_idc == 3 {
226        reader.read_bit()?; // separate_colour_plane_flag
227    }
228    let pic_width = reader.read_ue()?;
229    let pic_height = reader.read_ue()?;
230
231    let conformance_window = reader.read_bit()? != 0;
232    if conformance_window {
233        reader.read_ue()?; // conf_win_left_offset
234        reader.read_ue()?; // conf_win_right_offset
235        reader.read_ue()?; // conf_win_top_offset
236        reader.read_ue()?; // conf_win_bottom_offset
237    }
238
239    let bit_depth_luma = reader.read_ue()? as u8 + 8;
240    let bit_depth_chroma = reader.read_ue()? as u8 + 8;
241    let log2_max_pic_order_cnt = reader.read_ue()? as u8 + 4;
242
243    // sub_layer_ordering_info_present_flag
244    let sub_layer_ordering_info_present = reader.read_bit()? != 0;
245    let start = if sub_layer_ordering_info_present {
246        0
247    } else {
248        max_sub_layers - 1
249    };
250    for _ in start..max_sub_layers {
251        reader.read_ue()?; // max_dec_pic_buffering_minus1
252        reader.read_ue()?; // max_num_reorder_pics
253        reader.read_ue()?; // max_latency_increase_plus1
254    }
255
256    let log2_min_cb_size = reader.read_ue()? as u8 + 3;
257    let log2_diff_max_min_cb_size = reader.read_ue()? as u8;
258    let log2_min_transform_size = reader.read_ue()? as u8 + 2;
259    let log2_diff_max_min_transform_size = reader.read_ue()? as u8;
260    let max_transform_hierarchy_depth_inter = reader.read_ue()? as u8;
261    let max_transform_hierarchy_depth_intra = reader.read_ue()? as u8;
262
263    // scaling_list_enabled_flag
264    let scaling_list_enabled = reader.read_bit()? != 0;
265    if scaling_list_enabled {
266        let scaling_list_data_present = reader.read_bit()? != 0;
267        if scaling_list_data_present {
268            skip_scaling_list_data(&mut reader)?;
269        }
270    }
271
272    // amp_enabled_flag, sample_adaptive_offset_enabled_flag
273    let _amp_enabled = reader.read_bit()?;
274    let sample_adaptive_offset_enabled = reader.read_bit()? != 0;
275
276    // pcm_enabled_flag
277    let pcm_enabled = reader.read_bit()? != 0;
278    if pcm_enabled {
279        // pcm_sample_bit_depth_luma_minus1 (4) + pcm_sample_bit_depth_chroma_minus1 (4)
280        reader.read_bits(4)?;
281        reader.read_bits(4)?;
282        reader.read_ue()?; // log2_min_pcm_luma_coding_block_size_minus3
283        reader.read_ue()?; // log2_diff_max_min_pcm_luma_coding_block_size
284        reader.read_bit()?; // pcm_loop_filter_disabled_flag
285    }
286
287    let num_short_term_ref_pic_sets = reader.read_ue()? as u8;
288    // Skip actual short-term ref pic set parsing (complex; fill defaults below)
289
290    // For remaining flags that require parsing the ref pic sets first,
291    // use conservative defaults.
292    Ok(HevcSps {
293        sps_id,
294        vps_id,
295        max_sub_layers,
296        chroma_format_idc,
297        pic_width,
298        pic_height,
299        bit_depth_luma,
300        bit_depth_chroma,
301        log2_max_pic_order_cnt,
302        log2_min_cb_size,
303        log2_diff_max_min_cb_size,
304        log2_min_transform_size,
305        log2_diff_max_min_transform_size,
306        max_transform_hierarchy_depth_inter,
307        max_transform_hierarchy_depth_intra,
308        sample_adaptive_offset_enabled,
309        pcm_enabled,
310        num_short_term_ref_pic_sets,
311        long_term_ref_pics_present: false,
312        sps_temporal_mvp_enabled: false,
313        strong_intra_smoothing_enabled: false,
314    })
315}
316
317// ---------------------------------------------------------------------------
318// PPS parsing
319// ---------------------------------------------------------------------------
320
321/// Parse HEVC PPS from NAL unit payload (after the 2-byte NAL header).
322pub fn parse_hevc_pps(data: &[u8]) -> Result<HevcPps, VideoError> {
323    let mut reader = BitstreamReader::new(data);
324
325    let pps_id = reader.read_ue()? as u8;
326    let sps_id = reader.read_ue()? as u8;
327    let dependent_slice_segments_enabled = reader.read_bit()? != 0;
328    let output_flag_present = reader.read_bit()? != 0;
329    let num_extra_slice_header_bits = reader.read_bits(3)? as u8;
330    let sign_data_hiding_enabled = reader.read_bit()? != 0;
331    let cabac_init_present = reader.read_bit()? != 0;
332    let num_ref_idx_l0_default = reader.read_ue()? as u8 + 1;
333    let num_ref_idx_l1_default = reader.read_ue()? as u8 + 1;
334    let init_qp_minus26 = reader.read_se()?;
335    let init_qp = (26 + init_qp_minus26) as i8;
336    let constrained_intra_pred = reader.read_bit()? != 0;
337    let transform_skip_enabled = reader.read_bit()? != 0;
338    let cu_qp_delta_enabled = reader.read_bit()? != 0;
339    if cu_qp_delta_enabled {
340        reader.read_ue()?; // diff_cu_qp_delta_depth
341    }
342    let cb_qp_offset = reader.read_se()? as i8;
343    let cr_qp_offset = reader.read_se()? as i8;
344    let _slice_chroma_qp_offsets_present = reader.read_bit()?;
345    let _weighted_pred = reader.read_bit()?;
346    let _weighted_bipred = reader.read_bit()?;
347    let _transquant_bypass_enabled = reader.read_bit()?;
348    let tiles_enabled = reader.read_bit()? != 0;
349    let entropy_coding_sync_enabled = reader.read_bit()? != 0;
350
351    if tiles_enabled {
352        let num_tile_columns = reader.read_ue()? + 1;
353        let num_tile_rows = reader.read_ue()? + 1;
354        let uniform_spacing = reader.read_bit()? != 0;
355        if !uniform_spacing {
356            for _ in 0..num_tile_columns - 1 {
357                reader.read_ue()?;
358            }
359            for _ in 0..num_tile_rows - 1 {
360                reader.read_ue()?;
361            }
362        }
363        if tiles_enabled || entropy_coding_sync_enabled {
364            reader.read_bit()?; // loop_filter_across_tiles_enabled_flag
365        }
366    }
367
368    let loop_filter_across_slices_enabled = reader.read_bit()? != 0;
369    let deblocking_filter_control_present = reader.read_bit()? != 0;
370    let mut deblocking_filter_override_enabled = false;
371    let mut deblocking_filter_disabled = false;
372    if deblocking_filter_control_present {
373        deblocking_filter_override_enabled = reader.read_bit()? != 0;
374        deblocking_filter_disabled = reader.read_bit()? != 0;
375        if !deblocking_filter_disabled {
376            reader.read_se()?; // pps_beta_offset_div2
377            reader.read_se()?; // pps_tc_offset_div2
378        }
379    }
380
381    Ok(HevcPps {
382        pps_id,
383        sps_id,
384        dependent_slice_segments_enabled,
385        output_flag_present,
386        num_extra_slice_header_bits,
387        sign_data_hiding_enabled,
388        cabac_init_present,
389        num_ref_idx_l0_default,
390        num_ref_idx_l1_default,
391        init_qp,
392        constrained_intra_pred,
393        transform_skip_enabled,
394        cu_qp_delta_enabled,
395        cb_qp_offset,
396        cr_qp_offset,
397        deblocking_filter_override_enabled,
398        deblocking_filter_disabled,
399        loop_filter_across_slices_enabled,
400        tiles_enabled,
401        entropy_coding_sync_enabled,
402    })
403}
404
405// ---------------------------------------------------------------------------
406// Scaling list data parsing (§7.3.4)
407// ---------------------------------------------------------------------------
408
409/// Parse and discard scaling_list_data() per HEVC spec §7.3.4.
410/// Advances the bitstream position correctly without storing values.
411fn skip_scaling_list_data(reader: &mut BitstreamReader) -> Result<(), VideoError> {
412    for size_id in 0..4u8 {
413        let matrix_count: u8 = if size_id == 3 { 2 } else { 6 };
414        let matrix_step: u8 = if size_id == 3 { 3 } else { 1 };
415        for matrix_idx in 0..matrix_count {
416            let _matrix_id = matrix_idx * matrix_step;
417            let pred_mode_flag = reader.read_bit()?;
418            if pred_mode_flag == 0 {
419                // scaling_list_pred_matrix_id_delta
420                reader.read_ue()?;
421            } else {
422                let coef_num = std::cmp::min(64, 1u32 << (4 + (u32::from(size_id) << 1)));
423                if size_id > 1 {
424                    // scaling_list_dc_coef_minus8
425                    reader.read_se()?;
426                }
427                for _ in 0..coef_num {
428                    // scaling_list_delta_coef
429                    reader.read_se()?;
430                }
431            }
432        }
433    }
434    Ok(())
435}
436
437// ---------------------------------------------------------------------------
438// Helpers
439// ---------------------------------------------------------------------------
440
441/// Extract frame dimensions from HEVC SPS.
442pub fn hevc_frame_dimensions(sps: &HevcSps) -> (u32, u32) {
443    (sps.pic_width, sps.pic_height)
444}
445
446// ---------------------------------------------------------------------------
447// Intra prediction modes
448// ---------------------------------------------------------------------------
449
/// HEVC intra prediction mode.
///
/// The `u8` discriminant of each variant equals the mode index used by the
/// bitstream: 0 is planar, 1 is DC and 2..=34 are the angular modes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum HevcIntraMode {
    Planar = 0,
    Dc = 1,
    Angular2 = 2,
    Angular3 = 3,
    Angular4 = 4,
    Angular5 = 5,
    Angular6 = 6,
    Angular7 = 7,
    Angular8 = 8,
    Angular9 = 9,
    Angular10 = 10,
    Angular11 = 11,
    Angular12 = 12,
    Angular13 = 13,
    Angular14 = 14,
    Angular15 = 15,
    Angular16 = 16,
    Angular17 = 17,
    Angular18 = 18,
    Angular19 = 19,
    Angular20 = 20,
    Angular21 = 21,
    Angular22 = 22,
    Angular23 = 23,
    Angular24 = 24,
    Angular25 = 25,
    Angular26 = 26,
    Angular27 = 27,
    Angular28 = 28,
    Angular29 = 29,
    Angular30 = 30,
    Angular31 = 31,
    Angular32 = 32,
    Angular33 = 33,
    Angular34 = 34,
}

impl HevcIntraMode {
    /// Look up the mode for a raw index.
    ///
    /// Returns `Some` for indices `0..=34` and `None` for anything larger.
    pub fn from_index(idx: u8) -> Option<Self> {
        // Position in the table equals the variant's discriminant.
        #[rustfmt::skip]
        let by_index: [Self; 35] = [
            Self::Planar, Self::Dc,
            Self::Angular2, Self::Angular3, Self::Angular4, Self::Angular5,
            Self::Angular6, Self::Angular7, Self::Angular8, Self::Angular9,
            Self::Angular10, Self::Angular11, Self::Angular12, Self::Angular13,
            Self::Angular14, Self::Angular15, Self::Angular16, Self::Angular17,
            Self::Angular18, Self::Angular19, Self::Angular20, Self::Angular21,
            Self::Angular22, Self::Angular23, Self::Angular24, Self::Angular25,
            Self::Angular26, Self::Angular27, Self::Angular28, Self::Angular29,
            Self::Angular30, Self::Angular31, Self::Angular32, Self::Angular33,
            Self::Angular34,
        ];
        by_index.get(usize::from(idx)).copied()
    }
}
534
/// DC intra prediction: fills the block with the rounded average of the
/// top and left neighbours.
///
/// * `top` / `left` — neighbouring samples; the first `block_size` entries
///   of each are used.
/// * `block_size` — side length of the square block.
/// * `out` — row-major destination; the first `block_size * block_size`
///   entries are overwritten with the DC value.
///
/// The DC value is `(sum + block_size) / (2 * block_size)`. No boundary
/// smoothing is applied here. A `block_size` of zero is a no-op (it would
/// otherwise divide by zero).
///
/// # Panics
/// Panics if `top` or `left` hold fewer than `block_size` samples or `out`
/// holds fewer than `block_size * block_size`.
pub fn intra_predict_dc(top: &[i16], left: &[i16], block_size: usize, out: &mut [i16]) {
    debug_assert!(top.len() >= block_size);
    debug_assert!(left.len() >= block_size);
    debug_assert!(out.len() >= block_size * block_size);
    if block_size == 0 {
        // Nothing to predict, and the average below would divide by zero.
        return;
    }
    let sum: i32 = top[..block_size]
        .iter()
        .chain(&left[..block_size])
        .map(|&v| i32::from(v))
        .sum();
    let n = block_size as i32;
    // Adding `n` (half the divisor) rounds to nearest.
    let dc = ((sum + n) / (2 * n)) as i16;
    for v in out[..block_size * block_size].iter_mut() {
        *v = dc;
    }
}
547
/// Planar intra prediction (HEVC mode 0).
///
/// Each output sample is the rounded average of a horizontal interpolation
/// (between `left[y]` and `top_right`) and a vertical interpolation
/// (between `top[x]` and `bottom_left`), written row-major into `out`.
///
/// The normalising shift is derived from the trailing zero count of
/// `block_size`, i.e. it is `log2(block_size) + 1` for power-of-two sizes.
pub fn intra_predict_planar(
    top: &[i16],
    left: &[i16],
    top_right: i16,
    bottom_left: i16,
    block_size: usize,
    out: &mut [i16],
) {
    debug_assert!(top.len() >= block_size);
    debug_assert!(left.len() >= block_size);
    debug_assert!(out.len() >= block_size * block_size);

    let size = block_size as i32;
    let shift = (block_size as u32).trailing_zeros() + 1;
    let corner_tr = i32::from(top_right);
    let corner_bl = i32::from(bottom_left);

    for row in 0..block_size {
        let row_i = row as i32;
        // Quantities that stay constant along one output row.
        let left_sample = i32::from(left[row]);
        let top_weight = size - 1 - row_i;
        let bottom_term = (row_i + 1) * corner_bl;

        for col in 0..block_size {
            let col_i = col as i32;
            let horizontal = (size - 1 - col_i) * left_sample + (col_i + 1) * corner_tr;
            let vertical = top_weight * i32::from(top[col]) + bottom_term;
            out[row * block_size + col] = ((horizontal + vertical + size) >> shift) as i16;
        }
    }
}
570
/// HEVC angular intra prediction with fractional-sample interpolation.
///
/// ITU-T H.265, section 8.4.4.2.6. Modes 2-34 project through reference
/// samples at angles specified by the `INTRA_PRED_ANGLE` table.
/// Fractional positions use 32-phase linear interpolation.
///
/// * `top` / `left` — neighbouring samples. At least `block_size` entries
///   are required; up to `2 * block_size` entries of the main reference
///   (`top` for modes 18..=34, `left` for modes 2..=17) are used when the
///   caller provides them, which positive angles need to reach the
///   top-right / bottom-left neighbours.
/// * `mode` — intra mode index in `2..=34`.
/// * `out` — row-major destination of `block_size * block_size` samples.
///
/// This signature carries no separate top-left sample, so the first sample
/// of the side reference stands in for the corner. Reference positions for
/// which no sample was supplied read as mid-grey (128).
///
/// # Panics
/// Panics if `mode` is outside `2..=34` or `out` is too short.
pub fn intra_predict_angular(
    top: &[i16],
    left: &[i16],
    mode: u8,
    block_size: usize,
    out: &mut [i16],
) {
    debug_assert!((2..=34).contains(&mode));
    debug_assert!(top.len() >= block_size);
    debug_assert!(left.len() >= block_size);
    debug_assert!(out.len() >= block_size * block_size);

    // ITU-T H.265, Table 8-4: intraPredAngle for modes 2..34
    #[rustfmt::skip]
    const INTRA_PRED_ANGLE: [i32; 33] = [
        // modes 2..34 (index 0 = mode 2)
        32, 26, 21, 17, 13, 9, 5, 2, 0, -2, -5, -9, -13, -17, -21, -26,
        -32, -26, -21, -17, -13, -9, -5, -2, 0, 2, 5, 9, 13, 17, 21, 26, 32,
    ];

    let n = block_size;
    if n == 0 {
        return;
    }
    let angle = INTRA_PRED_ANGLE[(mode - 2) as usize];
    let is_vertical = mode >= 18; // modes 18-34 are vertical-dominant

    // Vertical modes walk along the top row and borrow from the left column
    // for negative angles; horizontal modes are the transpose.
    let (main, side) = if is_vertical { (top, left) } else { (left, top) };
    let reference = build_angular_reference(main, side, n, angle);

    // Spec index `i` (which may be negative) lives at `reference[i + n]`.
    let sample = |i: i32| -> i32 {
        let pos = i + n as i32;
        if pos < 0 {
            128
        } else {
            reference.get(pos as usize).map_or(128, |&s| i32::from(s))
        }
    };

    // `major` runs along the prediction direction (rows for vertical modes,
    // columns for horizontal ones); `minor` runs across it.
    for major in 0..n {
        let delta = (major as i32 + 1) * angle;
        let idx_offset = delta >> 5;
        let frac = delta & 31;
        for minor in 0..n {
            let base = minor as i32 + idx_offset + 1;
            let value = if frac == 0 {
                sample(base)
            } else {
                // 32-phase linear interpolation; shift before narrowing so
                // the weighted sum is never truncated.
                ((32 - frac) * sample(base) + frac * sample(base + 1) + 16) >> 5
            };
            let pos = if is_vertical {
                major * n + minor
            } else {
                minor * n + major
            };
            out[pos] = value as i16;
        }
    }
}

/// Build the extended 1-D reference used by angular prediction.
///
/// The returned vector has `3 * n + 1` entries and stores spec index `i`
/// (valid for `-n..=2n`) at position `i + n`: index 0 is the corner,
/// `1..=2n` are main-reference samples and negative indices are side
/// samples projected onto the main axis (only filled for negative angles).
fn build_angular_reference(main: &[i16], side: &[i16], n: usize, angle: i32) -> Vec<i16> {
    let mut reference = vec![128i16; 3 * n + 1];

    // Corner approximation: first side sample (no true top-left is available).
    reference[n] = side.first().copied().unwrap_or(128);

    let avail = main.len().min(2 * n);
    reference[n + 1..n + 1 + avail].copy_from_slice(&main[..avail]);

    if angle < 0 {
        // Rounded 8192 / |angle|, matching the spec's invAngle magnitudes.
        let inv_angle = (256 * 32 + (-angle) / 2) / (-angle);
        let lowest = (n as i32 * angle) >> 5; // never below -n since angle >= -32
        for k in lowest..0 {
            // ref[k] = p[-1][-1 + ((k * invAngle + 128) >> 8)]; the leading
            // -1 maps the spec's corner-relative position to a side index.
            let side_pos = ((-k * inv_angle + 128) >> 8) - 1;
            if side_pos >= 0 {
                if let Some(&s) = side.get(side_pos as usize) {
                    reference[(k + n as i32) as usize] = s;
                }
            }
        }
    }
    reference
}
677
678// ---------------------------------------------------------------------------
679// Transform / dequantisation
680// ---------------------------------------------------------------------------
681
/// HEVC 4×4 DST-VII core matrix (for intra 4×4 luma TUs).
const DST4_MATRIX: [[i32; 4]; 4] = [
    [29, 55, 74, 84],
    [74, 74, 0, -74],
    [84, -29, -74, 55],
    [55, -84, 74, -29],
];

/// HEVC 4×4 DCT-II core matrix.
const DCT4_MATRIX: [[i32; 4]; 4] = [
    [64, 64, 64, 64],
    [83, 36, -36, -83],
    [64, -64, -64, 64],
    [36, -83, 83, -36],
];

/// HEVC 8×8 DCT-II core matrix.
const DCT8_MATRIX: [[i32; 8]; 8] = [
    [64, 64, 64, 64, 64, 64, 64, 64],
    [89, 75, 50, 18, -18, -50, -75, -89],
    [83, 36, -36, -83, -83, -36, 36, 83],
    [75, -18, -89, -50, 50, 89, 18, -75],
    [64, -64, -64, 64, 64, -64, -64, 64],
    [50, -89, 18, 75, -75, -18, 89, -50],
    [36, -83, 83, -36, -36, 83, -83, 36],
    [18, -50, 75, -89, 89, -75, 50, -18],
];

/// Inverse 4×4 DST (HEVC, for intra 4×4 luma).
///
/// Two separable passes over the row-major `coeffs`: the first works along
/// each row and rounds with `(x + 64) >> 7`, the second works along each
/// column and rounds with `(x + 2048) >> 12`. The result is written
/// row-major into `out`.
pub fn hevc_inverse_dst_4x4(coeffs: &[i32; 16], out: &mut [i32; 16]) {
    let mut stage = [0i32; 16];

    // Pass 1: along rows.
    for (src_row, dst_row) in coeffs.chunks_exact(4).zip(stage.chunks_exact_mut(4)) {
        for (col, dst) in dst_row.iter_mut().enumerate() {
            let acc: i32 = DST4_MATRIX
                .iter()
                .zip(src_row)
                .map(|(basis, &c)| basis[col] * c)
                .sum();
            *dst = (acc + 64) >> 7;
        }
    }

    // Pass 2: along columns.
    for col in 0..4 {
        for row in 0..4 {
            let acc: i32 = DST4_MATRIX
                .iter()
                .enumerate()
                .map(|(k, basis)| basis[row] * stage[k * 4 + col])
                .sum();
            out[row * 4 + col] = (acc + 2048) >> 12;
        }
    }
}

/// Inverse 4×4 DCT-II (HEVC).
///
/// Same two-pass structure and rounding as the 4×4 DST, using the DCT core
/// matrix.
pub fn hevc_inverse_dct_4x4(coeffs: &[i32; 16], out: &mut [i32; 16]) {
    let mut stage = [0i32; 16];

    // Pass 1: along rows.
    for (src_row, dst_row) in coeffs.chunks_exact(4).zip(stage.chunks_exact_mut(4)) {
        for (col, dst) in dst_row.iter_mut().enumerate() {
            let acc: i32 = DCT4_MATRIX
                .iter()
                .zip(src_row)
                .map(|(basis, &c)| basis[col] * c)
                .sum();
            *dst = (acc + 64) >> 7;
        }
    }

    // Pass 2: along columns.
    for col in 0..4 {
        for row in 0..4 {
            let acc: i32 = DCT4_MATRIX
                .iter()
                .enumerate()
                .map(|(k, basis)| basis[row] * stage[k * 4 + col])
                .sum();
            out[row * 4 + col] = (acc + 2048) >> 12;
        }
    }
}

/// Inverse 8×8 DCT-II (HEVC).
///
/// Row pass with `(x + 64) >> 7` rounding followed by a column pass with
/// `(x + 2048) >> 12` rounding; input and output are row-major.
pub fn hevc_inverse_dct_8x8(coeffs: &[i32; 64], out: &mut [i32; 64]) {
    let mut stage = [0i32; 64];

    // Pass 1: along rows.
    for (src_row, dst_row) in coeffs.chunks_exact(8).zip(stage.chunks_exact_mut(8)) {
        for (col, dst) in dst_row.iter_mut().enumerate() {
            let acc: i32 = DCT8_MATRIX
                .iter()
                .zip(src_row)
                .map(|(basis, &c)| basis[col] * c)
                .sum();
            *dst = (acc + 64) >> 7;
        }
    }

    // Pass 2: along columns.
    for col in 0..8 {
        for row in 0..8 {
            let acc: i32 = DCT8_MATRIX
                .iter()
                .enumerate()
                .map(|(k, basis)| basis[row] * stage[k * 8 + col])
                .sum();
            out[row * 8 + col] = (acc + 2048) >> 12;
        }
    }
}
780
/// Inverse 16×16 DCT-II (HEVC), computed as a direct matrix multiply.
///
/// A row pass rounds with `(x + 64) >> 7`, then a column pass rounds with
/// `(x + 2048) >> 12`. `coeffs` and `out` are both row-major.
pub fn hevc_inverse_dct_16x16(coeffs: &[i32; 256], out: &mut [i32; 256]) {
    const N: usize = 16;

    // HEVC 16-point DCT-II core matrix.
    #[rustfmt::skip]
    static HEVC_DCT16: [[i32; N]; N] = [
        [64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64],
        [90, 87, 80, 70, 57, 43, 25, 9, -9, -25, -43, -57, -70, -80, -87, -90],
        [89, 75, 50, 18, -18, -50, -75, -89, -89, -75, -50, -18, 18, 50, 75, 89],
        [87, 57, 9, -43, -80, -90, -70, -25, 25, 70, 90, 80, 43, -9, -57, -87],
        [83, 36, -36, -83, -83, -36, 36, 83, 83, 36, -36, -83, -83, -36, 36, 83],
        [80, 9, -70, -87, -25, 57, 90, 43, -43, -90, -57, 25, 87, 70, -9, -80],
        [75, -18, -89, -50, 50, 89, 18, -75, -75, 18, 89, 50, -50, -89, -18, 75],
        [70, -43, -87, 9, 90, 25, -80, -57, 57, 80, -25, -90, -9, 87, 43, -70],
        [64, -64, -64, 64, 64, -64, -64, 64, 64, -64, -64, 64, 64, -64, -64, 64],
        [57, -80, -25, 90, -9, -87, 43, 70, -70, -43, 87, 9, -90, 25, 80, -57],
        [50, -89, 18, 75, -75, -18, 89, -50, -50, 89, -18, -75, 75, 18, -89, 50],
        [43, -90, 57, 25, -87, 70, 9, -80, 80, -9, -70, 87, -25, -57, 90, -43],
        [36, -83, 83, -36, -36, 83, -83, 36, 36, -83, 83, -36, -36, 83, -83, 36],
        [25, -70, 90, -80, 43, 9, -57, 87, -87, 57, -9, -43, 80, -90, 70, -25],
        [18, -50, 75, -89, 89, -75, 50, -18, -18, 50, -75, 89, -89, 75, -50, 18],
        [9, -25, 43, -57, 70, -80, 87, -90, 90, -87, 80, -70, 57, -43, 25, -9],
    ];

    let mut stage = [0i32; N * N];

    // Pass 1: along rows.
    for (src_row, dst_row) in coeffs.chunks_exact(N).zip(stage.chunks_exact_mut(N)) {
        for (col, dst) in dst_row.iter_mut().enumerate() {
            let acc: i32 = HEVC_DCT16
                .iter()
                .zip(src_row)
                .map(|(basis, &c)| basis[col] * c)
                .sum();
            *dst = (acc + 64) >> 7;
        }
    }

    // Pass 2: along columns.
    for col in 0..N {
        for row in 0..N {
            let acc: i32 = HEVC_DCT16
                .iter()
                .enumerate()
                .map(|(k, basis)| basis[row] * stage[k * N + col])
                .sum();
            out[row * N + col] = (acc + 2048) >> 12;
        }
    }
}
854
855/// Generic inverse DCT for 32×32 blocks (direct matrix multiply, simplified).
856pub fn hevc_inverse_dct_32x32(coeffs: &[i32; 1024], out: &mut [i32; 1024]) {
857    // HEVC 32-point DCT-II core matrix.
858    static HEVC_DCT32: [[i32; 32]; 32] = hevc_dct32_matrix();
859    let mut tmp = [0i32; 1024];
860    for i in 0..32 {
861        for j in 0..32 {
862            let mut sum = 0i64;
863            for k in 0..32 {
864                sum += HEVC_DCT32[k][j] as i64 * coeffs[i * 32 + k] as i64;
865            }
866            tmp[i * 32 + j] = ((sum + 64) >> 7) as i32;
867        }
868    }
869    for j in 0..32 {
870        for i in 0..32 {
871            let mut sum = 0i64;
872            for k in 0..32 {
873                sum += HEVC_DCT32[k][i] as i64 * tmp[k * 32 + j] as i64;
874            }
875            out[i * 32 + j] = ((sum + 2048) >> 12) as i32;
876        }
877    }
878}
879
880/// Build the HEVC 32-point DCT-II transform matrix at compile time.
881const fn hevc_dct32_matrix() -> [[i32; 32]; 32] {
882    // Even rows (0,2,4,...,30) come from 16-point matrix expanded to 32 columns.
883    // Odd rows (1,3,5,...,31) are the 32-point odd basis from HEVC spec Table 8-7.
884    let even16: [[i32; 16]; 16] = [
885        [
886            64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
887        ],
888        [
889            90, 87, 80, 70, 57, 43, 25, 9, -9, -25, -43, -57, -70, -80, -87, -90,
890        ],
891        [
892            89, 75, 50, 18, -18, -50, -75, -89, -89, -75, -50, -18, 18, 50, 75, 89,
893        ],
894        [
895            87, 57, 9, -43, -80, -90, -70, -25, 25, 70, 90, 80, 43, -9, -57, -87,
896        ],
897        [
898            83, 36, -36, -83, -83, -36, 36, 83, 83, 36, -36, -83, -83, -36, 36, 83,
899        ],
900        [
901            80, 9, -70, -87, -25, 57, 90, 43, -43, -90, -57, 25, 87, 70, -9, -80,
902        ],
903        [
904            75, -18, -89, -50, 50, 89, 18, -75, -75, 18, 89, 50, -50, -89, -18, 75,
905        ],
906        [
907            70, -43, -87, 9, 90, 25, -80, -57, 57, 80, -25, -90, -9, 87, 43, -70,
908        ],
909        [
910            64, -64, -64, 64, 64, -64, -64, 64, 64, -64, -64, 64, 64, -64, -64, 64,
911        ],
912        [
913            57, -80, -25, 90, -9, -87, 43, 70, -70, -43, 87, 9, -90, 25, 80, -57,
914        ],
915        [
916            50, -89, 18, 75, -75, -18, 89, -50, -50, 89, -18, -75, 75, 18, -89, 50,
917        ],
918        [
919            43, -90, 57, 25, -87, 70, 9, -80, 80, -9, -70, 87, -25, -57, 90, -43,
920        ],
921        [
922            36, -83, 83, -36, -36, 83, -83, 36, 36, -83, 83, -36, -36, 83, -83, 36,
923        ],
924        [
925            25, -70, 90, -80, 43, 9, -57, 87, -87, 57, -9, -43, 80, -90, 70, -25,
926        ],
927        [
928            18, -50, 75, -89, 89, -75, 50, -18, -18, 50, -75, 89, -89, 75, -50, 18,
929        ],
930        [
931            9, -25, 43, -57, 70, -80, 87, -90, 90, -87, 80, -70, 57, -43, 25, -9,
932        ],
933    ];
934    let odd_rows: [[i32; 32]; 16] = [
935        [
936            90, 90, 88, 85, 82, 78, 73, 67, 61, 54, 46, 38, 31, 22, 13, 4, -4, -13, -22, -31, -38,
937            -46, -54, -61, -67, -73, -78, -82, -85, -88, -90, -90,
938        ],
939        [
940            90, 82, 67, 46, 22, -4, -31, -54, -73, -85, -90, -88, -78, -61, -38, -13, 13, 38, 61,
941            78, 88, 90, 85, 73, 54, 31, 4, -22, -46, -67, -82, -90,
942        ],
943        [
944            88, 67, 31, -13, -54, -82, -90, -78, -46, -4, 38, 73, 90, 85, 61, 22, -22, -61, -85,
945            -90, -73, -38, 4, 46, 78, 90, 82, 54, 13, -31, -67, -88,
946        ],
947        [
948            85, 46, -13, -67, -90, -73, -22, 38, 82, 88, 54, -4, -61, -90, -78, -31, 31, 78, 90,
949            61, 4, -54, -88, -82, -38, 22, 73, 90, 67, 13, -46, -85,
950        ],
951        [
952            82, 22, -54, -90, -61, 13, 78, 85, 31, -46, -90, -67, 4, 73, 88, 38, -38, -88, -73, -4,
953            67, 90, 46, -31, -85, -78, -13, 61, 90, 54, -22, -82,
954        ],
955        [
956            78, -4, -82, -73, 13, 85, 67, -22, -88, -61, 31, 90, 54, -38, -90, -46, 46, 90, 38,
957            -54, -90, -31, 61, 88, 22, -67, -85, -13, 73, 82, 4, -78,
958        ],
959        [
960            73, -31, -90, -22, 78, 67, -38, -90, -13, 82, 61, -46, -88, -4, 85, 54, -54, -85, 4,
961            88, 46, -61, -82, 13, 90, 38, -67, -78, 22, 90, 31, -73,
962        ],
963        [
964            67, -54, -78, 38, 85, -22, -90, 4, 90, 13, -88, -31, 82, 46, -73, -61, 61, 73, -46,
965            -82, 31, 88, -13, -90, -4, 90, 22, -85, -38, 78, 54, -67,
966        ],
967        [
968            61, -73, -46, 82, 31, -88, -13, 90, -4, -90, 22, 85, -38, -78, 54, 67, -67, -54, 78,
969            38, -85, -22, 90, 4, -90, 13, 88, -31, -82, 46, 73, -61,
970        ],
971        [
972            54, -85, -4, 88, -46, -61, 82, 13, -90, 38, 67, -78, -22, 90, -31, -73, 73, 31, -90,
973            22, 78, -67, -38, 90, -13, -82, 61, 46, -88, 4, 85, -54,
974        ],
975        [
976            46, -90, 38, 54, -90, 31, 61, -88, 22, 67, -85, 13, 73, -82, 4, 78, -78, -4, 82, -73,
977            -13, 85, -67, -22, 88, -61, -31, 90, -54, -38, 90, -46,
978        ],
979        [
980            38, -88, 73, -4, -67, 90, -46, -31, 85, -78, 13, 61, -90, 54, 22, -82, 82, -22, -54,
981            90, -61, -13, 78, -85, 31, 46, -90, 67, 4, -73, 88, -38,
982        ],
983        [
984            31, -78, 90, -61, 4, 54, -88, 82, -38, -22, 73, -90, 67, -13, -46, 85, -85, 46, 13,
985            -67, 90, -73, 22, 38, -82, 88, -54, -4, 61, -90, 78, -31,
986        ],
987        [
988            22, -61, 85, -90, 73, -38, -4, 46, -78, 90, -82, 54, -13, -31, 67, -88, 88, -67, 31,
989            13, -54, 82, -90, 78, -46, 4, 38, -73, 90, -85, 61, -22,
990        ],
991        [
992            13, -38, 61, -78, 88, -90, 85, -73, 54, -31, 4, 22, -46, 67, -82, 90, -90, 82, -67, 46,
993            -22, -4, 31, -54, 73, -85, 90, -88, 78, -61, 38, -13,
994        ],
995        [
996            4, -13, 22, -31, 38, -46, 54, -61, 67, -73, 78, -82, 85, -88, 90, -90, 90, -90, 88,
997            -85, 82, -78, 73, -67, 61, -54, 46, -38, 31, -22, 13, -4,
998        ],
999    ];
1000    // Expand the 16-point even basis into 32-column rows
1001    let even_rows_full: [[i32; 32]; 16] = expand_even_rows(&even16);
1002    // Assemble: even rows at indices 0,2,4,...,30; odd rows at 1,3,5,...,31
1003    let mut m = [[0i32; 32]; 32];
1004    let mut row = 0;
1005    while row < 16 {
1006        m[row * 2] = even_rows_full[row];
1007        m[row * 2 + 1] = odd_rows[row];
1008        row += 1;
1009    }
1010    m
1011}
1012
/// Expand the 16-point DCT-II basis into the even-indexed rows of the
/// 32-point basis.
///
/// Row `2k` of a 2N-point DCT-II is `cos((2n + 1)·k·π / 2N)`. For `n < N`
/// that is exactly row `k` of the N-point basis, and because
/// `cos(2kπ − θ) = cos θ` the row is mirror-symmetric about its centre for
/// **every** `k`:
///
/// ```text
/// T32[2k][n]      = T16[k][n]   for n in 0..16
/// T32[2k][31 - n] = T16[k][n]   for n in 0..16
/// ```
///
/// (The antisymmetry of odd 16-point rows lives *inside* their 16 columns;
/// it must not be applied a second time when mirroring into columns 16..31.)
///
/// Returns 16 rows of 32 columns; row `r` of the result is row `2r` of the
/// 32-point matrix.
const fn expand_even_rows(even16: &[[i32; 16]; 16]) -> [[i32; 32]; 16] {
    let mut out = [[0i32; 32]; 16];
    let mut r = 0;
    // `while` loops because `for` is not usable in a `const fn`.
    while r < 16 {
        let mut c = 0;
        while c < 16 {
            let value = even16[r][c];
            out[r][c] = value;
            // Mirror into the right half without any sign change.
            out[r][31 - c] = value;
            c += 1;
        }
        r += 1;
    }
    out
}
1049
/// HEVC dequantisation (scaling) of a transform block, in place.
///
/// Implements the flat-scaling-list form of the HEVC scaling process:
///
/// ```text
/// bdShift = bit_depth + log2_transform_size - 5
/// coeff   = Clip3(-32768, 32767,
///                 ((level * 16 * levelScale[qp % 6] << (qp / 6))
///                  + (1 << (bdShift - 1))) >> bdShift)
/// ```
///
/// * `coeffs` – quantised levels on entry, scaled coefficients on return.
/// * `qp` – quantisation parameter; negative values are treated as 0.
/// * `bit_depth` – sample bit depth of the component.
/// * `log2_transform_size` – log2 of the transform block width.
///
/// The left shift by `qp / 6` and the right shift by `bdShift` are folded
/// into one net shift; this is exact because the rounding offset is a power
/// of two. Arithmetic is done in `i64` and the result is clipped to the
/// 16-bit coefficient range, so the output always fits in `i16`.
pub fn hevc_dequant(coeffs: &mut [i32], qp: i32, bit_depth: u8, log2_transform_size: u8) {
    const LEVEL_SCALE: [i64; 6] = [40, 45, 51, 57, 64, 72];
    // Scaling factor `m` used when no scaling list is applied.
    const FLAT_SCALING_FACTOR: i64 = 16;
    const COEFF_MIN: i64 = -32768;
    const COEFF_MAX: i64 = 32767;

    let qp = qp.max(0);
    let scale = FLAT_SCALING_FACTOR * LEVEL_SCALE[(qp % 6) as usize];
    let bd_shift = i32::from(bit_depth) + i32::from(log2_transform_size) - 5;
    // Positive: net right shift (with rounding). Zero/negative: net left shift.
    let net_shift = bd_shift - qp / 6;

    if net_shift > 0 {
        // |level * scale| < 2^42, so any shift beyond 62 yields 0 anyway;
        // capping keeps the shift amount valid for i64.
        let shift = net_shift.min(62) as u32;
        let round = 1i64 << (shift - 1);
        for c in coeffs.iter_mut() {
            let scaled = (i64::from(*c) * scale + round) >> shift;
            *c = scaled.clamp(COEFF_MIN, COEFF_MAX) as i32;
        }
    } else {
        // Any non-zero product shifted left by 20 or more is already far
        // outside the clip range, so capping at 20 does not change the
        // clipped result and keeps the intermediate inside i64.
        let shift = net_shift.unsigned_abs().min(20);
        for c in coeffs.iter_mut() {
            let scaled = (i64::from(*c) * scale) << shift;
            *c = scaled.clamp(COEFF_MIN, COEFF_MAX) as i32;
        }
    }
}
1080
1081// ---------------------------------------------------------------------------
1082// Coding Tree Unit decode framework
1083// ---------------------------------------------------------------------------
1084
/// Prediction mode for a coding unit.
///
/// Stored per CU in `DecodedCu::pred_mode` and forwarded, together with the
/// CU position and size, to the frame finalisation step.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HevcPredMode {
    /// Intra-predicted CU. The stub `decode_coding_tree` tags every leaf
    /// with this mode.
    Intra,
    /// Inter-predicted CU (presumably motion-compensated from a reference
    /// picture; not produced anywhere in this part of the file).
    Inter,
    /// Skip-mode CU (presumably inter prediction without coded residual —
    /// TODO confirm against the CABAC syntax parser).
    Skip,
}
1092
/// Result of decoding one coding unit leaf.
#[derive(Debug, Clone)]
pub struct DecodedCu {
    /// Horizontal position of the CU's top-left luma sample in the picture.
    pub x: usize,
    /// Vertical position of the CU's top-left luma sample in the picture.
    pub y: usize,
    /// CU width/height in luma samples (`1 << log2_cu_size`).
    pub size: usize,
    /// How this CU was predicted.
    pub pred_mode: HevcPredMode,
    /// Reconstructed luma samples (row-major, `size × size`).
    ///
    /// The picture assembly in `HevcDecoder::decode_picture` reads this
    /// buffer with a row stride of `size` and derives the number of rows
    /// from `recon_luma.len() / size`.
    pub recon_luma: Vec<i16>,
}
1103
1104/// Recursively decode a coding tree (quad-tree split).
1105///
1106/// `depth` starts at 0 for the CTU root. `max_depth` is derived from SPS
1107/// (log2_diff_max_min_cb_size). When `depth == max_depth` or the split flag
1108/// is 0, the node is a leaf CU.
1109///
1110/// This is a framework function; actual CABAC parsing is not implemented.
1111/// Instead, it treats each leaf CU as intra-DC predicted with zero residual.
1112pub fn decode_coding_tree(
1113    x: usize,
1114    y: usize,
1115    log2_cu_size: u8,
1116    depth: u8,
1117    max_depth: u8,
1118    qp: i8,
1119    pic_width: usize,
1120    pic_height: usize,
1121    results: &mut Vec<DecodedCu>,
1122) {
1123    let cu_size = 1usize << log2_cu_size;
1124    let _ = qp; // will be used once CABAC residual decoding is added
1125
1126    // If outside picture bounds, skip
1127    if x >= pic_width || y >= pic_height {
1128        return;
1129    }
1130
1131    // Decide whether to split (framework: split until min CU size)
1132    let should_split = depth < max_depth && cu_size > 8;
1133
1134    if should_split {
1135        let half = log2_cu_size - 1;
1136        let half_size = 1usize << half;
1137        let next_depth = depth + 1;
1138        decode_coding_tree(
1139            x, y, half, next_depth, max_depth, qp, pic_width, pic_height, results,
1140        );
1141        decode_coding_tree(
1142            x + half_size,
1143            y,
1144            half,
1145            next_depth,
1146            max_depth,
1147            qp,
1148            pic_width,
1149            pic_height,
1150            results,
1151        );
1152        decode_coding_tree(
1153            x,
1154            y + half_size,
1155            half,
1156            next_depth,
1157            max_depth,
1158            qp,
1159            pic_width,
1160            pic_height,
1161            results,
1162        );
1163        decode_coding_tree(
1164            x + half_size,
1165            y + half_size,
1166            half,
1167            next_depth,
1168            max_depth,
1169            qp,
1170            pic_width,
1171            pic_height,
1172            results,
1173        );
1174    } else {
1175        // Leaf CU — apply intra DC prediction with zero residual
1176        let actual_w = cu_size.min(pic_width.saturating_sub(x));
1177        let actual_h = cu_size.min(pic_height.saturating_sub(y));
1178        let dc_val = 1i16 << 7; // 128 for 8-bit
1179        let recon = vec![dc_val; actual_w * actual_h];
1180
1181        results.push(DecodedCu {
1182            x,
1183            y,
1184            size: cu_size,
1185            pred_mode: HevcPredMode::Intra,
1186            recon_luma: recon,
1187        });
1188    }
1189}
1190
1191// ---------------------------------------------------------------------------
1192// HEVC Decoder
1193// ---------------------------------------------------------------------------
1194
/// Top-level HEVC decoder state.
///
/// Holds the most recently received parameter sets plus the reference
/// picture buffer; NAL units are fed in through `decode_nal`.
pub struct HevcDecoder {
    /// Last parsed video parameter set, replaced on every VPS NAL.
    vps: Option<HevcVps>,
    /// Last parsed sequence parameter set; required before any slice.
    sps: Option<HevcSps>,
    /// Last parsed picture parameter set; required before any slice.
    pps: Option<HevcPps>,
    /// Decoded Picture Buffer for inter prediction reference frames.
    dpb: super::hevc_inter::HevcDpb,
    /// Picture Order Count counter; incremented once per decoded picture.
    poc: i32,
}
1205
1206impl HevcDecoder {
1207    /// Create a new decoder with no parameter sets.
1208    pub fn new() -> Self {
1209        Self {
1210            vps: None,
1211            sps: None,
1212            pps: None,
1213            dpb: super::hevc_inter::HevcDpb::new(16), // max 16 reference frames
1214            poc: 0,
1215        }
1216    }
1217
1218    /// Current SPS, if any.
1219    pub fn sps(&self) -> Option<&HevcSps> {
1220        self.sps.as_ref()
1221    }
1222
1223    /// Current PPS, if any.
1224    pub fn pps(&self) -> Option<&HevcPps> {
1225        self.pps.as_ref()
1226    }
1227
1228    /// Decode a single NAL unit (payload after start code, including the 2-byte header).
1229    ///
1230    /// Returns `Some(DecodedFrame)` when a complete picture is produced (IDR / CRA),
1231    /// or `None` for parameter-set and other non-VCL NALs.
1232    pub fn decode_nal(
1233        &mut self,
1234        nal_data: &[u8],
1235    ) -> Result<Option<crate::DecodedFrame>, VideoError> {
1236        use crate::HevcNalUnitType;
1237
1238        if nal_data.len() < 2 {
1239            return Err(VideoError::Codec("NAL unit too short".into()));
1240        }
1241        let nal_type = HevcNalUnitType::from_header(nal_data);
1242        let payload = &nal_data[2..]; // skip 2-byte NAL header
1243
1244        match nal_type {
1245            HevcNalUnitType::VpsNut => {
1246                self.vps = Some(parse_hevc_vps(payload)?);
1247                Ok(None)
1248            }
1249            HevcNalUnitType::SpsNut => {
1250                self.sps = Some(parse_hevc_sps(payload)?);
1251                Ok(None)
1252            }
1253            HevcNalUnitType::PpsNut => {
1254                self.pps = Some(parse_hevc_pps(payload)?);
1255                Ok(None)
1256            }
1257            HevcNalUnitType::IdrWRadl | HevcNalUnitType::IdrNLp | HevcNalUnitType::CraNut => {
1258                self.decode_picture(payload)
1259            }
1260            _ => Ok(None), // non-VCL or unsupported VCL
1261        }
1262    }
1263
1264    /// Decode a picture from a slice NAL payload.
1265    ///
1266    /// When the payload has enough data, uses CABAC-driven coding tree
1267    /// decoding via [`super::hevc_syntax::decode_coding_tree_cabac`].
1268    /// Falls back to the stub (DC fill) path when the payload is too short
1269    /// to bootstrap the CABAC decoder or when no SPS/PPS is available.
1270    fn decode_picture(
1271        &mut self,
1272        payload: &[u8],
1273    ) -> Result<Option<crate::DecodedFrame>, VideoError> {
1274        let sps = self
1275            .sps
1276            .as_ref()
1277            .ok_or_else(|| VideoError::Codec("Slice received before SPS".into()))?;
1278        let pps = self
1279            .pps
1280            .as_ref()
1281            .ok_or_else(|| VideoError::Codec("Slice received before PPS".into()))?;
1282
1283        let w = sps.pic_width as usize;
1284        let h = sps.pic_height as usize;
1285        let ctu_size_log2 = sps.log2_min_cb_size + sps.log2_diff_max_min_cb_size;
1286        let ctu_size = 1usize << ctu_size_log2;
1287        let max_depth = sps.log2_diff_max_min_cb_size;
1288
1289        // Use CABAC path when we have a meaningful payload, otherwise fall
1290        // back to the deterministic stub (useful for unit tests that build
1291        // an HevcDecoder with synthetic parameter sets but no real slice
1292        // data).
1293        let use_cabac = payload.len() >= 2;
1294
1295        let mut cus = Vec::new();
1296        let mut sao_list: Vec<super::hevc_filter::SaoParams> = Vec::new();
1297
1298        if use_cabac {
1299            let slice_qp = pps.init_qp as i32;
1300            let mut cabac_state = super::hevc_syntax::HevcSliceCabacState::new(payload, slice_qp);
1301            let mut recon_luma = vec![128i16; w * h];
1302            let min_pu = 4usize;
1303            let pic_w_pu = w.div_ceil(min_pu);
1304            let pic_h_pu = h.div_ceil(min_pu);
1305            let mut mv_field =
1306                vec![super::hevc_inter::HevcMvField::unavailable(); pic_w_pu * pic_h_pu];
1307
1308            let mut ctu_y = 0;
1309            while ctu_y < h {
1310                let mut ctu_x = 0;
1311                while ctu_x < w {
1312                    // Parse SAO parameters per CTU when enabled in SPS
1313                    if sps.sample_adaptive_offset_enabled {
1314                        let left_avail = ctu_x > 0;
1315                        let above_avail = ctu_y > 0;
1316                        let sao = super::hevc_filter::parse_sao_params(
1317                            &mut cabac_state.cabac,
1318                            left_avail,
1319                            above_avail,
1320                        );
1321                        sao_list.push(sao);
1322                    }
1323
1324                    super::hevc_syntax::decode_coding_tree_cabac(
1325                        &mut cabac_state,
1326                        ctu_x,
1327                        ctu_y,
1328                        ctu_size_log2,
1329                        0,
1330                        max_depth,
1331                        sps,
1332                        pps,
1333                        HevcSliceType::I, // default to I-slice
1334                        w,
1335                        h,
1336                        &mut recon_luma,
1337                        &mut cus,
1338                        &self.dpb,
1339                        &mut mv_field,
1340                    );
1341                    ctu_x += ctu_size;
1342                }
1343                ctu_y += ctu_size;
1344            }
1345        } else {
1346            // Stub path: split down to leaves and fill with DC 128
1347            let mut ctu_y = 0;
1348            while ctu_y < h {
1349                let mut ctu_x = 0;
1350                while ctu_x < w {
1351                    decode_coding_tree(
1352                        ctu_x,
1353                        ctu_y,
1354                        ctu_size_log2,
1355                        0,
1356                        max_depth,
1357                        pps.init_qp,
1358                        w,
1359                        h,
1360                        &mut cus,
1361                    );
1362                    ctu_x += ctu_size;
1363                }
1364                ctu_y += ctu_size;
1365            }
1366        }
1367
1368        // Assemble luma plane from decoded CUs
1369        let mut y_plane = vec![128u8; w * h];
1370        let mut cu_info: Vec<(usize, usize, usize, HevcPredMode)> = Vec::with_capacity(cus.len());
1371        for cu in &cus {
1372            let cu_w = cu.size.min(w.saturating_sub(cu.x));
1373            let cu_h = cu.recon_luma.len() / cu.size.max(1);
1374            let cu_h = cu_h.min(h.saturating_sub(cu.y));
1375            for row in 0..cu_h {
1376                for col in 0..cu_w {
1377                    let py = cu.y + row;
1378                    let px = cu.x + col;
1379                    if py < h && px < w {
1380                        y_plane[py * w + px] =
1381                            cu.recon_luma[row * cu.size + col].clamp(0, 255) as u8;
1382                    }
1383                }
1384            }
1385            cu_info.push((cu.x, cu.y, cu.size, cu.pred_mode));
1386        }
1387
1388        // Finalize: chroma fill, deblocking, SAO, YCbCr-to-RGB conversion
1389        let slice_qp = pps.init_qp.unsigned_abs();
1390        let sao_ref = if sao_list.is_empty() {
1391            None
1392        } else {
1393            Some(sao_list.as_slice())
1394        };
1395        let rgb = super::hevc_filter::finalize_hevc_frame(
1396            &mut y_plane,
1397            w,
1398            h,
1399            &cu_info,
1400            slice_qp,
1401            sao_ref,
1402        );
1403
1404        // Store reconstructed luma in DPB for future inter prediction
1405        self.dpb.add(super::hevc_inter::HevcReferencePicture {
1406            poc: self.poc,
1407            luma: y_plane.to_vec(),
1408            width: w,
1409            height: h,
1410            is_long_term: false,
1411        });
1412        self.poc += 1;
1413
1414        Ok(Some(crate::DecodedFrame {
1415            width: w,
1416            height: h,
1417            rgb8_data: rgb,
1418            timestamp_us: 0,
1419            keyframe: true,
1420        }))
1421    }
1422}
1423
1424// ---------------------------------------------------------------------------
1425// Tests
1426// ---------------------------------------------------------------------------
1427
1428#[cfg(test)]
1429mod tests {
1430    use super::*;
1431    use crate::BitstreamReader;
1432
1433    // -- Test helpers (same convention as h264_decoder.rs) -------------------
1434
1435    fn push_bits(bits: &mut Vec<u8>, value: u32, count: u8) {
1436        for i in (0..count).rev() {
1437            bits.push(((value >> i) & 1) as u8);
1438        }
1439    }
1440
1441    fn push_exp_golomb(bits: &mut Vec<u8>, value: u32) {
1442        if value == 0 {
1443            bits.push(1);
1444            return;
1445        }
1446        let code = value + 1;
1447        let bit_len = 32 - code.leading_zeros();
1448        let leading_zeros = bit_len - 1;
1449        for _ in 0..leading_zeros {
1450            bits.push(0);
1451        }
1452        for i in (0..bit_len).rev() {
1453            bits.push(((code >> i) & 1) as u8);
1454        }
1455    }
1456
1457    fn push_signed_exp_golomb(bits: &mut Vec<u8>, value: i32) {
1458        let code = if value <= 0 {
1459            (-value * 2) as u32
1460        } else {
1461            (value * 2 - 1) as u32
1462        };
1463        push_exp_golomb(bits, code);
1464    }
1465
1466    fn bits_to_bytes(bits: &[u8]) -> Vec<u8> {
1467        let mut bytes = Vec::new();
1468        for chunk in bits.chunks(8) {
1469            let mut byte = 0u8;
1470            for (i, &bit) in chunk.iter().enumerate() {
1471                byte |= bit << (7 - i);
1472            }
1473            bytes.push(byte);
1474        }
1475        bytes
1476    }
1477
1478    #[test]
1479    fn hevc_vps_parse() {
1480        let mut bits = Vec::new();
1481        // vps_id = 2 (4 bits)
1482        push_bits(&mut bits, 2, 4);
1483        // reserved 2 bits (vps_base_layer flags)
1484        push_bits(&mut bits, 0, 2);
1485        // max_layers_minus1 = 0 (6 bits) => max_layers = 1
1486        push_bits(&mut bits, 0, 6);
1487        // max_sub_layers_minus1 = 2 (3 bits) => max_sub_layers = 3
1488        push_bits(&mut bits, 2, 3);
1489        // temporal_id_nesting = 1 (1 bit)
1490        push_bits(&mut bits, 1, 1);
1491
1492        let bytes = bits_to_bytes(&bits);
1493        let vps = parse_hevc_vps(&bytes).unwrap();
1494        assert_eq!(vps.vps_id, 2);
1495        assert_eq!(vps.max_layers, 1);
1496        assert_eq!(vps.max_sub_layers, 3);
1497        assert!(vps.temporal_id_nesting);
1498    }
1499
1500    #[test]
1501    fn hevc_sps_dimensions() {
1502        // Build a minimal SPS bitstream for 1920x1080, 1 sub-layer, chroma 4:2:0.
1503        let mut bits = Vec::new();
1504
1505        // vps_id = 0 (4 bits)
1506        push_bits(&mut bits, 0, 4);
1507        // max_sub_layers_minus1 = 0 (3 bits) => max_sub_layers = 1
1508        push_bits(&mut bits, 0, 3);
1509        // temporal_id_nesting_flag = 1
1510        push_bits(&mut bits, 1, 1);
1511
1512        // profile_tier_level for max_sub_layers=1:
1513        // general_profile_space(2) + general_tier_flag(1) + general_profile_idc(5) = 8 bits
1514        push_bits(&mut bits, 0, 8);
1515        // general_profile_compatibility_flags (32 bits)
1516        push_bits(&mut bits, 0, 16);
1517        push_bits(&mut bits, 0, 16);
1518        // general_constraint_indicator_flags (48 bits)
1519        push_bits(&mut bits, 0, 16);
1520        push_bits(&mut bits, 0, 16);
1521        push_bits(&mut bits, 0, 16);
1522        // general_level_idc (8 bits)
1523        push_bits(&mut bits, 0, 8);
1524        // no sub_layer flags when max_sub_layers == 1
1525
1526        // sps_id = ue(0)
1527        push_exp_golomb(&mut bits, 0);
1528        // chroma_format_idc = ue(1) => 4:2:0
1529        push_exp_golomb(&mut bits, 1);
1530        // pic_width = ue(1920)
1531        push_exp_golomb(&mut bits, 1920);
1532        // pic_height = ue(1080)
1533        push_exp_golomb(&mut bits, 1080);
1534        // conformance_window_flag = 0
1535        push_bits(&mut bits, 0, 1);
1536        // bit_depth_luma_minus8 = ue(0) => 8
1537        push_exp_golomb(&mut bits, 0);
1538        // bit_depth_chroma_minus8 = ue(0) => 8
1539        push_exp_golomb(&mut bits, 0);
1540        // log2_max_pic_order_cnt_lsb_minus4 = ue(0) => 4
1541        push_exp_golomb(&mut bits, 0);
1542        // sub_layer_ordering_info_present_flag = 1
1543        push_bits(&mut bits, 1, 1);
1544        // For 1 sub-layer: max_dec_pic_buffering_minus1, max_num_reorder_pics, max_latency_increase_plus1
1545        push_exp_golomb(&mut bits, 0);
1546        push_exp_golomb(&mut bits, 0);
1547        push_exp_golomb(&mut bits, 0);
1548        // log2_min_luma_coding_block_size_minus3 = ue(0) => 3
1549        push_exp_golomb(&mut bits, 0);
1550        // log2_diff_max_min_luma_coding_block_size = ue(3)
1551        push_exp_golomb(&mut bits, 3);
1552        // log2_min_luma_transform_block_size_minus2 = ue(0) => 2
1553        push_exp_golomb(&mut bits, 0);
1554        // log2_diff_max_min_luma_transform_block_size = ue(3)
1555        push_exp_golomb(&mut bits, 3);
1556        // max_transform_hierarchy_depth_inter = ue(0)
1557        push_exp_golomb(&mut bits, 0);
1558        // max_transform_hierarchy_depth_intra = ue(0)
1559        push_exp_golomb(&mut bits, 0);
1560        // scaling_list_enabled_flag = 0
1561        push_bits(&mut bits, 0, 1);
1562        // amp_enabled_flag = 0
1563        push_bits(&mut bits, 0, 1);
1564        // sample_adaptive_offset_enabled_flag = 0
1565        push_bits(&mut bits, 0, 1);
1566        // pcm_enabled_flag = 0
1567        push_bits(&mut bits, 0, 1);
1568        // num_short_term_ref_pic_sets = ue(0)
1569        push_exp_golomb(&mut bits, 0);
1570
1571        // Pad to byte boundary
1572        while bits.len() % 8 != 0 {
1573            bits.push(0);
1574        }
1575
1576        let bytes = bits_to_bytes(&bits);
1577        let sps = parse_hevc_sps(&bytes).unwrap();
1578        assert_eq!(sps.pic_width, 1920);
1579        assert_eq!(sps.pic_height, 1080);
1580        assert_eq!(sps.chroma_format_idc, 1);
1581        assert_eq!(sps.bit_depth_luma, 8);
1582        assert_eq!(sps.bit_depth_chroma, 8);
1583        assert_eq!(sps.vps_id, 0);
1584        assert_eq!(sps.sps_id, 0);
1585    }
1586
1587    #[test]
1588    fn hevc_slice_type_enum() {
1589        assert_eq!(HevcSliceType::B as u8, 0);
1590        assert_eq!(HevcSliceType::P as u8, 1);
1591        assert_eq!(HevcSliceType::I as u8, 2);
1592    }
1593
1594    #[test]
1595    fn hevc_frame_dimensions_from_sps() {
1596        let sps = HevcSps {
1597            sps_id: 0,
1598            vps_id: 0,
1599            max_sub_layers: 1,
1600            chroma_format_idc: 1,
1601            pic_width: 3840,
1602            pic_height: 2160,
1603            bit_depth_luma: 10,
1604            bit_depth_chroma: 10,
1605            log2_max_pic_order_cnt: 8,
1606            log2_min_cb_size: 3,
1607            log2_diff_max_min_cb_size: 3,
1608            log2_min_transform_size: 2,
1609            log2_diff_max_min_transform_size: 3,
1610            max_transform_hierarchy_depth_inter: 1,
1611            max_transform_hierarchy_depth_intra: 1,
1612            sample_adaptive_offset_enabled: true,
1613            pcm_enabled: false,
1614            num_short_term_ref_pic_sets: 0,
1615            long_term_ref_pics_present: false,
1616            sps_temporal_mvp_enabled: true,
1617            strong_intra_smoothing_enabled: true,
1618        };
1619        let (w, h) = hevc_frame_dimensions(&sps);
1620        assert_eq!(w, 3840);
1621        assert_eq!(h, 2160);
1622    }
1623
1624    // -- Scaling list parsing tests -----------------------------------------
1625
/// Feeds `parse_hevc_sps` a 64x64 SPS that sets both scaling-list flags and
/// carries `scaling_list_data()`, then checks the parsed picture size.
///
/// Every matrix in the scaling list data is written with `pred_mode_flag = 0`
/// and `pred_matrix_id_delta = ue(0)` (copy from default).
#[test]
fn hevc_sps_with_scaling_list_data() {
    let mut stream = Vec::new();

    // Leading header fields (same layout as the hevc_sps_dimensions test).
    push_bits(&mut stream, 0, 4); // vps_id
    push_bits(&mut stream, 0, 3); // max_sub_layers_minus1
    push_bits(&mut stream, 1, 1); // temporal_id_nesting

    // profile_tier_level, all zero: 8 bits, five 16-bit words, 8 bits.
    push_bits(&mut stream, 0, 8);
    for _ in 0..5 {
        push_bits(&mut stream, 0, 16);
    }
    push_bits(&mut stream, 0, 8);

    // Picture format.
    push_exp_golomb(&mut stream, 0); // sps_id
    push_exp_golomb(&mut stream, 1); // chroma_format_idc
    push_exp_golomb(&mut stream, 64); // pic_width
    push_exp_golomb(&mut stream, 64); // pic_height
    push_bits(&mut stream, 0, 1); // no conformance window
    push_exp_golomb(&mut stream, 0); // bit_depth_luma_minus8
    push_exp_golomb(&mut stream, 0); // bit_depth_chroma_minus8
    push_exp_golomb(&mut stream, 0); // log2_max_poc_minus4

    // sub_layer_ordering_info_present = 1, followed by three ue(0) fields.
    push_bits(&mut stream, 1, 1);
    for _ in 0..3 {
        push_exp_golomb(&mut stream, 0);
    }

    // Coding-block and transform-block size limits.
    push_exp_golomb(&mut stream, 0); // log2_min_cb_size_minus3
    push_exp_golomb(&mut stream, 3); // log2_diff_max_min_cb_size
    push_exp_golomb(&mut stream, 0); // log2_min_transform_size_minus2
    push_exp_golomb(&mut stream, 3); // log2_diff_max_min_transform_size
    push_exp_golomb(&mut stream, 0); // max_transform_hierarchy_depth_inter
    push_exp_golomb(&mut stream, 0); // max_transform_hierarchy_depth_intra

    push_bits(&mut stream, 1, 1); // scaling_list_enabled_flag = 1
    push_bits(&mut stream, 1, 1); // scaling_list_data_present_flag = 1

    // scaling_list_data(): sizeId 0, 1 and 2 have 6 matrices each, sizeId 3
    // has 2. Every matrix is encoded identically, so one loop covers them all.
    for _ in 0..(6 + 6 + 6 + 2) {
        push_bits(&mut stream, 0, 1); // pred_mode_flag = 0
        push_exp_golomb(&mut stream, 0); // pred_matrix_id_delta = 0
    }

    // Remaining SPS fields.
    push_bits(&mut stream, 0, 1); // amp_enabled
    push_bits(&mut stream, 0, 1); // sample_adaptive_offset_enabled
    push_bits(&mut stream, 0, 1); // pcm_enabled
    push_exp_golomb(&mut stream, 0); // num_short_term_ref_pic_sets

    // Zero-pad up to the next byte boundary.
    let padded_len = (stream.len() + 7) / 8 * 8;
    stream.resize(padded_len, 0);

    let payload = bits_to_bytes(&stream);
    let sps = parse_hevc_sps(&payload).unwrap();
    assert_eq!((sps.pic_width, sps.pic_height), (64, 64));
}
1708
/// Builds a standalone `scaling_list_data()` bitstream in which every matrix
/// uses `pred_mode_flag = 1` (explicit coefficients) and expects
/// `skip_scaling_list_data` to return `Ok`.
#[test]
fn hevc_scaling_list_with_explicit_coeffs() {
    // (sizeId, number of matrices, coefNum) where
    // coefNum = min(64, 1 << (4 + 2 * sizeId)).
    let layout = [(0, 6, 16), (1, 6, 64), (2, 6, 64), (3, 2, 64)];

    let mut stream = Vec::new();
    for (size_id, matrix_count, coef_num) in layout {
        for _ in 0..matrix_count {
            push_bits(&mut stream, 1, 1); // pred_mode_flag = 1
            if size_id > 1 {
                // dc_coef_minus8 is only written for sizeId > 1.
                push_signed_exp_golomb(&mut stream, 0);
            }
            for _ in 0..coef_num {
                push_signed_exp_golomb(&mut stream, 0); // delta_coef = 0
            }
        }
    }

    // Zero-pad up to the next byte boundary.
    let padded_len = (stream.len() + 7) / 8 * 8;
    stream.resize(padded_len, 0);

    let payload = bits_to_bytes(&stream);
    let mut reader = BitstreamReader::new(&payload);
    assert!(skip_scaling_list_data(&mut reader).is_ok());
}
1757
1758    // -- Intra prediction tests ---------------------------------------------
1759
/// DC prediction of a 4x4 block from a flat top row of 100 and a flat left
/// column of 200 is expected to fill the block with 150.
#[test]
fn intra_dc_prediction_4x4() {
    let above = [100i16; 4];
    let beside = [200i16; 4];
    let mut pred = [0i16; 16];

    intra_predict_dc(&above, &beside, 4, &mut pred);

    // (4*100 + 4*200 + 4) / 8 = 1204 / 8, which truncates to 150.
    assert_eq!(pred, [150i16; 16]);
}
1771
/// DC prediction of an 8x8 block whose neighbours are all 128 is expected to
/// yield 128 for every sample.
#[test]
fn intra_dc_prediction_8x8() {
    let above = [128i16; 8];
    let beside = [128i16; 8];
    let mut pred = [0i16; 64];

    intra_predict_dc(&above, &beside, 8, &mut pred);

    assert_eq!(pred, [128i16; 64]);
}
1782
/// Planar prediction of a 4x4 block with every neighbour input equal to 100
/// is expected to produce a flat block of 100.
#[test]
fn intra_planar_prediction_4x4() {
    let above = [100i16; 4];
    let beside = [100i16; 4];
    let mut pred = [0i16; 16];

    intra_predict_planar(&above, &beside, 100, 100, 4, &mut pred);

    // Uniform neighbours leave nothing to interpolate between.
    assert_eq!(pred, [100i16; 16]);
}
1794
/// Angular mode 10 on a 4x4 block: every sample in row `y` is expected to
/// equal `left[y]`.
#[test]
fn intra_angular_horizontal() {
    let top = [50i16; 4];
    let left = [200i16, 201, 202, 203];
    let mut pred = [0i16; 16];

    intra_predict_angular(&top, &left, 10, 4, &mut pred);

    // Walk the output one 4-sample row at a time, pairing each row with the
    // left neighbour it should replicate.
    for (row, &expected) in pred.chunks_exact(4).zip(&left) {
        for &sample in row {
            assert_eq!(sample, expected);
        }
    }
}
1808
/// Angular mode 26 on a 4x4 block: every sample in column `x` is expected to
/// equal `top[x]`.
#[test]
fn intra_angular_vertical() {
    let top = [200i16, 201, 202, 203];
    let left = [50i16; 4];
    let mut pred = [0i16; 16];

    intra_predict_angular(&top, &left, 26, 4, &mut pred);

    // Each 4-sample output row must be a copy of the top neighbours.
    for row in pred.chunks_exact(4) {
        assert_eq!(row, &top[..]);
    }
}
1822
1823    // -- Transform roundtrip tests ------------------------------------------
1824
/// Runs the inverse 4x4 DST on a block whose only non-zero coefficient is a
/// large value at index 0 and checks that the output is not all zero.
#[test]
fn hevc_dst_4x4_dc_roundtrip() {
    // The DST is asymmetric, so the output is not expected to be uniform.
    // 16384 is large enough to survive the >>12 shift in the column pass.
    let input = {
        let mut block = [0i32; 16];
        block[0] = 16384;
        block
    };
    let mut residual = [0i32; 16];

    hevc_inverse_dst_4x4(&input, &mut residual);

    assert!(
        residual.iter().any(|&sample| sample != 0),
        "DST output should have non-zero values"
    );
}
1837
/// A DC-only input to the inverse 4x4 DCT is expected to give a non-zero,
/// perfectly uniform output block.
#[test]
fn hevc_dct_4x4_dc_roundtrip() {
    let input = {
        let mut block = [0i32; 16];
        block[0] = 256;
        block
    };
    let mut residual = [0i32; 16];

    hevc_inverse_dct_4x4(&input, &mut residual);

    let reference = residual[0];
    assert_ne!(reference, 0);
    // Every sample must match the first one.
    assert_eq!(residual, [reference; 16]);
}
1852
/// An all-zero coefficient block must come out of the inverse 4x4 DCT as all
/// zeros.
#[test]
fn hevc_dct_4x4_zero_input() {
    let input = [0i32; 16];
    // Pre-filled with 99 so a transform that leaves samples unwritten fails.
    let mut residual = [99i32; 16];

    hevc_inverse_dct_4x4(&input, &mut residual);

    assert_eq!(residual, [0i32; 16]);
}
1862
/// A DC-only input to the inverse 8x8 DCT is expected to give a non-zero,
/// perfectly uniform output block.
#[test]
fn hevc_dct_8x8_dc_coefficient() {
    let input = {
        let mut block = [0i32; 64];
        block[0] = 128;
        block
    };
    let mut residual = [0i32; 64];

    hevc_inverse_dct_8x8(&input, &mut residual);

    let reference = residual[0];
    assert_ne!(reference, 0);
    // Every sample must match the first one.
    assert_eq!(residual, [reference; 64]);
}
1876
1877    // -- Coding tree / HevcDecoder tests ------------------------------------
1878
/// A 64x64 CTU inside a 64x64 picture with max_depth 3 is expected to split
/// into 64 intra CUs of size 8.
#[test]
fn decode_coding_tree_splits_correctly() {
    let mut leaves = Vec::new();

    decode_coding_tree(0, 0, 6, 0, 3, 26, 64, 64, &mut leaves);

    // 64 -> 32 -> 16 -> 8 is three levels of 4-way splitting: 4^3 = 64 leaves.
    assert_eq!(leaves.len(), 64);
    leaves.iter().for_each(|leaf| {
        assert_eq!(leaf.size, 8);
        assert_eq!(leaf.pred_mode, HevcPredMode::Intra);
    });
}
1891
/// A 64x64 CTU at the origin of a 48x48 picture must yield at least one CU,
/// and no CU may start at or beyond the picture edge.
#[test]
fn decode_coding_tree_boundary_clipping() {
    let mut leaves = Vec::new();

    decode_coding_tree(0, 0, 6, 0, 3, 26, 48, 48, &mut leaves);

    assert!(!leaves.is_empty());
    // CUs may touch the boundary, but their origin has to lie inside it.
    assert!(leaves.iter().all(|leaf| leaf.x < 48));
    assert!(leaves.iter().all(|leaf| leaf.y < 48));
}
1904
/// A fresh decoder holds no SPS/PPS, rejects a one-byte NAL, and is expected
/// to answer an AUD NAL with `Ok(None)`.
#[test]
fn hevc_decoder_new_and_nal_routing() {
    let mut decoder = HevcDecoder::new();
    assert!(decoder.sps().is_none());
    assert!(decoder.pps().is_none());

    // A single byte is too short to be a NAL unit and must be rejected.
    assert!(decoder.decode_nal(&[0x00]).is_err());

    // nal_unit_type sits in bits [6:1] of the first byte.
    // Type 35 (AUD) = 0b100011, so the first byte is 0b0_100011_0 = 0x46.
    let aud_nal = [0b0100_0110, 0x01];
    assert!(decoder.decode_nal(&aud_nal).unwrap().is_none());
}
1922
/// An IDR slice NAL given to a decoder that has not yet received any
/// parameter sets must produce an error.
#[test]
fn hevc_decoder_idr_without_sps_errors() {
    // IDR_W_RADL is type 19 = 0b010011; placed in bits [6:1] of the first
    // byte that gives 0b0_010011_0 = 0x26.
    let idr_nal = [0x26, 0x01, 0x00];

    let mut decoder = HevcDecoder::new();
    assert!(decoder.decode_nal(&idr_nal).is_err());
}
1932
/// `HevcIntraMode::from_index` is checked at indices 0, 1, 2 and 34, which
/// must map to a mode, and at 35, which must give `None`.
#[test]
fn hevc_intra_mode_from_index() {
    let expectations = [
        (0, Some(HevcIntraMode::Planar)),
        (1, Some(HevcIntraMode::Dc)),
        (2, Some(HevcIntraMode::Angular2)),
        (34, Some(HevcIntraMode::Angular34)),
        (35, None),
    ];

    for (index, expected) in expectations {
        assert_eq!(HevcIntraMode::from_index(index), expected);
    }
}
1944
/// Smoke test: calling `hevc_dequant` on a block of 100s must change the
/// first coefficient.
#[test]
fn hevc_dequant_basic() {
    let mut block = [100i32; 16];

    hevc_dequant(&mut block, 26, 8, 2);

    // Only checks that the value moved, not what it moved to.
    assert_ne!(block[0], 100);
}
1952}