Skip to main content

yscv_video/
hevc_syntax.rs

1//! HEVC CU/PU/TU syntax parsing using CABAC (ITU-T H.265, sections 7.3.8-7.3.12).
2//!
3//! This module reads coded data from the bitstream via the CABAC engine
4//! ([`CabacDecoder`]) and produces decoded coding unit results including
5//! prediction mode, intra modes, and transform coefficients.
6
7use super::hevc_cabac::{CabacDecoder, ContextModel};
8use super::hevc_decoder::{HevcPps, HevcPredMode, HevcSliceType, HevcSps};
9use super::hevc_inter::{HevcMvField, parse_inter_prediction};
10
11// ---------------------------------------------------------------------------
12// Context index offsets (ITU-T H.265, Table 9-4)
13// ---------------------------------------------------------------------------
14
/// `split_cu_flag` — 3 contexts (ctxInc from left/above neighbour depths).
pub const CTX_SPLIT_CU_FLAG: usize = 0;
/// `cu_skip_flag` — 3 contexts.
pub const CTX_CU_SKIP_FLAG: usize = 3;
/// `pred_mode_flag` — 1 context.
pub const CTX_PRED_MODE_FLAG: usize = 6;
/// `part_mode` — 4 contexts.
pub const CTX_PART_MODE: usize = 7;
/// `prev_intra_luma_pred_flag` — 1 context.
pub const CTX_PREV_INTRA_LUMA_PRED_FLAG: usize = 11;
/// `intra_chroma_pred_mode` — 1 context.
pub const CTX_INTRA_CHROMA_PRED_MODE: usize = 12;
/// `cbf_luma` — 2 contexts (indexed by transform depth).
pub const CTX_CBF_LUMA: usize = 13;
/// `cbf_cb` / `cbf_cr` — 4 contexts (shared).
pub const CTX_CBF_CB: usize = 15;
/// `last_sig_coeff_x_prefix` — 18 contexts (15 luma, 3 chroma).
pub const CTX_LAST_SIG_COEFF_X_PREFIX: usize = 19;
/// `last_sig_coeff_y_prefix` — 18 contexts (15 luma, 3 chroma).
pub const CTX_LAST_SIG_COEFF_Y_PREFIX: usize = 37;
/// `coded_sub_block_flag` — 4 contexts (2 luma, 2 chroma).
pub const CTX_CODED_SUB_BLOCK_FLAG: usize = 55;
/// `sig_coeff_flag` — 44 contexts (27 luma, 15 chroma, 2 transform-skip).
pub const CTX_SIG_COEFF_FLAG: usize = 59;
/// `coeff_abs_level_greater1_flag` — 24 contexts (16 luma, 8 chroma).
pub const CTX_COEFF_ABS_LEVEL_GREATER1: usize = 103;
/// `coeff_abs_level_greater2_flag` — 6 contexts (4 luma, 2 chroma).
pub const CTX_COEFF_ABS_LEVEL_GREATER2: usize = 127;
/// Total number of CABAC context models for syntax parsing.
pub const NUM_CABAC_CONTEXTS: usize = 133;

// ---------------------------------------------------------------------------
// Default initialisation values (initType = 0, I slices)
// ---------------------------------------------------------------------------

/// Context `initValue`s for I slices (initType = 0), taken from the
/// per-syntax-element initialisation tables that follow Table 9-4 of
/// ITU-T H.265. One value per context in the order defined by the `CTX_*`
/// constants above.
///
/// Syntax elements that never occur in I slices (`cu_skip_flag`,
/// `pred_mode_flag`, `part_mode` bins 1..3) have no initType-0 value in the
/// spec; their entries here are placeholders that appear to come from the
/// initType-1 (P slice) table.
#[rustfmt::skip]
const INIT_VALUES_I_SLICE: [u8; NUM_CABAC_CONTEXTS] = [
    // split_cu_flag (3)
    139, 141, 157,
    // cu_skip_flag (3) — placeholder, not used in I slices
    197, 185, 201,
    // pred_mode_flag (1) — placeholder, not used in I slices
    149,
    // part_mode (4)
    184, 154, 139, 154,
    // prev_intra_luma_pred_flag (1)
    184,
    // intra_chroma_pred_mode (1) — initType 0 value is 63 (152 is initType 1/2)
    63,
    // cbf_luma (2)
    111, 141,
    // cbf_cb (4)
    94, 138, 182, 154,
    // last_sig_coeff_x_prefix (18)
    110, 110, 124, 125, 140, 153, 125, 127, 140,
    109, 111, 143, 127, 111, 79, 108, 123, 63,
    // last_sig_coeff_y_prefix (18)
    110, 110, 124, 125, 140, 153, 125, 127, 140,
    109, 111, 143, 127, 111, 79, 108, 123, 63,
    // coded_sub_block_flag (4)
    91, 171, 134, 141,
    // sig_coeff_flag (44): luma 0..=26, chroma 27..=41, transform-skip 42..=43
    111, 111, 125, 110, 110,  94, 124, 108, 124, 107,
    125, 141, 179, 153, 125, 107, 125, 141, 179, 153,
    125, 107, 125, 141, 179, 153, 125, 140, 139, 182,
    182, 152, 136, 152, 136, 153, 136, 139, 111, 136,
    139, 111, 141, 111,
    // coeff_abs_level_greater1 (24)
    140, 92, 137, 138, 140, 152, 138, 139, 153,  74,
    149,  92, 139, 107, 122, 152, 140, 179, 166, 182,
    140, 227, 122, 197,
    // coeff_abs_level_greater2 (6)
    138, 153, 136, 167, 152, 152,
];
91
92// ---------------------------------------------------------------------------
93// Scan orders for coefficient coding (ITU-T H.265, 6.5.3)
94// ---------------------------------------------------------------------------
95
/// Up-right diagonal scan of a 4x4 block (ITU-T H.265, 6.5.3).
///
/// Entry `i` is the raster index (`y * 4 + x`) of the `i`-th position in
/// scan order; anti-diagonals are walked from bottom-left to top-right.
#[rustfmt::skip]
const SCAN_ORDER_4X4_DIAG: [u8; 16] = [
     0,  4,  1,  8,  5,  2, 12,  9,
     6,  3, 13, 10,  7, 14, 11, 15,
];

/// Up-right diagonal scan of the 2x2 sub-block grid of an 8x8 TU, as raster
/// indices (`y * 2 + x`).
#[rustfmt::skip]
const SCAN_ORDER_2X2_DIAG: [u8; 4] = [0, 2, 1, 3];

