1use std::fmt::Debug;
7
8use super::{RawImageRectMut, Rect, internal::RawImageBuffer};
9
/// Backing memory of a `JxlOutputBuffer`.
pub(crate) enum BufferStorage<'a> {
    /// Every row lives in one slice; row `i` starts at byte `i * bytes_between_rows`.
    Contiguous {
        data: &'a mut [u8],
        bytes_between_rows: usize,
    },
    /// Each row is its own independent slice (e.g. after splitting a buffer into column
    /// fragments).
    // NOTE(review): the reason for this `allow(dead_code)` is not visible here; new values of
    // this variant appear to originate only from `threads`-gated splitting code — confirm the
    // allow is still needed.
    #[allow(dead_code)]
    Fragmented {
        rows: Vec<&'a mut [u8]>,
    },
}
23
24pub struct JxlOutputBuffer<'a> {
25 storage: BufferStorage<'a>,
26 bytes_per_row: usize,
27 num_rows: usize,
28 row_offset: usize,
31}
32
33impl<'a> JxlOutputBuffer<'a> {
34 pub fn from_image_rect_mut(raw: RawImageRectMut<'a>) -> Self {
35 Self {
36 storage: BufferStorage::Contiguous {
37 data: raw.storage,
38 bytes_between_rows: raw.bytes_between_rows,
39 },
40 bytes_per_row: raw.bytes_per_row,
41 num_rows: raw.num_rows,
42 row_offset: 0,
43 }
44 }
45
46 pub fn new(buf: &'a mut [u8], num_rows: usize, bytes_per_row: usize) -> Self {
48 RawImageBuffer::check_vals(num_rows, bytes_per_row, bytes_per_row);
49 let expected_len = if num_rows == 0 {
50 0
51 } else {
52 (num_rows - 1) * bytes_per_row + bytes_per_row
53 };
54 assert!(buf.len() >= expected_len);
55 Self {
56 storage: BufferStorage::Contiguous {
57 data: if expected_len == 0 {
58 &mut []
59 } else {
60 &mut buf[..expected_len]
61 },
62 bytes_between_rows: bytes_per_row,
63 },
64 bytes_per_row,
65 num_rows,
66 row_offset: 0,
67 }
68 }
69
70 #[cfg(feature = "allow-unsafe")]
80 #[allow(unsafe_code)]
81 pub unsafe fn new_from_ptr(
82 buf: *mut std::mem::MaybeUninit<u8>,
83 num_rows: usize,
84 bytes_per_row: usize,
85 bytes_between_rows: usize,
86 ) -> Self {
87 RawImageBuffer::check_vals(num_rows, bytes_per_row, bytes_between_rows);
88 let total_len = if num_rows == 0 {
89 0
90 } else {
91 (num_rows - 1) * bytes_between_rows + bytes_per_row
92 };
93 let data = if total_len == 0 {
94 &mut []
95 } else {
96 unsafe { std::slice::from_raw_parts_mut(buf as *mut u8, total_len) }
99 };
100 Self {
101 storage: BufferStorage::Contiguous {
102 data,
103 bytes_between_rows,
104 },
105 bytes_per_row,
106 num_rows,
107 row_offset: 0,
108 }
109 }
110
111 pub(crate) fn reborrow(lender: &'a mut JxlOutputBuffer<'_>) -> JxlOutputBuffer<'a> {
112 JxlOutputBuffer {
113 storage: match &mut lender.storage {
114 BufferStorage::Contiguous {
115 data,
116 bytes_between_rows,
117 } => BufferStorage::Contiguous {
118 data,
119 bytes_between_rows: *bytes_between_rows,
120 },
121 BufferStorage::Fragmented { rows } => BufferStorage::Fragmented {
122 rows: rows.iter_mut().map(|r| &mut **r).collect(),
123 },
124 },
125 bytes_per_row: lender.bytes_per_row,
126 num_rows: lender.num_rows,
127 row_offset: lender.row_offset,
128 }
129 }
130
131 #[inline]
136 pub(crate) fn row_mut(&mut self, row: usize) -> &mut [u8] {
137 let local_row = row.wrapping_sub(self.row_offset);
138 assert!(
139 local_row < self.num_rows,
140 "row {row} out of range [{}, {})",
141 self.row_offset,
142 self.row_offset + self.num_rows,
143 );
144 match &mut self.storage {
145 BufferStorage::Contiguous {
146 data,
147 bytes_between_rows,
148 } => {
149 let start = local_row * *bytes_between_rows;
150 &mut data[start..start + self.bytes_per_row]
151 }
152 BufferStorage::Fragmented { rows } => &mut rows[local_row][..self.bytes_per_row],
153 }
154 }
155
156 #[inline]
157 pub fn write_bytes(&mut self, row: usize, col: usize, bytes: &[u8]) {
158 let slice = self.row_mut(row);
159 slice[col..col + bytes.len()].copy_from_slice(bytes);
160 }
161
162 pub fn byte_size(&self) -> (usize, usize) {
163 (self.bytes_per_row, self.num_rows)
164 }
165
166 #[cfg(feature = "threads")]
179 #[allow(dead_code)] pub(crate) fn split_into_row_bands(
181 &mut self,
182 split_rows: &[usize],
183 ) -> Vec<JxlOutputBuffer<'_>> {
184 let bpr = self.bytes_per_row;
185 let nrows = self.num_rows;
186 let base_offset = self.row_offset;
187
188 match &mut self.storage {
189 BufferStorage::Contiguous {
190 data,
191 bytes_between_rows,
192 } => {
193 let btr = *bytes_between_rows;
194 let mut result = Vec::with_capacity(split_rows.len() + 1);
195 let mut remaining: &mut [u8] = data;
196 let mut current_row = 0;
197
198 for &split_row in split_rows.iter().chain(std::iter::once(&nrows)) {
199 assert!(
200 split_row >= current_row && split_row <= nrows,
201 "split_rows must be sorted and <= num_rows"
202 );
203 let band_rows = split_row - current_row;
204
205 if band_rows == 0 {
206 result.push(JxlOutputBuffer {
207 storage: BufferStorage::Contiguous {
208 data: &mut [],
209 bytes_between_rows: btr,
210 },
211 bytes_per_row: bpr,
212 num_rows: 0,
213 row_offset: base_offset + current_row,
214 });
215 } else {
216 let span = (band_rows - 1) * btr + bpr;
217 if split_row < nrows {
218 let total_bytes = band_rows * btr;
219 let tmp = remaining;
220 let (band_full, rest) = tmp.split_at_mut(total_bytes);
221 result.push(JxlOutputBuffer {
222 storage: BufferStorage::Contiguous {
223 data: &mut band_full[..span],
224 bytes_between_rows: btr,
225 },
226 bytes_per_row: bpr,
227 num_rows: band_rows,
228 row_offset: base_offset + current_row,
229 });
230 remaining = rest;
231 } else {
232 let tmp = remaining;
233 let (band, _) = tmp.split_at_mut(span);
234 result.push(JxlOutputBuffer {
235 storage: BufferStorage::Contiguous {
236 data: band,
237 bytes_between_rows: btr,
238 },
239 bytes_per_row: bpr,
240 num_rows: band_rows,
241 row_offset: base_offset + current_row,
242 });
243 remaining = &mut [];
244 }
245 }
246 current_row = split_row;
247 }
248 result
249 }
250 BufferStorage::Fragmented { rows } => {
251 let mut result = Vec::with_capacity(split_rows.len() + 1);
252 let mut remaining: &mut [&'_ mut [u8]] = rows;
253 let mut current_row = 0;
254
255 for &split_row in split_rows.iter().chain(std::iter::once(&nrows)) {
256 assert!(
257 split_row >= current_row && split_row <= nrows,
258 "split_rows must be sorted and <= num_rows"
259 );
260 let band_rows = split_row - current_row;
261
262 let tmp = remaining;
263 let (band_slice, rest) = tmp.split_at_mut(band_rows);
264 result.push(JxlOutputBuffer {
265 storage: BufferStorage::Fragmented {
266 rows: band_slice.iter_mut().map(|r| &mut **r).collect(),
267 },
268 bytes_per_row: bpr,
269 num_rows: band_rows,
270 row_offset: base_offset + current_row,
271 });
272 remaining = rest;
273 current_row = split_row;
274 }
275 result
276 }
277 }
278 }
279
280 #[cfg(feature = "threads")]
292 #[allow(dead_code)] pub(crate) fn split_into_col_fragments(
294 &mut self,
295 split_cols: &[usize],
296 ) -> Vec<JxlOutputBuffer<'_>> {
297 let bpr = self.bytes_per_row;
298 let nrows = self.num_rows;
299 let base_offset = self.row_offset;
300 let num_frags = split_cols.len() + 1;
301
302 for (i, &col) in split_cols.iter().enumerate() {
304 assert!(col <= bpr, "split_col {col} exceeds bytes_per_row {bpr}");
305 if i > 0 {
306 assert!(col >= split_cols[i - 1], "split_cols must be sorted");
307 }
308 }
309
310 let mut fragment_rows: Vec<Vec<&mut [u8]>> =
312 (0..num_frags).map(|_| Vec::with_capacity(nrows)).collect();
313
314 match &mut self.storage {
315 BufferStorage::Contiguous {
316 data,
317 bytes_between_rows,
318 } => {
319 let btr = *bytes_between_rows;
320 let mut remaining: &mut [u8] = data;
321
322 for row_idx in 0..nrows {
323 let tmp = remaining;
326 let split_point = if row_idx < nrows - 1 { btr } else { tmp.len() };
327 let (chunk, rest) = tmp.split_at_mut(split_point);
328 remaining = rest;
329 let row_useful = &mut chunk[..bpr];
330
331 let mut col_remaining = row_useful;
333 let mut prev_col = 0;
334 for (frag_idx, &split_col) in split_cols.iter().enumerate() {
335 let width = split_col - prev_col;
336 let (frag, rest) = col_remaining.split_at_mut(width);
337 fragment_rows[frag_idx].push(frag);
338 col_remaining = rest;
339 prev_col = split_col;
340 }
341 fragment_rows[num_frags - 1].push(col_remaining);
342 }
343 }
344 BufferStorage::Fragmented { rows } => {
345 for row in rows.iter_mut() {
346 let mut col_remaining: &mut [u8] = row;
347 let mut prev_col = 0;
348 for (frag_idx, &split_col) in split_cols.iter().enumerate() {
349 let width = split_col - prev_col;
350 let (frag, rest) = col_remaining.split_at_mut(width);
351 fragment_rows[frag_idx].push(frag);
352 col_remaining = rest;
353 prev_col = split_col;
354 }
355 fragment_rows[num_frags - 1].push(col_remaining);
356 }
357 }
358 }
359
360 let mut result = Vec::with_capacity(num_frags);
362 let mut prev_col = 0;
363 for (frag_idx, rows) in fragment_rows.into_iter().enumerate() {
364 let col_end = if frag_idx < split_cols.len() {
365 split_cols[frag_idx]
366 } else {
367 bpr
368 };
369 let width = col_end - prev_col;
370 result.push(JxlOutputBuffer {
371 storage: BufferStorage::Fragmented { rows },
372 bytes_per_row: width,
373 num_rows: nrows,
374 row_offset: base_offset,
375 });
376 prev_col = col_end;
377 }
378 result
379 }
380
381 #[cfg(feature = "threads")]
393 pub(crate) fn split_into_tile_grid(
394 &mut self,
395 split_rows: &[usize],
396 split_cols_per_band: &[&[usize]],
397 ) -> Vec<Vec<JxlOutputBuffer<'_>>> {
398 let bpr = self.bytes_per_row;
399 let nrows = self.num_rows;
400 let base_offset = self.row_offset;
401 let num_bands = split_rows.len() + 1;
402
403 assert_eq!(
404 split_cols_per_band.len(),
405 num_bands,
406 "need one set of split_cols per band"
407 );
408
409 let BufferStorage::Contiguous {
410 data,
411 bytes_between_rows,
412 } = &mut self.storage
413 else {
414 panic!("split_into_tile_grid requires Contiguous storage")
415 };
416 let btr = *bytes_between_rows;
417 let mut remaining: &mut [u8] = data;
418
419 let mut result: Vec<Vec<JxlOutputBuffer<'_>>> = Vec::with_capacity(num_bands);
420 let mut current_row = 0;
421
422 for band_idx in 0..num_bands {
423 let band_end = if band_idx < split_rows.len() {
424 split_rows[band_idx]
425 } else {
426 nrows
427 };
428 assert!(
429 band_end >= current_row && band_end <= nrows,
430 "split_rows must be sorted and <= num_rows"
431 );
432 let band_rows = band_end - current_row;
433 let split_cols = split_cols_per_band[band_idx];
434 let num_frags = split_cols.len() + 1;
435
436 let mut fragment_rows: Vec<Vec<&mut [u8]>> = (0..num_frags)
438 .map(|_| Vec::with_capacity(band_rows))
439 .collect();
440
441 for row_offset_in_band in 0..band_rows {
442 let is_last_overall_row = current_row + row_offset_in_band == nrows - 1;
443 let tmp = remaining;
444 let split_point = if is_last_overall_row { tmp.len() } else { btr };
445 let (chunk, rest) = tmp.split_at_mut(split_point);
446 remaining = rest;
447 let row_useful = &mut chunk[..bpr];
448
449 let mut col_remaining = row_useful;
451 let mut prev_col = 0;
452 for (frag_idx, &split_col) in split_cols.iter().enumerate() {
453 let width = split_col - prev_col;
454 let (frag, rest) = col_remaining.split_at_mut(width);
455 fragment_rows[frag_idx].push(frag);
456 col_remaining = rest;
457 prev_col = split_col;
458 }
459 fragment_rows[num_frags - 1].push(col_remaining);
460 }
461
462 let mut band_frags = Vec::with_capacity(num_frags);
464 let mut prev_col = 0;
465 for (frag_idx, rows) in fragment_rows.into_iter().enumerate() {
466 let col_end = if frag_idx < split_cols.len() {
467 split_cols[frag_idx]
468 } else {
469 bpr
470 };
471 let width = col_end - prev_col;
472 band_frags.push(JxlOutputBuffer {
473 storage: BufferStorage::Fragmented { rows },
474 bytes_per_row: width,
475 num_rows: band_rows,
476 row_offset: base_offset + current_row,
477 });
478 prev_col = col_end;
479 }
480 result.push(band_frags);
481 current_row = band_end;
482 }
483 result
484 }
485
486 pub fn rect(&mut self, rect: Rect) -> JxlOutputBuffer<'_> {
487 if rect.size.0 == 0 || rect.size.1 == 0 {
488 return JxlOutputBuffer {
489 storage: BufferStorage::Contiguous {
490 data: &mut [],
491 bytes_between_rows: 0,
492 },
493 bytes_per_row: 0,
494 num_rows: 0,
495 row_offset: 0,
496 };
497 }
498 assert!(
499 rect.origin.1 >= self.row_offset,
500 "rect origin row {} < row_offset {}",
501 rect.origin.1,
502 self.row_offset,
503 );
504 let local_y = rect.origin.1 - self.row_offset;
505 assert!(local_y + rect.size.1 <= self.num_rows);
506 assert!(rect.origin.0 + rect.size.0 <= self.bytes_per_row);
507
508 match &mut self.storage {
509 BufferStorage::Contiguous {
510 data,
511 bytes_between_rows,
512 } => {
513 let btr = *bytes_between_rows;
514 let new_start = local_y * btr + rect.origin.0;
515 let data_span = (rect.size.1 - 1) * btr + rect.size.0;
516 assert!(new_start + data_span <= data.len());
517 JxlOutputBuffer {
518 storage: BufferStorage::Contiguous {
519 data: &mut data[new_start..new_start + data_span],
520 bytes_between_rows: btr,
521 },
522 bytes_per_row: rect.size.0,
523 num_rows: rect.size.1,
524 row_offset: 0,
526 }
527 }
528 BufferStorage::Fragmented { rows } => {
529 let col_start = rect.origin.0;
530 let col_end = col_start + rect.size.0;
531 let sub_rows: Vec<&mut [u8]> = rows[local_y..local_y + rect.size.1]
532 .iter_mut()
533 .map(|row| &mut row[col_start..col_end])
534 .collect();
535 JxlOutputBuffer {
536 storage: BufferStorage::Fragmented { rows: sub_rows },
537 bytes_per_row: rect.size.0,
538 num_rows: rect.size.1,
539 row_offset: 0,
541 }
542 }
543 }
544 }
545}
546
547impl Debug for JxlOutputBuffer<'_> {
548 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
549 write!(
550 f,
551 "JxlOutputBuffer {}x{}",
552 self.bytes_per_row, self.num_rows
553 )
554 }
555}