Skip to main content

yscv_video/
h264_bslice.rs

1//! H.264 B-slice decoding: bi-directional motion vectors, prediction modes,
2//! and bi-predictive motion compensation.
3
4use crate::h264_motion::{MotionVector, motion_compensate_16x16, parse_mvd, predict_mv};
5use crate::{BitstreamReader, VideoError};
6
7// ---------------------------------------------------------------------------
8// Bi-directional motion vector types
9// ---------------------------------------------------------------------------
10
11/// Bi-directional motion vectors for B-slice macroblocks.
12#[derive(Debug, Clone, Copy, Default)]
13pub struct BiMotionVector {
14    /// Motion vector from past (forward) reference.
15    pub forward: MotionVector,
16    /// Motion vector from future (backward) reference.
17    pub backward: MotionVector,
18    /// Prediction mode for this macroblock.
19    pub mode: BPredMode,
20}
21
/// How a B-slice macroblock forms its inter prediction.
///
/// `Forward` is the [`Default`] variant. The enum is fieldless, so it also
/// derives `Eq` and `Hash` and can be used as a set or map key.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum BPredMode {
    /// Predict from the forward (past) reference only.
    #[default]
    Forward,
    /// Predict from the backward (future) reference only.
    Backward,
    /// Average the forward and backward predictions (rounding up on ties).
    BiPred,
    /// Direct prediction: MVs are meant to be derived rather than coded.
    ///
    /// NOTE(review): in this module `Direct` macroblocks are averaged exactly
    /// like `BiPred`, and `decode_b_macroblock` leaves both vectors at their
    /// default; no co-located-MB derivation is performed yet.
    Direct,
}
35
36// ---------------------------------------------------------------------------
37// Bi-directional motion compensation
38// ---------------------------------------------------------------------------
39
40/// Bi-directional motion compensation: dispatches to the correct mode and
41/// averages forward and backward predictions when needed.
42#[allow(clippy::too_many_arguments)]
43pub fn motion_compensate_bipred(
44    ref_fwd: &[u8],
45    ref_bwd: &[u8],
46    width: usize,
47    height: usize,
48    channels: usize,
49    bi_mv: &BiMotionVector,
50    mb_x: usize,
51    mb_y: usize,
52    output: &mut [u8],
53    out_width: usize,
54) {
55    match bi_mv.mode {
56        BPredMode::Forward => {
57            motion_compensate_16x16(
58                ref_fwd,
59                width,
60                height,
61                channels,
62                bi_mv.forward,
63                mb_x,
64                mb_y,
65                output,
66                out_width,
67            );
68        }
69        BPredMode::Backward => {
70            motion_compensate_16x16(
71                ref_bwd,
72                width,
73                height,
74                channels,
75                bi_mv.backward,
76                mb_x,
77                mb_y,
78                output,
79                out_width,
80            );
81        }
82        BPredMode::BiPred | BPredMode::Direct => {
83            // Allocate temporary buffers for each prediction direction.
84            let mut fwd_block = vec![0u8; 16 * 16 * channels];
85            let mut bwd_block = vec![0u8; 16 * 16 * channels];
86
87            motion_compensate_16x16(
88                ref_fwd,
89                width,
90                height,
91                channels,
92                bi_mv.forward,
93                mb_x,
94                mb_y,
95                &mut fwd_block,
96                16,
97            );
98            motion_compensate_16x16(
99                ref_bwd,
100                width,
101                height,
102                channels,
103                bi_mv.backward,
104                mb_x,
105                mb_y,
106                &mut bwd_block,
107                16,
108            );
109
110            // Average the two predictions with rounding.
111            for row in 0..16 {
112                for col in 0..16 {
113                    let dst_y = mb_y * 16 + row;
114                    let dst_x = mb_x * 16 + col;
115                    for c in 0..channels {
116                        let f = fwd_block[(row * 16 + col) * channels + c] as u16;
117                        let b = bwd_block[(row * 16 + col) * channels + c] as u16;
118                        let dst_idx = (dst_y * out_width + dst_x) * channels + c;
119                        if dst_idx < output.len() {
120                            output[dst_idx] = (f + b).div_ceil(2) as u8;
121                        }
122                    }
123                }
124            }
125        }
126    }
127}
128
129// ---------------------------------------------------------------------------
130// B-slice macroblock decoder
131// ---------------------------------------------------------------------------
132
133/// Decode a B-slice macroblock: parse mb_type and motion vectors, then apply
134/// motion compensation.
135///
136/// Returns the decoded `BiMotionVector` so the caller can store it for
137/// neighboring-block prediction of subsequent macroblocks.
138#[allow(clippy::too_many_arguments)]
139pub fn decode_b_macroblock(
140    reader: &mut BitstreamReader,
141    ref_fwd: &[u8],
142    ref_bwd: &[u8],
143    width: usize,
144    height: usize,
145    mb_x: usize,
146    mb_y: usize,
147    neighbor_mvs_fwd: &[MotionVector],
148    neighbor_mvs_bwd: &[MotionVector],
149    output: &mut [u8],
150    out_width: usize,
151) -> Result<BiMotionVector, VideoError> {
152    // 1. Parse mb_type to determine prediction mode.
153    let mb_type = reader.read_ue()?;
154    let mode = match mb_type {
155        0 => BPredMode::Direct,
156        1 => BPredMode::Forward,
157        2 => BPredMode::Backward,
158        _ => BPredMode::BiPred,
159    };
160
161    // 2. Parse motion vector differences and predict final MVs.
162    let forward = if mode == BPredMode::Forward || mode == BPredMode::BiPred {
163        let (mvd_x, mvd_y) = parse_mvd(reader)?;
164        let predicted = predict_mv(
165            neighbor_mvs_fwd.first().copied().unwrap_or_default(),
166            neighbor_mvs_fwd.get(1).copied().unwrap_or_default(),
167            neighbor_mvs_fwd.get(2).copied().unwrap_or_default(),
168        );
169        MotionVector {
170            dx: predicted.dx + mvd_x,
171            dy: predicted.dy + mvd_y,
172            ref_idx: 0,
173        }
174    } else {
175        MotionVector::default()
176    };
177
178    let backward = if mode == BPredMode::Backward || mode == BPredMode::BiPred {
179        let (mvd_x, mvd_y) = parse_mvd(reader)?;
180        let predicted = predict_mv(
181            neighbor_mvs_bwd.first().copied().unwrap_or_default(),
182            neighbor_mvs_bwd.get(1).copied().unwrap_or_default(),
183            neighbor_mvs_bwd.get(2).copied().unwrap_or_default(),
184        );
185        MotionVector {
186            dx: predicted.dx + mvd_x,
187            dy: predicted.dy + mvd_y,
188            ref_idx: 0,
189        }
190    } else {
191        MotionVector::default()
192    };
193
194    let bi_mv = BiMotionVector {
195        forward,
196        backward,
197        mode,
198    };
199
200    // 3. Apply motion compensation.
201    let channels = 3;
202    motion_compensate_bipred(
203        ref_fwd, ref_bwd, width, height, channels, &bi_mv, mb_x, mb_y, output, out_width,
204    );
205
206    Ok(bi_mv)
207}
208
209// ---------------------------------------------------------------------------
210// Tests
211// ---------------------------------------------------------------------------
212
#[cfg(test)]
mod tests {
    use super::*;