/// Up-right diagonal scan of the 4x4 sub-block grid of a 16x16 TU.
///
/// The grid has the same shape as a 4x4 coefficient block, so the table is
/// identical to [`SCAN_ORDER_4X4_DIAG`].
const SCAN_ORDER_4X4_SUBBLOCK_DIAG: [u8; 16] = SCAN_ORDER_4X4_DIAG;
113
114// ---------------------------------------------------------------------------
115// Slice-level CABAC state
116// ---------------------------------------------------------------------------
117
118/// CABAC state for decoding a single slice, holding all context models and the
119/// arithmetic decoder tied to the slice payload data.
120pub struct HevcSliceCabacState<'a> {
121    /// Adaptive probability contexts for all syntax elements.
122    pub contexts: Vec<ContextModel>,
123    /// The arithmetic decoder reading from the slice payload.
124    pub cabac: CabacDecoder<'a>,
125}
126
127impl<'a> HevcSliceCabacState<'a> {
128    /// Create a new slice CABAC state from slice payload bytes and QP.
129    ///
130    /// Initialises all context models according to ITU-T H.265 Table 9-4
131    /// using the given `slice_qp` and the I-slice initialisation table.
132    pub fn new(slice_data: &'a [u8], slice_qp: i32) -> Self {
133        let mut contexts = Vec::with_capacity(NUM_CABAC_CONTEXTS);
134        for &iv in &INIT_VALUES_I_SLICE {
135            let mut ctx = ContextModel::new(iv);
136            ctx.init(slice_qp, iv);
137            contexts.push(ctx);
138        }
139        let cabac = CabacDecoder::new(slice_data);
140        HevcSliceCabacState { contexts, cabac }
141    }
142
143    /// Re-initialise all context models for a given QP (e.g. at WPP row start).
144    pub fn reinit_contexts(&mut self, slice_qp: i32) {
145        for (ctx, &iv) in self.contexts.iter_mut().zip(INIT_VALUES_I_SLICE.iter()) {
146            ctx.init(slice_qp, iv);
147        }
148    }
149}
150
151// ---------------------------------------------------------------------------
152// CU-level data
153// ---------------------------------------------------------------------------
154
155/// Decoded data produced by [`parse_coding_unit`].
156#[derive(Debug, Clone)]
157pub struct CodingUnitData {
158    /// Prediction mode (Intra, Inter, Skip).
159    pub pred_mode: HevcPredMode,
160    /// Luma intra prediction mode index (0..=34).
161    pub intra_mode_luma: u8,
162    /// Chroma intra prediction mode index.
163    pub intra_mode_chroma: u8,
164    /// Whether the luma CBF is set (nonzero residual).
165    pub cbf_luma: bool,
166    /// Whether the Cb CBF is set.
167    pub cbf_cb: bool,
168    /// Whether the Cr CBF is set.
169    pub cbf_cr: bool,
170    /// Transform coefficients (luma) in scan order, length = block_size^2.
171    pub residual_luma: Vec<i16>,
172}
173
174// ---------------------------------------------------------------------------
175// Coding tree traversal (split_cu_flag)
176// ---------------------------------------------------------------------------
177
178/// Read `split_cu_flag` from the bitstream.
179///
180/// Context index is derived from the current depth plus the availability of
181/// left/above neighbours (simplified: `ctx_idx = depth.min(2)`).
182pub fn parse_split_cu_flag(
183    state: &mut HevcSliceCabacState<'_>,
184    depth: u8,
185    _left_available: bool,
186    _above_available: bool,
187) -> bool {
188    // Context selection: depth contributes to the index (spec 9.3.4.2.2).
189    // Simplified: left+above availability each add 1 in the real spec, but
190    // here we approximate with just depth clamped to 0..2.
191    let ctx_idx = CTX_SPLIT_CU_FLAG + (depth as usize).min(2);
192    state.cabac.decode_decision(&mut state.contexts[ctx_idx])
193}
194
195// ---------------------------------------------------------------------------
196// CU-level syntax parsing
197// ---------------------------------------------------------------------------
198
199/// Parse a coding unit from the bitstream (ITU-T H.265, 7.3.8.5).
200///
201/// Returns the prediction mode, intra luma/chroma modes, CBF flags, and
202/// residual transform coefficients for the luma plane.
203pub fn parse_coding_unit(
204    state: &mut HevcSliceCabacState<'_>,
205    _x: usize,
206    _y: usize,
207    log2_cu_size: u32,
208    sps: &HevcSps,
209    pps: &HevcPps,
210    slice_type: HevcSliceType,
211) -> CodingUnitData {
212    let cu_size = 1u32 << log2_cu_size;
213    let num_samples = (cu_size * cu_size) as usize;
214
215    // -- cu_skip_flag (P/B slices only) ------------------------------------
216    let skip_flag = if slice_type != HevcSliceType::I {
217        let ctx_idx = CTX_CU_SKIP_FLAG; // simplified: always ctx 0
218        state.cabac.decode_decision(&mut state.contexts[ctx_idx])
219    } else {
220        false
221    };
222
223    if skip_flag {
224        return CodingUnitData {
225            pred_mode: HevcPredMode::Skip,
226            intra_mode_luma: 0,
227            intra_mode_chroma: 0,
228            cbf_luma: false,
229            cbf_cb: false,
230            cbf_cr: false,
231            residual_luma: vec![0; num_samples],
232        };
233    }
234
235    // -- pred_mode_flag (non-I slices) -------------------------------------
236    let pred_mode = if slice_type == HevcSliceType::I {
237        HevcPredMode::Intra
238    } else {
239        let ctx_idx = CTX_PRED_MODE_FLAG;
240        if state.cabac.decode_decision(&mut state.contexts[ctx_idx]) {
241            HevcPredMode::Intra
242        } else {
243            HevcPredMode::Inter
244        }
245    };
246
247    // -- part_mode ----------------------------------------------------------
248    // For intra CUs the only valid mode is PART_2Nx2N (spec 7.4.9.5).
249    // For inter we decode but currently only handle 2Nx2N.
250    if pred_mode == HevcPredMode::Inter {
251        let ctx_idx = CTX_PART_MODE;
252        let _part_2nx2n = state.cabac.decode_decision(&mut state.contexts[ctx_idx]);
253        // If not 2Nx2N, additional bins would follow; skip for now.
254    }
255
256    // -- Intra mode signalling ---------------------------------------------
257    let (intra_mode_luma, intra_mode_chroma) = if pred_mode == HevcPredMode::Intra {
258        let luma = parse_intra_mode_luma(state);
259        let chroma = parse_intra_chroma_pred_mode(state);
260        (luma, chroma)
261    } else {
262        (0u8, 0u8)
263    };
264
265    // -- Transform tree (simplified: single TU = CU) -----------------------
266    let log2_min_tu = sps.log2_min_transform_size as u32;
267    let log2_tu = log2_cu_size.max(log2_min_tu);
268
269    // CBF flags
270    let cbf_cb = parse_cbf_chroma(state, 0);
271    let cbf_cr = parse_cbf_chroma(state, 0);
272    let cbf_luma = parse_cbf_luma(state, 0);
273
274    // Residual coefficients (luma)
275    let residual_luma = if cbf_luma {
276        parse_transform_unit(state, log2_tu, true, pps.sign_data_hiding_enabled)
277    } else {
278        vec![0; num_samples]
279    };
280
281    CodingUnitData {
282        pred_mode,
283        intra_mode_luma,
284        intra_mode_chroma,
285        cbf_luma,
286        cbf_cb,
287        cbf_cr,
288        residual_luma,
289    }
290}
291
292// ---------------------------------------------------------------------------
293// Intra mode signalling (ITU-T H.265, 7.3.8.5 + 8.4.2)
294// ---------------------------------------------------------------------------
295
296/// Parse luma intra prediction mode.
297///
298/// Reads `prev_intra_luma_pred_flag`; if set, reads `mpm_idx` (TR-coded,
299/// 0..2); otherwise reads `rem_intra_luma_pred_mode` (5 bypass bins).
300///
301/// The Most Probable Mode (MPM) list is constructed from DC, Planar, and
302/// Angular-26 as a simplified default (real spec uses neighbour modes).
303fn parse_intra_mode_luma(state: &mut HevcSliceCabacState<'_>) -> u8 {
304    let ctx_idx = CTX_PREV_INTRA_LUMA_PRED_FLAG;
305    let prev_flag = state.cabac.decode_decision(&mut state.contexts[ctx_idx]);
306
307    if prev_flag {
308        // mpm_idx: truncated unary, max 2, bypass coded
309        let mpm_idx = parse_mpm_idx(state);
310        // Simplified MPM list: {Planar(0), DC(1), Angular-26(26)}
311        let mpm_list = build_default_mpm_list();
312        mpm_list[mpm_idx as usize]
313    } else {
314        // rem_intra_luma_pred_mode: 5 bypass bins (0..31)
315        let rem = state.cabac.decode_fl(5) as u8;
316        // Map rem to actual mode, skipping MPM entries (simplified)
317        let mpm_list = build_default_mpm_list();
318        remap_rem_mode(rem, &mpm_list)
319    }
320}
321
322/// Decode `mpm_idx` — truncated unary bypass code, max value 2.
323fn parse_mpm_idx(state: &mut HevcSliceCabacState<'_>) -> u8 {
324    // mpm_idx is bypass-coded as truncated unary with cMax=2
325    if !state.cabac.decode_bypass() {
326        0
327    } else if !state.cabac.decode_bypass() {
328        1
329    } else {
330        2
331    }
332}
333
/// MPM candidate list used when neither neighbour mode is available.
///
/// Unavailable neighbours count as DC, and two equal non-angular candidates
/// yield {Planar, DC, Angular-26}. That list is already in ascending order.
fn build_default_mpm_list() -> [u8; 3] {
    const PLANAR: u8 = 0;
    const DC: u8 = 1;
    const ANGULAR_26: u8 = 26;
    [PLANAR, DC, ANGULAR_26]
}
344
/// Build the three most probable luma modes from the left (A) and above (B)
/// neighbour modes, following ITU-T H.265, 8.4.2.
///
/// - A != B: `[A, B, X]`, where `X` is the first of Planar, DC and
///   Angular-26 not already present.
/// - A == B < 2 (Planar/DC): `[Planar, DC, Angular-26]`.
/// - A == B angular: `A` followed by its two adjacent angular modes, with
///   wrap-around over the range 2..=33.
pub fn build_mpm_list(left_mode: u8, above_mode: u8) -> [u8; 3] {
    const PLANAR: u8 = 0;
    const DC: u8 = 1;
    const ANGULAR_26: u8 = 26;

    if left_mode != above_mode {
        let third = if left_mode != PLANAR && above_mode != PLANAR {
            PLANAR
        } else if left_mode != DC && above_mode != DC {
            DC
        } else {
            ANGULAR_26
        };
        return [left_mode, above_mode, third];
    }

    if left_mode < 2 {
        return [PLANAR, DC, ANGULAR_26];
    }

    [
        left_mode,
        2 + ((left_mode + 29) % 32),
        2 + ((left_mode - 2 + 1) % 32),
    ]
}
375
/// Translate `rem_intra_luma_pred_mode` (0..=31) into an intra mode index.
///
/// The candidates are visited in ascending order, and the mode is bumped
/// past each candidate it reaches (ITU-T H.265, 8.4.2). The rem-th non-MPM
/// mode results. The result is capped at 34.
fn remap_rem_mode(rem: u8, mpm_list: &[u8; 3]) -> u8 {
    let mut ascending = *mpm_list;
    ascending.sort_unstable();

    ascending
        .iter()
        .fold(rem, |mode, &candidate| {
            if mode >= candidate {
                mode + 1
            } else {
                mode
            }
        })
        .min(34)
}
393
394/// Parse `intra_chroma_pred_mode` (ITU-T H.265, 7.3.8.5).
395///
396/// One context-coded bin selects between mode 4 (derived from luma) and an
397/// explicit 2-bit bypass-coded index (0..3 mapping to planar/angular/DC/angular).
398fn parse_intra_chroma_pred_mode(state: &mut HevcSliceCabacState<'_>) -> u8 {
399    let ctx_idx = CTX_INTRA_CHROMA_PRED_MODE;
400    let derived = state.cabac.decode_decision(&mut state.contexts[ctx_idx]);
401
402    if !derived {
403        // Mode 4: "derived from luma" (DM mode)
404        4
405    } else {
406        // 2 bypass bins encoding index 0..3
407        state.cabac.decode_fl(2) as u8
408    }
409}
410
411// ---------------------------------------------------------------------------
412// CBF (Coded Block Flag) parsing
413// ---------------------------------------------------------------------------
414
415/// Parse `cbf_luma` (ITU-T H.265, 7.3.8.11).
416///
417/// Context index depends on the transform depth within the CU.
418fn parse_cbf_luma(state: &mut HevcSliceCabacState<'_>, trafo_depth: u32) -> bool {
419    let ctx_idx = CTX_CBF_LUMA + (trafo_depth.min(1) as usize);
420    state.cabac.decode_decision(&mut state.contexts[ctx_idx])
421}
422
423/// Parse `cbf_cb` or `cbf_cr` (ITU-T H.265, 7.3.8.11).
424///
425/// Contexts are shared between Cb and Cr; index depends on transform depth.
426fn parse_cbf_chroma(state: &mut HevcSliceCabacState<'_>, trafo_depth: u32) -> bool {
427    let ctx_idx = CTX_CBF_CB + (trafo_depth.min(3) as usize);
428    state.cabac.decode_decision(&mut state.contexts[ctx_idx])
429}
430
431// ---------------------------------------------------------------------------
432// TU-level residual parsing (ITU-T H.265, 7.3.8.11 + 7.3.8.12)
433// ---------------------------------------------------------------------------
434
435/// Parse a transform unit's residual coefficients.
436///
437/// Implements the coefficient coding syntax from ITU-T H.265 section 7.3.8.12
438/// including last-significant-coefficient position, sub-block significance
439/// flags, per-coefficient significance, greater-than-1/2 flags, and bypass-
440/// coded sign/remaining-level.
441///
442/// Returns the dequantised coefficient array in raster order (row-major),
443/// with length `(1 << log2_tu_size)^2`.
444pub fn parse_transform_unit(
445    state: &mut HevcSliceCabacState<'_>,
446    log2_tu_size: u32,
447    is_luma: bool,
448    sign_data_hiding_enabled: bool,
449) -> Vec<i16> {
450    let tu_size = 1u32 << log2_tu_size;
451    let num_coeffs = (tu_size * tu_size) as usize;
452    let mut coeffs = vec![0i16; num_coeffs];
453
454    // -- Last significant coefficient position ----------------------------
455    let (last_x, last_y) = parse_last_sig_coeff_pos(state, log2_tu_size, is_luma);
456
457    if last_x >= tu_size || last_y >= tu_size {
458        // Out of bounds — treat as all-zero TU.
459        return coeffs;
460    }
461
462    // -- Sub-block and coefficient scanning --------------------------------
463    let log2_sub = 2u32; // 4x4 sub-blocks
464    let sub_size = 1u32 << log2_sub;
465    let num_sub_x = tu_size >> log2_sub;
466    let num_sub_total = (num_sub_x * num_sub_x) as usize;
467
468    // Determine which sub-block contains the last significant coeff
469    let last_sub_x = last_x >> log2_sub;
470    let last_sub_y = last_y >> log2_sub;
471    let last_sub_scan = sub_pos_to_scan_idx(last_sub_x, last_sub_y, num_sub_x);
472
473    // Sub-block coded flags (all blocks up to and including last are
474    // potentially coded; the last sub-block is implicitly coded).
475    let mut sub_coded = vec![false; num_sub_total];
476    if last_sub_scan < num_sub_total {
477        sub_coded[last_sub_scan] = true; // last sub-block always coded
478    }
479
480    // Process sub-blocks in reverse scan order
481    let first_sub = if last_sub_scan < num_sub_total {
482        last_sub_scan
483    } else {
484        0
485    };
486
487    for sub_scan in (0..=first_sub).rev() {
488        // Read coded_sub_block_flag for non-last, non-DC sub-blocks
489        if sub_scan < first_sub && sub_scan > 0 {
490            let ctx_idx = CTX_CODED_SUB_BLOCK_FLAG + if is_luma { 0 } else { 2 };
491            sub_coded[sub_scan] = state.cabac.decode_decision(&mut state.contexts[ctx_idx]);
492        } else if sub_scan == 0 && first_sub > 0 {
493            // DC sub-block: infer coded if any higher sub-block is coded
494            sub_coded[0] = true;
495        } else if sub_scan == first_sub {
496            sub_coded[sub_scan] = true;
497        }
498
499        if !sub_coded[sub_scan] {
500            continue;
501        }
502
503        // Get sub-block position in raster order
504        let (sub_x, sub_y) = scan_idx_to_sub_pos(sub_scan, num_sub_x);
505        let base_x = sub_x * sub_size;
506        let base_y = sub_y * sub_size;
507
508        // Determine the last coeff scan index within this sub-block
509        let last_scan_in_sub = if sub_scan == first_sub {
510            // The sub-block that contains the last significant coeff
511            let local_x = last_x - base_x;
512            let local_y = last_y - base_y;
513            local_pos_to_scan_idx(local_x, local_y)
514        } else {
515            15 // full 4x4 sub-block
516        };
517
518        // Parse significance flags, levels, and signs for this sub-block
519        parse_subblock_coeffs(
520            state,
521            &mut coeffs,
522            tu_size,
523            base_x,
524            base_y,
525            last_scan_in_sub,
526            is_luma,
527            sub_scan == first_sub,
528            sign_data_hiding_enabled,
529        );
530    }
531
532    coeffs
533}
534
535// ---------------------------------------------------------------------------
536// Last significant coefficient position
537// ---------------------------------------------------------------------------
538
539/// Parse `last_sig_coeff_x_prefix/suffix` and `last_sig_coeff_y_prefix/suffix`.
540///
541/// Returns `(last_x, last_y)` — the position of the last significant
542/// coefficient in the TU (in scan-to-raster mapped coordinates).
543fn parse_last_sig_coeff_pos(
544    state: &mut HevcSliceCabacState<'_>,
545    log2_tu_size: u32,
546    is_luma: bool,
547) -> (u32, u32) {
548    let last_x = parse_last_sig_coeff_prefix_suffix(state, log2_tu_size, is_luma, true);
549    let last_y = parse_last_sig_coeff_prefix_suffix(state, log2_tu_size, is_luma, false);
550    (last_x, last_y)
551}
552
553/// Parse one component (X or Y) of the last significant coefficient position.
554///
555/// The prefix is truncated-unary coded with context models; the suffix
556/// (if prefix >= 2) is bypass-coded fixed-length.
557fn parse_last_sig_coeff_prefix_suffix(
558    state: &mut HevcSliceCabacState<'_>,
559    log2_tu_size: u32,
560    is_luma: bool,
561    is_x: bool,
562) -> u32 {
563    let tu_size = 1u32 << log2_tu_size;
564    // Maximum prefix value = 2 * (log2_tu_size - 1)  (capped at TU size)
565    let max_prefix = if log2_tu_size > 1 {
566        2 * (log2_tu_size - 1)
567    } else {
568        0
569    };
570    // Context offset depends on component (x/y) and luma/chroma
571    let ctx_base = if is_x {
572        CTX_LAST_SIG_COEFF_X_PREFIX
573    } else {
574        CTX_LAST_SIG_COEFF_Y_PREFIX
575    };
576    let ctx_offset_c = if is_luma { 0usize } else { 9 };
577
578    // Decode prefix as truncated unary
579    let mut prefix = 0u32;
580    while prefix < max_prefix {
581        // Context index: 3 * (log2_tu_size - 2) + (prefix >> 1), capped to 8
582        let ctx_inc = if log2_tu_size >= 2 {
583            let group = (prefix >> 1) as usize;
584            let base = 3 * ((log2_tu_size as usize).saturating_sub(2));
585            (base + group).min(8)
586        } else {
587            0
588        };
589        let ctx_idx = ctx_base + ctx_offset_c + ctx_inc;
590        let ctx_idx = ctx_idx.min(state.contexts.len() - 1);
591        if state.cabac.decode_decision(&mut state.contexts[ctx_idx]) {
592            prefix += 1;
593        } else {
594            break;
595        }
596    }
597
598    // Decode suffix (if prefix >= 2)
599    if prefix < 2 {
600        return prefix;
601    }
602    let suffix_len = (prefix >> 1) - 1;
603    if suffix_len == 0 {
604        return prefix;
605    }
606    let suffix = state.cabac.decode_fl(suffix_len);
607    let value = (1u32 << suffix_len) + suffix + prefix - 2;
608    value.min(tu_size - 1)
609}
610
611// ---------------------------------------------------------------------------
612// Sub-block coefficient parsing
613// ---------------------------------------------------------------------------
614
615/// Parse significance flags, levels, and signs for one 4x4 sub-block.
616#[allow(clippy::too_many_arguments)]
617fn parse_subblock_coeffs(
618    state: &mut HevcSliceCabacState<'_>,
619    coeffs: &mut [i16],
620    tu_size: u32,
621    base_x: u32,
622    base_y: u32,
623    last_scan_pos: u32,
624    is_luma: bool,
625    is_last_subblock: bool,
626    sign_data_hiding_enabled: bool,
627) {
628    // Step 1: significance flags
629    let mut sig = [false; 16];
630    let mut num_sig = 0u32;
631
632    for scan_idx in (0..=last_scan_pos.min(15)).rev() {
633        if is_last_subblock && scan_idx == last_scan_pos {
634            // The last significant position is implicitly significant
635            sig[scan_idx as usize] = true;
636            num_sig += 1;
637            continue;
638        }
639        // Read sig_coeff_flag
640        let ctx_inc = sig_coeff_ctx_inc(scan_idx, is_luma);
641        let ctx_idx = (CTX_SIG_COEFF_FLAG + ctx_inc).min(state.contexts.len() - 1);
642        sig[scan_idx as usize] = state.cabac.decode_decision(&mut state.contexts[ctx_idx]);
643        if sig[scan_idx as usize] {
644            num_sig += 1;
645        }
646    }
647
648    if num_sig == 0 {
649        return;
650    }
651
652    // Step 2: greater1 and greater2 flags (up to 8 coefficients per sub-block)
653    let mut greater1 = [false; 16];
654    let mut greater2 = [false; 16];
655    let mut coeff_count = 0u32;
656    let max_greater1 = 8u32;
657
658    // Context set selection for greater1 (simplified)
659    let ctx_set = if is_luma { 0usize } else { 12 };
660
661    for scan_idx in (0..=last_scan_pos.min(15)).rev() {
662        if !sig[scan_idx as usize] {
663            continue;
664        }
665        if coeff_count < max_greater1 {
666            let ctx_inc = ctx_set + (coeff_count as usize).min(3);
667            let ctx_idx = (CTX_COEFF_ABS_LEVEL_GREATER1 + ctx_inc).min(state.contexts.len() - 1);
668            greater1[scan_idx as usize] = state.cabac.decode_decision(&mut state.contexts[ctx_idx]);
669        }
670        coeff_count += 1;
671    }
672
673    // greater2 flag: only for the first greater1 coefficient
674    let mut first_greater1_scan = None;
675    for scan_idx in (0..=last_scan_pos.min(15)).rev() {
676        if sig[scan_idx as usize] && greater1[scan_idx as usize] {
677            first_greater1_scan = Some(scan_idx);
678            break;
679        }
680    }
681    if let Some(scan_idx) = first_greater1_scan {
682        let ctx_inc = if is_luma { 0usize } else { 3 };
683        let ctx_idx = (CTX_COEFF_ABS_LEVEL_GREATER2 + ctx_inc).min(state.contexts.len() - 1);
684        greater2[scan_idx as usize] = state.cabac.decode_decision(&mut state.contexts[ctx_idx]);
685    }
686
687    // Step 3: signs (bypass coded)
688    let mut signs = [false; 16];
689    let mut num_hidden = 0u32;
690    for scan_idx in (0..=last_scan_pos.min(15)).rev() {
691        if sig[scan_idx as usize] {
692            num_hidden += 1;
693            // Sign data hiding: last sign may be inferred
694            let hide = sign_data_hiding_enabled
695                && num_sig > 1
696                && scan_idx == 0
697                && last_scan_pos - scan_idx > 3;
698            if !hide {
699                signs[scan_idx as usize] = state.cabac.decode_bypass();
700            }
701        }
702    }
703
704    // Step 4: remaining level (bypass-coded Exp-Golomb-Rice)
705    let mut abs_levels = [0i32; 16];
706    let mut rice_param = 0u32;
707
708    for scan_idx in (0..=last_scan_pos.min(15)).rev() {
709        if !sig[scan_idx as usize] {
710            continue;
711        }
712        let mut base_level = 1i32;
713        if greater1[scan_idx as usize] {
714            base_level += 1;
715        }
716        if greater2[scan_idx as usize] {
717            base_level += 1;
718        }
719
720        // coeff_abs_level_remaining is coded if the level exceeds the base
721        let needs_remaining = greater2[scan_idx as usize]
722            || (greater1[scan_idx as usize] && first_greater1_scan != Some(scan_idx));
723        let remaining =
724            if needs_remaining || (greater1[scan_idx as usize] && greater2[scan_idx as usize]) {
725                decode_coeff_abs_level_remaining(state, rice_param)
726            } else if !greater1[scan_idx as usize] {
727                0
728            } else {
729                0
730            };
731
732        let abs_val = base_level + remaining as i32;
733        abs_levels[scan_idx as usize] = abs_val;
734
735        // Update Rice parameter
736        if abs_val > (3i32 << rice_param) {
737            rice_param = (rice_param + 1).min(4);
738        }
739    }
740
741    // Ignore hidden sign count warning
742    let _ = num_hidden;
743
744    // Step 5: write coefficients to the output buffer
745    for scan_idx in 0..=last_scan_pos.min(15) {
746        if !sig[scan_idx as usize] {
747            continue;
748        }
749        let (lx, ly) = scan_to_local_pos(scan_idx);
750        let px = base_x + lx;
751        let py = base_y + ly;
752        if px < tu_size && py < tu_size {
753            let idx = (py * tu_size + px) as usize;
754            let val = abs_levels[scan_idx as usize] as i16;
755            coeffs[idx] = if signs[scan_idx as usize] { -val } else { val };
756        }
757    }
758}
759
760/// Decode `coeff_abs_level_remaining` using Exp-Golomb-Rice bypass coding
761/// (ITU-T H.265, 9.3.3.11).
762fn decode_coeff_abs_level_remaining(state: &mut HevcSliceCabacState<'_>, rice_param: u32) -> u32 {
763    // Count prefix ones (up to a max to avoid infinite loops)
764    let mut prefix = 0u32;
765    let max_prefix = 28u32; // safety limit
766    while prefix < max_prefix && state.cabac.decode_bypass() {
767        prefix += 1;
768    }
769
770    if prefix < 3 {
771        // Standard Rice coding: suffix has rice_param bits
772        let suffix = if rice_param > 0 {
773            state.cabac.decode_fl(rice_param)
774        } else {
775            0
776        };
777        (prefix << rice_param) + suffix
778    } else {
779        // Exp-Golomb extension: suffix has (prefix - 3 + rice_param) bits
780        let suffix_len = prefix - 3 + rice_param;
781        let suffix = state.cabac.decode_fl(suffix_len);
782        ((1u32 << suffix_len) - 1 + (3u32 << rice_param)).wrapping_add(suffix)
783    }
784}
785
786// ---------------------------------------------------------------------------
787// Scan-order helpers
788// ---------------------------------------------------------------------------
789
/// Context increment for `sig_coeff_flag` at a scan position of a 4x4
/// sub-block.
///
/// This is a simplified, position-only grouping, not the full Table 9-39
/// derivation:
/// - position 0 gives 0;
/// - positions 1..=4 give 1, 5..=8 give 2, 9..=12 give 3;
/// - 13 and above (clamped at 15) give 4.
///
/// Chroma increments are offset by 27 (the size of the luma context range).
fn sig_coeff_ctx_inc(scan_idx: u32, is_luma: bool) -> usize {
    let chroma_offset = if is_luma { 0usize } else { 27 };
    let pos = scan_idx.min(15) as usize;
    let group = if pos == 0 { 0 } else { (pos - 1) / 4 + 1 };
    chroma_offset + group
}
805
806/// Convert a sub-block scan index to raster (x, y) position within the
807/// sub-block grid.
808fn scan_idx_to_sub_pos(scan_idx: usize, num_sub_x: u32) -> (u32, u32) {
809    let num_sub = (num_sub_x * num_sub_x) as usize;
810    if num_sub <= 4 {
811        // 2x2 grid
812        let idx = if scan_idx < 4 {
813            SCAN_ORDER_2X2_DIAG[scan_idx] as u32
814        } else {
815            scan_idx as u32
816        };
817        (idx % num_sub_x, idx / num_sub_x)
818    } else if num_sub <= 16 {
819        // 4x4 grid
820        let idx = if scan_idx < 16 {
821            SCAN_ORDER_4X4_SUBBLOCK_DIAG[scan_idx] as u32
822        } else {
823            scan_idx as u32
824        };
825        (idx % num_sub_x, idx / num_sub_x)
826    } else {
827        // Larger: use raster fallback
828        let idx = scan_idx as u32;
829        (idx % num_sub_x, idx / num_sub_x)
830    }
831}
832
833/// Convert a raster sub-block position to a scan index (reverse of above).
834fn sub_pos_to_scan_idx(sub_x: u32, sub_y: u32, num_sub_x: u32) -> usize {
835    let raster = sub_y * num_sub_x + sub_x;
836    let num_sub = (num_sub_x * num_sub_x) as usize;
837    if num_sub <= 4 {
838        for (i, &s) in SCAN_ORDER_2X2_DIAG.iter().enumerate() {
839            if s as u32 == raster {
840                return i;
841            }
842        }
843        raster as usize
844    } else if num_sub <= 16 {
845        for (i, &s) in SCAN_ORDER_4X4_SUBBLOCK_DIAG.iter().enumerate() {
846            if s as u32 == raster {
847                return i;
848            }
849        }
850        raster as usize
851    } else {
852        raster as usize
853    }
854}
855
856/// Convert a scan index within a 4x4 sub-block to local (x, y) position.
857fn scan_to_local_pos(scan_idx: u32) -> (u32, u32) {
858    let idx = if (scan_idx as usize) < 16 {
859        SCAN_ORDER_4X4_DIAG[scan_idx as usize] as u32
860    } else {
861        scan_idx
862    };
863    (idx % 4, idx / 4)
864}
865
866/// Convert a local (x, y) within a 4x4 sub-block to a scan index.
867fn local_pos_to_scan_idx(lx: u32, ly: u32) -> u32 {
868    let raster = ly * 4 + lx;
869    for (i, &s) in SCAN_ORDER_4X4_DIAG.iter().enumerate() {
870        if s as u32 == raster {
871            return i as u32;
872        }
873    }
874    raster
875}
876
877// ---------------------------------------------------------------------------
878// Coding tree integration
879// ---------------------------------------------------------------------------
880
881/// Recursively decode a coding tree using CABAC, producing decoded CU leaves.
882///
883/// This replaces the stub `decode_coding_tree` in `hevc_decoder.rs` with
884/// actual CABAC-driven split decisions and CU parsing.
885pub fn decode_coding_tree_cabac(
886    state: &mut HevcSliceCabacState<'_>,
887    x: usize,
888    y: usize,
889    log2_cu_size: u8,
890    depth: u8,
891    max_depth: u8,
892    sps: &HevcSps,
893    pps: &HevcPps,
894    slice_type: HevcSliceType,
895    pic_width: usize,
896    pic_height: usize,
897    recon_luma: &mut Vec<i16>,
898    results: &mut Vec<super::hevc_decoder::DecodedCu>,
899    dpb: &super::hevc_inter::HevcDpb,
900    mv_field: &mut Vec<HevcMvField>,
901) {
902    let cu_size = 1usize << log2_cu_size;
903
904    // Out of picture bounds — skip
905    if x >= pic_width || y >= pic_height {
906        return;
907    }
908
909    // Decide whether to split
910    let can_split = depth < max_depth && cu_size > (1usize << sps.log2_min_cb_size);
911    let must_split = cu_size > 64; // CTU is at most 64x64
912
913    let should_split = if must_split {
914        true
915    } else if can_split {
916        let left_avail = x > 0;
917        let above_avail = y > 0;
918        parse_split_cu_flag(state, depth, left_avail, above_avail)
919    } else {
920        false
921    };
922
923    if should_split {
924        let half = log2_cu_size - 1;
925        let half_size = 1usize << half;
926        let nd = depth + 1;
927        decode_coding_tree_cabac(
928            state, x, y, half, nd, max_depth, sps, pps, slice_type, pic_width, pic_height,
929            recon_luma, results, dpb, mv_field,
930        );
931        decode_coding_tree_cabac(
932            state,
933            x + half_size,
934            y,
935            half,
936            nd,
937            max_depth,
938            sps,
939            pps,
940            slice_type,
941            pic_width,
942            pic_height,
943            recon_luma,
944            results,
945            dpb,
946            mv_field,
947        );
948        decode_coding_tree_cabac(
949            state,
950            x,
951            y + half_size,
952            half,
953            nd,
954            max_depth,
955            sps,
956            pps,
957            slice_type,
958            pic_width,
959            pic_height,
960            recon_luma,
961            results,
962            dpb,
963            mv_field,
964        );
965        decode_coding_tree_cabac(
966            state,
967            x + half_size,
968            y + half_size,
969            half,
970            nd,
971            max_depth,
972            sps,
973            pps,
974            slice_type,
975            pic_width,
976            pic_height,
977            recon_luma,
978            results,
979            dpb,
980            mv_field,
981        );
982    } else {
983        // Leaf CU — parse prediction/residual via CABAC
984        let cu_data = parse_coding_unit(state, x, y, log2_cu_size as u32, sps, pps, slice_type);
985
986        let actual_w = cu_size.min(pic_width.saturating_sub(x));
987        let actual_h = cu_size.min(pic_height.saturating_sub(y));
988
989        // Intra prediction (from reconstructed neighbours in recon_luma)
990        let mut pred = vec![0i16; cu_size * cu_size];
991        if cu_data.pred_mode == HevcPredMode::Intra {
992            // Build top and left reference samples
993            let top = build_top_ref(recon_luma, x, y, cu_size, pic_width);
994            let left = build_left_ref(recon_luma, x, y, cu_size, pic_width, pic_height);
995
996            match cu_data.intra_mode_luma {
997                0 => {
998                    let top_right = if x + cu_size < pic_width && y > 0 {
999                        recon_luma[(y - 1) * pic_width + x + cu_size]
1000                    } else {
1001                        *top.last().unwrap_or(&128)
1002                    };
1003                    let bottom_left = if y + cu_size < pic_height && x > 0 {
1004                        recon_luma[(y + cu_size) * pic_width + x - 1]
1005                    } else {
1006                        *left.last().unwrap_or(&128)
1007                    };
1008                    super::hevc_decoder::intra_predict_planar(
1009                        &top,
1010                        &left,
1011                        top_right,
1012                        bottom_left,
1013                        cu_size,
1014                        &mut pred,
1015                    );
1016                }
1017                1 => {
1018                    super::hevc_decoder::intra_predict_dc(&top, &left, cu_size, &mut pred);
1019                }
1020                m @ 2..=34 => {
1021                    super::hevc_decoder::intra_predict_angular(&top, &left, m, cu_size, &mut pred);
1022                }
1023                _ => {
1024                    // Fallback DC
1025                    super::hevc_decoder::intra_predict_dc(&top, &left, cu_size, &mut pred);
1026                }
1027            }
1028        } else {
1029            // Inter/Skip: parse inter prediction data and motion compensate
1030            // from DPB reference frames.
1031            let min_pu = 4usize;
1032            let pic_w_pu = pic_width.div_ceil(min_pu);
1033            let inter_mv =
1034                parse_inter_prediction(state, sps, slice_type, mv_field, pic_w_pu, x, y, cu_size);
1035
1036            // Store MV in the picture-wide MV field for future merge candidates
1037            let pu_x = x / min_pu;
1038            let pu_y = y / min_pu;
1039            let pu_w = cu_size / min_pu;
1040            for py in 0..pu_w {
1041                for px in 0..pu_w {
1042                    let idx = (pu_y + py) * pic_w_pu + (pu_x + px);
1043                    if idx < mv_field.len() {
1044                        mv_field[idx] = inter_mv;
1045                    }
1046                }
1047            }
1048
1049            // Motion compensate from DPB reference
1050            let ref_poc = inter_mv.ref_idx[0] as i32; // L0 reference POC
1051            if let Some(ref_pic) = dpb.get_by_poc(ref_poc) {
1052                super::hevc_inter::hevc_mc_luma(
1053                    ref_pic,
1054                    x as i32,
1055                    y as i32,
1056                    inter_mv.mv[0],
1057                    cu_size,
1058                    cu_size,
1059                    &mut pred,
1060                );
1061            } else {
1062                // No reference available — fall back to mid-grey
1063                for v in pred.iter_mut() {
1064                    *v = 128;
1065                }
1066            }
1067        }
1068
1069        // Add residual to prediction
1070        let mut recon = vec![0i16; cu_size * cu_size];
1071        for i in 0..cu_size * cu_size {
1072            let r = if i < cu_data.residual_luma.len() {
1073                cu_data.residual_luma[i] as i32
1074            } else {
1075                0
1076            };
1077            recon[i] = (pred[i] as i32 + r).clamp(0, 255) as i16;
1078        }
1079
1080        // Write reconstructed samples back to the picture buffer
1081        for row in 0..actual_h {
1082            for col in 0..actual_w {
1083                let py = y + row;
1084                let px = x + col;
1085                if py < pic_height && px < pic_width {
1086                    recon_luma[py * pic_width + px] = recon[row * cu_size + col];
1087                }
1088            }
1089        }
1090
1091        results.push(super::hevc_decoder::DecodedCu {
1092            x,
1093            y,
1094            size: cu_size,
1095            pred_mode: cu_data.pred_mode,
1096            recon_luma: recon,
1097        });
1098    }
1099}
1100
1101// ---------------------------------------------------------------------------
1102// Reference sample helpers
1103// ---------------------------------------------------------------------------
1104
/// Collect the `block_size` samples directly above the block at `(x, y)` for
/// intra prediction.
///
/// Samples are taken from the row `y - 1` of `recon` (row stride
/// `pic_width`). Positions past the right picture edge, and the whole row
/// when `y == 0`, are filled with 128.
fn build_top_ref(
    recon: &[i16],
    x: usize,
    y: usize,
    block_size: usize,
    pic_width: usize,
) -> Vec<i16> {
    if y == 0 {
        return vec![128; block_size];
    }
    let row_start = (y - 1) * pic_width;
    (x..x + block_size)
        .map(|px| if px < pic_width { recon[row_start + px] } else { 128 })
        .collect()
}
1124
/// Collect the `block_size` samples directly left of the block at `(x, y)` for
/// intra prediction.
///
/// Samples are taken from column `x - 1` of `recon` (row stride
/// `pic_width`). Positions below the bottom picture edge, and the whole
/// column when `x == 0`, are filled with 128.
fn build_left_ref(
    recon: &[i16],
    x: usize,
    y: usize,
    block_size: usize,
    pic_width: usize,
    pic_height: usize,
) -> Vec<i16> {
    if x == 0 {
        return vec![128; block_size];
    }
    let col = x - 1;
    (y..y + block_size)
        .map(|py| if py < pic_height { recon[py * pic_width + col] } else { 128 })
        .collect()
}
1145
1146// ---------------------------------------------------------------------------
1147// Tests
1148// ---------------------------------------------------------------------------
1149
#[cfg(test)]
mod tests {
    use super::*;

