scirs2_transform/signal_transforms/
wpt.rs1use crate::error::{Result, TransformError};
7use crate::signal_transforms::dwt::{BoundaryMode, WaveletType, DWT};
8use scirs2_core::ndarray::{Array1, ArrayView1};
9use std::collections::HashMap;
10
11#[derive(Debug, Clone)]
13pub struct WaveletPacketNode {
14 pub data: Array1<f64>,
16 pub path: String,
18 pub level: usize,
20 pub index: usize,
22 pub cost: f64,
24}
25
26impl WaveletPacketNode {
27 pub fn new(data: Array1<f64>, path: String, level: usize, index: usize) -> Self {
29 let cost = Self::compute_cost(&data);
30 WaveletPacketNode {
31 data,
32 path,
33 level,
34 index,
35 cost,
36 }
37 }
38
39 fn compute_cost(data: &Array1<f64>) -> f64 {
41 let energy: f64 = data.iter().map(|x| x * x).sum();
42 if energy < 1e-10 {
43 return 0.0;
44 }
45
46 let mut entropy = 0.0;
47 for &val in data.iter() {
48 let p = (val * val) / energy;
49 if p > 1e-10 {
50 entropy -= p * p.ln();
51 }
52 }
53
54 entropy
55 }
56
57 pub fn update_cost(&mut self) {
59 self.cost = Self::compute_cost(&self.data);
60 }
61}
62
/// Cost functional for wavelet packet best-basis selection, set via `WPT::with_criterion`.
///
/// `WPT::new` starts with `Shannon`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BestBasisCriterion {
    /// Shannon entropy cost.
    Shannon,
    /// Threshold cost, parameterised by the contained threshold value.
    Threshold(f64),
    /// Log-energy entropy cost.
    LogEnergy,
    /// Stein's Unbiased Risk Estimate (SURE) cost.
    Sure,
}
75
76#[derive(Debug, Clone)]
78pub struct WPT {
79 wavelet: WaveletType,
80 max_level: usize,
81 boundary: BoundaryMode,
82 criterion: BestBasisCriterion,
83 nodes: HashMap<String, WaveletPacketNode>,
84}
85
86impl WPT {
87 pub fn new(wavelet: WaveletType, max_level: usize) -> Self {
89 WPT {
90 wavelet,
91 max_level,
92 boundary: BoundaryMode::Symmetric,
93 criterion: BestBasisCriterion::Shannon,
94 nodes: HashMap::new(),
95 }
96 }
97
98 pub fn with_boundary(mut self, boundary: BoundaryMode) -> Self {
100 self.boundary = boundary;
101 self
102 }
103
104 pub fn with_criterion(mut self, criterion: BestBasisCriterion) -> Self {
106 self.criterion = criterion;
107 self
108 }
109
110 pub fn decompose(&mut self, signal: &ArrayView1<f64>) -> Result<()> {
112 self.nodes.clear();
113
114 let root = WaveletPacketNode::new(signal.to_owned(), String::new(), 0, 0);
116 self.nodes.insert(String::new(), root);
117
118 self.decompose_node("", 0)?;
120
121 Ok(())
122 }
123
124 fn decompose_node(&mut self, path: &str, level: usize) -> Result<()> {
126 if level >= self.max_level {
127 return Ok(());
128 }
129
130 let node = self
132 .nodes
133 .get(path)
134 .ok_or_else(|| TransformError::InvalidInput(format!("Node not found: {}", path)))?
135 .clone();
136
137 let dwt = DWT::new(self.wavelet)?.with_boundary(self.boundary);
139
140 let (approx, detail) = dwt.decompose(&node.data.view())?;
142
143 let approx_path = format!("{}a", path);
145 let detail_path = format!("{}d", path);
146
147 let index = node.index;
148 let approx_node = WaveletPacketNode::new(approx, approx_path.clone(), level + 1, index * 2);
149 let detail_node =
150 WaveletPacketNode::new(detail, detail_path.clone(), level + 1, index * 2 + 1);
151
152 self.nodes.insert(approx_path.clone(), approx_node);
153 self.nodes.insert(detail_path.clone(), detail_node);
154
155 self.decompose_node(&approx_path, level + 1)?;
157 self.decompose_node(&detail_path, level + 1)?;
158
159 Ok(())
160 }
161
162 pub fn best_basis(&self) -> Result<Vec<WaveletPacketNode>> {
164 let mut best_nodes = Vec::new();
165 self.select_best_basis("", &mut best_nodes)?;
166 Ok(best_nodes)
167 }
168
169 fn select_best_basis(&self, path: &str, selected: &mut Vec<WaveletPacketNode>) -> Result<f64> {
171 let node = self
172 .nodes
173 .get(path)
174 .ok_or_else(|| TransformError::InvalidInput(format!("Node not found: {}", path)))?;
175
176 let approx_path = format!("{}a", path);
177 let detail_path = format!("{}d", path);
178
179 if self.nodes.contains_key(&approx_path) && self.nodes.contains_key(&detail_path) {
181 let approx_cost = self.select_best_basis(&approx_path, selected)?;
183 let detail_cost = self.select_best_basis(&detail_path, selected)?;
184 let children_cost = approx_cost + detail_cost;
185
186 if node.cost <= children_cost {
188 selected.retain(|n| !n.path.starts_with(path) || n.path == path);
190 selected.push(node.clone());
191 Ok(node.cost)
192 } else {
193 Ok(children_cost)
195 }
196 } else {
197 selected.push(node.clone());
199 Ok(node.cost)
200 }
201 }
202
203 pub fn reconstruct(&self, nodes: &[WaveletPacketNode]) -> Result<Array1<f64>> {
205 if nodes.is_empty() {
206 return Err(TransformError::InvalidInput(
207 "No nodes provided for reconstruction".to_string(),
208 ));
209 }
210
211 if let Some(root) = nodes.iter().find(|n| n.path.is_empty()) {
213 return Ok(root.data.clone());
214 }
215
216 Err(TransformError::NotImplemented(
218 "Reconstruction from arbitrary basis not yet implemented".to_string(),
219 ))
220 }
221
222 pub fn get_level(&self, level: usize) -> Vec<&WaveletPacketNode> {
224 self.nodes
225 .values()
226 .filter(|node| node.level == level)
227 .collect()
228 }
229
230 pub fn get_node(&self, path: &str) -> Option<&WaveletPacketNode> {
232 self.nodes.get(path)
233 }
234
235 pub fn nodes(&self) -> &HashMap<String, WaveletPacketNode> {
237 &self.nodes
238 }
239
240 pub fn best_basis_cost(&self) -> Result<f64> {
242 let best = self.best_basis()?;
243 Ok(best.iter().map(|node| node.cost).sum())
244 }
245}
246
247pub fn denoise_wpt(
249 signal: &ArrayView1<f64>,
250 wavelet: WaveletType,
251 level: usize,
252 threshold: f64,
253) -> Result<Array1<f64>> {
254 let mut wpt = WPT::new(wavelet, level);
256 wpt.decompose(signal)?;
257
258 let best = wpt.best_basis()?;
260
261 let mut denoised_nodes = Vec::new();
263 for mut node in best {
264 for val in node.data.iter_mut() {
266 if val.abs() < threshold {
267 *val = 0.0;
268 } else {
269 *val = if *val > 0.0 {
270 *val - threshold
271 } else {
272 *val + threshold
273 };
274 }
275 }
276 node.update_cost();
277 denoised_nodes.push(node);
278 }
279
280 wpt.reconstruct(&denoised_nodes)
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use std::collections::HashSet;

    #[test]
    fn test_wpt_decompose() -> Result<()> {
        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let mut wpt = WPT::new(WaveletType::Haar, 2);

        wpt.decompose(&signal.view())?;

        for &path in ["", "a", "d", "aa", "ad", "da", "dd"].iter() {
            assert!(wpt.get_node(path).is_some(), "missing node {:?}", path);
        }
        // Decomposition must stop at max_level.
        assert!(wpt.get_node("aaa").is_none());
        assert_eq!(wpt.nodes().len(), 7);

        Ok(())
    }

    #[test]
    fn test_wpt_best_basis() -> Result<()> {
        let signal = Array1::from_vec((0..16).map(|i| (i as f64 * 0.5).sin()).collect());
        let mut wpt = WPT::new(WaveletType::Haar, 3);

        wpt.decompose(&signal.view())?;
        let best = wpt.best_basis()?;

        assert!(!best.is_empty());

        let unique: HashSet<&str> = best.iter().map(|n| n.path.as_str()).collect();
        assert_eq!(unique.len(), best.len());

        // The basis must tile the tree: no selected node is an ancestor of another...
        for outer in &best {
            for inner in &best {
                if outer.path != inner.path {
                    assert!(!inner.path.starts_with(&outer.path));
                }
            }
        }
        // ...and the dyadic widths 2^-level of the selected nodes add up to one.
        let coverage: f64 = best.iter().map(|n| 0.5f64.powi(n.level as i32)).sum();
        assert_abs_diff_eq!(coverage, 1.0, epsilon = 1e-12);

        Ok(())
    }

    #[test]
    fn test_wpt_levels() -> Result<()> {
        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let mut wpt = WPT::new(WaveletType::Haar, 2);

        wpt.decompose(&signal.view())?;

        assert_eq!(wpt.get_level(0).len(), 1);
        assert_eq!(wpt.get_level(1).len(), 2);

        let level2 = wpt.get_level(2);
        assert_eq!(level2.len(), 4);
        let mut indices: Vec<usize> = level2.iter().map(|n| n.index).collect();
        indices.sort_unstable();
        assert_eq!(indices, vec![0, 1, 2, 3]);

        Ok(())
    }

    #[test]
    fn test_wavelet_packet_node_cost() {
        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let node = WaveletPacketNode::new(data, "test".to_string(), 1, 0);

        // p = [1, 4, 9, 16] / 30, so the entropy -sum(p ln p) is about 1.0784768.
        assert!(node.cost >= 0.0);
        assert_abs_diff_eq!(node.cost, 1.0784768, epsilon = 1e-6);

        // Zero energy costs nothing.
        let silent = WaveletPacketNode::new(Array1::<f64>::zeros(4), "z".to_string(), 1, 0);
        assert_abs_diff_eq!(silent.cost, 0.0, epsilon = 1e-12);

        // All energy in one coefficient: p = 1, so the entropy is zero.
        let mut spike =
            WaveletPacketNode::new(Array1::from_vec(vec![0.0, 0.0, 5.0, 0.0]), "s".to_string(), 1, 0);
        assert_abs_diff_eq!(spike.cost, 0.0, epsilon = 1e-12);

        // update_cost tracks changes to the coefficients.
        spike.data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        spike.update_cost();
        assert_abs_diff_eq!(spike.cost, 1.0784768, epsilon = 1e-6);
    }

    #[test]
    fn test_best_basis_criterion() {
        assert_eq!(
            WPT::new(WaveletType::Haar, 3).criterion,
            BestBasisCriterion::Shannon
        );

        let wpt1 = WPT::new(WaveletType::Haar, 3).with_criterion(BestBasisCriterion::Shannon);
        assert_eq!(wpt1.criterion, BestBasisCriterion::Shannon);

        let wpt2 = WPT::new(WaveletType::Haar, 3).with_criterion(BestBasisCriterion::LogEnergy);
        assert_eq!(wpt2.criterion, BestBasisCriterion::LogEnergy);

        let wpt3 =
            WPT::new(WaveletType::Haar, 3).with_criterion(BestBasisCriterion::Threshold(0.5));
        assert_eq!(wpt3.criterion, BestBasisCriterion::Threshold(0.5));
    }
}