1use crate::h264_motion::{MotionVector, motion_compensate_16x16, parse_mvd, predict_mv};
5use crate::{BitstreamReader, VideoError};
6
7#[derive(Debug, Clone, Copy, Default)]
13pub struct BiMotionVector {
14 pub forward: MotionVector,
16 pub backward: MotionVector,
18 pub mode: BPredMode,
20}
21
/// How a B-macroblock is predicted from its reference pictures.
///
/// The default is [`BPredMode::Forward`].
// `Eq`/`Hash` are derived because equality on this fieldless enum is total,
// which also lets the mode be used as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum BPredMode {
    /// Predict from the forward reference picture only.
    #[default]
    Forward,
    /// Predict from the backward reference picture only.
    Backward,
    /// Average the forward and backward predictions.
    BiPred,
    /// Direct mode: no motion-vector differences are read for it in this
    /// module, and it is motion-compensated exactly like [`BPredMode::BiPred`].
    Direct,
}
35
36#[allow(clippy::too_many_arguments)]
43pub fn motion_compensate_bipred(
44 ref_fwd: &[u8],
45 ref_bwd: &[u8],
46 width: usize,
47 height: usize,
48 channels: usize,
49 bi_mv: &BiMotionVector,
50 mb_x: usize,
51 mb_y: usize,
52 output: &mut [u8],
53 out_width: usize,
54) {
55 match bi_mv.mode {
56 BPredMode::Forward => {
57 motion_compensate_16x16(
58 ref_fwd,
59 width,
60 height,
61 channels,
62 bi_mv.forward,
63 mb_x,
64 mb_y,
65 output,
66 out_width,
67 );
68 }
69 BPredMode::Backward => {
70 motion_compensate_16x16(
71 ref_bwd,
72 width,
73 height,
74 channels,
75 bi_mv.backward,
76 mb_x,
77 mb_y,
78 output,
79 out_width,
80 );
81 }
82 BPredMode::BiPred | BPredMode::Direct => {
83 let mut fwd_block = vec![0u8; 16 * 16 * channels];
85 let mut bwd_block = vec![0u8; 16 * 16 * channels];
86
87 motion_compensate_16x16(
88 ref_fwd,
89 width,
90 height,
91 channels,
92 bi_mv.forward,
93 mb_x,
94 mb_y,
95 &mut fwd_block,
96 16,
97 );
98 motion_compensate_16x16(
99 ref_bwd,
100 width,
101 height,
102 channels,
103 bi_mv.backward,
104 mb_x,
105 mb_y,
106 &mut bwd_block,
107 16,
108 );
109
110 for row in 0..16 {
112 for col in 0..16 {
113 let dst_y = mb_y * 16 + row;
114 let dst_x = mb_x * 16 + col;
115 for c in 0..channels {
116 let f = fwd_block[(row * 16 + col) * channels + c] as u16;
117 let b = bwd_block[(row * 16 + col) * channels + c] as u16;
118 let dst_idx = (dst_y * out_width + dst_x) * channels + c;
119 if dst_idx < output.len() {
120 output[dst_idx] = (f + b).div_ceil(2) as u8;
121 }
122 }
123 }
124 }
125 }
126 }
127}
128
129#[allow(clippy::too_many_arguments)]
139pub fn decode_b_macroblock(
140 reader: &mut BitstreamReader,
141 ref_fwd: &[u8],
142 ref_bwd: &[u8],
143 width: usize,
144 height: usize,
145 mb_x: usize,
146 mb_y: usize,
147 neighbor_mvs_fwd: &[MotionVector],
148 neighbor_mvs_bwd: &[MotionVector],
149 output: &mut [u8],
150 out_width: usize,
151) -> Result<BiMotionVector, VideoError> {
152 let mb_type = reader.read_ue()?;
154 let mode = match mb_type {
155 0 => BPredMode::Direct,
156 1 => BPredMode::Forward,
157 2 => BPredMode::Backward,
158 _ => BPredMode::BiPred,
159 };
160
161 let forward = if mode == BPredMode::Forward || mode == BPredMode::BiPred {
163 let (mvd_x, mvd_y) = parse_mvd(reader)?;
164 let predicted = predict_mv(
165 neighbor_mvs_fwd.first().copied().unwrap_or_default(),
166 neighbor_mvs_fwd.get(1).copied().unwrap_or_default(),
167 neighbor_mvs_fwd.get(2).copied().unwrap_or_default(),
168 );
169 MotionVector {
170 dx: predicted.dx + mvd_x,
171 dy: predicted.dy + mvd_y,
172 ref_idx: 0,
173 }
174 } else {
175 MotionVector::default()
176 };
177
178 let backward = if mode == BPredMode::Backward || mode == BPredMode::BiPred {
179 let (mvd_x, mvd_y) = parse_mvd(reader)?;
180 let predicted = predict_mv(
181 neighbor_mvs_bwd.first().copied().unwrap_or_default(),
182 neighbor_mvs_bwd.get(1).copied().unwrap_or_default(),
183 neighbor_mvs_bwd.get(2).copied().unwrap_or_default(),
184 );
185 MotionVector {
186 dx: predicted.dx + mvd_x,
187 dy: predicted.dy + mvd_y,
188 ref_idx: 0,
189 }
190 } else {
191 MotionVector::default()
192 };
193
194 let bi_mv = BiMotionVector {
195 forward,
196 backward,
197 mode,
198 };
199
200 let channels = 3;
202 motion_compensate_bipred(
203 ref_fwd, ref_bwd, width, height, channels, &bi_mv, mb_x, mb_y, output, out_width,
204 );
205
206 Ok(bi_mv)
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    const WIDTH: usize = 32;
    const HEIGHT: usize = 32;

    /// Predicts macroblock (0, 0) of a single-channel `WIDTH`x`HEIGHT` frame whose
    /// forward/backward references are flat planes of `fwd_value`/`bwd_value`,
    /// using zero motion vectors, and returns the whole output frame.
    fn predict_flat(mode: BPredMode, fwd_value: u8, bwd_value: u8) -> Vec<u8> {
        let channels = 1;
        let plane_len = WIDTH * HEIGHT * channels;
        let ref_fwd = vec![fwd_value; plane_len];
        let ref_bwd = vec![bwd_value; plane_len];
        let bi_mv = BiMotionVector {
            mode,
            ..Default::default()
        };

        let mut output = vec![0u8; plane_len];
        motion_compensate_bipred(
            &ref_fwd, &ref_bwd, WIDTH, HEIGHT, channels, &bi_mv, 0, 0, &mut output, WIDTH,
        );
        output
    }

    /// Asserts every pixel of the top-left 16x16 macroblock equals `expected`.
    fn assert_macroblock_filled(output: &[u8], expected: u8, label: &str) {
        let coords = (0..16).flat_map(|row| (0..16).map(move |col| (row, col)));
        for (row, col) in coords {
            assert_eq!(
                output[row * WIDTH + col],
                expected,
                "{label} mismatch at ({row}, {col})"
            );
        }
    }

    #[test]
    fn bipred_averages_references() {
        let output = predict_flat(BPredMode::BiPred, 100, 200);
        assert_macroblock_filled(&output, 150, "bipred average");
    }

    #[test]
    fn bpred_forward_only() {
        let output = predict_flat(BPredMode::Forward, 42, 200);
        assert_macroblock_filled(&output, 42, "forward-only");
    }

    #[test]
    fn bpred_backward_only() {
        let output = predict_flat(BPredMode::Backward, 42, 77);
        assert_macroblock_filled(&output, 77, "backward-only");
    }
}