    const FRAME_W: usize = 32;
    const FRAME_H: usize = 32;

    /// Predicts macroblock (0, 0) of a single-channel 32x32 frame whose
    /// forward reference is filled with `fwd` and backward reference with
    /// `bwd`, using zero motion vectors, and returns the whole output frame.
    fn predict_flat(mode: BPredMode, fwd: u8, bwd: u8) -> Vec<u8> {
        let channels = 1;
        let ref_fwd = vec![fwd; FRAME_W * FRAME_H * channels];
        let ref_bwd = vec![bwd; FRAME_W * FRAME_H * channels];
        let bi_mv = BiMotionVector {
            forward: MotionVector::default(),
            backward: MotionVector::default(),
            mode,
        };

        let mut output = vec![0u8; FRAME_W * FRAME_H * channels];
        motion_compensate_bipred(
            &ref_fwd,
            &ref_bwd,
            FRAME_W,
            FRAME_H,
            channels,
            &bi_mv,
            0,
            0,
            &mut output,
            FRAME_W,
        );
        output
    }

    /// Asserts that every sample of the top-left 16x16 macroblock equals
    /// `expected`, labelling failures with `what`.
    fn assert_mb_filled(output: &[u8], expected: u8, what: &str) {
        for (row, line) in output.chunks(FRAME_W).take(16).enumerate() {
            for (col, &px) in line[..16].iter().enumerate() {
                assert_eq!(px, expected, "{what} mismatch at ({row}, {col})");
            }
        }
    }

    #[test]
    fn bipred_averages_references() {
        // Round-half-up average: (100 + 200 + 1) / 2 = 150.
        let output = predict_flat(BPredMode::BiPred, 100, 200);
        assert_mb_filled(&output, 150, "bipred average");
    }

    #[test]
    fn bpred_forward_only() {
        // Only the forward reference may contribute.
        let output = predict_flat(BPredMode::Forward, 42, 200);
        assert_mb_filled(&output, 42, "forward-only");
    }

    #[test]
    fn bpred_backward_only() {
        // Only the backward reference may contribute.
        let output = predict_flat(BPredMode::Backward, 42, 77);
        assert_mb_filled(&output, 77, "backward-only");
    }
}