rustdct/algorithm/
type4_convert_to_type3.rs1use std::sync::Arc;
2
3use rustfft::num_complex::Complex;
4use rustfft::Length;
5
6use crate::common::dct_error_inplace;
7use crate::{twiddles, Dct4, DctNum, Dst4, RequiredScratch, TransformType2And3, TransformType4};
8
9pub struct Type4ConvertToType3Even<T> {
30 inner_dct: Arc<dyn TransformType2And3<T>>,
31 twiddles: Box<[Complex<T>]>,
32 scratch_len: usize,
33}
34
35impl<T: DctNum> Type4ConvertToType3Even<T> {
36 pub fn new(inner_dct: Arc<dyn TransformType2And3<T>>) -> Self {
38 let inner_len = inner_dct.len();
39 let len = inner_len * 2;
40
41 let twiddles: Vec<Complex<T>> = (0..inner_len)
42 .map(|i| twiddles::single_twiddle(2 * i + 1, len * 8).conj())
43 .collect();
44
45 let inner_scratch = inner_dct.get_scratch_len();
46 let scratch_len = if inner_scratch <= len {
47 len
48 } else {
49 len + inner_scratch
50 };
51
52 Self {
53 inner_dct: inner_dct,
54 twiddles: twiddles.into_boxed_slice(),
55 scratch_len,
56 }
57 }
58}
59impl<T: DctNum> Dct4<T> for Type4ConvertToType3Even<T> {
60 fn process_dct4_with_scratch(&self, buffer: &mut [T], scratch: &mut [T]) {
61 let scratch = validate_buffers!(buffer, scratch, self.len(), self.get_scratch_len());
62
63 let (self_scratch, extra_scratch) = scratch.split_at_mut(self.len());
64
65 let len = self.len();
66 let inner_len = len / 2;
67
68 let (mut output_left, mut output_right) = self_scratch.split_at_mut(inner_len);
70
71 output_left[0] = buffer[0] * T::two();
72 for k in 1..inner_len {
73 output_left[k] = buffer[2 * k - 1] + buffer[2 * k];
74 output_right[k - 1] = buffer[2 * k - 1] - buffer[2 * k];
75 }
76 output_right[inner_len - 1] = buffer[len - 1] * T::two();
77
78 let inner_scratch = if extra_scratch.len() > 0 {
80 extra_scratch
81 } else {
82 &mut buffer[..]
83 };
84
85 self.inner_dct
86 .process_dct3_with_scratch(&mut output_left, inner_scratch);
87 self.inner_dct
88 .process_dst3_with_scratch(&mut output_right, inner_scratch);
89
90 for k in 0..inner_len {
92 let twiddle = self.twiddles[k];
93 let cos_value = output_left[k];
94 let sin_value = output_right[k];
95
96 buffer[k] = cos_value * twiddle.re + sin_value * twiddle.im;
97 buffer[len - 1 - k] = cos_value * twiddle.im - sin_value * twiddle.re;
98 }
99 }
100}
101impl<T: DctNum> Dst4<T> for Type4ConvertToType3Even<T> {
102 fn process_dst4_with_scratch(&self, buffer: &mut [T], scratch: &mut [T]) {
103 let scratch = validate_buffers!(buffer, scratch, self.len(), self.get_scratch_len());
104
105 let (self_scratch, extra_scratch) = scratch.split_at_mut(self.len());
106
107 let len = self.len();
108 let inner_len = len / 2;
109
110 let (mut output_left, mut output_right) = self_scratch.split_at_mut(inner_len);
112
113 output_right[0] = buffer[0] * T::two();
114 for k in 1..inner_len {
115 output_left[k - 1] = buffer[2 * k - 1] + buffer[2 * k];
116 output_right[k] = buffer[2 * k] - buffer[2 * k - 1];
117 }
118 output_left[inner_len - 1] = buffer[len - 1] * T::two();
119
120 let inner_scratch = if extra_scratch.len() > 0 {
122 extra_scratch
123 } else {
124 &mut buffer[..]
125 };
126
127 self.inner_dct
128 .process_dst3_with_scratch(&mut output_left, inner_scratch);
129 self.inner_dct
130 .process_dct3_with_scratch(&mut output_right, inner_scratch);
131
132 for k in 0..inner_len {
134 let twiddle = self.twiddles[k];
135 let cos_value = output_left[k];
136 let sin_value = output_right[k];
137
138 buffer[k] = cos_value * twiddle.re + sin_value * twiddle.im;
139 buffer[len - 1 - k] = sin_value * twiddle.re - cos_value * twiddle.im;
140 }
141 }
142}
143impl<T> RequiredScratch for Type4ConvertToType3Even<T> {
144 fn get_scratch_len(&self) -> usize {
145 self.scratch_len
146 }
147}
148impl<T: DctNum> TransformType4<T> for Type4ConvertToType3Even<T> {}
149impl<T> Length for Type4ConvertToType3Even<T> {
150 fn len(&self) -> usize {
151 self.twiddles.len() * 2
152 }
153}
154
#[cfg(test)]
mod test {
    use super::*;
    use crate::algorithm::{Type2And3Naive, Type4Naive};
    use crate::test_utils::{compare_float_vectors, random_signal};

    /// Checks the DCT4 against `Type4Naive` for every even length from 2 to 38.
    #[test]
    fn unittest_dct4_via_type3() {
        for inner_size in 1..20 {
            let size = inner_size * 2;

            let mut expected_buffer = random_signal(size);
            let mut actual_buffer = expected_buffer.clone();

            let naive_dct4 = Type4Naive::new(size);
            naive_dct4.process_dct4(&mut expected_buffer);

            let inner_dct3 = Arc::new(Type2And3Naive::new(inner_size));
            let dct = Type4ConvertToType3Even::new(inner_dct3);
            dct.process_dct4(&mut actual_buffer);

            // Report both buffers only on failure instead of printing them unconditionally.
            assert!(
                compare_float_vectors(&expected_buffer, &actual_buffer),
                "len = {}\nexpected: {:?}\nactual:   {:?}",
                size,
                expected_buffer,
                actual_buffer
            );
        }
    }

    /// Checks the DST4 against `Type4Naive` for every even length from 2 to 38.
    #[test]
    fn unittest_dst4_via_type3() {
        for inner_size in 1..20 {
            let size = inner_size * 2;

            let mut expected_buffer = random_signal(size);
            let mut actual_buffer = expected_buffer.clone();

            let naive_dst4 = Type4Naive::new(size);
            naive_dst4.process_dst4(&mut expected_buffer);

            let inner_dst3 = Arc::new(Type2And3Naive::new(inner_size));
            let dst = Type4ConvertToType3Even::new(inner_dst3);
            dst.process_dst4(&mut actual_buffer);

            // Report both buffers only on failure instead of printing them unconditionally.
            assert!(
                compare_float_vectors(&expected_buffer, &actual_buffer),
                "len = {}\nexpected: {:?}\nactual:   {:?}",
                size,
                expected_buffer,
                actual_buffer
            );
        }
    }
}