rustdct/algorithm/
type4_convert_to_type3.rs

1use std::sync::Arc;
2
3use rustfft::num_complex::Complex;
4use rustfft::Length;
5
6use crate::common::dct_error_inplace;
7use crate::{twiddles, Dct4, DctNum, Dst4, RequiredScratch, TransformType2And3, TransformType4};
8
9/// DCT4 and DST4 implementation that converts the problem into two DCT3 of half size.
10///
11/// If the inner DCT3 is O(nlogn), then so is this. This algorithm can only be used if the problem size is even.
12///
13/// ~~~
14/// // Computes a DCT Type 4 of size 1234
15/// use std::sync::Arc;
16/// use rustdct::Dct4;
17/// use rustdct::algorithm::Type4ConvertToType3Even;
18/// use rustdct::DctPlanner;
19///
20/// let len = 1234;
21/// let mut planner = DctPlanner::new();
22/// let inner_dct3 = planner.plan_dct3(len / 2);
23///
24/// let dct = Type4ConvertToType3Even::new(inner_dct3);
25///
26/// let mut buffer = vec![0f32; len];
27/// dct.process_dct4(&mut buffer);
28/// ~~~
29pub struct Type4ConvertToType3Even<T> {
30    inner_dct: Arc<dyn TransformType2And3<T>>,
31    twiddles: Box<[Complex<T>]>,
32    scratch_len: usize,
33}
34
35impl<T: DctNum> Type4ConvertToType3Even<T> {
36    /// Creates a new DCT4 context that will process signals of length `inner_dct.len() * 2`.
37    pub fn new(inner_dct: Arc<dyn TransformType2And3<T>>) -> Self {
38        let inner_len = inner_dct.len();
39        let len = inner_len * 2;
40
41        let twiddles: Vec<Complex<T>> = (0..inner_len)
42            .map(|i| twiddles::single_twiddle(2 * i + 1, len * 8).conj())
43            .collect();
44
45        let inner_scratch = inner_dct.get_scratch_len();
46        let scratch_len = if inner_scratch <= len {
47            len
48        } else {
49            len + inner_scratch
50        };
51
52        Self {
53            inner_dct: inner_dct,
54            twiddles: twiddles.into_boxed_slice(),
55            scratch_len,
56        }
57    }
58}
59impl<T: DctNum> Dct4<T> for Type4ConvertToType3Even<T> {
60    fn process_dct4_with_scratch(&self, buffer: &mut [T], scratch: &mut [T]) {
61        let scratch = validate_buffers!(buffer, scratch, self.len(), self.get_scratch_len());
62
63        let (self_scratch, extra_scratch) = scratch.split_at_mut(self.len());
64
65        let len = self.len();
66        let inner_len = len / 2;
67
68        //pre-process the input by splitting into into two arrays, one for the inner DCT3, and the other for the DST3
69        let (mut output_left, mut output_right) = self_scratch.split_at_mut(inner_len);
70
71        output_left[0] = buffer[0] * T::two();
72        for k in 1..inner_len {
73            output_left[k] = buffer[2 * k - 1] + buffer[2 * k];
74            output_right[k - 1] = buffer[2 * k - 1] - buffer[2 * k];
75        }
76        output_right[inner_len - 1] = buffer[len - 1] * T::two();
77
78        //run the two inner DCTs on our separated arrays
79        let inner_scratch = if extra_scratch.len() > 0 {
80            extra_scratch
81        } else {
82            &mut buffer[..]
83        };
84
85        self.inner_dct
86            .process_dct3_with_scratch(&mut output_left, inner_scratch);
87        self.inner_dct
88            .process_dst3_with_scratch(&mut output_right, inner_scratch);
89
90        //post-process the data by combining it back into a single array
91        for k in 0..inner_len {
92            let twiddle = self.twiddles[k];
93            let cos_value = output_left[k];
94            let sin_value = output_right[k];
95
96            buffer[k] = cos_value * twiddle.re + sin_value * twiddle.im;
97            buffer[len - 1 - k] = cos_value * twiddle.im - sin_value * twiddle.re;
98        }
99    }
100}
101impl<T: DctNum> Dst4<T> for Type4ConvertToType3Even<T> {
102    fn process_dst4_with_scratch(&self, buffer: &mut [T], scratch: &mut [T]) {
103        let scratch = validate_buffers!(buffer, scratch, self.len(), self.get_scratch_len());
104
105        let (self_scratch, extra_scratch) = scratch.split_at_mut(self.len());
106
107        let len = self.len();
108        let inner_len = len / 2;
109
110        //pre-process the input by splitting into into two arrays, one for the inner DCT3, and the other for the DST3
111        let (mut output_left, mut output_right) = self_scratch.split_at_mut(inner_len);
112
113        output_right[0] = buffer[0] * T::two();
114        for k in 1..inner_len {
115            output_left[k - 1] = buffer[2 * k - 1] + buffer[2 * k];
116            output_right[k] = buffer[2 * k] - buffer[2 * k - 1];
117        }
118        output_left[inner_len - 1] = buffer[len - 1] * T::two();
119
120        //run the two inner DCTs on our separated arrays
121        let inner_scratch = if extra_scratch.len() > 0 {
122            extra_scratch
123        } else {
124            &mut buffer[..]
125        };
126
127        self.inner_dct
128            .process_dst3_with_scratch(&mut output_left, inner_scratch);
129        self.inner_dct
130            .process_dct3_with_scratch(&mut output_right, inner_scratch);
131
132        //post-process the data by combining it back into a single array
133        for k in 0..inner_len {
134            let twiddle = self.twiddles[k];
135            let cos_value = output_left[k];
136            let sin_value = output_right[k];
137
138            buffer[k] = cos_value * twiddle.re + sin_value * twiddle.im;
139            buffer[len - 1 - k] = sin_value * twiddle.re - cos_value * twiddle.im;
140        }
141    }
142}
143impl<T> RequiredScratch for Type4ConvertToType3Even<T> {
144    fn get_scratch_len(&self) -> usize {
145        self.scratch_len
146    }
147}
148impl<T: DctNum> TransformType4<T> for Type4ConvertToType3Even<T> {}
149impl<T> Length for Type4ConvertToType3Even<T> {
150    fn len(&self) -> usize {
151        self.twiddles.len() * 2
152    }
153}
154
#[cfg(test)]
mod test {
    use super::*;
    use crate::algorithm::{Type2And3Naive, Type4Naive};
    use crate::test_utils::{compare_float_vectors, random_signal};

