Skip to main content

scirs2_fft/
wavelet_packets.rs

1//! Wavelet Packet Transform (WPT)
2//!
3//! Wavelet packets generalize the discrete wavelet transform by allowing full
4//! decomposition of both approximation and detail subbands at each level.
5//! This produces a complete binary tree of subband coefficients.
6//!
7//! The best-basis algorithm (Coifman–Wickerhauser 1992) selects an optimal
8//! orthonormal basis from the packet tree by minimising an additive cost function
9//! (e.g. Shannon entropy or log-energy).
10//!
11//! # References
12//! - Coifman, R.R. & Wickerhauser, M.V. (1992). Entropy-based algorithms for best
13//!   basis selection. IEEE Trans. Inf. Theory, 38(2), 713–718.
14//! - Mallat, S. (1999). A Wavelet Tour of Signal Processing. Academic Press.
15
16use std::collections::HashMap;
17use std::f64::consts::LN_2;
18
19use crate::error::{FFTError, FFTResult};
20
21// ─────────────────────────────────────────────────────────────────────────────
22// Wavelet filter definitions
23// ─────────────────────────────────────────────────────────────────────────────
24
/// Supported wavelet families.
///
/// Every variant except [`Wavelet::Bior22`] is an orthonormal wavelet.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Wavelet {
    /// Haar wavelet (db1)
    Haar,
    /// Daubechies 4-tap (db2)
    Db2,
    /// Daubechies 6-tap (db3)
    Db3,
    /// Daubechies 8-tap (db4)
    Db4,
    /// Daubechies 10-tap (db5)
    Db5,
    /// Symlet 4-tap (sym2)
    Sym2,
    /// Symlet 8-tap (sym4)
    Sym4,
    /// Coiflet 6-tap (coif1)
    Coif1,
    /// Biorthogonal 2.2 (bior2.2, the LeGall 5/3 pair).
    ///
    /// This is the only non-orthonormal variant: it is a biorthogonal filter
    /// pair rather than an orthonormal wavelet.
    Bior22,
}
47
48/// Low-pass (scaling) and high-pass (wavelet) analysis filters for a wavelet.
49#[derive(Debug, Clone)]
50pub struct WaveletFilters {
51    /// Low-pass decomposition filter h₀
52    pub lo_d: Vec<f64>,
53    /// High-pass decomposition filter h₁
54    pub hi_d: Vec<f64>,
55    /// Low-pass reconstruction filter g₀
56    pub lo_r: Vec<f64>,
57    /// High-pass reconstruction filter g₁
58    pub hi_r: Vec<f64>,
59}
60
61impl WaveletFilters {
62    /// Return filters for the given wavelet.
63    pub fn for_wavelet(w: Wavelet) -> Self {
64        match w {
65            Wavelet::Haar => {
66                let s = 1.0_f64 / 2.0_f64.sqrt();
67                let lo = vec![s, s];
68                let hi = vec![s, -s];
69                // For orthogonal wavelets with the transpose synthesis formula,
70                // the synthesis filters equal the analysis filters (lo_r = lo_d, hi_r = hi_d).
71                let lo_r = lo.clone();
72                let hi_r = hi.clone();
73                WaveletFilters {
74                    lo_d: lo,
75                    hi_d: hi,
76                    lo_r,
77                    hi_r,
78                }
79            }
80            Wavelet::Db2 => {
81                let s3 = 3.0_f64.sqrt();
82                let norm = 4.0 * 2.0_f64.sqrt(); // 4*sqrt(2)
83                let lo = vec![
84                    (1.0 + s3) / norm,
85                    (3.0 + s3) / norm,
86                    (3.0 - s3) / norm,
87                    (1.0 - s3) / norm,
88                ];
89                let hi = qmf_hi(&lo);
90                // Orthogonal wavelet: synthesis filters equal analysis filters for transpose synthesis.
91                let lo_r = lo.clone();
92                let hi_r = hi.clone();
93                WaveletFilters {
94                    lo_d: lo,
95                    hi_d: hi,
96                    lo_r,
97                    hi_r,
98                }
99            }
100            Wavelet::Db3 => {
101                // Daubechies db3 (6-tap) coefficients
102                let lo = vec![
103                    0.035226291882100656,
104                    -0.08544127388202666,
105                    -0.13501102001039084,
106                    0.4598775021193313,
107                    0.8068915093133388,
108                    0.3326705529509569,
109                ];
110                let hi = qmf_hi(&lo);
111                // Orthogonal wavelet: synthesis filters equal analysis filters for transpose synthesis.
112                let lo_r = lo.clone();
113                let hi_r = hi.clone();
114                WaveletFilters {
115                    lo_d: lo,
116                    hi_d: hi,
117                    lo_r,
118                    hi_r,
119                }
120            }
121            Wavelet::Db4 => {
122                // Daubechies db4 (8-tap) coefficients
123                let lo = vec![
124                    -0.010597401784997278,
125                    0.032883011666982945,
126                    0.030841381835986965,
127                    -0.18703481171888114,
128                    -0.027_983_769_416_983_85,
129                    0.6308807679295904,
130                    0.7148465705525415,
131                    0.23037781330885523,
132                ];
133                let hi = qmf_hi(&lo);
134                // Orthogonal wavelet: synthesis filters equal analysis filters for transpose synthesis.
135                let lo_r = lo.clone();
136                let hi_r = hi.clone();
137                WaveletFilters {
138                    lo_d: lo,
139                    hi_d: hi,
140                    lo_r,
141                    hi_r,
142                }
143            }
144            Wavelet::Db5 => {
145                // Daubechies db5 (10-tap) coefficients
146                let lo = vec![
147                    0.003335725285001549,
148                    -0.012580751999015526,
149                    -0.006241490213011705,
150                    0.07757149384006515,
151                    -0.03224486958502952,
152                    -0.24229488706619015,
153                    0.13842814590110342,
154                    0.7243085284377729,
155                    0.6038292697974729,
156                    0.160102397974125,
157                ];
158                let hi = qmf_hi(&lo);
159                // Orthogonal wavelet: synthesis filters equal analysis filters for transpose synthesis.
160                let lo_r = lo.clone();
161                let hi_r = hi.clone();
162                WaveletFilters {
163                    lo_d: lo,
164                    hi_d: hi,
165                    lo_r,
166                    hi_r,
167                }
168            }
169            Wavelet::Sym2 => {
170                // Symlet sym2 = db2 (same energy, different phase)
171                let s3 = 3.0_f64.sqrt();
172                let lo = vec![
173                    (1.0 - s3) / 8.0_f64.sqrt(),
174                    (3.0 - s3) / 8.0_f64.sqrt(),
175                    (3.0 + s3) / 8.0_f64.sqrt(),
176                    (1.0 + s3) / 8.0_f64.sqrt(),
177                ];
178                let hi = qmf_hi(&lo);
179                // Orthogonal wavelet: synthesis filters equal analysis filters for transpose synthesis.
180                let lo_r = lo.clone();
181                let hi_r = hi.clone();
182                WaveletFilters {
183                    lo_d: lo,
184                    hi_d: hi,
185                    lo_r,
186                    hi_r,
187                }
188            }
189            Wavelet::Sym4 => {
190                // Symlet sym4 (8-tap)
191                let lo = vec![
192                    -0.07576571478927333,
193                    -0.02963552764599851,
194                    0.49761866763201545,
195                    0.8037387518059161,
196                    0.29785779560527736,
197                    -0.09921954357684722,
198                    -0.012603967262037833,
199                    0.032_223_100_604_042_7,
200                ];
201                let hi = qmf_hi(&lo);
202                // Orthogonal wavelet: synthesis filters equal analysis filters for transpose synthesis.
203                let lo_r = lo.clone();
204                let hi_r = hi.clone();
205                WaveletFilters {
206                    lo_d: lo,
207                    hi_d: hi,
208                    lo_r,
209                    hi_r,
210                }
211            }
212            Wavelet::Coif1 => {
213                // Coiflet coif1 (6-tap)
214                let lo = vec![
215                    -0.015655728135960927,
216                    -0.07273261951285047,
217                    0.3848648565381134,
218                    0.8525720202122554,
219                    0.3378976624578092,
220                    -0.07273261951285047,
221                ];
222                let hi = qmf_hi(&lo);
223                // Orthogonal wavelet: synthesis filters equal analysis filters for transpose synthesis.
224                let lo_r = lo.clone();
225                let hi_r = hi.clone();
226                WaveletFilters {
227                    lo_d: lo,
228                    hi_d: hi,
229                    lo_r,
230                    hi_r,
231                }
232            }
233            Wavelet::Bior22 => {
234                // Biorthogonal 2.2 analysis filters
235                let lo = vec![-0.125, 0.25, 0.75, 0.25, -0.125];
236                let hi = vec![-0.25, 0.5, -0.25];
237                let lo_r: Vec<f64> = lo.iter().rev().cloned().collect();
238                let hi_r: Vec<f64> = hi.iter().rev().cloned().collect();
239                WaveletFilters {
240                    lo_d: lo,
241                    hi_d: hi,
242                    lo_r,
243                    hi_r,
244                }
245            }
246        }
247    }
248}
249
/// Build the high-pass QMF filter from a low-pass filter.
///
/// h₁[k] = (−1)^(L−1−k) · h₀[L−1−k]   for k = 0..L
///
/// For the even-length filters used in this module this equals
/// (−1)^(k+1) · h₀[L−1−k], the same sign convention as PyWavelets' `dec_hi`.
/// An empty input yields an empty output.
fn qmf_hi(lo: &[f64]) -> Vec<f64> {
    let len = lo.len();
    (0..len)
        .map(|k| {
            // Index of the mirrored low-pass tap; its parity decides the sign.
            let src = len - 1 - k;
            let v = lo[src];
            if src % 2 == 0 {
                v
            } else {
                -v
            }
        })
        .collect()
}
261
262// ─────────────────────────────────────────────────────────────────────────────
263// Core convolution / subsampling helpers
264// ─────────────────────────────────────────────────────────────────────────────
265
266/// Convolve `signal` with `filter` using periodic (circular) boundary extension
267/// and then down-sample by 2 (keep even-indexed samples).
268///
269/// The output length is `ceil(signal.len() / 2)` = `(signal.len() + 1) / 2`.
270/// Using circular (periodic) boundary means the output length depends only on
271/// the signal length, not the filter length.
272fn conv_downsample(signal: &[f64], filter: &[f64]) -> Vec<f64> {
273    let n = signal.len();
274    let out_len = n.div_ceil(2); // ceil(n/2), independent of filter length
275    let mut out = vec![0.0_f64; out_len];
276    for k in 0..out_len {
277        let t = 2 * k;
278        let mut acc = 0.0_f64;
279        for (j, &h) in filter.iter().enumerate() {
280            // periodic (circular) boundary
281            let idx = ((t as isize - j as isize).rem_euclid(n as isize)) as usize;
282            acc += signal[idx] * h;
283        }
284        out[k] = acc;
285    }
286    out
287}
288
/// Synthesis step: the exact transpose (adjoint) of `conv_downsample`.
///
/// `conv_downsample` computes `y[k] = Σ_j filter[j] · x[(2k − j) mod n]` and
/// uses *every* tap, wrapping around the signal as often as needed.  Its
/// transpose is
///
///   xhat[m] = Σₖ input[k] · p[(2k − m) mod n],   p[r] = Σ_{j ≡ r (mod n)} filter[j]
///
/// where `n = target_len` and `p` is the filter periodised to length `n`.
///
/// * When the filter is no longer than `n`, `p` is the filter itself (zero
///   beyond its support).
/// * When it is longer, which happens at deep levels with long filters, the
///   wrapped taps are folded in so the step stays the true adjoint.
///
/// For orthogonal filters and even `n` the analysis operator is orthogonal, so
/// this transpose reconstructs its input exactly.
///
/// `target_len` must equal the length of the signal that was passed to
/// `conv_downsample` to produce `input`.  A `target_len` of 0 yields an empty
/// vector.
fn synthesis_step(input: &[f64], filter: &[f64], target_len: usize) -> Vec<f64> {
    let n_out = target_len;
    if n_out == 0 {
        return Vec::new();
    }

    // Periodise the filter to length n_out (or keep it as-is if shorter).
    let mut periodized = vec![0.0_f64; filter.len().min(n_out)];
    for (j, &h) in filter.iter().enumerate() {
        periodized[j % n_out] += h;
    }

    let period = n_out as isize;
    (0..n_out)
        .map(|m| {
            let mut acc = 0.0_f64;
            for (k, &c) in input.iter().enumerate() {
                let r = (2 * k as isize - m as isize).rem_euclid(period) as usize;
                // Residues outside the (short) filter support contribute nothing.
                if let Some(&h) = periodized.get(r) {
                    acc += c * h;
                }
            }
            acc
        })
        .collect()
}
318
319// ─────────────────────────────────────────────────────────────────────────────
320// Node & tree structures
321// ─────────────────────────────────────────────────────────────────────────────
322
/// A single node in the wavelet packet tree.
///
/// A node owns the subband coefficients for one `(level, index)` position, with
/// `index ∈ [0, 2^level)`.  `wpd` places the children of `(level, index)` at
/// `(level + 1, 2·index)` (low-pass) and `(level + 1, 2·index + 1)` (high-pass).
#[derive(Debug, Clone)]
pub struct WaveletPacketNode {
    /// Subband coefficients at this node.
    pub coeffs: Vec<f64>,
    /// Decomposition level (0 = root, i.e. the original signal).
    pub level: usize,
    /// Node index within the level, in natural (Paley) order.
    ///
    /// This is not frequency order: the high-pass branch mirrors the spectrum,
    /// so a Gray-code permutation is needed to sort nodes by frequency.
    pub index: usize,
}