    /// Create a CABAC state from raw bytes with default QP 26.
    fn make_state(data: &[u8]) -> HevcSliceCabacState<'_> {
        HevcSliceCabacState::new(data, 26)
    }

    /// Build a default SPS for testing.
    fn test_sps() -> HevcSps {
        HevcSps {
            sps_id: 0,
            vps_id: 0,
            max_sub_layers: 1,
            chroma_format_idc: 1,
            pic_width: 64,
            pic_height: 64,
            bit_depth_luma: 8,
            bit_depth_chroma: 8,
            log2_max_pic_order_cnt: 4,
            log2_min_cb_size: 3,
            log2_diff_max_min_cb_size: 3,
            log2_min_transform_size: 2,
            log2_diff_max_min_transform_size: 3,
            max_transform_hierarchy_depth_inter: 1,
            max_transform_hierarchy_depth_intra: 1,
            sample_adaptive_offset_enabled: false,
            pcm_enabled: false,
            num_short_term_ref_pic_sets: 0,
            long_term_ref_pics_present: false,
            sps_temporal_mvp_enabled: false,
            strong_intra_smoothing_enabled: false,
        }
    }

    /// Build a default PPS for testing.
    fn test_pps() -> HevcPps {
        HevcPps {
            pps_id: 0,
            sps_id: 0,
            dependent_slice_segments_enabled: false,
            output_flag_present: false,
            num_extra_slice_header_bits: 0,
            sign_data_hiding_enabled: false,
            cabac_init_present: false,
            num_ref_idx_l0_default: 1,
            num_ref_idx_l1_default: 1,
            init_qp: 26,
            constrained_intra_pred: false,
            transform_skip_enabled: false,
            cu_qp_delta_enabled: false,
            cb_qp_offset: 0,
            cr_qp_offset: 0,
            deblocking_filter_override_enabled: false,
            deblocking_filter_disabled: true,
            loop_filter_across_slices_enabled: false,
            tiles_enabled: false,
            entropy_coding_sync_enabled: false,
        }
    }

    // -- Context initialisation tests ----------------------------------------

    #[test]
    fn cabac_state_context_count() {
        let data = [0u8; 16];
        let state = make_state(&data);
        assert_eq!(state.contexts.len(), NUM_CABAC_CONTEXTS);
    }

    #[test]
    fn cabac_state_reinit_preserves_count() {
        let data = [0u8; 16];
        let mut state = make_state(&data);
        state.reinit_contexts(30);
        assert_eq!(state.contexts.len(), NUM_CABAC_CONTEXTS);
    }

    // -- split_cu_flag tests -------------------------------------------------

    #[test]
    fn split_cu_flag_deterministic() {
        // Decoding the same bitstream twice from fresh CABAC states must give
        // identical split decisions at every depth (and must not panic).
        let data = [0x00u8; 32];
        let decode = || {
            let mut state = make_state(&data);
            [
                parse_split_cu_flag(&mut state, 0, false, false),
                parse_split_cu_flag(&mut state, 1, true, false),
                parse_split_cu_flag(&mut state, 2, true, true),
            ]
        };
        assert_eq!(decode(), decode());
    }

    #[test]
    fn split_cu_flag_depth_clamp() {
        // Very high depth should still select a valid context (clamped to 2)
        let data = [0xFFu8; 32];
        let mut state = make_state(&data);
        let _ = parse_split_cu_flag(&mut state, 10, true, true);
    }

    // -- Intra mode signalling tests -----------------------------------------

    #[test]
    fn mpm_list_default_construction() {
        let mpm = build_default_mpm_list();
        assert_eq!(mpm.len(), 3);
        assert!(mpm.contains(&0)); // Planar
        assert!(mpm.contains(&1)); // DC
        assert!(mpm.contains(&26)); // Angular-26
    }

    #[test]
    fn mpm_list_from_neighbours_equal_dc() {
        let mpm = build_mpm_list(1, 1);
        assert_eq!(mpm[0], 0); // Planar
        assert_eq!(mpm[1], 1); // DC
        assert_eq!(mpm[2], 26); // Angular-26
    }

    #[test]
    fn mpm_list_from_neighbours_equal_angular() {
        let mpm = build_mpm_list(10, 10);
        assert_eq!(mpm[0], 10);
        // mpm[1] = 2 + ((10 + 29) % 32) = 2 + 7 = 9
        assert_eq!(mpm[1], 9);
        // mpm[2] = 2 + ((10 - 2 + 1) % 32) = 2 + 9 = 11
        assert_eq!(mpm[2], 11);
    }

    #[test]
    fn mpm_list_from_neighbours_different() {
        let mpm = build_mpm_list(5, 10);
        assert_eq!(mpm[0], 5);
        assert_eq!(mpm[1], 10);
        assert_eq!(mpm[2], 0); // Planar (neither is 0)
    }

    #[test]
    fn remap_rem_mode_skips_mpms() {
        let mpm = [0u8, 1, 26];
        // rem=0 should give mode 2 (skipping 0 and 1)
        let mode = remap_rem_mode(0, &mpm);
        assert_eq!(mode, 2);
        // rem=23 should skip modes 0 and 1: 23 -> 24 -> 25
        // (26 is in MPM but 25 < 26, so no further skip)
        let mode = remap_rem_mode(23, &mpm);
        assert_eq!(mode, 25);
        // rem=24 should skip modes 0, 1, and 26: 24 -> 25 -> 26 -> 27
        let mode = remap_rem_mode(24, &mpm);
        assert_eq!(mode, 27);
    }

    #[test]
    fn remap_rem_mode_clamped() {
        let mpm = [0u8, 1, 2];
        // rem=31 => walks past 0,1,2 so mode = 34, clamped
        let mode = remap_rem_mode(31, &mpm);
        assert_eq!(mode, 34);
    }

    // -- Residual coefficient parsing tests ----------------------------------

    #[test]
    fn parse_tu_all_zero_stream() {
        // An all-zero stream must still decode into a full 4x4 coefficient
        // block without panicking.
        let data = [0x00u8; 64];
        let mut state = make_state(&data);
        let coeffs = parse_transform_unit(&mut state, 2, true, false);
        assert_eq!(coeffs.len(), 16); // 4x4
    }

    #[test]
    fn parse_tu_all_ones_stream() {
        // An all-ones stream exercises the "prefix keeps incrementing" path.
        let data = [0xFFu8; 128];
        let mut state = make_state(&data);
        let coeffs = parse_transform_unit(&mut state, 2, true, false);
        assert_eq!(coeffs.len(), 16);
    }

    #[test]
    fn parse_tu_8x8_size() {
        let data = [0x55u8; 128];
        let mut state = make_state(&data);
        let coeffs = parse_transform_unit(&mut state, 3, true, false);
        assert_eq!(coeffs.len(), 64); // 8x8
    }

    #[test]
    fn parse_tu_chroma() {
        let data = [0xAAu8; 64];
        let mut state = make_state(&data);
        let coeffs = parse_transform_unit(&mut state, 2, false, false);
        assert_eq!(coeffs.len(), 16);
    }

    // -- Full CU parsing tests -----------------------------------------------

    #[test]
    fn parse_cu_intra_i_slice() {
        let data = [0x00u8; 128];
        let mut state = make_state(&data);
        let sps = test_sps();
        let pps = test_pps();
        let cu = parse_coding_unit(&mut state, 0, 0, 3, &sps, &pps, HevcSliceType::I);
        assert_eq!(cu.pred_mode, HevcPredMode::Intra);
        assert!(cu.intra_mode_luma <= 34);
    }

    #[test]
    fn parse_cu_p_slice_may_skip() {
        // In a P slice, the first bin decoded is cu_skip_flag.
        let data = [0xFFu8; 128];
        let mut state = make_state(&data);
        let sps = test_sps();
        let pps = test_pps();
        let cu = parse_coding_unit(&mut state, 0, 0, 3, &sps, &pps, HevcSliceType::P);
        // Should produce a valid result (skip, intra, or inter)
        assert!(matches!(
            cu.pred_mode,
            HevcPredMode::Intra | HevcPredMode::Inter | HevcPredMode::Skip
        ));
    }

    // -- Coding tree integration tests ---------------------------------------

    #[test]
    fn coding_tree_cabac_produces_cus() {
        let data = [0x00u8; 512];
        let mut state = make_state(&data);
        let sps = test_sps();
        let pps = test_pps();
        let mut recon = vec![128i16; 64 * 64];
        let mut results = Vec::new();

        let dpb = crate::hevc_inter::HevcDpb::new(16);
        let mut mv_field = vec![crate::hevc_inter::HevcMvField::unavailable(); 16 * 16];
        decode_coding_tree_cabac(
            &mut state,
            0,
            0,
            6, // 64x64 CTU
            0,
            3, // max_depth
            &sps,
            &pps,
            HevcSliceType::I,
            64,
            64,
            &mut recon,
            &mut results,
            &dpb,
            &mut mv_field,
        );
        // Should produce at least one CU
        assert!(!results.is_empty());
        // All CUs should be within picture bounds
        for cu in &results {
            assert!(cu.x < 64);
            assert!(cu.y < 64);
        }
    }

    #[test]
    fn coding_tree_cabac_boundary() {
        // 48x48 picture with 64x64 CTU — boundary clipping
        let data = [0x55u8; 512];
        let mut state = make_state(&data);
        let sps = test_sps();
        let pps = test_pps();
        let mut recon = vec![128i16; 48 * 48];
        let mut results = Vec::new();
        let dpb = crate::hevc_inter::HevcDpb::new(16);
        let mut mv_field = vec![crate::hevc_inter::HevcMvField::unavailable(); 12 * 12];

        decode_coding_tree_cabac(
            &mut state,
            0,
            0,
            6,
            0,
            3,
            &sps,
            &pps,
            HevcSliceType::I,
            48,
            48,
            &mut recon,
            &mut results,
            &dpb,
            &mut mv_field,
        );
        assert!(!results.is_empty());
        for cu in &results {
            assert!(cu.x < 48);
            assert!(cu.y < 48);
        }
    }

    // -- Scan order tests ----------------------------------------------------

    #[test]
    fn scan_4x4_roundtrip() {
        // Every position 0..15 should appear exactly once in the diagonal scan.
        let mut seen = [false; 16];
        for &s in &SCAN_ORDER_4X4_DIAG {
            assert!(!seen[s as usize], "duplicate in scan order");
            seen[s as usize] = true;
        }
        assert!(seen.iter().all(|&v| v));
    }

    #[test]
    fn scan_to_local_roundtrip() {
        for scan_idx in 0..16u32 {
            let (lx, ly) = scan_to_local_pos(scan_idx);
            let back = local_pos_to_scan_idx(lx, ly);
            assert_eq!(back, scan_idx, "roundtrip failed for scan_idx={scan_idx}");
        }
    }

    #[test]
    fn sub_pos_scan_roundtrip_2x2() {
        for scan_idx in 0..4usize {
            let (sx, sy) = scan_idx_to_sub_pos(scan_idx, 2);
            let back = sub_pos_to_scan_idx(sx, sy, 2);
            assert_eq!(back, scan_idx, "2x2 roundtrip failed at {scan_idx}");
        }
    }

    #[test]
    fn sub_pos_scan_roundtrip_4x4() {
        for scan_idx in 0..16usize {
            let (sx, sy) = scan_idx_to_sub_pos(scan_idx, 4);
            let back = sub_pos_to_scan_idx(sx, sy, 4);
            assert_eq!(back, scan_idx, "4x4 roundtrip failed at {scan_idx}");
        }
    }
}