scirs2_transform/signal_transforms/
wpt.rs1use crate::error::{Result, TransformError};
7use crate::signal_transforms::dwt::{BoundaryMode, WaveletType, DWT};
8use scirs2_core::ndarray::{Array1, ArrayView1};
9use std::collections::HashMap;
10
11#[derive(Debug, Clone)]
13pub struct WaveletPacketNode {
14 pub data: Array1<f64>,
16 pub path: String,
18 pub level: usize,
20 pub index: usize,
22 pub cost: f64,
24}
25
26impl WaveletPacketNode {
27 pub fn new(data: Array1<f64>, path: String, level: usize, index: usize) -> Self {
29 let cost = Self::compute_cost(&data);
30 WaveletPacketNode {
31 data,
32 path,
33 level,
34 index,
35 cost,
36 }
37 }
38
39 fn compute_cost(data: &Array1<f64>) -> f64 {
41 let energy: f64 = data.iter().map(|x| x * x).sum();
42 if energy < 1e-10 {
43 return 0.0;
44 }
45
46 let mut entropy = 0.0;
47 for &val in data.iter() {
48 let p = (val * val) / energy;
49 if p > 1e-10 {
50 entropy -= p * p.ln();
51 }
52 }
53
54 entropy
55 }
56
57 pub fn update_cost(&mut self) {
59 self.cost = Self::compute_cost(&self.data);
60 }
61}
62
/// Cost functional options for wavelet-packet best-basis selection.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum BestBasisCriterion {
    /// Shannon entropy.
    Shannon,
    /// Threshold-based cost; the payload is the threshold value.
    Threshold(f64),
    /// Log-energy entropy.
    LogEnergy,
    /// Stein's unbiased risk estimate (SURE).
    Sure,
}
75
76#[derive(Debug, Clone)]
78pub struct WPT {
79 wavelet: WaveletType,
80 max_level: usize,
81 boundary: BoundaryMode,
82 criterion: BestBasisCriterion,
83 nodes: HashMap<String, WaveletPacketNode>,
84}
85
86impl WPT {
87 pub fn new(wavelet: WaveletType, max_level: usize) -> Self {
89 WPT {
90 wavelet,
91 max_level,
92 boundary: BoundaryMode::Symmetric,
93 criterion: BestBasisCriterion::Shannon,
94 nodes: HashMap::new(),
95 }
96 }
97
98 pub fn with_boundary(mut self, boundary: BoundaryMode) -> Self {
100 self.boundary = boundary;
101 self
102 }
103
104 pub fn with_criterion(mut self, criterion: BestBasisCriterion) -> Self {
106 self.criterion = criterion;
107 self
108 }
109
110 pub fn decompose(&mut self, signal: &ArrayView1<f64>) -> Result<()> {
112 self.nodes.clear();
113
114 let root = WaveletPacketNode::new(signal.to_owned(), String::new(), 0, 0);
116 self.nodes.insert(String::new(), root);
117
118 self.decompose_node("", 0)?;
120
121 Ok(())
122 }
123
124 fn decompose_node(&mut self, path: &str, level: usize) -> Result<()> {
126 if level >= self.max_level {
127 return Ok(());
128 }
129
130 let node = self
132 .nodes
133 .get(path)
134 .ok_or_else(|| TransformError::InvalidInput(format!("Node not found: {}", path)))?
135 .clone();
136
137 let dwt = DWT::new(self.wavelet)?.with_boundary(self.boundary);
139
140 let (approx, detail) = dwt.decompose(&node.data.view())?;
142
143 let approx_path = format!("{}a", path);
145 let detail_path = format!("{}d", path);
146
147 let index = node.index;
148 let approx_node = WaveletPacketNode::new(approx, approx_path.clone(), level + 1, index * 2);
149 let detail_node =
150 WaveletPacketNode::new(detail, detail_path.clone(), level + 1, index * 2 + 1);
151
152 self.nodes.insert(approx_path.clone(), approx_node);
153 self.nodes.insert(detail_path.clone(), detail_node);
154
155 self.decompose_node(&approx_path, level + 1)?;
157 self.decompose_node(&detail_path, level + 1)?;
158
159 Ok(())
160 }
161
162 pub fn best_basis(&self) -> Result<Vec<WaveletPacketNode>> {
164 let mut best_nodes = Vec::new();
165 self.select_best_basis("", &mut best_nodes)?;
166 Ok(best_nodes)
167 }
168
169 fn select_best_basis(&self, path: &str, selected: &mut Vec<WaveletPacketNode>) -> Result<f64> {
171 let node = self
172 .nodes
173 .get(path)
174 .ok_or_else(|| TransformError::InvalidInput(format!("Node not found: {}", path)))?;
175
176 let approx_path = format!("{}a", path);
177 let detail_path = format!("{}d", path);
178
179 if self.nodes.contains_key(&approx_path) && self.nodes.contains_key(&detail_path) {
181 let approx_cost = self.select_best_basis(&approx_path, selected)?;
183 let detail_cost = self.select_best_basis(&detail_path, selected)?;
184 let children_cost = approx_cost + detail_cost;
185
186 if node.cost <= children_cost {
188 selected.retain(|n| !n.path.starts_with(path) || n.path == path);
190 selected.push(node.clone());
191 Ok(node.cost)
192 } else {
193 Ok(children_cost)
195 }
196 } else {
197 selected.push(node.clone());
199 Ok(node.cost)
200 }
201 }
202
203 pub fn reconstruct(&self, nodes: &[WaveletPacketNode]) -> Result<Array1<f64>> {
210 if nodes.is_empty() {
211 return Err(TransformError::InvalidInput(
212 "No nodes provided for reconstruction".to_string(),
213 ));
214 }
215
216 if let Some(root) = nodes.iter().find(|n| n.path.is_empty()) {
218 return Ok(root.data.clone());
219 }
220
221 let mut tree: HashMap<String, Array1<f64>> = nodes
223 .iter()
224 .map(|n| (n.path.clone(), n.data.clone()))
225 .collect();
226
227 let dwt = DWT::new(self.wavelet)?.with_boundary(self.boundary);
229
230 let max_level = nodes.iter().map(|n| n.level).max().unwrap_or(0);
232
233 for _level in (1..=max_level).rev() {
235 let parents: Vec<String> = tree
237 .keys()
238 .filter_map(|p| {
239 if p.is_empty() {
240 return None;
241 }
242 let parent = &p[..p.len() - 1];
244 let approx_key = format!("{}a", parent);
246 let detail_key = format!("{}d", parent);
247 if tree.contains_key(&approx_key)
248 && tree.contains_key(&detail_key)
249 && !tree.contains_key(parent)
250 {
251 Some(parent.to_string())
252 } else {
253 None
254 }
255 })
256 .collect::<std::collections::HashSet<_>>()
257 .into_iter()
258 .collect();
259
260 for parent in parents {
261 let approx_key = format!("{}a", parent);
262 let detail_key = format!("{}d", parent);
263
264 let approx = tree.remove(&approx_key).ok_or_else(|| {
265 TransformError::InvalidInput(format!("Missing approx node: {}", approx_key))
266 })?;
267 let detail = tree.remove(&detail_key).ok_or_else(|| {
268 TransformError::InvalidInput(format!("Missing detail node: {}", detail_key))
269 })?;
270
271 let reconstructed = dwt.reconstruct(&approx.view(), &detail.view())?;
272 tree.insert(parent, reconstructed);
273 }
274 }
275
276 tree.remove("").ok_or_else(|| {
278 TransformError::InvalidInput(
279 "Could not fully reconstruct to root — basis nodes may be incomplete".to_string(),
280 )
281 })
282 }
283
284 pub fn get_level(&self, level: usize) -> Vec<&WaveletPacketNode> {
286 self.nodes
287 .values()
288 .filter(|node| node.level == level)
289 .collect()
290 }
291
292 pub fn get_node(&self, path: &str) -> Option<&WaveletPacketNode> {
294 self.nodes.get(path)
295 }
296
297 pub fn nodes(&self) -> &HashMap<String, WaveletPacketNode> {
299 &self.nodes
300 }
301
302 pub fn best_basis_cost(&self) -> Result<f64> {
304 let best = self.best_basis()?;
305 Ok(best.iter().map(|node| node.cost).sum())
306 }
307}
308
309pub fn denoise_wpt(
311 signal: &ArrayView1<f64>,
312 wavelet: WaveletType,
313 level: usize,
314 threshold: f64,
315) -> Result<Array1<f64>> {
316 let mut wpt = WPT::new(wavelet, level);
318 wpt.decompose(signal)?;
319
320 let best = wpt.best_basis()?;
322
323 let mut denoised_nodes = Vec::new();
325 for mut node in best {
326 for val in node.data.iter_mut() {
328 if val.abs() < threshold {
329 *val = 0.0;
330 } else {
331 *val = if *val > 0.0 {
332 *val - threshold
333 } else {
334 *val + threshold
335 };
336 }
337 }
338 node.update_cost();
339 denoised_nodes.push(node);
340 }
341
342 wpt.reconstruct(&denoised_nodes)
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Eight-sample ramp 1.0..=8.0 shared by several tests.
    fn ramp8() -> Array1<f64> {
        Array1::from_vec((1_i32..=8).map(f64::from).collect())
    }

    #[test]
    fn test_wpt_decompose() -> Result<()> {
        let input = ramp8();
        let mut transform = WPT::new(WaveletType::Haar, 2);
        transform.decompose(&input.view())?;

        for path in ["", "a", "d", "aa", "ad", "da", "dd"].iter() {
            assert!(transform.get_node(path).is_some(), "missing node {:?}", path);
        }
        Ok(())
    }

    #[test]
    fn test_wpt_best_basis() -> Result<()> {
        let input = Array1::from_vec(
            (0_i32..16)
                .map(|k| (f64::from(k) * 0.5).sin())
                .collect::<Vec<f64>>(),
        );
        let mut transform = WPT::new(WaveletType::Haar, 3);
        transform.decompose(&input.view())?;

        let basis = transform.best_basis()?;
        assert!(!basis.is_empty());

        // Every selected path must be distinct.
        let unique: std::collections::HashSet<&str> =
            basis.iter().map(|node| node.path.as_str()).collect();
        assert_eq!(unique.len(), basis.len());
        Ok(())
    }

    #[test]
    fn test_wpt_levels() -> Result<()> {
        let input = ramp8();
        let mut transform = WPT::new(WaveletType::Haar, 2);
        transform.decompose(&input.view())?;

        let expected: [(usize, usize); 3] = [(0, 1), (1, 2), (2, 4)];
        for &(depth, count) in expected.iter() {
            assert_eq!(transform.get_level(depth).len(), count);
        }
        Ok(())
    }

    #[test]
    fn test_wavelet_packet_node_cost() {
        let coeffs = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let leaf = WaveletPacketNode::new(coeffs, "test".to_string(), 1, 0);
        assert!(leaf.cost >= 0.0);
    }

    #[test]
    fn test_best_basis_criterion() {
        let choices = [BestBasisCriterion::Shannon, BestBasisCriterion::LogEnergy];
        for &choice in choices.iter() {
            let transform = WPT::new(WaveletType::Haar, 3).with_criterion(choice);
            assert_eq!(transform.criterion, choice);
        }
    }

    #[test]
    fn test_wpt_reconstruct_from_best_basis() -> Result<()> {
        let input = ramp8();
        let expected_len = input.len();

        let mut transform = WPT::new(WaveletType::Haar, 2);
        transform.decompose(&input.view())?;
        let basis = transform.best_basis()?;
        let output = transform.reconstruct(&basis)?;

        // Allow a couple of samples of boundary slack.
        let got_len = output.len();
        let gap = got_len.max(expected_len) - got_len.min(expected_len);
        assert!(
            gap <= 2,
            "Reconstructed length {} too different from original {}",
            got_len,
            expected_len
        );
        Ok(())
    }

    #[test]
    fn test_wpt_reconstruct_leaf_nodes() -> Result<()> {
        let input = ramp8();
        let mut transform = WPT::new(WaveletType::Haar, 1);
        transform.decompose(&input.view())?;

        let leaves: Vec<WaveletPacketNode> =
            transform.get_level(1).into_iter().cloned().collect();
        assert_eq!(leaves.len(), 2);

        let output = transform.reconstruct(&leaves)?;
        assert!(!output.is_empty());
        Ok(())
    }

    #[test]
    fn test_wpt_reconstruct_root_shortcut() -> Result<()> {
        let samples = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let root = WaveletPacketNode::new(samples.clone(), String::new(), 0, 0);

        let transform = WPT::new(WaveletType::Haar, 2);
        let output = transform.reconstruct(&[root])?;

        assert_eq!(output.len(), samples.len());
        assert_abs_diff_eq!(output[0], samples[0], epsilon = 1e-10);
        Ok(())
    }
}