impl WaveletPacketNode {
    /// Create a new node.
    pub fn new(coeffs: Vec<f64>, level: usize, index: usize) -> Self {
        WaveletPacketNode {
            coeffs,
            level,
            index,
        }
    }

    /// Returns `true` if this node is the root (level 0).
    pub fn is_root(&self) -> bool {
        self.level == 0
    }

    /// Flat `HashMap` key for `(level, index)`: `(level << 32) | index`.
    ///
    /// The level occupies the upper 32 bits and the index the lower 32 bits, so
    /// keys are distinct as long as `index < 2^32`.
    fn key(level: usize, index: usize) -> u64 {
        ((level as u64) << 32) | (index as u64)
    }
}
357
358/// A full binary tree of wavelet packet nodes.
359///
360/// Nodes are stored in a `HashMap` keyed by `(level, index)`.
361/// The tree is built by `wpd` and each interior node stores *both* the
362/// node's own coefficients and its children (low/high subbands).
363#[derive(Debug, Clone)]
364pub struct WaveletPacketTree {
365    /// All computed nodes, keyed by `WaveletPacketNode::key(level, index)`.
366    nodes: HashMap<u64, WaveletPacketNode>,
367    /// Maximum decomposition depth.
368    pub max_level: usize,
369    /// Wavelet used to build this tree.
370    pub wavelet: Wavelet,
371    /// Length of the original signal (needed for reconstruction).
372    pub signal_len: usize,
373}
374
375impl WaveletPacketTree {
376    /// Create an empty tree.
377    pub fn new(wavelet: Wavelet, max_level: usize, signal_len: usize) -> Self {
378        WaveletPacketTree {
379            nodes: HashMap::new(),
380            max_level,
381            wavelet,
382            signal_len,
383        }
384    }
385
386    /// Insert a node into the tree.
387    pub fn insert(&mut self, node: WaveletPacketNode) {
388        let key = WaveletPacketNode::key(node.level, node.index);
389        self.nodes.insert(key, node);
390    }
391
392    /// Retrieve a node by `(level, index)`.
393    pub fn get(&self, level: usize, index: usize) -> Option<&WaveletPacketNode> {
394        self.nodes.get(&WaveletPacketNode::key(level, index))
395    }
396
397    /// Iterate over all nodes at a given `level`.
398    pub fn nodes_at_level(&self, level: usize) -> impl Iterator<Item = &WaveletPacketNode> {
399        self.nodes.values().filter(move |n| n.level == level)
400    }
401
402    /// All nodes in the tree.
403    pub fn all_nodes(&self) -> impl Iterator<Item = &WaveletPacketNode> {
404        self.nodes.values()
405    }
406}
407
408// ─────────────────────────────────────────────────────────────────────────────
409// Wavelet Packet Decomposition (WPD)
410// ─────────────────────────────────────────────────────────────────────────────
411
412/// Perform a full wavelet packet decomposition up to `max_level`.
413///
414/// Every node (approximation *and* detail) at every level is recursively
415/// decomposed, producing a complete binary tree with `2^(max_level+1) - 1` nodes.
416///
417/// # Arguments
418///
419/// * `signal`    – Real-valued input signal.
420/// * `wavelet`   – Wavelet to use (determines analysis filters).
421/// * `max_level` – Maximum decomposition depth.  The root is level 0.
422///
423/// # Errors
424///
425/// Returns `FFTError::ValueError` if `signal` is empty or `max_level == 0`.
426///
427/// # Example
428///
429/// ```
430/// use scirs2_fft::wavelet_packets::{wpd, Wavelet};
431///
432/// let signal: Vec<f64> = (0..64).map(|i| (i as f64 * 0.1).sin()).collect();
433/// let tree = wpd(&signal, Wavelet::Db4, 3).expect("decomposition failed");
434/// // Tree has nodes at levels 0 through 3
435/// assert!(tree.get(0, 0).is_some());
436/// assert!(tree.get(3, 7).is_some());
437/// ```
438pub fn wpd(signal: &[f64], wavelet: Wavelet, max_level: usize) -> FFTResult<WaveletPacketTree> {
439    if signal.is_empty() {
440        return Err(FFTError::ValueError("signal must be non-empty".to_string()));
441    }
442    if max_level == 0 {
443        return Err(FFTError::ValueError("max_level must be >= 1".to_string()));
444    }
445
446    let filters = WaveletFilters::for_wavelet(wavelet);
447    let signal_len = signal.len();
448    let mut tree = WaveletPacketTree::new(wavelet, max_level, signal_len);
449
450    // Root node (level 0, index 0) = original signal
451    tree.insert(WaveletPacketNode::new(signal.to_vec(), 0, 0));
452
453    // BFS decomposition
454    for level in 0..max_level {
455        let num_nodes = 1_usize << level;
456        for index in 0..num_nodes {
457            let coeffs = match tree.get(level, index) {
458                Some(n) => n.coeffs.clone(),
459                None => {
460                    return Err(FFTError::InternalError(format!(
461                        "missing node ({level}, {index})"
462                    )))
463                }
464            };
465
466            // Low-pass child (approximation) → (level+1, 2*index)
467            let lo = conv_downsample(&coeffs, &filters.lo_d);
468            tree.insert(WaveletPacketNode::new(lo, level + 1, 2 * index));
469
470            // High-pass child (detail) → (level+1, 2*index+1)
471            let hi = conv_downsample(&coeffs, &filters.hi_d);
472            tree.insert(WaveletPacketNode::new(hi, level + 1, 2 * index + 1));
473        }
474    }
475
476    Ok(tree)
477}
478
479// ─────────────────────────────────────────────────────────────────────────────
480// Cost functions
481// ─────────────────────────────────────────────────────────────────────────────
482
/// Shannon entropy cost function (un-normalised, additive).
///
/// E(s) = −∑ |s_i|² · log₂(|s_i|²)
///
/// Coefficients equal to zero are skipped, matching the limit p·log p → 0.
/// Coefficients with |s_i| > 1 contribute negative terms.
///
/// # Example
///
/// ```
/// use scirs2_fft::wavelet_packets::shannon_entropy;
///
/// let coeffs = vec![0.5, -0.5, 0.5, -0.5];
/// let e = shannon_entropy(&coeffs);
/// assert!(e >= 0.0);
/// ```
pub fn shannon_entropy(coeffs: &[f64]) -> f64 {
    coeffs
        .iter()
        .map(|&c| c * c)
        .filter(|&energy| energy > 0.0)
        .map(|energy| -energy * energy.log2())
        .sum()
}
511
/// Log-energy entropy cost function.
///
/// E(s) = ∑ log₂(|s_i|²), summed over non-zero coefficients only.
///
/// The base-2 logarithm is evaluated as `ln(x) / ln 2`.
pub fn log_energy_entropy(coeffs: &[f64]) -> f64 {
    coeffs
        .iter()
        .map(|&c| c * c)
        .filter(|&energy| energy > 0.0)
        .map(|energy| energy.ln() / LN_2)
        .sum()
}
528
/// ℓᵖ sparsity cost: E(s) = ∑ |s_i|^p.
///
/// This is the p-th power of the ℓᵖ norm (no 1/p root is taken), which keeps
/// the cost additive across subbands.  Choosing p < 2 rewards sparse
/// coefficient vectors.
pub fn lp_norm_cost(coeffs: &[f64], p: f64) -> f64 {
    coeffs
        .iter()
        .copied()
        .map(f64::abs)
        .map(|magnitude| magnitude.powf(p))
        .sum()
}
535
536// ─────────────────────────────────────────────────────────────────────────────
537// Best Basis Selection (Coifman–Wickerhauser)
538// ─────────────────────────────────────────────────────────────────────────────
539
540/// Select the best orthonormal basis from a wavelet packet tree.
541///
542/// The algorithm minimises an additive cost function `cost_fn` using a
543/// bottom-up pass: a parent node is kept when its cost is *less than or equal*
544/// to the sum of the costs of its two children.
545///
546/// # Arguments
547///
548/// * `tree`    – Packet tree produced by `wpd`.
549/// * `cost_fn` – Additive cost function; must satisfy `cost(A∪B) = cost(A) + cost(B)`.
550///
551/// # Returns
552///
553/// A `Vec<WaveletPacketNode>` that forms a partition of the time-frequency
554/// plane (i.e. a valid orthonormal basis).
555///
556/// # Errors
557///
558/// Returns `FFTError::ValueError` if the tree is empty.
559///
560/// # Example
561///
562/// ```
563/// use scirs2_fft::wavelet_packets::{wpd, best_basis, shannon_entropy, Wavelet};
564///
565/// let signal: Vec<f64> = (0..64).map(|i| (i as f64 * 0.1).sin()).collect();
566/// let tree = wpd(&signal, Wavelet::Haar, 3).expect("decomp");
567/// let basis = best_basis(&tree, shannon_entropy).expect("basis");
568/// assert!(!basis.is_empty());
569/// ```
570pub fn best_basis<F>(tree: &WaveletPacketTree, cost_fn: F) -> FFTResult<Vec<WaveletPacketNode>>
571where
572    F: Fn(&[f64]) -> f64,
573{
574    if tree.max_level == 0 {
575        return Err(FFTError::ValueError("tree is empty".to_string()));
576    }
577
578    // Pre-compute costs for every node in the tree
579    let mut costs: HashMap<u64, f64> = HashMap::new();
580    for node in tree.all_nodes() {
581        let key = WaveletPacketNode::key(node.level, node.index);
582        costs.insert(key, cost_fn(&node.coeffs));
583    }
584
585    // best_flag[key] = true  →  keep this node (do NOT split)
586    let mut best_flag: HashMap<u64, bool> = HashMap::new();
587
588    // Bottom-up: iterate from max_level - 1 down to 0
589    for level in (0..tree.max_level).rev() {
590        let num_nodes = 1_usize << level;
591        for index in 0..num_nodes {
592            let parent_key = WaveletPacketNode::key(level, index);
593            let left_key = WaveletPacketNode::key(level + 1, 2 * index);
594            let right_key = WaveletPacketNode::key(level + 1, 2 * index + 1);
595
596            let parent_cost = match costs.get(&parent_key) {
597                Some(&c) => c,
598                None => continue,
599            };
600
601            // Children cost is the sum; if a child is already "split", we use
602            // the *best* cost that the subtree achieves (propagated upward).
603            let left_cost = effective_cost(&costs, &best_flag, level + 1, 2 * index);
604            let right_cost = effective_cost(&costs, &best_flag, level + 1, 2 * index + 1);
605            let children_cost = left_cost + right_cost;
606
607            if parent_cost <= children_cost {
608                // Parent is better (or equal) → keep parent, prune children
609                best_flag.insert(parent_key, false); // false = "not split"
610                costs.insert(parent_key, parent_cost);
611            } else {
612                // Split is better → mark parent as "split"
613                best_flag.insert(parent_key, true);
614                // Update the effective cost of this node to the children sum
615                // so grandparents can compare correctly
616                costs.insert(parent_key, children_cost);
617            }
618
619            // Ensure leaf flags exist for the children (they have no children of their own)
620            best_flag.entry(left_key).or_insert(false);
621            best_flag.entry(right_key).or_insert(false);
622        }
623    }
624
625    // Collect the basis by selecting nodes that are NOT split
626    let mut basis: Vec<WaveletPacketNode> = Vec::new();
627    collect_basis(tree, &best_flag, 0, 0, &mut basis)?;
628
629    Ok(basis)
630}
631
632/// Recursively collect basis nodes starting from `(level, index)`.
633fn collect_basis(
634    tree: &WaveletPacketTree,
635    best_flag: &HashMap<u64, bool>,
636    level: usize,
637    index: usize,
638    out: &mut Vec<WaveletPacketNode>,
639) -> FFTResult<()> {
640    let key = WaveletPacketNode::key(level, index);
641    let is_split = best_flag.get(&key).copied().unwrap_or(false);
642
643    if !is_split || level == tree.max_level {
644        // Leaf of best basis tree
645        if let Some(node) = tree.get(level, index) {
646            out.push(node.clone());
647        }
648    } else {
649        collect_basis(tree, best_flag, level + 1, 2 * index, out)?;
650        collect_basis(tree, best_flag, level + 1, 2 * index + 1, out)?;
651    }
652    Ok(())
653}
654
655/// Return the effective (post-best-basis) cost for a node.
656fn effective_cost(
657    costs: &HashMap<u64, f64>,
658    best_flag: &HashMap<u64, bool>,
659    level: usize,
660    index: usize,
661) -> f64 {
662    // If the node has already been processed (and possibly "split"), its
663    // cost in the map already reflects the best achievable cost.
664    let key = WaveletPacketNode::key(level, index);
665    costs.get(&key).copied().unwrap_or(f64::INFINITY)
666}
667
668// ─────────────────────────────────────────────────────────────────────────────
669// Reconstruction
670// ─────────────────────────────────────────────────────────────────────────────
671
672/// Reconstruct the signal from a set of wavelet packet nodes forming a basis.
673///
674/// The nodes must constitute a valid partition of the time-frequency plane
675/// (e.g. those returned by `best_basis`).  Mixed-level bases (where some nodes
676/// are at depth 2 and others at depth 3, etc.) are fully supported.
677///
678/// # Arguments
679///
680/// * `tree`       – The original packet tree (provides wavelet & signal length).
681/// * `basis_nodes` – A valid wavelet packet basis (partition of the root).
682///
683/// # Errors
684///
685/// Returns `FFTError::InternalError` if reconstruction encounters a missing node.
686///
687/// # Example
688///
689/// ```
690/// use scirs2_fft::wavelet_packets::{wpd, best_basis, wp_reconstruct, shannon_entropy, Wavelet};
691///
692/// let signal: Vec<f64> = (0..64).map(|i| (i as f64 * 0.1).sin()).collect();
693/// let tree = wpd(&signal, Wavelet::Haar, 3).expect("decomp");
694/// let basis = best_basis(&tree, shannon_entropy).expect("basis");
695/// let recon = wp_reconstruct(&tree, &basis).expect("recon");
696/// assert_eq!(recon.len(), signal.len());
697/// // Perfect reconstruction (approx)
698/// for (a, b) in signal.iter().zip(recon.iter()) {
699///     assert!((a - b).abs() < 1e-10, "mismatch: {} vs {}", a, b);
700/// }
701/// ```
702pub fn wp_reconstruct(
703    tree: &WaveletPacketTree,
704    basis_nodes: &[WaveletPacketNode],
705) -> FFTResult<Vec<f64>> {
706    if basis_nodes.is_empty() {
707        return Err(FFTError::ValueError(
708            "basis_nodes must be non-empty".to_string(),
709        ));
710    }
711
712    let filters = WaveletFilters::for_wavelet(tree.wavelet);
713
714    // Map each basis node into the tree storage so we can do upward synthesis
715    let mut node_map: HashMap<u64, Vec<f64>> = HashMap::new();
716    for node in basis_nodes {
717        let key = WaveletPacketNode::key(node.level, node.index);
718        node_map.insert(key, node.coeffs.clone());
719    }
720
721    // We need to know the coefficient length at each level.
722    // We reuse the already-computed nodes in the tree.
723    // Bottom-up synthesis from max_level to level 0
724    for level in (1..=tree.max_level).rev() {
725        let num_nodes = 1_usize << level;
726        let parent_level = level - 1;
727        let num_parents = 1_usize << parent_level;
728
729        for p_idx in 0..num_parents {
730            let left_key = WaveletPacketNode::key(level, 2 * p_idx);
731            let right_key = WaveletPacketNode::key(level, 2 * p_idx + 1);
732            let parent_key = WaveletPacketNode::key(parent_level, p_idx);
733
734            // Skip if parent already exists in the basis (it was a leaf)
735            if node_map.contains_key(&parent_key) {
736                continue;
737            }
738
739            // Both children must be present to reconstruct the parent
740            let left_coeffs = match node_map.get(&left_key) {
741                Some(c) => c.clone(),
742                None => continue,
743            };
744            let right_coeffs = match node_map.get(&right_key) {
745                Some(c) => c.clone(),
746                None => continue,
747            };
748
749            // Target length: get from the tree if available, else estimate
750            let target_len = tree
751                .get(parent_level, p_idx)
752                .map(|n| n.coeffs.len())
753                .unwrap_or_else(|| {
754                    // Estimate: parent length ≈ 2 * child length
755                    left_coeffs.len() * 2
756                });
757
758            // Synthesis: lo branch + hi branch (transpose of analysis)
759            let lo_rec = synthesis_step(&left_coeffs, &filters.lo_r, target_len);
760            let hi_rec = synthesis_step(&right_coeffs, &filters.hi_r, target_len);
761            let parent_coeffs: Vec<f64> = lo_rec
762                .iter()
763                .zip(hi_rec.iter())
764                .map(|(a, b)| a + b)
765                .collect();
766
767            node_map.insert(parent_key, parent_coeffs);
768        }
769
770        // We no longer need the children at this level to save memory
771        for idx in 0..num_nodes {
772            // Only remove if both siblings have been consumed
773            let left_key = WaveletPacketNode::key(level, idx);
774            // Keep it if still needed (might be a basis leaf)
775            let _ = left_key;
776        }
777    }
778
779    // The reconstructed signal is the root (level 0, index 0)
780    let root_key = WaveletPacketNode::key(0, 0);
781    node_map.remove(&root_key).ok_or_else(|| {
782        FFTError::InternalError("reconstruction failed: root not reached".to_string())
783    })
784}
785
786// ─────────────────────────────────────────────────────────────────────────────
787// WPT Denoising
788// ─────────────────────────────────────────────────────────────────────────────
789
/// Thresholding rule for wavelet denoising.
///
/// In every rule `τ` is the threshold value supplied alongside the method.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ThresholdMethod {
    /// Hard thresholding: coefficients with |c| < τ are set to 0; the rest are kept.
    Hard,
    /// Soft thresholding: coefficients with |c| ≤ τ become 0; the rest shrink toward zero by τ.
    Soft,
    /// Non-negative garrote: |c| ≤ τ → 0, otherwise c → c − τ²/c.
    Garrote,
    /// Firm (semi-soft) thresholding with lower threshold τ and upper threshold `t2`.
    ///
    /// |c| ≤ τ → 0, |c| ≥ `t2` → c unchanged, with a linear ramp in between.
    Firm {
        /// Upper threshold.  It should exceed τ; if `t2 ≤ τ`, the rule
        /// effectively reduces to hard thresholding.
        t2: f64,
    },
}
802
803/// Apply a scalar threshold to a coefficient vector.
804fn threshold_coeffs(coeffs: &[f64], tau: f64, method: ThresholdMethod) -> Vec<f64> {
805    coeffs
806        .iter()
807        .map(|&c| apply_threshold(c, tau, method))
808        .collect()
809}
810
811/// Apply threshold to a single coefficient.
812fn apply_threshold(c: f64, tau: f64, method: ThresholdMethod) -> f64 {
813    match method {
814        ThresholdMethod::Hard => {
815            if c.abs() >= tau {
816                c
817            } else {
818                0.0
819            }
820        }
821        ThresholdMethod::Soft => {
822            if c > tau {
823                c - tau
824            } else if c < -tau {
825                c + tau
826            } else {
827                0.0
828            }
829        }
830        ThresholdMethod::Garrote => {
831            if c.abs() <= tau {
832                0.0
833            } else {
834                c - tau * tau / c
835            }
836        }
837        ThresholdMethod::Firm { t2 } => {
838            let t1 = tau;
839            let abs_c = c.abs();
840            if abs_c <= t1 {
841                0.0
842            } else if abs_c >= t2 {
843                c
844            } else {
845                // Linear ramp
846                c.signum() * t1 * (abs_c - t1) / (t2 - t1)
847            }
848        }
849    }
850}
851
852/// Denoise a signal using the Wavelet Packet Transform.
853///
854/// The procedure is:
855/// 1. Compute the full WPT up to `max_level`.
856/// 2. Select the best basis using Shannon entropy.
857/// 3. Threshold the coefficients in the best-basis nodes.
858/// 4. Reconstruct the signal.
859///
860/// # Arguments
861///
862/// * `signal`    – Noisy input signal.
863/// * `wavelet`   – Wavelet to use.
864/// * `max_level` – Maximum decomposition depth.
865/// * `threshold` – Threshold value τ.
866/// * `method`    – Thresholding method.
867///
868/// # Errors
869///
870/// Propagates any error from `wpd`, `best_basis`, or `wp_reconstruct`.
871///
872/// # Example
873///
874/// ```
875/// use scirs2_fft::wavelet_packets::{wp_denoising, ThresholdMethod, Wavelet};
876///
877/// let signal: Vec<f64> = (0..64).map(|i| (i as f64 * 0.2).sin()).collect();
878/// let denoised = wp_denoising(&signal, Wavelet::Db4, 3, 0.05, ThresholdMethod::Soft)
879///     .expect("denoising failed");
880/// assert_eq!(denoised.len(), signal.len());
881/// ```
882pub fn wp_denoising(
883    signal: &[f64],
884    wavelet: Wavelet,
885    max_level: usize,
886    threshold: f64,
887    method: ThresholdMethod,
888) -> FFTResult<Vec<f64>> {
889    // 1. Decompose
890    let tree = wpd(signal, wavelet, max_level)?;
891
892    // 2. Best basis
893    let basis = best_basis(&tree, shannon_entropy)?;
894
895    // 3. Threshold coefficients (do NOT threshold the root / approximation leaf)
896    let thresholded: Vec<WaveletPacketNode> = basis
897        .into_iter()
898        .map(|mut node| {
899            if node.level > 0 {
900                node.coeffs = threshold_coeffs(&node.coeffs, threshold, method);
901            }
902            node
903        })
904        .collect();
905
906    // 4. Reconstruct
907    let mut recon = wp_reconstruct(&tree, &thresholded)?;
908
909    // Trim or extend to original signal length
910    recon.truncate(signal.len());
911    while recon.len() < signal.len() {
912        recon.push(0.0);
913    }
914
915    Ok(recon)
916}
917
918// ─────────────────────────────────────────────────────────────────────────────
919// Tests
920// ─────────────────────────────────────────────────────────────────────────────
921
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a simple test signal.
    fn test_signal(n: usize) -> Vec<f64> {
        (0..n)
            .map(|i| {
                let t = i as f64 / n as f64;
                (2.0 * std::f64::consts::PI * 5.0 * t).sin()
                    + 0.5 * (2.0 * std::f64::consts::PI * 13.0 * t).sin()
            })
            .collect()
    }

    #[test]
    fn test_haar_decomp_shape() {
        let sig = test_signal(64);
        let tree = wpd(&sig, Wavelet::Haar, 3).expect("wpd failed");
        // All nodes at level 3 should exist
        for idx in 0..8 {
            assert!(tree.get(3, idx).is_some(), "missing node (3, {idx})");
        }
    }

    #[test]
    fn test_qmf_energy_preservation() {
        // lo and hi filters of db2 should each have unit energy
        let filters = WaveletFilters::for_wavelet(Wavelet::Db2);
        let e_lo: f64 = filters.lo_d.iter().map(|&c| c * c).sum();
        let e_hi: f64 = filters.hi_d.iter().map(|&c| c * c).sum();
        assert!((e_lo - 1.0).abs() < 1e-10, "lo energy {e_lo}");
        assert!((e_hi - 1.0).abs() < 1e-10, "hi energy {e_hi}");
    }

    #[test]
    fn test_shannon_entropy_uniform() {
        // Uniform signal (all equal nonzero): entropy should be positive
        let coeffs = vec![0.5_f64; 8];
        let e = shannon_entropy(&coeffs);
        assert!(e > 0.0, "expected positive entropy, got {e}");
    }

    #[test]
    fn test_shannon_entropy_sparse() {
        // A single non-zero coefficient → minimum entropy (sparse)
        let mut coeffs = vec![0.0_f64; 64];
        coeffs[0] = 1.0;
        let e = shannon_entropy(&coeffs);
        assert!(e.abs() < 1e-12, "sparse signal entropy {e}");
    }

    #[test]
    fn test_best_basis_returns_valid_partition() {
        let sig = test_signal(64);
        let tree = wpd(&sig, Wavelet::Db2, 3).expect("wpd");
        let basis = best_basis(&tree, shannon_entropy).expect("best_basis");

        // Basis must be non-empty
        assert!(!basis.is_empty(), "basis is empty");

        // All nodes in basis must exist in the tree
        for node in &basis {
            assert!(
                tree.get(node.level, node.index).is_some(),
                "basis node ({}, {}) not in tree",
                node.level,
                node.index
            );
        }
    }

    #[test]
    fn test_haar_perfect_reconstruction() {
        let sig = test_signal(64);
        let tree = wpd(&sig, Wavelet::Haar, 2).expect("wpd");
        // Use all leaf nodes as basis (no simplification)
        let basis: Vec<WaveletPacketNode> = (0..4_usize)
            .filter_map(|idx| tree.get(2, idx).cloned())
            .collect();
        let recon = wp_reconstruct(&tree, &basis).expect("recon");
        for (i, (&s, &r)) in sig.iter().zip(recon.iter()).enumerate() {
            assert!(
                (s - r).abs() < 1e-10,
                "mismatch at {i}: orig={s}, recon={r}"
            );
        }
    }

    #[test]
    fn test_denoising_length_preserved() {
        let sig = test_signal(64);
        let denoised =
            wp_denoising(&sig, Wavelet::Db4, 3, 0.1, ThresholdMethod::Soft).expect("denoise");
        assert_eq!(denoised.len(), sig.len());
    }

    #[test]
    fn test_threshold_hard() {
        let coeffs = vec![1.0, -0.5, 0.3, -0.1, 2.0];
        let out = threshold_coeffs(&coeffs, 0.4, ThresholdMethod::Hard);
        assert_eq!(out, vec![1.0, -0.5, 0.0, 0.0, 2.0]);
    }

    #[test]
    fn test_threshold_hard_keeps_boundary() {
        // |c| == τ is kept by hard thresholding.
        let out = threshold_coeffs(&[0.4, -0.4, 0.39], 0.4, ThresholdMethod::Hard);
        assert_eq!(out, vec![0.4, -0.4, 0.0]);
    }

    #[test]
    fn test_threshold_soft() {
        let out = threshold_coeffs(&[1.0, -1.5, 0.2], 0.5, ThresholdMethod::Soft);
        assert!((out[0] - 0.5).abs() < 1e-12);
        assert!((out[1] - (-1.0)).abs() < 1e-12);
        assert!((out[2] - 0.0).abs() < 1e-12);
    }

    #[test]
    fn test_threshold_garrote() {
        // c → c - τ²/c for |c| > τ, 0 otherwise (including |c| == τ).
        let out = threshold_coeffs(&[2.0, -2.0, 0.5, 1.0], 1.0, ThresholdMethod::Garrote);
        assert!((out[0] - 1.5).abs() < 1e-12);
        assert!((out[1] + 1.5).abs() < 1e-12);
        assert_eq!(out[2], 0.0);
        assert_eq!(out[3], 0.0);
    }

    #[test]
    fn test_threshold_firm_endpoints() {
        // Below (or at) the lower threshold → 0; at/above t2 → unchanged.
        let method = ThresholdMethod::Firm { t2: 2.0 };
        let out = threshold_coeffs(&[0.5, -1.0, 2.0, -3.0], 1.0, method);
        assert_eq!(out, vec![0.0, 0.0, 2.0, -3.0]);
    }

    #[test]
    fn test_wpd_error_on_empty() {
        let result = wpd(&[], Wavelet::Haar, 2);
        assert!(result.is_err());
    }

    #[test]
    fn test_wpd_error_on_zero_level() {
        let result = wpd(&[1.0, 2.0, 3.0], Wavelet::Haar, 0);
        assert!(result.is_err());
    }
}