Skip to main content

scirs2_transform/signal_transforms/
wpt.rs

//! Wavelet Packet Transform (WPT) implementation.
//!
//! Provides wavelet packet decomposition and best-basis selection.
//! Wavelet packets extend the DWT by decomposing both the approximation and the detail
//! coefficients at every level.
5
6use crate::error::{Result, TransformError};
7use crate::signal_transforms::dwt::{BoundaryMode, WaveletType, DWT};
8use scirs2_core::ndarray::{Array1, ArrayView1};
9use std::collections::HashMap;
10
11/// Wavelet packet node
12#[derive(Debug, Clone)]
13pub struct WaveletPacketNode {
14    /// Node data (coefficients)
15    pub data: Array1<f64>,
16    /// Node path (sequence of 'a' for approximation, 'd' for detail)
17    pub path: String,
18    /// Level in the packet tree
19    pub level: usize,
20    /// Node index at this level
21    pub index: usize,
22    /// Cost/entropy of this node
23    pub cost: f64,
24}
25
26impl WaveletPacketNode {
27    /// Create a new wavelet packet node
28    pub fn new(data: Array1<f64>, path: String, level: usize, index: usize) -> Self {
29        let cost = Self::compute_cost(&data);
30        WaveletPacketNode {
31            data,
32            path,
33            level,
34            index,
35            cost,
36        }
37    }
38
39    /// Compute the cost (Shannon entropy) of the node
40    fn compute_cost(data: &Array1<f64>) -> f64 {
41        let energy: f64 = data.iter().map(|x| x * x).sum();
42        if energy < 1e-10 {
43            return 0.0;
44        }
45
46        let mut entropy = 0.0;
47        for &val in data.iter() {
48            let p = (val * val) / energy;
49            if p > 1e-10 {
50                entropy -= p * p.ln();
51            }
52        }
53
54        entropy
55    }
56
57    /// Update the cost
58    pub fn update_cost(&mut self) {
59        self.cost = Self::compute_cost(&self.data);
60    }
61}
62
/// Cost functional used to pick the best basis out of a wavelet packet tree.
///
/// The default is [`BestBasisCriterion::Shannon`], the same criterion `WPT::new` selects.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum BestBasisCriterion {
    /// Shannon entropy of the squared coefficients.
    #[default]
    Shannon,
    /// Number of coefficients whose magnitude exceeds the given threshold.
    Threshold(f64),
    /// Log energy of the coefficients.
    LogEnergy,
    /// Stein's Unbiased Risk Estimate (SURE).
    Sure,
}
75
76/// Wavelet Packet Transform
77#[derive(Debug, Clone)]
78pub struct WPT {
79    wavelet: WaveletType,
80    max_level: usize,
81    boundary: BoundaryMode,
82    criterion: BestBasisCriterion,
83    nodes: HashMap<String, WaveletPacketNode>,
84}
85
86impl WPT {
87    /// Create a new WPT instance
88    pub fn new(wavelet: WaveletType, max_level: usize) -> Self {
89        WPT {
90            wavelet,
91            max_level,
92            boundary: BoundaryMode::Symmetric,
93            criterion: BestBasisCriterion::Shannon,
94            nodes: HashMap::new(),
95        }
96    }
97
98    /// Set the boundary mode
99    pub fn with_boundary(mut self, boundary: BoundaryMode) -> Self {
100        self.boundary = boundary;
101        self
102    }
103
104    /// Set the best basis criterion
105    pub fn with_criterion(mut self, criterion: BestBasisCriterion) -> Self {
106        self.criterion = criterion;
107        self
108    }
109
110    /// Perform full wavelet packet decomposition
111    pub fn decompose(&mut self, signal: &ArrayView1<f64>) -> Result<()> {
112        self.nodes.clear();
113
114        // Create root node
115        let root = WaveletPacketNode::new(signal.to_owned(), String::new(), 0, 0);
116        self.nodes.insert(String::new(), root);
117
118        // Recursively decompose
119        self.decompose_node("", 0)?;
120
121        Ok(())
122    }
123
124    /// Recursively decompose a node
125    fn decompose_node(&mut self, path: &str, level: usize) -> Result<()> {
126        if level >= self.max_level {
127            return Ok(());
128        }
129
130        // Get the current node
131        let node = self
132            .nodes
133            .get(path)
134            .ok_or_else(|| TransformError::InvalidInput(format!("Node not found: {}", path)))?
135            .clone();
136
137        // Create DWT instance
138        let dwt = DWT::new(self.wavelet)?.with_boundary(self.boundary);
139
140        // Decompose
141        let (approx, detail) = dwt.decompose(&node.data.view())?;
142
143        // Create child nodes
144        let approx_path = format!("{}a", path);
145        let detail_path = format!("{}d", path);
146
147        let index = node.index;
148        let approx_node = WaveletPacketNode::new(approx, approx_path.clone(), level + 1, index * 2);
149        let detail_node =
150            WaveletPacketNode::new(detail, detail_path.clone(), level + 1, index * 2 + 1);
151
152        self.nodes.insert(approx_path.clone(), approx_node);
153        self.nodes.insert(detail_path.clone(), detail_node);
154
155        // Recursively decompose child nodes
156        self.decompose_node(&approx_path, level + 1)?;
157        self.decompose_node(&detail_path, level + 1)?;
158
159        Ok(())
160    }
161
162    /// Select the best basis using the specified criterion
163    pub fn best_basis(&self) -> Result<Vec<WaveletPacketNode>> {
164        let mut best_nodes = Vec::new();
165        self.select_best_basis("", &mut best_nodes)?;
166        Ok(best_nodes)
167    }
168
169    /// Recursively select best basis
170    fn select_best_basis(&self, path: &str, selected: &mut Vec<WaveletPacketNode>) -> Result<f64> {
171        let node = self
172            .nodes
173            .get(path)
174            .ok_or_else(|| TransformError::InvalidInput(format!("Node not found: {}", path)))?;
175
176        let approx_path = format!("{}a", path);
177        let detail_path = format!("{}d", path);
178
179        // Check if we have children
180        if self.nodes.contains_key(&approx_path) && self.nodes.contains_key(&detail_path) {
181            // Compute cost of decomposition
182            let approx_cost = self.select_best_basis(&approx_path, selected)?;
183            let detail_cost = self.select_best_basis(&detail_path, selected)?;
184            let children_cost = approx_cost + detail_cost;
185
186            // Compare with keeping this node
187            if node.cost <= children_cost {
188                // Keep this node
189                selected.retain(|n| !n.path.starts_with(path) || n.path == path);
190                selected.push(node.clone());
191                Ok(node.cost)
192            } else {
193                // Use children
194                Ok(children_cost)
195            }
196        } else {
197            // Leaf node
198            selected.push(node.clone());
199            Ok(node.cost)
200        }
201    }
202
203    /// Reconstruct signal from wavelet packet coefficients
204    ///
205    /// Performs inverse WPT from a set of leaf nodes (e.g. a best-basis selection).
206    /// The algorithm works bottom-up: it places each basis node at its position in the
207    /// packet tree, then repeatedly merges pairs of sibling nodes using the inverse DWT
208    /// until the root (level 0) is reached.
209    pub fn reconstruct(&self, nodes: &[WaveletPacketNode]) -> Result<Array1<f64>> {
210        if nodes.is_empty() {
211            return Err(TransformError::InvalidInput(
212                "No nodes provided for reconstruction".to_string(),
213            ));
214        }
215
216        // Short-circuit: if the root node is among the inputs, return it directly.
217        if let Some(root) = nodes.iter().find(|n| n.path.is_empty()) {
218            return Ok(root.data.clone());
219        }
220
221        // Build a mutable map path -> data, starting from all input nodes.
222        let mut tree: HashMap<String, Array1<f64>> = nodes
223            .iter()
224            .map(|n| (n.path.clone(), n.data.clone()))
225            .collect();
226
227        // Create one DWT instance for reconstruction filters.
228        let dwt = DWT::new(self.wavelet)?.with_boundary(self.boundary);
229
230        // Find the maximum depth we need to collapse to.
231        let max_level = nodes.iter().map(|n| n.level).max().unwrap_or(0);
232
233        // Bottom-up merging: at each level collapse sibling pairs.
234        for _level in (1..=max_level).rev() {
235            // Collect all unique parent paths that still need merging.
236            let parents: Vec<String> = tree
237                .keys()
238                .filter_map(|p| {
239                    if p.is_empty() {
240                        return None;
241                    }
242                    // Parent is everything but the last char ('a' or 'd')
243                    let parent = &p[..p.len() - 1];
244                    // Only include if BOTH children are present and parent is absent
245                    let approx_key = format!("{}a", parent);
246                    let detail_key = format!("{}d", parent);
247                    if tree.contains_key(&approx_key)
248                        && tree.contains_key(&detail_key)
249                        && !tree.contains_key(parent)
250                    {
251                        Some(parent.to_string())
252                    } else {
253                        None
254                    }
255                })
256                .collect::<std::collections::HashSet<_>>()
257                .into_iter()
258                .collect();
259
260            for parent in parents {
261                let approx_key = format!("{}a", parent);
262                let detail_key = format!("{}d", parent);
263
264                let approx = tree.remove(&approx_key).ok_or_else(|| {
265                    TransformError::InvalidInput(format!("Missing approx node: {}", approx_key))
266                })?;
267                let detail = tree.remove(&detail_key).ok_or_else(|| {
268                    TransformError::InvalidInput(format!("Missing detail node: {}", detail_key))
269                })?;
270
271                let reconstructed = dwt.reconstruct(&approx.view(), &detail.view())?;
272                tree.insert(parent, reconstructed);
273            }
274        }
275
276        // The root should now be in the tree.
277        tree.remove("").ok_or_else(|| {
278            TransformError::InvalidInput(
279                "Could not fully reconstruct to root — basis nodes may be incomplete".to_string(),
280            )
281        })
282    }
283
284    /// Get all nodes at a specific level
285    pub fn get_level(&self, level: usize) -> Vec<&WaveletPacketNode> {
286        self.nodes
287            .values()
288            .filter(|node| node.level == level)
289            .collect()
290    }
291
292    /// Get a specific node by path
293    pub fn get_node(&self, path: &str) -> Option<&WaveletPacketNode> {
294        self.nodes.get(path)
295    }
296
297    /// Get all nodes
298    pub fn nodes(&self) -> &HashMap<String, WaveletPacketNode> {
299        &self.nodes
300    }
301
302    /// Compute the total cost of the best basis
303    pub fn best_basis_cost(&self) -> Result<f64> {
304        let best = self.best_basis()?;
305        Ok(best.iter().map(|node| node.cost).sum())
306    }
307}
308
309/// Denoise using wavelet packet transform
310pub fn denoise_wpt(
311    signal: &ArrayView1<f64>,
312    wavelet: WaveletType,
313    level: usize,
314    threshold: f64,
315) -> Result<Array1<f64>> {
316    // Perform WPT
317    let mut wpt = WPT::new(wavelet, level);
318    wpt.decompose(signal)?;
319
320    // Get best basis
321    let best = wpt.best_basis()?;
322
323    // Apply thresholding
324    let mut denoised_nodes = Vec::new();
325    for mut node in best {
326        // Soft thresholding
327        for val in node.data.iter_mut() {
328            if val.abs() < threshold {
329                *val = 0.0;
330            } else {
331                *val = if *val > 0.0 {
332                    *val - threshold
333                } else {
334                    *val + threshold
335                };
336            }
337        }
338        node.update_cost();
339        denoised_nodes.push(node);
340    }
341
342    // Reconstruct
343    wpt.reconstruct(&denoised_nodes)
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_wpt_decompose() -> Result<()> {
        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let mut wpt = WPT::new(WaveletType::Haar, 2);

        wpt.decompose(&signal.view())?;

        // Should have nodes at levels 0, 1, 2
        for path in ["", "a", "d", "aa", "ad", "da", "dd"] {
            assert!(wpt.get_node(path).is_some(), "missing node {:?}", path);
        }

        Ok(())
    }

    #[test]
    fn test_wpt_best_basis() -> Result<()> {
        let signal = Array1::from_vec((0..16).map(|i| (i as f64 * 0.5).sin()).collect());
        let mut wpt = WPT::new(WaveletType::Haar, 3);

        wpt.decompose(&signal.view())?;
        let best = wpt.best_basis()?;

        assert!(!best.is_empty());

        // Check that all selected nodes are unique
        let mut paths: Vec<_> = best.iter().map(|n| n.path.clone()).collect();
        paths.sort();
        paths.dedup();
        assert_eq!(paths.len(), best.len());

        Ok(())
    }

    #[test]
    fn test_wpt_best_basis_is_a_partition_for_every_criterion() -> Result<()> {
        let signal = Array1::from_vec((0..16).map(|i| (i as f64 * 0.5).sin()).collect());
        let criteria = [
            BestBasisCriterion::Shannon,
            BestBasisCriterion::Threshold(0.1),
            BestBasisCriterion::LogEnergy,
            BestBasisCriterion::Sure,
        ];

        for criterion in criteria {
            let mut wpt = WPT::new(WaveletType::Haar, 3).with_criterion(criterion);
            wpt.decompose(&signal.view())?;
            let best = wpt.best_basis()?;

            // No selected node may be an ancestor of another...
            for (i, a) in best.iter().enumerate() {
                for b in &best[i + 1..] {
                    assert!(
                        !a.path.starts_with(&b.path) && !b.path.starts_with(&a.path),
                        "{:?}: overlapping nodes {:?} and {:?}",
                        criterion,
                        a.path,
                        b.path
                    );
                }
            }

            // ...and together they must tile the tree: a level-L node covers 2^-L of it.
            let coverage: f64 = best.iter().map(|n| 0.5_f64.powi(n.level as i32)).sum();
            assert_abs_diff_eq!(coverage, 1.0, epsilon = 1e-12);
        }

        Ok(())
    }

    #[test]
    fn test_best_basis_requires_decomposition() {
        let wpt = WPT::new(WaveletType::Haar, 2);
        assert!(wpt.best_basis().is_err());
    }

    #[test]
    fn test_wpt_levels() -> Result<()> {
        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let mut wpt = WPT::new(WaveletType::Haar, 2);

        wpt.decompose(&signal.view())?;

        assert_eq!(wpt.get_level(0).len(), 1);
        assert_eq!(wpt.get_level(1).len(), 2);
        assert_eq!(wpt.get_level(2).len(), 4);

        Ok(())
    }

    #[test]
    fn test_wavelet_packet_node_cost() {
        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let node = WaveletPacketNode::new(data, "test".to_string(), 1, 0);

        assert!(node.cost >= 0.0);
    }

    #[test]
    fn test_best_basis_criterion() {
        let wpt1 = WPT::new(WaveletType::Haar, 3).with_criterion(BestBasisCriterion::Shannon);
        assert_eq!(wpt1.criterion, BestBasisCriterion::Shannon);

        let wpt2 = WPT::new(WaveletType::Haar, 3).with_criterion(BestBasisCriterion::LogEnergy);
        assert_eq!(wpt2.criterion, BestBasisCriterion::LogEnergy);
    }

    #[test]
    fn test_wpt_reconstruct_from_best_basis() -> Result<()> {
        // Reconstruct from best-basis nodes and check length is preserved
        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let original_len = signal.len();
        let mut wpt = WPT::new(WaveletType::Haar, 2);
        wpt.decompose(&signal.view())?;
        let best = wpt.best_basis()?;
        let reconstructed = wpt.reconstruct(&best)?;
        // Reconstruction may differ in length due to boundary effects; allow ±2 samples
        let diff = (reconstructed.len() as isize - original_len as isize).unsigned_abs();
        assert!(
            diff <= 2,
            "Reconstructed length {} too different from original {}",
            reconstructed.len(),
            original_len
        );
        Ok(())
    }

    #[test]
    fn test_wpt_reconstruct_leaf_nodes() -> Result<()> {
        // Feed all leaf nodes at level 1 and verify reconstruction succeeds
        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let mut wpt = WPT::new(WaveletType::Haar, 1);
        wpt.decompose(&signal.view())?;
        let level1: Vec<WaveletPacketNode> = wpt.get_level(1).into_iter().cloned().collect();
        // Both "a" and "d" nodes should be present, allowing reconstruction
        assert_eq!(level1.len(), 2);
        let reconstructed = wpt.reconstruct(&level1)?;
        assert!(!reconstructed.is_empty());
        Ok(())
    }

    #[test]
    fn test_wpt_reconstruct_rejects_empty_and_incomplete_bases() {
        let wpt = WPT::new(WaveletType::Haar, 2);
        assert!(wpt.reconstruct(&[]).is_err());

        // An approximation branch without its detail sibling can never reach the root.
        let lonely = WaveletPacketNode::new(
            Array1::from_vec(vec![1.0, 2.0]),
            "a".to_string(),
            1,
            0,
        );
        assert!(wpt.reconstruct(&[lonely]).is_err());
    }

    #[test]
    fn test_wpt_reconstruct_root_shortcut() -> Result<()> {
        // If the root node (empty path) is in the slice, reconstruct returns it directly.
        let data = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let root = WaveletPacketNode::new(data.clone(), String::new(), 0, 0);
        let wpt = WPT::new(WaveletType::Haar, 2);
        let result = wpt.reconstruct(&[root])?;
        assert_eq!(result.len(), data.len());
        assert_abs_diff_eq!(result[0], data[0], epsilon = 1e-10);
        Ok(())
    }
}