    /// Compares the DCT4 computed via two half-size type-3 transforms against the naive DCT4
    /// for every even length from 2 to 38.
    #[test]
    fn unittest_dct4_via_type3() {
        for inner_size in 1..20 {
            let size = inner_size * 2;

            let mut expected_buffer = random_signal(size);
            let mut actual_buffer = expected_buffer.clone();

            let naive_dct4 = Type4Naive::new(size);
            naive_dct4.process_dct4(&mut expected_buffer);

            let inner_dct3 = Arc::new(Type2And3Naive::new(inner_size));
            let dct = Type4ConvertToType3Even::new(inner_dct3);
            dct.process_dct4(&mut actual_buffer);

            println!();
            println!("expected: {:?}", expected_buffer);
            println!("actual:   {:?}", actual_buffer);

            assert!(
                compare_float_vectors(&expected_buffer, &actual_buffer),
                "len = {}",
                size
            );
        }
    }

    /// Compares the DST4 computed via two half-size type-3 transforms against the naive DST4
    /// for every even length from 2 to 38.
    #[test]
    fn unittest_dst4_via_type3() {
        for inner_size in 1..20 {
            let size = inner_size * 2;

            let mut expected_buffer = random_signal(size);
            let mut actual_buffer = expected_buffer.clone();

            let naive_dst4 = Type4Naive::new(size);
            naive_dst4.process_dst4(&mut expected_buffer);

            let inner_dst3 = Arc::new(Type2And3Naive::new(inner_size));
            let dst = Type4ConvertToType3Even::new(inner_dst3);
            dst.process_dst4(&mut actual_buffer);

            println!();
            println!("expected: {:?}", expected_buffer);
            println!("actual:   {:?}", actual_buffer);

            assert!(
                compare_float_vectors(&expected_buffer, &actual_buffer),
                "len = {}",
                size
            );
        }
    }
}