1use crate::error::{IffError, Result};
8use crate::prime::QuantizationTable;
9use crate::compression::{compress_rle, decompress_rle};
10use serde::{Deserialize, Serialize};
11
/// One of the four sub-bands produced by a single level of a separable 2-D
/// wavelet decomposition, named by the low/high-pass filter applied along each axis.
///
/// NOTE(review): which axis the first letter refers to (i.e. whether `LH` is
/// horizontal or vertical detail) is not fixed anywhere in this file — confirm
/// against the code that consumes this enum.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SubBand {
    /// Low-pass along both axes (the approximation band).
    LL,
    /// Low-pass along one axis, high-pass along the other.
    LH,
    /// High-pass along one axis, low-pass along the other.
    HL,
    /// High-pass along both axes.
    HH,
}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct WaveletDecomposition {
24 pub width: u32,
26 pub height: u32,
27 pub levels: usize,
29 pub channels: u8,
31 pub data: Vec<u8>,
34}
35
36impl WaveletDecomposition {
37 pub fn new(width: u32, height: u32, levels: usize, channels: u8) -> Self {
39 WaveletDecomposition {
40 width,
41 height,
42 levels,
43 channels,
44 data: Vec::new(),
45 }
46 }
47
48 pub fn from_dense(
50 width: u32,
51 height: u32,
52 levels: usize,
53 channel_data: &[Vec<i32>],
54 ) -> Result<Self> {
55 let mut all_data = Vec::new();
56 for channel in channel_data {
57 all_data.extend_from_slice(channel);
58 }
59
60 let compressed = compress_rle(&all_data)?;
61
62 Ok(WaveletDecomposition {
63 width,
64 height,
65 levels,
66 channels: channel_data.len() as u8,
67 data: compressed,
68 })
69 }
70
71 pub fn to_dense(&self) -> Result<Vec<Vec<i32>>> {
73 let total_pixels = (self.width * self.height) as usize;
74 let expected_len = total_pixels * self.channels as usize;
75
76 let all_data = decompress_rle(&self.data, Some(expected_len))?;
77
78 let mut channels = Vec::with_capacity(self.channels as usize);
79 for i in 0..self.channels as usize {
80 let start = i * total_pixels;
81 let end = start + total_pixels;
82 if end > all_data.len() {
83 return Err(IffError::Other("Insufficient data for channels".to_string()));
84 }
85 channels.push(all_data[start..end].to_vec());
86 }
87
88 Ok(channels)
89 }
90}
91
/// Reversible integer CDF 5/3 wavelet transform configured with a level count.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Cdf53Transform {
    /// Maximum number of decomposition levels. Fewer are applied once the region
    /// being transformed becomes narrower or shorter than 2 samples.
    levels: usize,
}
97
98impl Cdf53Transform {
99 pub fn new(levels: usize) -> Self {
101 Cdf53Transform { levels }
102 }
103
104 pub fn forward(&self, image: &[i32], width: usize, height: usize) -> Result<Vec<i32>> {
106 if image.len() != width * height {
107 return Err(IffError::Other(
108 "Image dimensions don't match data length".to_string(),
109 ));
110 }
111
112 let mut buffer = image.to_vec();
114
115 let mut current_width = width;
117 let mut current_height = height;
118
119 for _ in 0..self.levels {
120 if current_width < 2 || current_height < 2 {
121 break; }
123
124 for y in 0..current_height {
126 self.forward_1d_row(&mut buffer, y, current_width, width);
127 }
128
129 for x in 0..current_width {
131 self.forward_1d_col(&mut buffer, x, current_height, width);
132 }
133
134 current_width /= 2;
136 current_height /= 2;
137 }
138
139 Ok(buffer)
140 }
141
142 pub fn inverse(&self, coefficients: &[i32], width: usize, height: usize) -> Result<Vec<i32>> {
144 if coefficients.len() != width * height {
145 return Err(IffError::Other(
146 "Coefficient dimensions don't match data length".to_string(),
147 ));
148 }
149
150 let mut buffer = coefficients.to_vec();
151
152 let mut dims = Vec::new();
163 let mut w = width;
164 let mut h = height;
165 for _ in 0..self.levels {
166 dims.push((w, h));
167 w /= 2;
168 h /= 2;
169 }
170
171 for (w, h) in dims.iter().rev() {
172 let current_width = *w;
173 let current_height = *h;
174
175 if current_width < 2 || current_height < 2 {
176 continue;
177 }
178
179 for x in 0..current_width {
181 self.inverse_1d_col(&mut buffer, x, current_height, width);
182 }
183
184 for y in 0..current_height {
186 self.inverse_1d_row(&mut buffer, y, current_width, width);
187 }
188 }
189
190 Ok(buffer)
191 }
192
193 pub fn quantize(&self, buffer: &mut [i32], width: usize, height: usize, table: &QuantizationTable) {
195 let mut current_width = width;
200 let mut current_height = height;
201
202 for level in 0..self.levels {
203 let half_w = current_width / 2;
204 let half_h = current_height / 2;
205
206 if half_w == 0 || half_h == 0 { break; }
207
208 for y in 0..half_h {
216 for x in half_w..current_width {
217 self.quantize_pixel(buffer, x, y, width, table);
218 }
219 }
220
221 for y in half_h..current_height {
223 for x in 0..half_w {
224 self.quantize_pixel(buffer, x, y, width, table);
225 }
226 }
227
228 for y in half_h..current_height {
230 for x in half_w..current_width {
231 self.quantize_pixel(buffer, x, y, width, table);
232 }
233 }
234
235 if level == self.levels - 1 {
237 for y in 0..half_h {
238 for x in 0..half_w {
239 self.quantize_pixel(buffer, x, y, width, table);
240 }
241 }
242 }
243
244 current_width = half_w;
245 current_height = half_h;
246 }
247 }
248
249 fn quantize_pixel(&self, buffer: &mut [i32], x: usize, y: usize, width: usize, table: &QuantizationTable) {
250 let idx = y * width + x;
251 let val = buffer[idx];
252 let step = table.get_step(x, y);
253
254 let quantized = val / step as i32;
255
256 if quantized.abs() < 2 { buffer[idx] = 0;
259 } else {
260 buffer[idx] = quantized * step as i32;
261 }
262 }
263
264 fn forward_1d_row(&self, data: &mut [i32], y: usize, width: usize, stride: usize) {
266 if width < 2 {
267 return;
268 }
269
270 let offset = y * stride;
271 let mut temp = vec![0i32; width];
272
273 for i in 0..(width / 2) {
275 let left = data[offset + 2 * i];
276 let right = if 2 * i + 2 < width {
277 data[offset + 2 * i + 2]
278 } else {
279 data[offset + 2 * i] };
281 temp[width / 2 + i] = data[offset + 2 * i + 1] - ((left + right) / 2);
282 }
283
284 for i in 0..(width / 2) {
286 let d_left = if i > 0 {
287 temp[width / 2 + i - 1]
288 } else {
289 temp[width / 2] };
291 let d_right = if i < width / 2 {
292 temp[width / 2 + i]
293 } else {
294 temp[width / 2 + i - 1] };
296 temp[i] = data[offset + 2 * i] + ((d_left + d_right + 2) / 4);
297 }
298
299 for i in 0..width {
301 data[offset + i] = temp[i];
302 }
303 }
304
305 fn forward_1d_col(&self, data: &mut [i32], x: usize, height: usize, stride: usize) {
307 if height < 2 {
308 return;
309 }
310
311 let mut temp = vec![0i32; height];
312
313 for i in 0..(height / 2) {
315 let top = data[2 * i * stride + x];
316 let bottom = if 2 * i + 2 < height {
317 data[(2 * i + 2) * stride + x]
318 } else {
319 data[2 * i * stride + x] };
321 temp[height / 2 + i] = data[(2 * i + 1) * stride + x] - ((top + bottom) / 2);
322 }
323
324 for i in 0..(height / 2) {
326 let d_top = if i > 0 {
327 temp[height / 2 + i - 1]
328 } else {
329 temp[height / 2] };
331 let d_bottom = if i < height / 2 {
332 temp[height / 2 + i]
333 } else {
334 temp[height / 2 + i - 1] };
336 temp[i] = data[2 * i * stride + x] + ((d_top + d_bottom + 2) / 4);
337 }
338
339 for i in 0..height {
341 data[i * stride + x] = temp[i];
342 }
343 }
344
345 fn inverse_1d_row(&self, data: &mut [i32], y: usize, width: usize, stride: usize) {
347 if width < 2 {
348 return;
349 }
350
351 let offset = y * stride;
352 let mut temp = vec![0i32; width];
353
354 for i in 0..width {
356 temp[i] = data[offset + i];
357 }
358
359 for i in 0..(width / 2) {
361 let d_left = if i > 0 {
362 temp[width / 2 + i - 1]
363 } else {
364 temp[width / 2]
365 };
366 let d_right = if i < width / 2 {
367 temp[width / 2 + i]
368 } else {
369 temp[width / 2 + i - 1]
370 };
371 data[offset + 2 * i] = temp[i] - ((d_left + d_right + 2) / 4);
372 }
373
374 for i in 0..(width / 2) {
376 let left = data[offset + 2 * i];
377 let right = if 2 * i + 2 < width {
378 data[offset + 2 * i + 2]
379 } else {
380 data[offset + 2 * i]
381 };
382 data[offset + 2 * i + 1] = temp[width / 2 + i] + ((left + right) / 2);
383 }
384 }
385
386 fn inverse_1d_col(&self, data: &mut [i32], x: usize, height: usize, stride: usize) {
388 if height < 2 {
389 return;
390 }
391
392 let mut temp = vec![0i32; height];
393
394 for i in 0..height {
396 temp[i] = data[i * stride + x];
397 }
398
399 for i in 0..(height / 2) {
401 let d_top = if i > 0 {
402 temp[height / 2 + i - 1]
403 } else {
404 temp[height / 2]
405 };
406 let d_bottom = if i < height / 2 {
407 temp[height / 2 + i]
408 } else {
409 temp[height / 2 + i - 1]
410 };
411 data[2 * i * stride + x] = temp[i] - ((d_top + d_bottom + 2) / 4);
412 }
413
414 for i in 0..(height / 2) {
416 let top = data[2 * i * stride + x];
417 let bottom = if 2 * i + 2 < height {
418 data[(2 * i + 2) * stride + x]
419 } else {
420 data[2 * i * stride + x]
421 };
422 data[(2 * i + 1) * stride + x] = temp[height / 2 + i] + ((top + bottom) / 2);
423 }
424 }
425}
426
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random image with values in `-256..256` (simple LCG).
    fn sample_image(width: usize, height: usize, seed: u32) -> Vec<i32> {
        let mut state = seed;
        (0..width * height)
            .map(|_| {
                state = state.wrapping_mul(1_103_515_245).wrapping_add(12_345);
                ((state >> 16) % 512) as i32 - 256
            })
            .collect()
    }

    /// Runs forward then inverse and asserts the original image comes back unchanged.
    fn assert_round_trip(image: &[i32], width: usize, height: usize, levels: usize) {
        let transform = Cdf53Transform::new(levels);
        let coeffs = transform.forward(image, width, height).unwrap();
        let reconstructed = transform.inverse(&coeffs, width, height).unwrap();
        assert_eq!(
            image,
            reconstructed.as_slice(),
            "Perfect reconstruction failed for {}x{} with {} levels",
            width,
            height,
            levels
        );
    }

    #[test]
    fn test_perfect_reconstruction() {
        let width = 8;
        let height = 8;
        let image: Vec<i32> = (0..(width * height))
            .map(|i| (i % 256) as i32 - 128)
            .collect();
        assert_round_trip(&image, width, height, 2);
    }

    #[test]
    fn test_reconstruction_across_sizes_and_levels() {
        // Level counts past the point where the image collapses must still round-trip.
        for &(width, height) in &[(2, 2), (8, 8), (16, 8), (4, 16), (32, 32)] {
            let image = sample_image(width, height, (width * 31 + height) as u32);
            for levels in 0..=6 {
                assert_round_trip(&image, width, height, levels);
            }
        }
    }

    #[test]
    fn test_zero_levels_is_identity() {
        let image = sample_image(8, 4, 7);
        let coeffs = Cdf53Transform::new(0).forward(&image, 8, 4).unwrap();
        assert_eq!(coeffs, image);
    }

    #[test]
    fn test_rejects_mismatched_dimensions() {
        let transform = Cdf53Transform::new(1);
        assert!(transform.forward(&[0; 5], 2, 2).is_err());
        assert!(transform.inverse(&[0; 5], 2, 2).is_err());
    }
}