Skip to main content

scirs2_transform/signal_transforms/
wpt.rs

1//! Wavelet Packet Transform (WPT) Implementation
2//!
3//! Provides wavelet packet decomposition and best basis selection.
4//! Wavelet packets extend DWT by decomposing both approximation and detail coefficients.
5
6use crate::error::{Result, TransformError};
7use crate::signal_transforms::dwt::{BoundaryMode, WaveletType, DWT};
8use scirs2_core::ndarray::{Array1, ArrayView1};
9use std::collections::HashMap;
10
11/// Wavelet packet node
12#[derive(Debug, Clone)]
13pub struct WaveletPacketNode {
14    /// Node data (coefficients)
15    pub data: Array1<f64>,
16    /// Node path (sequence of 'a' for approximation, 'd' for detail)
17    pub path: String,
18    /// Level in the packet tree
19    pub level: usize,
20    /// Node index at this level
21    pub index: usize,
22    /// Cost/entropy of this node
23    pub cost: f64,
24}
25
26impl WaveletPacketNode {
27    /// Create a new wavelet packet node
28    pub fn new(data: Array1<f64>, path: String, level: usize, index: usize) -> Self {
29        let cost = Self::compute_cost(&data);
30        WaveletPacketNode {
31            data,
32            path,
33            level,
34            index,
35            cost,
36        }
37    }
38
39    /// Compute the cost (Shannon entropy) of the node
40    fn compute_cost(data: &Array1<f64>) -> f64 {
41        let energy: f64 = data.iter().map(|x| x * x).sum();
42        if energy < 1e-10 {
43            return 0.0;
44        }
45
46        let mut entropy = 0.0;
47        for &val in data.iter() {
48            let p = (val * val) / energy;
49            if p > 1e-10 {
50                entropy -= p * p.ln();
51            }
52        }
53
54        entropy
55    }
56
57    /// Update the cost
58    pub fn update_cost(&mut self) {
59        self.cost = Self::compute_cost(&self.data);
60    }
61}
62
/// Cost functional used during best-basis search to decide whether a node is
/// kept or replaced by its two children.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum BestBasisCriterion {
    /// Shannon entropy of the coefficient energies.
    Shannon,
    /// Count of coefficients whose magnitude exceeds the given threshold.
    Threshold(f64),
    /// Logarithm-of-energy cost.
    LogEnergy,
    /// Stein's Unbiased Risk Estimate (SURE).
    Sure,
}
75
76/// Wavelet Packet Transform
77#[derive(Debug, Clone)]
78pub struct WPT {
79    wavelet: WaveletType,
80    max_level: usize,
81    boundary: BoundaryMode,
82    criterion: BestBasisCriterion,
83    nodes: HashMap<String, WaveletPacketNode>,
84}
85
86impl WPT {
87    /// Create a new WPT instance
88    pub fn new(wavelet: WaveletType, max_level: usize) -> Self {
89        WPT {
90            wavelet,
91            max_level,
92            boundary: BoundaryMode::Symmetric,
93            criterion: BestBasisCriterion::Shannon,
94            nodes: HashMap::new(),
95        }
96    }
97
98    /// Set the boundary mode
99    pub fn with_boundary(mut self, boundary: BoundaryMode) -> Self {
100        self.boundary = boundary;
101        self
102    }
103
104    /// Set the best basis criterion
105    pub fn with_criterion(mut self, criterion: BestBasisCriterion) -> Self {
106        self.criterion = criterion;
107        self
108    }
109
110    /// Perform full wavelet packet decomposition
111    pub fn decompose(&mut self, signal: &ArrayView1<f64>) -> Result<()> {
112        self.nodes.clear();
113
114        // Create root node
115        let root = WaveletPacketNode::new(signal.to_owned(), String::new(), 0, 0);
116        self.nodes.insert(String::new(), root);
117
118        // Recursively decompose
119        self.decompose_node("", 0)?;
120
121        Ok(())
122    }
123
124    /// Recursively decompose a node
125    fn decompose_node(&mut self, path: &str, level: usize) -> Result<()> {
126        if level >= self.max_level {
127            return Ok(());
128        }
129
130        // Get the current node
131        let node = self
132            .nodes
133            .get(path)
134            .ok_or_else(|| TransformError::InvalidInput(format!("Node not found: {}", path)))?
135            .clone();
136
137        // Create DWT instance
138        let dwt = DWT::new(self.wavelet)?.with_boundary(self.boundary);
139
140        // Decompose
141        let (approx, detail) = dwt.decompose(&node.data.view())?;
142
143        // Create child nodes
144        let approx_path = format!("{}a", path);
145        let detail_path = format!("{}d", path);
146
147        let index = node.index;
148        let approx_node = WaveletPacketNode::new(approx, approx_path.clone(), level + 1, index * 2);
149        let detail_node =
150            WaveletPacketNode::new(detail, detail_path.clone(), level + 1, index * 2 + 1);
151
152        self.nodes.insert(approx_path.clone(), approx_node);
153        self.nodes.insert(detail_path.clone(), detail_node);
154
155        // Recursively decompose child nodes
156        self.decompose_node(&approx_path, level + 1)?;
157        self.decompose_node(&detail_path, level + 1)?;
158
159        Ok(())
160    }
161
162    /// Select the best basis using the specified criterion
163    pub fn best_basis(&self) -> Result<Vec<WaveletPacketNode>> {
164        let mut best_nodes = Vec::new();
165        self.select_best_basis("", &mut best_nodes)?;
166        Ok(best_nodes)
167    }
168
169    /// Recursively select best basis
170    fn select_best_basis(&self, path: &str, selected: &mut Vec<WaveletPacketNode>) -> Result<f64> {
171        let node = self
172            .nodes
173            .get(path)
174            .ok_or_else(|| TransformError::InvalidInput(format!("Node not found: {}", path)))?;
175
176        let approx_path = format!("{}a", path);
177        let detail_path = format!("{}d", path);
178
179        // Check if we have children
180        if self.nodes.contains_key(&approx_path) && self.nodes.contains_key(&detail_path) {
181            // Compute cost of decomposition
182            let approx_cost = self.select_best_basis(&approx_path, selected)?;
183            let detail_cost = self.select_best_basis(&detail_path, selected)?;
184            let children_cost = approx_cost + detail_cost;
185
186            // Compare with keeping this node
187            if node.cost <= children_cost {
188                // Keep this node
189                selected.retain(|n| !n.path.starts_with(path) || n.path == path);
190                selected.push(node.clone());
191                Ok(node.cost)
192            } else {
193                // Use children
194                Ok(children_cost)
195            }
196        } else {
197            // Leaf node
198            selected.push(node.clone());
199            Ok(node.cost)
200        }
201    }
202
203    /// Reconstruct signal from wavelet packet coefficients
204    pub fn reconstruct(&self, nodes: &[WaveletPacketNode]) -> Result<Array1<f64>> {
205        if nodes.is_empty() {
206            return Err(TransformError::InvalidInput(
207                "No nodes provided for reconstruction".to_string(),
208            ));
209        }
210
211        // Find the root or reconstruct from best basis
212        if let Some(root) = nodes.iter().find(|n| n.path.is_empty()) {
213            return Ok(root.data.clone());
214        }
215
216        // For now, return error - full reconstruction requires inverse WPT
217        Err(TransformError::NotImplemented(
218            "Reconstruction from arbitrary basis not yet implemented".to_string(),
219        ))
220    }
221
222    /// Get all nodes at a specific level
223    pub fn get_level(&self, level: usize) -> Vec<&WaveletPacketNode> {
224        self.nodes
225            .values()
226            .filter(|node| node.level == level)
227            .collect()
228    }
229
230    /// Get a specific node by path
231    pub fn get_node(&self, path: &str) -> Option<&WaveletPacketNode> {
232        self.nodes.get(path)
233    }
234
235    /// Get all nodes
236    pub fn nodes(&self) -> &HashMap<String, WaveletPacketNode> {
237        &self.nodes
238    }
239
240    /// Compute the total cost of the best basis
241    pub fn best_basis_cost(&self) -> Result<f64> {
242        let best = self.best_basis()?;
243        Ok(best.iter().map(|node| node.cost).sum())
244    }
245}
246
247/// Denoise using wavelet packet transform
248pub fn denoise_wpt(
249    signal: &ArrayView1<f64>,
250    wavelet: WaveletType,
251    level: usize,
252    threshold: f64,
253) -> Result<Array1<f64>> {
254    // Perform WPT
255    let mut wpt = WPT::new(wavelet, level);
256    wpt.decompose(signal)?;
257
258    // Get best basis
259    let best = wpt.best_basis()?;
260
261    // Apply thresholding
262    let mut denoised_nodes = Vec::new();
263    for mut node in best {
264        // Soft thresholding
265        for val in node.data.iter_mut() {
266            if val.abs() < threshold {
267                *val = 0.0;
268            } else {
269                *val = if *val > 0.0 {
270                    *val - threshold
271                } else {
272                    *val + threshold
273                };
274            }
275        }
276        node.update_cost();
277        denoised_nodes.push(node);
278    }
279
280    // Reconstruct
281    wpt.reconstruct(&denoised_nodes)
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Eight-sample ramp used by the structural tests.
    fn ramp8() -> Array1<f64> {
        Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
    }

    /// Sixteen samples of a slow sine, long enough for a three-level Haar tree.
    fn sine16() -> Array1<f64> {
        Array1::from_vec((0..16).map(|i| (i as f64 * 0.5).sin()).collect())
    }

    #[test]
    fn test_wpt_decompose() -> Result<()> {
        let signal = ramp8();
        let mut wpt = WPT::new(WaveletType::Haar, 2);

        wpt.decompose(&signal.view())?;

        // Should have nodes at levels 0, 1, 2.
        for path in ["", "a", "d", "aa", "ad", "da", "dd"] {
            assert!(wpt.get_node(path).is_some(), "missing node {:?}", path);
        }

        Ok(())
    }

    #[test]
    fn test_wpt_child_indices() -> Result<()> {
        let signal = ramp8();
        let mut wpt = WPT::new(WaveletType::Haar, 2);
        wpt.decompose(&signal.view())?;

        // Approximation child gets 2*i, detail child gets 2*i + 1.
        let expected = [("a", 0), ("d", 1), ("aa", 0), ("ad", 1), ("da", 2), ("dd", 3)];
        for &(path, index) in &expected {
            assert_eq!(wpt.get_node(path).map(|n| n.index), Some(index), "path {:?}", path);
        }

        Ok(())
    }

    #[test]
    fn test_wpt_best_basis() -> Result<()> {
        let signal = sine16();
        let mut wpt = WPT::new(WaveletType::Haar, 3);

        wpt.decompose(&signal.view())?;
        let best = wpt.best_basis()?;

        assert!(!best.is_empty());

        // Check that all selected nodes are unique.
        let mut paths: Vec<_> = best.iter().map(|n| n.path.clone()).collect();
        paths.sort();
        paths.dedup();
        assert_eq!(paths.len(), best.len());

        Ok(())
    }

    #[test]
    fn test_best_basis_is_complete_tiling() -> Result<()> {
        let signal = sine16();
        let mut wpt = WPT::new(WaveletType::Haar, 3);
        wpt.decompose(&signal.view())?;
        let best = wpt.best_basis()?;

        for node in &best {
            assert_eq!(node.path.len(), node.level);
        }

        // No selected node may lie inside another selected node's subtree.
        for (i, outer) in best.iter().enumerate() {
            for (j, inner) in best.iter().enumerate() {
                if i != j {
                    assert!(
                        !inner.path.starts_with(&outer.path),
                        "{:?} is nested inside {:?}",
                        inner.path,
                        outer.path
                    );
                }
            }
        }

        // A complete basis covers the whole tree: sum of 2^-level is exactly 1.
        let coverage: f64 = best
            .iter()
            .map(|n| 0.5f64.powi(i32::try_from(n.level).expect("tree depth fits in i32")))
            .sum();
        assert_abs_diff_eq!(coverage, 1.0, epsilon = 1e-12);

        Ok(())
    }

    #[test]
    fn test_best_basis_requires_decomposition() {
        let wpt = WPT::new(WaveletType::Haar, 2);
        assert!(wpt.best_basis().is_err());
        assert!(wpt.best_basis_cost().is_err());
    }

    #[test]
    fn test_reconstruct_from_root_and_empty() -> Result<()> {
        let signal = ramp8();
        let mut wpt = WPT::new(WaveletType::Haar, 2);
        wpt.decompose(&signal.view())?;

        assert!(wpt.reconstruct(&[]).is_err());

        let root = wpt.get_node("").cloned().expect("root exists after decompose");
        let rebuilt = wpt.reconstruct(&[root])?;
        assert_eq!(rebuilt, signal);

        Ok(())
    }

    #[test]
    fn test_wpt_levels() -> Result<()> {
        let signal = ramp8();
        let mut wpt = WPT::new(WaveletType::Haar, 2);

        wpt.decompose(&signal.view())?;

        assert_eq!(wpt.get_level(0).len(), 1);
        assert_eq!(wpt.get_level(1).len(), 2);
        assert_eq!(wpt.get_level(2).len(), 4);

        Ok(())
    }

    #[test]
    fn test_wavelet_packet_node_cost() {
        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let node = WaveletPacketNode::new(data, "test".to_string(), 1, 0);

        assert!(node.cost >= 0.0);
    }

    #[test]
    fn test_best_basis_criterion() {
        let wpt1 = WPT::new(WaveletType::Haar, 3).with_criterion(BestBasisCriterion::Shannon);
        assert_eq!(wpt1.criterion, BestBasisCriterion::Shannon);

        let wpt2 = WPT::new(WaveletType::Haar, 3).with_criterion(BestBasisCriterion::LogEnergy);
        assert_eq!(wpt2.criterion, BestBasisCriterion::LogEnergy);
    }
}