Skip to main content

vector_ta/indicators/
bandpass.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use numpy::PyUntypedArrayMethods;
3#[cfg(feature = "python")]
4use numpy::{IntoPyArray, PyArray1};
5#[cfg(feature = "python")]
6use pyo3::exceptions::PyValueError;
7#[cfg(feature = "python")]
8use pyo3::prelude::*;
9#[cfg(feature = "python")]
10use pyo3::types::PyDict;
11
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde::{Deserialize, Serialize};
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17#[cfg(all(feature = "python", feature = "cuda"))]
18use crate::cuda::bandpass_wrapper::CudaBandpass;
19#[cfg(all(feature = "python", feature = "cuda"))]
20use crate::cuda::cuda_available;
21#[cfg(all(feature = "python", feature = "cuda"))]
22use crate::cuda::moving_averages::DeviceArrayF32;
23use crate::indicators::highpass::{highpass, HighPassError, HighPassInput, HighPassParams};
24use crate::utilities::data_loader::{source_type, Candles};
25#[cfg(all(feature = "python", feature = "cuda"))]
26use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
27use crate::utilities::enums::Kernel;
28use crate::utilities::helpers::{
29    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
30    make_uninit_matrix,
31};
32#[cfg(feature = "python")]
33use crate::utilities::kernel_validation::validate_kernel;
34use aligned_vec::{AVec, CACHELINE_ALIGN};
35#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
36use core::arch::x86_64::*;
37#[cfg(all(feature = "python", feature = "cuda"))]
38use cust::context::Context;
39#[cfg(all(feature = "python", feature = "cuda"))]
40use cust::memory::DeviceBuffer;
41#[cfg(not(target_arch = "wasm32"))]
42use rayon::prelude::*;
43use std::convert::AsRef;
44use std::f64::consts::PI;
45#[cfg(all(feature = "python", feature = "cuda"))]
46use std::sync::Arc;
47use thiserror::Error;
48
49impl<'a> AsRef<[f64]> for BandPassInput<'a> {
50    #[inline(always)]
51    fn as_ref(&self) -> &[f64] {
52        match &self.data {
53            BandPassData::Slice(slice) => slice,
54            BandPassData::Candles { candles, source } => source_type(candles, source),
55        }
56    }
57}
58
/// Where a band-pass input takes its price series from.
#[derive(Debug, Clone)]
pub enum BandPassData<'a> {
    /// A named column of a candle set; the column is looked up with `source_type`.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided series used directly.
    Slice(&'a [f64]),
}
67
/// Parameters of the band-pass indicator.
///
/// `None` fields are replaced with the defaults when read (`period = 20`,
/// `bandwidth = 0.3`), matching the `Default` impl below.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct BandPassParams {
    /// Filter period in samples; used as `cos(2*PI / period)` for the band-pass coefficient.
    pub period: Option<usize>,
    /// Relative bandwidth; must be finite and at most 1.0 to pass validation.
    pub bandwidth: Option<f64>,
}

impl Default for BandPassParams {
    /// `period = 20`, `bandwidth = 0.3`.
    fn default() -> Self {
        Self {
            period: Some(20),
            bandwidth: Some(0.3),
        }
    }
}
86
/// Input to the band-pass indicator: a data source plus its parameters.
#[derive(Debug, Clone)]
pub struct BandPassInput<'a> {
    /// Series source (candle column or raw slice).
    pub data: BandPassData<'a>,
    /// Filter parameters; unset fields fall back to defaults.
    pub params: BandPassParams,
}
92
93impl<'a> BandPassInput<'a> {
94    #[inline]
95    pub fn from_candles(c: &'a Candles, s: &'a str, p: BandPassParams) -> Self {
96        Self {
97            data: BandPassData::Candles {
98                candles: c,
99                source: s,
100            },
101            params: p,
102        }
103    }
104    #[inline]
105    pub fn from_slice(sl: &'a [f64], p: BandPassParams) -> Self {
106        Self {
107            data: BandPassData::Slice(sl),
108            params: p,
109        }
110    }
111    #[inline]
112    pub fn with_default_candles(c: &'a Candles) -> Self {
113        Self::from_candles(c, "close", BandPassParams::default())
114    }
115    #[inline]
116    pub fn get_period(&self) -> usize {
117        self.params.period.unwrap_or(20)
118    }
119    #[inline]
120    pub fn get_bandwidth(&self) -> f64 {
121        self.params.bandwidth.unwrap_or(0.3)
122    }
123}
124
/// Fluent builder for single-series band-pass runs and streams.
#[derive(Copy, Clone, Debug)]
pub struct BandPassBuilder {
    /// Optional period override (`None` uses the default).
    period: Option<usize>,
    /// Optional bandwidth override (`None` uses the default).
    bandwidth: Option<f64>,
    /// Compute kernel passed to `bandpass_with_kernel`.
    kernel: Kernel,
}

impl Default for BandPassBuilder {
    /// No overrides; kernel selection is automatic.
    fn default() -> Self {
        Self {
            period: None,
            bandwidth: None,
            kernel: Kernel::Auto,
        }
    }
}
141
142impl BandPassBuilder {
143    #[inline(always)]
144    pub fn new() -> Self {
145        Self::default()
146    }
147    #[inline(always)]
148    pub fn period(mut self, n: usize) -> Self {
149        self.period = Some(n);
150        self
151    }
152    #[inline(always)]
153    pub fn bandwidth(mut self, b: f64) -> Self {
154        self.bandwidth = Some(b);
155        self
156    }
157    #[inline(always)]
158    pub fn kernel(mut self, k: Kernel) -> Self {
159        self.kernel = k;
160        self
161    }
162    #[inline(always)]
163    pub fn apply(self, c: &Candles) -> Result<BandPassOutput, BandPassError> {
164        let p = BandPassParams {
165            period: self.period,
166            bandwidth: self.bandwidth,
167        };
168        let i = BandPassInput::from_candles(c, "close", p);
169        bandpass_with_kernel(&i, self.kernel)
170    }
171    #[inline(always)]
172    pub fn apply_slice(self, d: &[f64]) -> Result<BandPassOutput, BandPassError> {
173        let p = BandPassParams {
174            period: self.period,
175            bandwidth: self.bandwidth,
176        };
177        let i = BandPassInput::from_slice(d, p);
178        bandpass_with_kernel(&i, self.kernel)
179    }
180    #[inline(always)]
181    pub fn into_stream(self) -> Result<BandPassStream, BandPassError> {
182        let p = BandPassParams {
183            period: self.period,
184            bandwidth: self.bandwidth,
185        };
186        BandPassStream::try_new(p)
187    }
188}
189
/// Errors produced by the band-pass indicator, its stream and its batch API.
#[derive(Debug, Error)]
pub enum BandPassError {
    /// The input series has no elements.
    #[error("bandpass: Input data slice is empty.")]
    EmptyInputData,
    /// No element of the input is finite (the check uses `is_finite`, so +/-inf count too).
    #[error("bandpass: All values are NaN.")]
    AllValuesNaN,
    /// The period is 0 or exceeds the data length (streams report `data_len: 0`).
    #[error("bandpass: Invalid period: period={period}, data_len={data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer than `needed` values remain after the first finite one.
    #[error("bandpass: Not enough valid data: needed={needed}, valid={valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// The bandwidth failed range/finiteness validation.
    #[error("bandpass: Invalid bandwidth={bandwidth}")]
    InvalidBandwidth { bandwidth: f64 },
    /// The derived pre-filter high-pass period rounded below 2.
    #[error("bandpass: hp_period too small ({hp_period})")]
    HpPeriodTooSmall { hp_period: usize },
    /// The derived trigger high-pass period rounded below 2.
    #[error("bandpass: trigger_period too small ({trigger_period})")]
    TriggerPeriodTooSmall { trigger_period: usize },
    /// A caller-provided output buffer has the wrong length.
    #[error("bandpass: Output length mismatch: expected={expected}, got={got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A batch sweep range could not be expanded into parameter combinations.
    #[error("bandpass: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// A non-batch kernel was passed to the batch API.
    #[error("bandpass: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// Propagated from one of the high-pass stages.
    #[error(transparent)]
    HighPassError(#[from] HighPassError),
}
219
/// Result of a single-series band-pass run; every vector has the input's length and
/// NaN in its warm-up prefix.
#[derive(Debug, Clone)]
pub struct BandPassOutput {
    /// Raw band-pass filter output.
    pub bp: Vec<f64>,
    /// `bp` divided by a peak envelope that decays by 0.991 per sample.
    pub bp_normalized: Vec<f64>,
    /// `1.0` when `bp_normalized < trigger`, `-1.0` when greater, `0.0` otherwise.
    pub signal: Vec<f64>,
    /// High-pass of `bp_normalized` with the derived trigger period.
    pub trigger: Vec<f64>,
}
227
/// Computes the band-pass indicator with automatic kernel selection.
///
/// Equivalent to `bandpass_with_kernel(input, Kernel::Auto)`.
#[inline]
pub fn bandpass(input: &BandPassInput) -> Result<BandPassOutput, BandPassError> {
    bandpass_with_kernel(input, Kernel::Auto)
}
232
233#[inline(always)]
234fn bandpass_prepare<'a>(
235    input: &'a BandPassInput,
236    kernel: Kernel,
237) -> Result<(&'a [f64], usize, usize, f64, usize, usize, Kernel), BandPassError> {
238    let data: &[f64] = input.as_ref();
239    let len = data.len();
240    let period = input.get_period();
241    let bandwidth = input.get_bandwidth();
242
243    if len == 0 {
244        return Err(BandPassError::EmptyInputData);
245    }
246    let first_valid = data
247        .iter()
248        .position(|x| x.is_finite())
249        .ok_or(BandPassError::AllValuesNaN)?;
250    if period == 0 || period > len {
251        return Err(BandPassError::InvalidPeriod {
252            period,
253            data_len: len,
254        });
255    }
256    if len - first_valid < period {
257        return Err(BandPassError::NotEnoughValidData {
258            needed: period,
259            valid: len - first_valid,
260        });
261    }
262    if !(0.0..=1.0).contains(&bandwidth) || !bandwidth.is_finite() {
263        return Err(BandPassError::InvalidBandwidth { bandwidth });
264    }
265
266    let hp_period = (4.0 * period as f64 / bandwidth).round() as usize;
267    if hp_period < 2 {
268        return Err(BandPassError::HpPeriodTooSmall { hp_period });
269    }
270    let trigger_period = ((period as f64 / bandwidth) / 1.5).round() as usize;
271    if trigger_period < 2 {
272        return Err(BandPassError::TriggerPeriodTooSmall { trigger_period });
273    }
274
275    let chosen = match kernel {
276        Kernel::Auto => detect_best_kernel(),
277        k => k,
278    };
279    Ok((
280        data,
281        len,
282        period,
283        bandwidth,
284        hp_period,
285        trigger_period,
286        chosen,
287    ))
288}
289
290pub fn bandpass_with_kernel(
291    input: &BandPassInput,
292    kernel: Kernel,
293) -> Result<BandPassOutput, BandPassError> {
294    let (data, len, period, bandwidth, hp_period, trigger_period, chosen) =
295        bandpass_prepare(input, kernel)?;
296
297    let mut hp_params = HighPassParams::default();
298    hp_params.period = Some(hp_period);
299    let hp = highpass(&HighPassInput::from_slice(data, hp_params))?.values;
300
301    let first_valid_hp = hp.iter().position(|&x| !x.is_nan()).unwrap_or(len);
302    let warmup_bp = first_valid_hp.saturating_add(2).max(2).min(len);
303
304    let beta = (2.0 * std::f64::consts::PI / period as f64).cos();
305    let gamma = (2.0 * std::f64::consts::PI * bandwidth / period as f64).cos();
306    let alpha = 1.0 / gamma - ((1.0 / (gamma * gamma)) - 1.0).sqrt();
307
308    let mut bp = alloc_with_nan_prefix(len, warmup_bp);
309
310    let bp_start = warmup_bp.saturating_sub(2);
311    unsafe {
312        match chosen {
313            Kernel::Scalar | Kernel::ScalarBatch => {
314                bandpass_scalar(&hp[bp_start..], period, alpha, beta, &mut bp[bp_start..])
315            }
316            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
317            Kernel::Avx2 | Kernel::Avx2Batch => {
318                bandpass_avx2(&hp[bp_start..], period, alpha, beta, &mut bp[bp_start..])
319            }
320            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
321            Kernel::Avx512 | Kernel::Avx512Batch => {
322                bandpass_avx512(&hp[bp_start..], period, alpha, beta, &mut bp[bp_start..])
323            }
324            _ => unreachable!(),
325        }
326    }
327
328    for v in &mut bp[..warmup_bp] {
329        *v = f64::NAN;
330    }
331
332    let mut bp_normalized = alloc_with_nan_prefix(len, warmup_bp);
333
334    let k = 0.991;
335    let mut peak_prev = 0.0f64;
336    for i in warmup_bp..len {
337        peak_prev *= k;
338        let abs_bp = bp[i].abs();
339        if abs_bp > peak_prev {
340            peak_prev = abs_bp;
341        }
342        bp_normalized[i] = if peak_prev != 0.0 {
343            bp[i] / peak_prev
344        } else {
345            0.0
346        };
347    }
348
349    let mut trigger = alloc_with_nan_prefix(len, warmup_bp);
350    if warmup_bp < len {
351        let mut trigger_params = HighPassParams::default();
352        trigger_params.period = Some(trigger_period);
353        let trig_inp = HighPassInput::from_slice(&bp_normalized[warmup_bp..], trigger_params);
354        crate::indicators::moving_averages::highpass::highpass_into_slice(
355            &mut trigger[warmup_bp..],
356            &trig_inp,
357            chosen,
358        )?;
359    }
360
361    let first_trig = trigger.iter().position(|x| !x.is_nan()).unwrap_or(len);
362    let warm_sig = warmup_bp.max(first_trig);
363    let mut signal = alloc_with_nan_prefix(len, warm_sig);
364    for i in warm_sig..len {
365        let bn = bp_normalized[i];
366        let tr = trigger[i];
367        signal[i] = if bn < tr {
368            1.0
369        } else if bn > tr {
370            -1.0
371        } else {
372            0.0
373        };
374    }
375
376    Ok(BandPassOutput {
377        bp,
378        bp_normalized,
379        signal,
380        trigger,
381    })
382}
383
/// Writes the band-pass outputs into caller-provided buffers using automatic kernel
/// selection (not compiled for the wasm32 + `wasm` build).
///
/// Each buffer must have the same length as the input series; see `bandpass_into_slice`
/// for the errors returned.
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
#[inline]
pub fn bandpass_into(
    input: &BandPassInput,
    bp_out: &mut [f64],
    bp_normalized_out: &mut [f64],
    signal_out: &mut [f64],
    trigger_out: &mut [f64],
) -> Result<(), BandPassError> {
    bandpass_into_slice(
        bp_out,
        bp_normalized_out,
        signal_out,
        trigger_out,
        input,
        Kernel::Auto,
    )
}
402
403#[inline]
404pub fn bandpass_into_slice(
405    bp_dst: &mut [f64],
406    bpn_dst: &mut [f64],
407    sig_dst: &mut [f64],
408    trig_dst: &mut [f64],
409    input: &BandPassInput,
410    kernel: Kernel,
411) -> Result<(), BandPassError> {
412    let (data, len, period, bandwidth, hp_period, trigger_period, chosen) =
413        bandpass_prepare(input, kernel)?;
414    if bp_dst.len() != len || bpn_dst.len() != len || sig_dst.len() != len || trig_dst.len() != len
415    {
416        return Err(BandPassError::OutputLengthMismatch {
417            expected: len,
418            got: *[bp_dst.len(), bpn_dst.len(), sig_dst.len(), trig_dst.len()]
419                .iter()
420                .min()
421                .unwrap_or(&0),
422        });
423    }
424
425    let mut hp_params = HighPassParams::default();
426    hp_params.period = Some(hp_period);
427    let hp = highpass(&HighPassInput::from_slice(data, hp_params))?.values;
428
429    let first_valid_hp = hp.iter().position(|&x| !x.is_nan()).unwrap_or(len);
430    let warm_bp = first_valid_hp.saturating_add(2).max(2).min(len);
431
432    let beta = (2.0 * std::f64::consts::PI / period as f64).cos();
433    let gamma = (2.0 * std::f64::consts::PI * bandwidth / period as f64).cos();
434    let alpha = 1.0 / gamma - ((1.0 / (gamma * gamma)) - 1.0).sqrt();
435
436    let bp_start = warm_bp.saturating_sub(2);
437    unsafe {
438        match chosen {
439            Kernel::Scalar | Kernel::ScalarBatch => bandpass_scalar(
440                &hp[bp_start..],
441                period,
442                alpha,
443                beta,
444                &mut bp_dst[bp_start..],
445            ),
446            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
447            Kernel::Avx2 | Kernel::Avx2Batch => bandpass_avx2(
448                &hp[bp_start..],
449                period,
450                alpha,
451                beta,
452                &mut bp_dst[bp_start..],
453            ),
454            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
455            Kernel::Avx512 | Kernel::Avx512Batch => bandpass_avx512(
456                &hp[bp_start..],
457                period,
458                alpha,
459                beta,
460                &mut bp_dst[bp_start..],
461            ),
462            _ => unreachable!(),
463        }
464    }
465
466    for v in &mut bp_dst[..warm_bp] {
467        *v = f64::NAN;
468    }
469
470    for v in &mut bpn_dst[..warm_bp] {
471        *v = f64::NAN;
472    }
473    let k = 0.991;
474    let mut peak = 0.0f64;
475    for i in warm_bp..len {
476        peak *= k;
477        let v = bp_dst[i];
478        let av = v.abs();
479        if av > peak {
480            peak = av;
481        }
482        bpn_dst[i] = if peak != 0.0 { v / peak } else { 0.0 };
483    }
484
485    for v in trig_dst.iter_mut() {
486        *v = f64::NAN;
487    }
488    if warm_bp < len {
489        let mut trigger_params = HighPassParams::default();
490        trigger_params.period = Some(trigger_period);
491        let trig_inp = HighPassInput::from_slice(&bpn_dst[warm_bp..], trigger_params);
492
493        crate::indicators::moving_averages::highpass::highpass_into_slice(
494            &mut trig_dst[warm_bp..],
495            &trig_inp,
496            chosen,
497        )?;
498    }
499
500    let first_trig = trig_dst.iter().position(|x| !x.is_nan()).unwrap_or(len);
501    let warm_sig = warm_bp.max(first_trig);
502    for v in &mut sig_dst[..warm_sig] {
503        *v = f64::NAN;
504    }
505    for i in warm_sig..len {
506        let bn = bpn_dst[i];
507        let tr = trig_dst[i];
508        sig_dst[i] = if bn < tr {
509            1.0
510        } else if bn > tr {
511            -1.0
512        } else {
513            0.0
514        };
515    }
516
517    Ok(())
518}
519
/// Scalar two-pole band-pass recurrence over a high-passed series.
///
/// `out[0]` and `out[1]` are seeded with `hp[0]` and `hp[1]`. For `i >= 2`:
///
/// `out[i] = 0.5*(1 - alpha)*(hp[i] - hp[i-2]) + beta*(1 + alpha)*out[i-1] - alpha*out[i-2]`
///
/// evaluated with fused multiply-adds. Only the first `hp.len()` entries of `out` are
/// written, and the function panics if `out` is shorter than `hp`. `period` is not used by
/// this kernel.
#[inline(always)]
pub fn bandpass_scalar(hp: &[f64], period: usize, alpha: f64, beta: f64, out: &mut [f64]) {
    let _ = period;
    let n = hp.len();
    if n == 0 {
        return;
    }

    // Seed the recurrence with the first (up to) two inputs.
    let seeds = n.min(2);
    out[..seeds].copy_from_slice(&hp[..seeds]);
    if n <= 2 {
        return;
    }

    let gain = 0.5 * (1.0 - alpha);
    let fb1 = beta * (1.0 + alpha);
    let fb2 = -alpha;

    let mut prev2 = out[0];
    let mut prev1 = out[1];
    for i in 2..n {
        let diff = hp[i] - hp[i - 2];
        let y = fb2.mul_add(prev2, fb1.mul_add(prev1, gain * diff));
        out[i] = y;
        prev2 = prev1;
        prev1 = y;
    }
}
585
/// AVX2 entry point for the band-pass recurrence.
///
/// NOTE(review): despite the name, this forwards to `bandpass_scalar_unchecked` (not
/// visible in this chunk) and ignores `period`. Presumably that is because the recurrence
/// is serial and gains nothing from SIMD; confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn bandpass_avx2(hp: &[f64], period: usize, alpha: f64, beta: f64, out: &mut [f64]) {
    // SAFETY: NOTE(review): the contract of `bandpass_scalar_unchecked` is not visible here;
    // it presumably requires `out.len() >= hp.len()`, which holds for both call sites in this
    // file (equal-length suffixes of `hp` and the output). Confirm.
    unsafe { bandpass_scalar_unchecked(hp, alpha, beta, out) }
}

/// AVX-512 entry point; currently forwards to the same unchecked scalar routine as
/// `bandpass_avx2` and ignores `period`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn bandpass_avx512(hp: &[f64], period: usize, alpha: f64, beta: f64, out: &mut [f64]) {
    // SAFETY: see the review note on `bandpass_avx2`'s identical call.
    unsafe { bandpass_scalar_unchecked(hp, alpha, beta, out) }
}

/// AVX-512 "short" variant; forwards to the unchecked scalar routine and ignores `period`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn bandpass_avx512_short(hp: &[f64], period: usize, alpha: f64, beta: f64, out: &mut [f64]) {
    // SAFETY: see the review note on `bandpass_avx2`'s identical call.
    unsafe { bandpass_scalar_unchecked(hp, alpha, beta, out) }
}

/// AVX-512 "long" variant; forwards to the unchecked scalar routine and ignores `period`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn bandpass_avx512_long(hp: &[f64], period: usize, alpha: f64, beta: f64, out: &mut [f64]) {
    // SAFETY: see the review note on `bandpass_avx2`'s identical call.
    unsafe { bandpass_scalar_unchecked(hp, alpha, beta, out) }
}
609
/// Incremental (one value at a time) band-pass filter.
#[derive(Debug, Clone)]
pub struct BandPassStream {
    /// Configured filter period (stored; not read by `update`).
    period: usize,

    /// Coefficient on `hp[t] - hp[t-2]`: `0.5 * (1 - alpha)`.
    c0: f64,
    /// Coefficient on the previous output: `beta * (1 + alpha)`.
    c1: f64,
    /// Coefficient on the output two steps back: `-alpha`.
    c2: f64,

    /// Pre-filter high-pass applied to each incoming value.
    hp_stream: crate::indicators::highpass::HighPassStream,

    /// Last two finite high-pass values (`hp_z1` is the most recent).
    hp_z1: f64,
    hp_z2: f64,
    /// Last two band-pass outputs (`y_z1` is the most recent).
    y_z1: f64,
    y_z2: f64,

    /// Number of seed samples collected so far (the recurrence starts after 2).
    hp_valid: u8,
}
627
628impl BandPassStream {
629    #[inline]
630    pub fn try_new(params: BandPassParams) -> Result<Self, BandPassError> {
631        let period = params.period.unwrap_or(20);
632        if period < 2 {
633            return Err(BandPassError::InvalidPeriod {
634                period,
635                data_len: 0,
636            });
637        }
638        let bandwidth = params.bandwidth.unwrap_or(0.3);
639        if !(0.0..=1.0).contains(&bandwidth) || !bandwidth.is_finite() {
640            return Err(BandPassError::InvalidBandwidth { bandwidth });
641        }
642
643        let hp_period = (4.0 * period as f64 / bandwidth).round() as usize;
644        if hp_period < 2 {
645            return Err(BandPassError::HpPeriodTooSmall { hp_period });
646        }
647
648        let mut hp_params = HighPassParams::default();
649        hp_params.period = Some(hp_period);
650        let hp_stream = crate::indicators::highpass::HighPassStream::try_new(hp_params)?;
651
652        use std::f64::consts::PI;
653        let beta = (2.0 * PI / period as f64).cos();
654        let gamma = (2.0 * PI * bandwidth / period as f64).cos();
655
656        #[inline(always)]
657        fn alpha_from_gamma(gamma: f64) -> f64 {
658            let g2 = gamma * gamma;
659            let s = (1.0 - g2).sqrt();
660            (1.0 - s) / gamma
661        }
662        let alpha = alpha_from_gamma(gamma);
663
664        let c0 = 0.5 * (1.0 - alpha);
665        let c1 = beta * (1.0 + alpha);
666        let c2 = -alpha;
667
668        Ok(Self {
669            period,
670            c0,
671            c1,
672            c2,
673            hp_stream,
674            hp_z1: 0.0,
675            hp_z2: 0.0,
676            y_z1: 0.0,
677            y_z2: 0.0,
678            hp_valid: 0,
679        })
680    }
681
682    #[inline(always)]
683    pub fn update(&mut self, value: f64) -> f64 {
684        let hp = self.hp_stream.update(value);
685
686        if !hp.is_finite() {
687            return f64::NAN;
688        }
689
690        if self.hp_valid < 2 {
691            let y = hp;
692
693            self.hp_z2 = self.hp_z1;
694            self.hp_z1 = hp;
695
696            self.y_z2 = self.y_z1;
697            self.y_z1 = y;
698
699            self.hp_valid += 1;
700            return f64::NAN;
701        }
702
703        let delta = hp - self.hp_z2;
704        let y = self
705            .c2
706            .mul_add(self.y_z2, self.c1.mul_add(self.y_z1, self.c0 * delta));
707
708        self.hp_z2 = self.hp_z1;
709        self.hp_z1 = hp;
710
711        self.y_z2 = self.y_z1;
712        self.y_z1 = y;
713
714        y
715    }
716}
717
/// Parameter sweep for the batch API; each axis is `(start, end, step)`.
///
/// A zero step (or `start == end`) means "just `start`".
#[derive(Clone, Debug)]
pub struct BandPassBatchRange {
    pub period: (usize, usize, usize),
    pub bandwidth: (f64, f64, f64),
}

impl Default for BandPassBatchRange {
    /// Fixed period 20; bandwidth swept from 0.3 to 0.549 in steps of 0.001.
    fn default() -> Self {
        Self {
            period: (20, 20, 0),
            bandwidth: (0.3, 0.549, 0.001),
        }
    }
}
732
/// Fluent builder for batch (parameter-sweep) band-pass runs.
#[derive(Clone, Debug, Default)]
pub struct BandPassBatchBuilder {
    /// Sweep over period and bandwidth.
    range: BandPassBatchRange,
    /// Batch kernel handed to `bandpass_batch_with_kernel`.
    kernel: Kernel,
}
738
739impl BandPassBatchBuilder {
740    pub fn new() -> Self {
741        Self::default()
742    }
743    pub fn kernel(mut self, k: Kernel) -> Self {
744        self.kernel = k;
745        self
746    }
747    #[inline]
748    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
749        self.range.period = (start, end, step);
750        self
751    }
752    #[inline]
753    pub fn period_static(mut self, p: usize) -> Self {
754        self.range.period = (p, p, 0);
755        self
756    }
757    #[inline]
758    pub fn bandwidth_range(mut self, start: f64, end: f64, step: f64) -> Self {
759        self.range.bandwidth = (start, end, step);
760        self
761    }
762    #[inline]
763    pub fn bandwidth_static(mut self, b: f64) -> Self {
764        self.range.bandwidth = (b, b, 0.0);
765        self
766    }
767    pub fn apply_slice(self, data: &[f64]) -> Result<BandPassBatchOutput, BandPassError> {
768        bandpass_batch_with_kernel(data, &self.range, self.kernel)
769    }
770    pub fn with_default_slice(
771        data: &[f64],
772        k: Kernel,
773    ) -> Result<BandPassBatchOutput, BandPassError> {
774        BandPassBatchBuilder::new().kernel(k).apply_slice(data)
775    }
776    pub fn apply_candles(
777        self,
778        c: &Candles,
779        src: &str,
780    ) -> Result<BandPassBatchOutput, BandPassError> {
781        let slice = source_type(c, src);
782        self.apply_slice(slice)
783    }
784    pub fn with_default_candles(c: &Candles) -> Result<BandPassBatchOutput, BandPassError> {
785        BandPassBatchBuilder::new()
786            .kernel(Kernel::Auto)
787            .apply_candles(c, "close")
788    }
789}
790
/// Result of a batch sweep. Each series is a row-major `rows x cols` matrix: row `r` holds
/// the output for `combos[r]`, and `cols` is the input length.
#[derive(Clone, Debug)]
pub struct BandPassBatchOutput {
    pub bp: Vec<f64>,
    pub bp_normalized: Vec<f64>,
    pub signal: Vec<f64>,
    pub trigger: Vec<f64>,
    /// Parameter combination for each row.
    pub combos: Vec<BandPassParams>,
    pub rows: usize,
    pub cols: usize,
}
801
802impl BandPassBatchOutput {
803    #[inline]
804    pub fn row_for_params(&self, p: &BandPassParams) -> Option<usize> {
805        self.combos.iter().position(|c| {
806            c.period.unwrap_or(20) == p.period.unwrap_or(20)
807                && (c.bandwidth.unwrap_or(0.3) - p.bandwidth.unwrap_or(0.3)).abs() < 1e-12
808        })
809    }
810    #[inline]
811    pub fn row_slices(&self, row: usize) -> Option<(&[f64], &[f64], &[f64], &[f64])> {
812        if row >= self.rows {
813            return None;
814        }
815        let (r, c) = (row, self.cols);
816        Some((
817            &self.bp[r * self.cols..r * self.cols + c],
818            &self.bp_normalized[r * self.cols..r * self.cols + c],
819            &self.signal[r * self.cols..r * self.cols + c],
820            &self.trigger[r * self.cols..r * self.cols + c],
821        ))
822    }
823}
824
825#[inline(always)]
826fn expand_grid(r: &BandPassBatchRange) -> Vec<BandPassParams> {
827    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, BandPassError> {
828        if step == 0 || start == end {
829            return Ok(vec![start]);
830        }
831        let mut vals = Vec::new();
832        if start < end {
833            let mut v = start;
834            while v <= end {
835                vals.push(v);
836                match v.checked_add(step) {
837                    Some(next) => {
838                        if next == v {
839                            break;
840                        }
841                        v = next;
842                    }
843                    None => break,
844                }
845            }
846        } else {
847            let mut v = start;
848            loop {
849                vals.push(v);
850                if v == end {
851                    break;
852                }
853                let next = v.saturating_sub(step);
854                if next == v {
855                    break;
856                }
857                v = next;
858                if v < end {
859                    break;
860                }
861            }
862        }
863        if vals.is_empty() {
864            return Err(BandPassError::InvalidRange { start, end, step });
865        }
866        Ok(vals)
867    }
868    fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, BandPassError> {
869        if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
870            return Ok(vec![start]);
871        }
872        let mut vals = Vec::new();
873        if start <= end {
874            let mut x = start;
875            loop {
876                vals.push(x);
877                if x >= end {
878                    break;
879                }
880                let next = x + step;
881                if !next.is_finite() || next == x {
882                    break;
883                }
884                x = next;
885                if x > end + 1e-12 {
886                    break;
887                }
888            }
889        } else {
890            let mut x = start;
891            loop {
892                vals.push(x);
893                if x <= end {
894                    break;
895                }
896                let next = x - step.abs();
897                if !next.is_finite() || next == x {
898                    break;
899                }
900                x = next;
901                if x < end - 1e-12 {
902                    break;
903                }
904            }
905        }
906        if vals.is_empty() {
907            return Err(BandPassError::InvalidRange {
908                start: start as usize,
909                end: end as usize,
910                step: step.abs() as usize,
911            });
912        }
913        Ok(vals)
914    }
915    let periods = match axis_usize(r.period) {
916        Ok(v) => v,
917        Err(_) => return Vec::new(),
918    };
919    let bandwidths = match axis_f64(r.bandwidth) {
920        Ok(v) => v,
921        Err(_) => return Vec::new(),
922    };
923    let mut out = Vec::with_capacity(periods.len() * bandwidths.len());
924    for &p in &periods {
925        for &b in &bandwidths {
926            out.push(BandPassParams {
927                period: Some(p),
928                bandwidth: Some(b),
929            });
930        }
931    }
932    out
933}
934
/// Python-visible handle to a band-pass result living in CUDA device memory (f32, 2-D).
///
/// Marked `unsendable`, so pyo3 restricts the object to the thread that created it.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(
    module = "ta_indicators.cuda",
    name = "BandPassDeviceArrayF32",
    unsendable
)]
pub struct BandPassDeviceArrayF32Py {
    /// Device buffer plus its `rows x cols` shape.
    pub(crate) inner: DeviceArrayF32,
    /// Holds a reference to the CUDA context; presumably this keeps it alive while the
    /// buffer exists. Confirm.
    pub(crate) _ctx: Arc<Context>,
    /// CUDA device ordinal reported through `__dlpack_device__`.
    pub(crate) device_id: u32,
    /// Stream handle stored as an integer (not read in the methods shown here).
    pub(crate) stream: usize,
}
947
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl BandPassDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) describing the buffer as a row-major
    /// `rows x cols` array of little-endian `f32` (`"<f4"`).
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let inner = &self.inner;
        let d = PyDict::new(py);
        d.set_item("shape", (inner.rows, inner.cols))?;
        d.set_item("typestr", "<f4")?;
        // Byte strides: a row spans `cols` f32 values; adjacent columns are one f32 apart.
        d.set_item(
            "strides",
            (
                inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        // (device pointer, read_only = false)
        d.set_item("data", (inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack `(device_type, device_id)`; device type 2 is `kDLCUDA` in the DLPack spec.
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        Ok((2, self.device_id as i32))
    }

    /// Exports the buffer as a DLPack capsule, moving ownership out of this object.
    ///
    /// After the call this wrapper holds an empty `0 x 0` buffer. `stream` is accepted but
    /// ignored. A `dl_device` that parses as `(type, id)` and differs from this buffer's
    /// device raises `ValueError` (cross-device copy is not implemented); a `dl_device`
    /// that does not parse as such a tuple is ignored.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__()?;
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    // Distinguish "copy requested but unsupported" from a plain mismatch.
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }
        let _ = stream;

        // Swap in an empty placeholder so the real buffer can be moved into the capsule.
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1023
1024pub fn bandpass_batch_with_kernel(
1025    data: &[f64],
1026    sweep: &BandPassBatchRange,
1027    k: Kernel,
1028) -> Result<BandPassBatchOutput, BandPassError> {
1029    let kernel = match k {
1030        Kernel::Auto => detect_best_batch_kernel(),
1031        other if other.is_batch() => other,
1032        other => return Err(BandPassError::InvalidKernelForBatch(other)),
1033    };
1034    let simd = match kernel {
1035        Kernel::Avx512Batch => Kernel::Avx512,
1036        Kernel::Avx2Batch => Kernel::Avx2,
1037        Kernel::ScalarBatch => Kernel::Scalar,
1038        _ => unreachable!(),
1039    };
1040    bandpass_batch_par_slice(data, sweep, simd)
1041}
1042
1043pub fn bandpass_batch_slice(
1044    data: &[f64],
1045    sweep: &BandPassBatchRange,
1046    kern: Kernel,
1047) -> Result<BandPassBatchOutput, BandPassError> {
1048    bandpass_batch_inner(data, sweep, kern, false)
1049}
1050
1051pub fn bandpass_batch_par_slice(
1052    data: &[f64],
1053    sweep: &BandPassBatchRange,
1054    kern: Kernel,
1055) -> Result<BandPassBatchOutput, BandPassError> {
1056    bandpass_batch_inner(data, sweep, kern, true)
1057}
1058
1059fn bandpass_batch_inner(
1060    data: &[f64],
1061    sweep: &BandPassBatchRange,
1062    kern: Kernel,
1063    parallel: bool,
1064) -> Result<BandPassBatchOutput, BandPassError> {
1065    let combos = expand_grid(sweep);
1066    if combos.is_empty() {
1067        return Err(BandPassError::InvalidRange {
1068            start: sweep.period.0,
1069            end: sweep.period.1,
1070            step: sweep.period.2,
1071        });
1072    }
1073
1074    let cols = data.len();
1075    if cols == 0 {
1076        return Err(BandPassError::EmptyInputData);
1077    }
1078    let rows = combos.len();
1079    rows.checked_mul(cols).ok_or(BandPassError::InvalidRange {
1080        start: sweep.period.0,
1081        end: sweep.period.1,
1082        step: sweep.period.2,
1083    })?;
1084
1085    let mut bp_mu = make_uninit_matrix(rows, cols);
1086    let mut bpn_mu = make_uninit_matrix(rows, cols);
1087    let mut sig_mu = make_uninit_matrix(rows, cols);
1088    let mut trg_mu = make_uninit_matrix(rows, cols);
1089
1090    let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
1091    let warms_bp: Vec<usize> = combos
1092        .iter()
1093        .map(|p| {
1094            let period = p.period.unwrap();
1095            let bandwidth = p.bandwidth.unwrap();
1096            let hp_p = (4.0 * period as f64 / bandwidth).round() as usize;
1097            let warm_hp = first + hp_p - 1;
1098            warm_hp.max(2)
1099        })
1100        .collect();
1101    let warms_trg: Vec<usize> = combos
1102        .iter()
1103        .map(|p| {
1104            let period = p.period.unwrap();
1105            let bandwidth = p.bandwidth.unwrap();
1106            let hp_p = (4.0 * period as f64 / bandwidth).round() as usize;
1107            let trig_p = ((period as f64 / bandwidth) / 1.5).round() as usize;
1108            let warm_hp = first + hp_p - 1;
1109            let warm_bp = warm_hp.max(2);
1110            warm_bp + trig_p - 1
1111        })
1112        .collect();
1113
1114    init_matrix_prefixes(&mut bp_mu, cols, &warms_bp);
1115    init_matrix_prefixes(&mut bpn_mu, cols, &warms_bp);
1116    init_matrix_prefixes(&mut trg_mu, cols, &warms_trg);
1117    init_matrix_prefixes(&mut sig_mu, cols, &warms_trg);
1118
1119    let mut bp_guard = core::mem::ManuallyDrop::new(bp_mu);
1120    let mut bpn_guard = core::mem::ManuallyDrop::new(bpn_mu);
1121    let mut sig_guard = core::mem::ManuallyDrop::new(sig_mu);
1122    let mut trg_guard = core::mem::ManuallyDrop::new(trg_mu);
1123
1124    let bp_out: &mut [f64] = unsafe {
1125        core::slice::from_raw_parts_mut(bp_guard.as_mut_ptr() as *mut f64, bp_guard.len())
1126    };
1127    let bpn_out: &mut [f64] = unsafe {
1128        core::slice::from_raw_parts_mut(bpn_guard.as_mut_ptr() as *mut f64, bpn_guard.len())
1129    };
1130    let sig_out: &mut [f64] = unsafe {
1131        core::slice::from_raw_parts_mut(sig_guard.as_mut_ptr() as *mut f64, sig_guard.len())
1132    };
1133    let trg_out: &mut [f64] = unsafe {
1134        core::slice::from_raw_parts_mut(trg_guard.as_mut_ptr() as *mut f64, trg_guard.len())
1135    };
1136
1137    bandpass_batch_inner_into(
1138        data, &combos, kern, parallel, bp_out, bpn_out, sig_out, trg_out,
1139    )?;
1140
1141    let bp = unsafe {
1142        Vec::from_raw_parts(
1143            bp_guard.as_mut_ptr() as *mut f64,
1144            bp_guard.len(),
1145            bp_guard.capacity(),
1146        )
1147    };
1148    let bpn = unsafe {
1149        Vec::from_raw_parts(
1150            bpn_guard.as_mut_ptr() as *mut f64,
1151            bpn_guard.len(),
1152            bpn_guard.capacity(),
1153        )
1154    };
1155    let sig = unsafe {
1156        Vec::from_raw_parts(
1157            sig_guard.as_mut_ptr() as *mut f64,
1158            sig_guard.len(),
1159            sig_guard.capacity(),
1160        )
1161    };
1162    let trg = unsafe {
1163        Vec::from_raw_parts(
1164            trg_guard.as_mut_ptr() as *mut f64,
1165            trg_guard.len(),
1166            trg_guard.capacity(),
1167        )
1168    };
1169
1170    Ok(BandPassBatchOutput {
1171        bp,
1172        bp_normalized: bpn,
1173        signal: sig,
1174        trigger: trg,
1175        combos,
1176        rows,
1177        cols,
1178    })
1179}
1180
1181#[inline(always)]
1182fn bandpass_batch_inner_into(
1183    data: &[f64],
1184    combos: &[BandPassParams],
1185    kern: Kernel,
1186    parallel: bool,
1187    bp_out: &mut [f64],
1188    bpn_out: &mut [f64],
1189    sig_out: &mut [f64],
1190    trg_out: &mut [f64],
1191) -> Result<(), BandPassError> {
1192    let rows = combos.len();
1193    let cols = data.len();
1194    let chosen = match kern {
1195        Kernel::Auto => detect_best_batch_kernel(),
1196        k => k,
1197    };
1198    let simd = match chosen {
1199        Kernel::ScalarBatch => Kernel::Scalar,
1200        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1201        Kernel::Avx2Batch => Kernel::Avx2,
1202        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1203        Kernel::Avx512Batch => Kernel::Avx512,
1204        _ => Kernel::Scalar,
1205    };
1206
1207    struct RowMeta {
1208        period: usize,
1209        bandwidth: f64,
1210        hp_p: usize,
1211        trig_p: usize,
1212        ksel: Kernel,
1213    }
1214
1215    let mut metas = Vec::with_capacity(rows);
1216    for &p in combos.iter() {
1217        let period = p.period.unwrap();
1218        let bandwidth = p.bandwidth.unwrap();
1219        let (_d, _len, _per, _bw, hp_p, trig_p, ksel) =
1220            bandpass_prepare(&BandPassInput::from_slice(data, p), simd)?;
1221        metas.push(RowMeta {
1222            period,
1223            bandwidth,
1224            hp_p,
1225            trig_p,
1226            ksel,
1227        });
1228    }
1229
1230    use std::collections::HashMap;
1231    let mut hp_cache: HashMap<usize, Vec<f64>> = HashMap::new();
1232    for meta in metas.iter() {
1233        if !hp_cache.contains_key(&meta.hp_p) {
1234            let mut hp_params = HighPassParams::default();
1235            hp_params.period = Some(meta.hp_p);
1236            let hp = highpass(&HighPassInput::from_slice(data, hp_params))?.values;
1237            hp_cache.insert(meta.hp_p, hp);
1238        }
1239    }
1240
1241    if parallel {
1242        #[cfg(not(target_arch = "wasm32"))]
1243        {
1244            use rayon::prelude::*;
1245            use std::sync::atomic::{AtomicPtr, Ordering};
1246
1247            let bp_ptr = AtomicPtr::new(bp_out.as_mut_ptr());
1248            let bpn_ptr = AtomicPtr::new(bpn_out.as_mut_ptr());
1249            let sig_ptr = AtomicPtr::new(sig_out.as_mut_ptr());
1250            let trg_ptr = AtomicPtr::new(trg_out.as_mut_ptr());
1251
1252            (0..rows)
1253                .into_par_iter()
1254                .try_for_each(|row| -> Result<(), BandPassError> {
1255                    let m = &metas[row];
1256                    let period = m.period;
1257                    let bandwidth = m.bandwidth;
1258
1259                    let bp_r = unsafe {
1260                        std::slice::from_raw_parts_mut(
1261                            bp_ptr.load(Ordering::Relaxed).add(row * cols),
1262                            cols,
1263                        )
1264                    };
1265                    let bpn_r = unsafe {
1266                        std::slice::from_raw_parts_mut(
1267                            bpn_ptr.load(Ordering::Relaxed).add(row * cols),
1268                            cols,
1269                        )
1270                    };
1271                    let sig_r = unsafe {
1272                        std::slice::from_raw_parts_mut(
1273                            sig_ptr.load(Ordering::Relaxed).add(row * cols),
1274                            cols,
1275                        )
1276                    };
1277                    let trg_r = unsafe {
1278                        std::slice::from_raw_parts_mut(
1279                            trg_ptr.load(Ordering::Relaxed).add(row * cols),
1280                            cols,
1281                        )
1282                    };
1283                    let trig_p = m.trig_p;
1284                    let ksel = m.ksel;
1285                    let hp = hp_cache.get(&m.hp_p).expect("hp cache missing");
1286
1287                    let first_valid_hp = hp.iter().position(|&x| !x.is_nan()).unwrap_or(cols);
1288                    let warm_bp = first_valid_hp.saturating_add(2).max(2).min(cols);
1289                    let bp_start = warm_bp.saturating_sub(2);
1290
1291                    let beta = (2.0 * std::f64::consts::PI / period as f64).cos();
1292                    let gamma = (2.0 * std::f64::consts::PI * bandwidth / period as f64).cos();
1293                    let alpha = 1.0 / gamma - ((1.0 / (gamma * gamma)) - 1.0).sqrt();
1294
1295                    unsafe {
1296                        match ksel {
1297                            Kernel::Scalar => bandpass_scalar(
1298                                &hp[bp_start..],
1299                                period,
1300                                alpha,
1301                                beta,
1302                                &mut bp_r[bp_start..],
1303                            ),
1304                            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1305                            Kernel::Avx2 => bandpass_avx2(
1306                                &hp[bp_start..],
1307                                period,
1308                                alpha,
1309                                beta,
1310                                &mut bp_r[bp_start..],
1311                            ),
1312                            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1313                            Kernel::Avx512 => bandpass_avx512(
1314                                &hp[bp_start..],
1315                                period,
1316                                alpha,
1317                                beta,
1318                                &mut bp_r[bp_start..],
1319                            ),
1320                            _ => unreachable!(),
1321                        }
1322                    }
1323
1324                    for v in &mut bp_r[..warm_bp] {
1325                        *v = f64::NAN;
1326                    }
1327
1328                    let k = 0.991;
1329                    let mut peak = 0.0f64;
1330                    for i in warm_bp..cols {
1331                        peak *= k;
1332                        let v = bp_r[i];
1333                        let av = v.abs();
1334                        if av > peak {
1335                            peak = av;
1336                        }
1337                        bpn_r[i] = if peak != 0.0 { v / peak } else { 0.0 };
1338                    }
1339
1340                    for v in trg_r.iter_mut() {
1341                        *v = f64::NAN;
1342                    }
1343                    if warm_bp < cols {
1344                        let mut trig_params = HighPassParams::default();
1345                        trig_params.period = Some(trig_p);
1346                        let trig_inp = HighPassInput::from_slice(&bpn_r[warm_bp..], trig_params);
1347                        crate::indicators::moving_averages::highpass::highpass_into_slice(
1348                            &mut trg_r[warm_bp..],
1349                            &trig_inp,
1350                            ksel,
1351                        )?;
1352                    }
1353
1354                    let first_tr = trg_r.iter().position(|x| !x.is_nan()).unwrap_or(cols);
1355                    let warm_sig = warm_bp.max(first_tr);
1356                    for v in &mut bpn_r[..warm_bp] {
1357                        if !v.is_nan() {
1358                            *v = f64::NAN;
1359                        }
1360                    }
1361                    for v in &mut sig_r[..warm_sig] {
1362                        *v = f64::NAN;
1363                    }
1364                    for i in warm_sig..cols {
1365                        let bn = bpn_r[i];
1366                        let tr = trg_r[i];
1367                        sig_r[i] = if bn < tr {
1368                            1.0
1369                        } else if bn > tr {
1370                            -1.0
1371                        } else {
1372                            0.0
1373                        };
1374                    }
1375
1376                    Ok(())
1377                })?;
1378        }
1379        #[cfg(target_arch = "wasm32")]
1380        {
1381            for row in 0..rows {
1382                let m = &metas[row];
1383                let period = m.period;
1384                let bandwidth = m.bandwidth;
1385
1386                let bp_r = &mut bp_out[row * cols..(row + 1) * cols];
1387                let bpn_r = &mut bpn_out[row * cols..(row + 1) * cols];
1388                let sig_r = &mut sig_out[row * cols..(row + 1) * cols];
1389                let trg_r = &mut trg_out[row * cols..(row + 1) * cols];
1390                let trig_p = m.trig_p;
1391                let ksel = m.ksel;
1392                let hp = hp_cache.get(&m.hp_p).expect("hp cache missing");
1393
1394                let first_valid_hp = hp.iter().position(|&x| !x.is_nan()).unwrap_or(cols);
1395                let warm_bp = first_valid_hp.saturating_add(2).max(2).min(cols);
1396                let bp_start = warm_bp.saturating_sub(2);
1397
1398                let beta = (2.0 * std::f64::consts::PI / period as f64).cos();
1399                let gamma = (2.0 * std::f64::consts::PI * bandwidth / period as f64).cos();
1400                let alpha = 1.0 / gamma - ((1.0 / (gamma * gamma)) - 1.0).sqrt();
1401
1402                unsafe {
1403                    match ksel {
1404                        Kernel::Scalar => bandpass_scalar(
1405                            &hp[bp_start..],
1406                            period,
1407                            alpha,
1408                            beta,
1409                            &mut bp_r[bp_start..],
1410                        ),
1411                        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1412                        Kernel::Avx2 => bandpass_avx2(
1413                            &hp[bp_start..],
1414                            period,
1415                            alpha,
1416                            beta,
1417                            &mut bp_r[bp_start..],
1418                        ),
1419                        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1420                        Kernel::Avx512 => bandpass_avx512(
1421                            &hp[bp_start..],
1422                            period,
1423                            alpha,
1424                            beta,
1425                            &mut bp_r[bp_start..],
1426                        ),
1427                        _ => unreachable!(),
1428                    }
1429                }
1430
1431                for v in &mut bp_r[..warm_bp] {
1432                    *v = f64::NAN;
1433                }
1434
1435                let k = 0.991;
1436                let mut peak = 0.0f64;
1437                for i in warm_bp..cols {
1438                    peak *= k;
1439                    let v = bp_r[i];
1440                    let av = v.abs();
1441                    if av > peak {
1442                        peak = av;
1443                    }
1444                    bpn_r[i] = if peak != 0.0 { v / peak } else { 0.0 };
1445                }
1446
1447                for v in trg_r.iter_mut() {
1448                    *v = f64::NAN;
1449                }
1450                if warm_bp < cols {
1451                    let mut trig_params = HighPassParams::default();
1452                    trig_params.period = Some(trig_p);
1453                    let trig_inp = HighPassInput::from_slice(&bpn_r[warm_bp..], trig_params);
1454                    crate::indicators::moving_averages::highpass::highpass_into_slice(
1455                        &mut trg_r[warm_bp..],
1456                        &trig_inp,
1457                        ksel,
1458                    )?;
1459                }
1460
1461                let first_tr = trg_r.iter().position(|x| !x.is_nan()).unwrap_or(cols);
1462                let warm_sig = warm_bp.max(first_tr);
1463                for v in &mut bpn_r[..warm_bp] {
1464                    if !v.is_nan() {
1465                        *v = f64::NAN;
1466                    }
1467                }
1468                for v in &mut sig_r[..warm_sig] {
1469                    *v = f64::NAN;
1470                }
1471                for i in warm_sig..cols {
1472                    let bn = bpn_r[i];
1473                    let tr = trg_r[i];
1474                    sig_r[i] = if bn < tr {
1475                        1.0
1476                    } else if bn > tr {
1477                        -1.0
1478                    } else {
1479                        0.0
1480                    };
1481                }
1482            }
1483        }
1484    } else {
1485        for row in 0..rows {
1486            let p = combos[row];
1487            let period = p.period.unwrap();
1488            let bandwidth = p.bandwidth.unwrap();
1489
1490            let bp_r = &mut bp_out[row * cols..(row + 1) * cols];
1491            let bpn_r = &mut bpn_out[row * cols..(row + 1) * cols];
1492            let sig_r = &mut sig_out[row * cols..(row + 1) * cols];
1493            let trg_r = &mut trg_out[row * cols..(row + 1) * cols];
1494
1495            let (_d, _len, _per, _bw, hp_p, trig_p, ksel) =
1496                bandpass_prepare(&BandPassInput::from_slice(data, p), simd)?;
1497
1498            let mut hp_params = HighPassParams::default();
1499            hp_params.period = Some(hp_p);
1500            let hp = highpass(&HighPassInput::from_slice(data, hp_params))?.values;
1501
1502            let first_valid_hp = hp.iter().position(|&x| !x.is_nan()).unwrap_or(cols);
1503            let warm_bp = first_valid_hp.saturating_add(2).max(2).min(cols);
1504            let bp_start = warm_bp.saturating_sub(2);
1505
1506            let beta = (2.0 * std::f64::consts::PI / period as f64).cos();
1507            let gamma = (2.0 * std::f64::consts::PI * bandwidth / period as f64).cos();
1508            let alpha = 1.0 / gamma - ((1.0 / (gamma * gamma)) - 1.0).sqrt();
1509
1510            unsafe {
1511                match ksel {
1512                    Kernel::Scalar => {
1513                        bandpass_scalar(&hp[bp_start..], period, alpha, beta, &mut bp_r[bp_start..])
1514                    }
1515                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1516                    Kernel::Avx2 => {
1517                        bandpass_avx2(&hp[bp_start..], period, alpha, beta, &mut bp_r[bp_start..])
1518                    }
1519                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1520                    Kernel::Avx512 => {
1521                        bandpass_avx512(&hp[bp_start..], period, alpha, beta, &mut bp_r[bp_start..])
1522                    }
1523                    _ => unreachable!(),
1524                }
1525            }
1526
1527            for v in &mut bp_r[..warm_bp] {
1528                *v = f64::NAN;
1529            }
1530
1531            let k = 0.991;
1532            let mut peak = 0.0f64;
1533            for i in warm_bp..cols {
1534                peak *= k;
1535                let v = bp_r[i];
1536                let av = v.abs();
1537                if av > peak {
1538                    peak = av;
1539                }
1540                bpn_r[i] = if peak != 0.0 { v / peak } else { 0.0 };
1541            }
1542
1543            for v in trg_r.iter_mut() {
1544                *v = f64::NAN;
1545            }
1546            if warm_bp < cols {
1547                let mut trig_params = HighPassParams::default();
1548                trig_params.period = Some(trig_p);
1549                let trig_vec =
1550                    highpass(&HighPassInput::from_slice(&bpn_r[warm_bp..], trig_params))?.values;
1551                trg_r[warm_bp..].copy_from_slice(&trig_vec);
1552            }
1553
1554            let first_tr = trg_r.iter().position(|x| !x.is_nan()).unwrap_or(cols);
1555            let warm_sig = warm_bp.max(first_tr);
1556            for v in &mut bpn_r[..warm_bp] {
1557                if !v.is_nan() {
1558                    *v = f64::NAN;
1559                }
1560            }
1561            for v in &mut sig_r[..warm_sig] {
1562                *v = f64::NAN;
1563            }
1564            for i in warm_sig..cols {
1565                let bn = bpn_r[i];
1566                let tr = trg_r[i];
1567                sig_r[i] = if bn < tr {
1568                    1.0
1569                } else if bn > tr {
1570                    -1.0
1571                } else {
1572                    0.0
1573                };
1574            }
1575        }
1576    }
1577
1578    Ok(())
1579}
1580
1581#[inline(always)]
1582pub fn bandpass_row_scalar(hp: &[f64], period: usize, alpha: f64, beta: f64, out: &mut [f64]) {
1583    bandpass_scalar(hp, period, alpha, beta, out)
1584}
1585
/// Row entry point for the AVX2 batch kernel.
///
/// Evaluates the recurrence with the `mul_add`-based routine `bandpass_scalar_unchecked`.
/// `period` is not used by that routine.
///
/// # Panics
/// Panics if `out` is shorter than `hp`. This check keeps the safe function from writing
/// past the end of `out`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn bandpass_row_avx2(hp: &[f64], period: usize, alpha: f64, beta: f64, out: &mut [f64]) {
    let _ = period;
    assert!(
        out.len() >= hp.len(),
        "bandpass_row_avx2: out.len() ({}) < hp.len() ({})",
        out.len(),
        hp.len()
    );
    // SAFETY: the assertion guarantees `out` holds one slot per `hp` element, which is the
    // precondition of `bandpass_scalar_unchecked`.
    unsafe { bandpass_scalar_unchecked(hp, alpha, beta, out) }
}
1591
/// Row entry point for the AVX-512 batch kernel.
///
/// Evaluates the recurrence with the `mul_add`-based routine `bandpass_scalar_unchecked`.
/// `period` is not used by that routine.
///
/// # Panics
/// Panics if `out` is shorter than `hp`. This check keeps the safe function from writing
/// past the end of `out`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn bandpass_row_avx512(hp: &[f64], period: usize, alpha: f64, beta: f64, out: &mut [f64]) {
    let _ = period;
    assert!(
        out.len() >= hp.len(),
        "bandpass_row_avx512: out.len() ({}) < hp.len() ({})",
        out.len(),
        hp.len()
    );
    // SAFETY: the assertion guarantees `out` holds one slot per `hp` element, which is the
    // precondition of `bandpass_scalar_unchecked`.
    unsafe { bandpass_scalar_unchecked(hp, alpha, beta, out) }
}
1597
/// Row entry point for the short-period AVX-512 batch variant.
///
/// Evaluates the recurrence with the `mul_add`-based routine `bandpass_scalar_unchecked`.
/// `period` is not used by that routine.
///
/// # Panics
/// Panics if `out` is shorter than `hp`. This check keeps the safe function from writing
/// past the end of `out`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn bandpass_row_avx512_short(
    hp: &[f64],
    period: usize,
    alpha: f64,
    beta: f64,
    out: &mut [f64],
) {
    let _ = period;
    assert!(
        out.len() >= hp.len(),
        "bandpass_row_avx512_short: out.len() ({}) < hp.len() ({})",
        out.len(),
        hp.len()
    );
    // SAFETY: the assertion guarantees `out` holds one slot per `hp` element, which is the
    // precondition of `bandpass_scalar_unchecked`.
    unsafe { bandpass_scalar_unchecked(hp, alpha, beta, out) }
}
1609
/// Row entry point for the long-period AVX-512 batch variant.
///
/// Evaluates the recurrence with the `mul_add`-based routine `bandpass_scalar_unchecked`.
/// `period` is not used by that routine.
///
/// # Panics
/// Panics if `out` is shorter than `hp`. This check keeps the safe function from writing
/// past the end of `out`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn bandpass_row_avx512_long(hp: &[f64], period: usize, alpha: f64, beta: f64, out: &mut [f64]) {
    let _ = period;
    assert!(
        out.len() >= hp.len(),
        "bandpass_row_avx512_long: out.len() ({}) < hp.len() ({})",
        out.len(),
        hp.len()
    );
    // SAFETY: the assertion guarantees `out` holds one slot per `hp` element, which is the
    // precondition of `bandpass_scalar_unchecked`.
    unsafe { bandpass_scalar_unchecked(hp, alpha, beta, out) }
}
1615
/// Evaluates the bandpass recurrence over `hp` with raw-pointer accesses (no bounds checks).
///
/// The first two outputs are copied from `hp` to seed the recursion. For `i >= 2`:
/// `out[i] = d*out[i-2] + c*out[i-1] + a*(hp[i] - hp[i-2])`, where `a = 0.5*(1 - alpha)`,
/// `c = beta*(1 + alpha)` and `d = -alpha`. Each step is evaluated with nested `mul_add`.
///
/// # Safety
/// The caller must ensure `out.len() >= hp.len()`. Indices `0..hp.len()` of `out` are
/// written; this is only checked by `debug_assert!`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn bandpass_scalar_unchecked(hp: &[f64], alpha: f64, beta: f64, out: &mut [f64]) {
    debug_assert!(out.len() >= hp.len());
    let len = hp.len();
    if len == 0 {
        return;
    }

    let out_ptr = out.as_mut_ptr();
    let hp_ptr = hp.as_ptr();

    // Seed values: the recursion needs two previous outputs.
    *out_ptr.add(0) = *hp_ptr.add(0);
    if len == 1 {
        return;
    }
    *out_ptr.add(1) = *hp_ptr.add(1);
    if len == 2 {
        return;
    }

    // Recurrence coefficients, hoisted out of the loop.
    let a = 0.5 * (1.0 - alpha);
    let c = beta * (1.0 + alpha);
    let d = -alpha;

    // The two most recent outputs, carried in locals instead of re-read from `out`.
    let mut y_im2 = *out_ptr.add(0);
    let mut y_im1 = *out_ptr.add(1);

    // Main loop, unrolled by 4; each step feeds the next, so the order is fixed.
    let mut i = 2usize;
    while i + 3 < len {
        let delta0 = *hp_ptr.add(i) - *hp_ptr.add(i - 2);
        let y0 = d.mul_add(y_im2, c.mul_add(y_im1, a * delta0));
        *out_ptr.add(i) = y0;

        let delta1 = *hp_ptr.add(i + 1) - *hp_ptr.add(i - 1);
        let y1 = d.mul_add(y_im1, c.mul_add(y0, a * delta1));
        *out_ptr.add(i + 1) = y1;

        let delta2 = *hp_ptr.add(i + 2) - *hp_ptr.add(i);
        let y2 = d.mul_add(y0, c.mul_add(y1, a * delta2));
        *out_ptr.add(i + 2) = y2;

        let delta3 = *hp_ptr.add(i + 3) - *hp_ptr.add(i + 1);
        let y3 = d.mul_add(y1, c.mul_add(y2, a * delta3));
        *out_ptr.add(i + 3) = y3;

        y_im2 = y2;
        y_im1 = y3;
        i += 4;
    }

    // Remainder: at most one pair of steps.
    while i + 1 < len {
        let delta0 = *hp_ptr.add(i) - *hp_ptr.add(i - 2);
        let y0 = d.mul_add(y_im2, c.mul_add(y_im1, a * delta0));
        *out_ptr.add(i) = y0;

        let delta1 = *hp_ptr.add(i + 1) - *hp_ptr.add(i - 1);
        let y1 = d.mul_add(y_im1, c.mul_add(y0, a * delta1));
        *out_ptr.add(i + 1) = y1;

        y_im2 = y0;
        y_im1 = y1;
        i += 2;
    }

    // Final single step when the remaining length is odd.
    if i < len {
        let delta = *hp_ptr.add(i) - *hp_ptr.add(i - 2);
        let y = d.mul_add(y_im2, c.mul_add(y_im1, a * delta));
        *out_ptr.add(i) = y;
    }
}
1687
1688#[inline(always)]
1689fn expand_grid_for_bandpass(r: &BandPassBatchRange) -> Vec<BandPassParams> {
1690    expand_grid(r)
1691}
1692
1693#[cfg(test)]
1694mod tests {
1695    use super::*;
1696    use crate::skip_if_unsupported;
1697    use crate::utilities::data_loader::read_candles_from_csv;
1698
1699    #[test]
1700    fn test_bandpass_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
1701        let len = 512usize;
1702        let mut data = Vec::with_capacity(len);
1703        for i in 0..len {
1704            let t = i as f64 * 0.07;
1705
1706            data.push((t).sin() * 0.8 + (0.33 * t).sin() * 0.2 + 0.001 * i as f64);
1707        }
1708
1709        let input = BandPassInput::from_slice(&data, BandPassParams::default());
1710
1711        let base = bandpass(&input)?;
1712
1713        let mut bp = vec![0.0; len];
1714        let mut bpn = vec![0.0; len];
1715        let mut sig = vec![0.0; len];
1716        let mut trg = vec![0.0; len];
1717
1718        #[allow(unused_mut)]
1719        let mut _ok = bandpass_into(&input, &mut bp, &mut bpn, &mut sig, &mut trg)?;
1720
1721        assert_eq!(bp.len(), base.bp.len());
1722        assert_eq!(bpn.len(), base.bp_normalized.len());
1723        assert_eq!(sig.len(), base.signal.len());
1724        assert_eq!(trg.len(), base.trigger.len());
1725
1726        fn eq_or_both_nan(a: f64, b: f64) -> bool {
1727            (a.is_nan() && b.is_nan()) || (a == b)
1728        }
1729
1730        for i in 0..len {
1731            assert!(
1732                eq_or_both_nan(bp[i], base.bp[i]),
1733                "bp mismatch at {}: {} vs {}",
1734                i,
1735                bp[i],
1736                base.bp[i]
1737            );
1738            assert!(
1739                eq_or_both_nan(bpn[i], base.bp_normalized[i]),
1740                "bpn mismatch at {}: {} vs {}",
1741                i,
1742                bpn[i],
1743                base.bp_normalized[i]
1744            );
1745            assert!(
1746                eq_or_both_nan(sig[i], base.signal[i]),
1747                "sig mismatch at {}: {} vs {}",
1748                i,
1749                sig[i],
1750                base.signal[i]
1751            );
1752            assert!(
1753                eq_or_both_nan(trg[i], base.trigger[i]),
1754                "trg mismatch at {}: {} vs {}",
1755                i,
1756                trg[i],
1757                base.trigger[i]
1758            );
1759        }
1760
1761        Ok(())
1762    }
1763
1764    fn check_bandpass_partial_params(
1765        test_name: &str,
1766        kernel: Kernel,
1767    ) -> Result<(), Box<dyn std::error::Error>> {
1768        skip_if_unsupported!(kernel, test_name);
1769        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1770        let candles = read_candles_from_csv(file_path)?;
1771        let default_params = BandPassParams::default();
1772        let input = BandPassInput::from_candles(&candles, "close", default_params);
1773        let output = bandpass_with_kernel(&input, kernel)?;
1774        assert_eq!(output.bp.len(), candles.close.len());
1775        Ok(())
1776    }
1777    fn check_bandpass_accuracy(
1778        test_name: &str,
1779        kernel: Kernel,
1780    ) -> Result<(), Box<dyn std::error::Error>> {
1781        skip_if_unsupported!(kernel, test_name);
1782        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1783        let candles = read_candles_from_csv(file_path)?;
1784        let input = BandPassInput::with_default_candles(&candles);
1785        let result = bandpass_with_kernel(&input, kernel)?;
1786        let expected_bp_last_five = [
1787            -236.23678021132827,
1788            -247.4846395608195,
1789            -242.3788746078502,
1790            -212.89589193350128,
1791            -179.97293838509464,
1792        ];
1793        let expected_bp_normalized_last_five = [
1794            -0.4399672555578846,
1795            -0.4651011734720517,
1796            -0.4596426251402882,
1797            -0.40739824942488945,
1798            -0.3475245023284841,
1799        ];
1800        let expected_signal_last_five = [-1.0, 1.0, 1.0, 1.0, 1.0];
1801        let expected_trigger_last_five = [
1802            -0.4746908356434579,
1803            -0.4353877348116954,
1804            -0.3727126131420441,
1805            -0.2746336628365846,
1806            -0.18240018384226137,
1807        ];
1808        let start = result.bp.len().saturating_sub(5);
1809        assert!(result.bp.len() >= 5);
1810        assert!(result.bp_normalized.len() >= 5);
1811        assert!(result.signal.len() >= 5);
1812        assert!(result.trigger.len() >= 5);
1813
1814        let first_bp = result
1815            .bp
1816            .iter()
1817            .position(|x| !x.is_nan())
1818            .unwrap_or(result.bp.len());
1819        let first_bpn = result
1820            .bp_normalized
1821            .iter()
1822            .position(|x| !x.is_nan())
1823            .unwrap_or(result.bp_normalized.len());
1824        let first_sig = result
1825            .signal
1826            .iter()
1827            .position(|x| !x.is_nan())
1828            .unwrap_or(result.signal.len());
1829        let first_trig = result
1830            .trigger
1831            .iter()
1832            .position(|x| !x.is_nan())
1833            .unwrap_or(result.trigger.len());
1834
1835        if first_sig >= start {
1836            panic!(
1837                "Signal values are all NaN in the last 5 indices. Debug info:\n\
1838				first_bp: {}, first_bpn: {}, first_sig: {}, first_trig: {}, start: {}, len: {}\n\
1839				bp[start]: {:?}, bpn[start]: {:?}, trig[start]: {:?}",
1840                first_bp,
1841                first_bpn,
1842                first_sig,
1843                first_trig,
1844                start,
1845                result.bp.len(),
1846                result.bp[start],
1847                result.bp_normalized[start],
1848                result.trigger[start]
1849            );
1850        }
1851        for (i, &value) in result.bp[start..].iter().enumerate() {
1852            assert!(
1853                (value - expected_bp_last_five[i]).abs() < 1e-1,
1854                "BP value mismatch at index {}: expected {}, got {}",
1855                i,
1856                expected_bp_last_five[i],
1857                value
1858            );
1859        }
1860        for (i, &value) in result.bp_normalized[start..].iter().enumerate() {
1861            assert!(
1862                (value - expected_bp_normalized_last_five[i]).abs() < 1e-1,
1863                "BP Normalized value mismatch at index {}: expected {}, got {}",
1864                i,
1865                expected_bp_normalized_last_five[i],
1866                value
1867            );
1868        }
1869        for (i, &value) in result.signal[start..].iter().enumerate() {
1870            if value.is_nan() {
1871                continue;
1872            }
1873            assert!(
1874                (value - expected_signal_last_five[i]).abs() < 1e-1,
1875                "Signal value mismatch at index {}: expected {}, got {}",
1876                i,
1877                expected_signal_last_five[i],
1878                value
1879            );
1880        }
1881        for (i, &value) in result.trigger[start..].iter().enumerate() {
1882            assert!(
1883                (value - expected_trigger_last_five[i]).abs() < 1e-1,
1884                "Trigger value mismatch at index {}: expected {}, got {}",
1885                i,
1886                expected_trigger_last_five[i],
1887                value
1888            );
1889        }
1890
1891        for val in &result.bp {
1892            if !val.is_nan() {
1893                assert!(val.is_finite(), "bp value not finite: {}", val);
1894            }
1895        }
1896        for val in &result.bp_normalized {
1897            if !val.is_nan() {
1898                assert!(val.is_finite(), "bp_normalized value not finite: {}", val);
1899            }
1900        }
1901        for val in &result.signal {
1902            if !val.is_nan() {
1903                assert!(val.is_finite(), "signal value not finite: {}", val);
1904            }
1905        }
1906        for val in &result.trigger {
1907            if !val.is_nan() {
1908                assert!(val.is_finite(), "trigger value not finite: {}", val);
1909            }
1910        }
1911        Ok(())
1912    }
1913    fn check_bandpass_default_candles(
1914        test_name: &str,
1915        kernel: Kernel,
1916    ) -> Result<(), Box<dyn std::error::Error>> {
1917        skip_if_unsupported!(kernel, test_name);
1918        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1919        let candles = read_candles_from_csv(file_path)?;
1920        let input = BandPassInput::with_default_candles(&candles);
1921        match input.data {
1922            BandPassData::Candles { source, .. } => assert_eq!(source, "close"),
1923            _ => panic!("Expected BandPassData::Candles"),
1924        }
1925        let output = bandpass_with_kernel(&input, kernel)?;
1926        assert_eq!(output.bp.len(), candles.close.len());
1927        Ok(())
1928    }
1929    fn check_bandpass_zero_period(
1930        test_name: &str,
1931        kernel: Kernel,
1932    ) -> Result<(), Box<dyn std::error::Error>> {
1933        skip_if_unsupported!(kernel, test_name);
1934        let input_data = [10.0, 20.0, 30.0];
1935        let params = BandPassParams {
1936            period: Some(0),
1937            bandwidth: Some(0.3),
1938        };
1939        let input = BandPassInput::from_slice(&input_data, params);
1940        let res = bandpass_with_kernel(&input, kernel);
1941        assert!(res.is_err());
1942        Ok(())
1943    }
1944    fn check_bandpass_period_exceeds_length(
1945        test_name: &str,
1946        kernel: Kernel,
1947    ) -> Result<(), Box<dyn std::error::Error>> {
1948        skip_if_unsupported!(kernel, test_name);
1949        let data_small = [10.0, 20.0, 30.0];
1950        let params = BandPassParams {
1951            period: Some(10),
1952            bandwidth: Some(0.3),
1953        };
1954        let input = BandPassInput::from_slice(&data_small, params);
1955        let res = bandpass_with_kernel(&input, kernel);
1956        assert!(res.is_err());
1957        Ok(())
1958    }
1959    fn check_bandpass_very_small_dataset(
1960        test_name: &str,
1961        kernel: Kernel,
1962    ) -> Result<(), Box<dyn std::error::Error>> {
1963        skip_if_unsupported!(kernel, test_name);
1964        let single_point = [42.0];
1965        let params = BandPassParams {
1966            period: Some(20),
1967            bandwidth: Some(0.3),
1968        };
1969        let input = BandPassInput::from_slice(&single_point, params);
1970        let res = bandpass_with_kernel(&input, kernel);
1971        assert!(res.is_err());
1972        Ok(())
1973    }
1974    fn check_bandpass_reinput(
1975        test_name: &str,
1976        kernel: Kernel,
1977    ) -> Result<(), Box<dyn std::error::Error>> {
1978        skip_if_unsupported!(kernel, test_name);
1979        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1980        let candles = read_candles_from_csv(file_path)?;
1981        let first_params = BandPassParams {
1982            period: Some(20),
1983            bandwidth: Some(0.3),
1984        };
1985        let first_input = BandPassInput::from_candles(&candles, "close", first_params);
1986        let first_result = bandpass_with_kernel(&first_input, kernel)?;
1987        let second_params = BandPassParams {
1988            period: Some(30),
1989            bandwidth: Some(0.5),
1990        };
1991        let second_input = BandPassInput::from_slice(&first_result.bp, second_params);
1992        let second_result = bandpass_with_kernel(&second_input, kernel)?;
1993        assert_eq!(second_result.bp.len(), first_result.bp.len());
1994        Ok(())
1995    }
1996    fn check_bandpass_nan_handling(
1997        test_name: &str,
1998        kernel: Kernel,
1999    ) -> Result<(), Box<dyn std::error::Error>> {
2000        skip_if_unsupported!(kernel, test_name);
2001        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2002        let candles = read_candles_from_csv(file_path)?;
2003        let input = BandPassInput::from_candles(
2004            &candles,
2005            "close",
2006            BandPassParams {
2007                period: Some(20),
2008                bandwidth: Some(0.3),
2009            },
2010        );
2011        let res = bandpass_with_kernel(&input, kernel)?;
2012        assert_eq!(res.bp.len(), candles.close.len());
2013
2014        let first_bp = res
2015            .bp
2016            .iter()
2017            .position(|x| !x.is_nan())
2018            .unwrap_or(res.bp.len());
2019        let first_bpn = res
2020            .bp_normalized
2021            .iter()
2022            .position(|x| !x.is_nan())
2023            .unwrap_or(res.bp_normalized.len());
2024        let first_sig = res
2025            .signal
2026            .iter()
2027            .position(|x| !x.is_nan())
2028            .unwrap_or(res.signal.len());
2029        let first_trig = res
2030            .trigger
2031            .iter()
2032            .position(|x| !x.is_nan())
2033            .unwrap_or(res.trigger.len());
2034
2035        let warmup_end = first_bp.max(first_bpn).max(first_sig).max(first_trig);
2036        if warmup_end < res.bp.len() {
2037            for i in warmup_end..res.bp.len() {
2038                assert!(!res.bp[i].is_nan(), "bp[{}] is NaN after warmup", i);
2039                assert!(
2040                    !res.bp_normalized[i].is_nan(),
2041                    "bp_normalized[{}] is NaN after warmup",
2042                    i
2043                );
2044                assert!(!res.signal[i].is_nan(), "signal[{}] is NaN after warmup", i);
2045                assert!(
2046                    !res.trigger[i].is_nan(),
2047                    "trigger[{}] is NaN after warmup",
2048                    i
2049                );
2050            }
2051        }
2052        Ok(())
2053    }
2054
2055    macro_rules! generate_all_bandpass_tests {
2056        ($($test_fn:ident),*) => {
2057            paste::paste! {
2058                $(
2059                    #[test]
2060                    fn [<$test_fn _scalar_f64>]() {
2061                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2062                    }
2063                )*
2064                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2065                $(
2066                    #[test]
2067                    fn [<$test_fn _avx2_f64>]() {
2068                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2069                    }
2070                    #[test]
2071                    fn [<$test_fn _avx512_f64>]() {
2072                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2073                    }
2074                )*
2075            }
2076        }
2077    }
2078
2079    #[cfg(debug_assertions)]
2080    fn check_bandpass_no_poison(
2081        test_name: &str,
2082        kernel: Kernel,
2083    ) -> Result<(), Box<dyn std::error::Error>> {
2084        skip_if_unsupported!(kernel, test_name);
2085
2086        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2087        let candles = read_candles_from_csv(file_path)?;
2088
2089        let test_params = vec![
2090            BandPassParams {
2091                period: Some(10),
2092                bandwidth: Some(0.2),
2093            },
2094            BandPassParams {
2095                period: Some(20),
2096                bandwidth: Some(0.3),
2097            },
2098            BandPassParams {
2099                period: Some(30),
2100                bandwidth: Some(0.4),
2101            },
2102            BandPassParams {
2103                period: Some(50),
2104                bandwidth: Some(0.5),
2105            },
2106            BandPassParams {
2107                period: Some(5),
2108                bandwidth: Some(0.1),
2109            },
2110            BandPassParams {
2111                period: Some(100),
2112                bandwidth: Some(0.8),
2113            },
2114        ];
2115
2116        for params in test_params {
2117            let input = BandPassInput::from_candles(&candles, "close", params.clone());
2118            let output = bandpass_with_kernel(&input, kernel)?;
2119
2120            for (i, &val) in output.bp.iter().enumerate() {
2121                if val.is_nan() {
2122                    continue;
2123                }
2124
2125                let bits = val.to_bits();
2126
2127                if bits == 0x11111111_11111111 {
2128                    panic!(
2129                    "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) in bp at index {} with params {:?}",
2130                    test_name, val, bits, i, params
2131                );
2132                }
2133
2134                if bits == 0x22222222_22222222 {
2135                    panic!(
2136                    "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) in bp at index {} with params {:?}",
2137                    test_name, val, bits, i, params
2138                );
2139                }
2140
2141                if bits == 0x33333333_33333333 {
2142                    panic!(
2143						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) in bp at index {} with params {:?}",
2144						test_name, val, bits, i, params
2145					);
2146                }
2147            }
2148
2149            for (i, &val) in output.bp_normalized.iter().enumerate() {
2150                if val.is_nan() {
2151                    continue;
2152                }
2153
2154                let bits = val.to_bits();
2155
2156                if bits == 0x11111111_11111111 {
2157                    panic!(
2158                    "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) in bp_normalized at index {} with params {:?}",
2159                    test_name, val, bits, i, params
2160                );
2161                }
2162
2163                if bits == 0x22222222_22222222 {
2164                    panic!(
2165                    "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) in bp_normalized at index {} with params {:?}",
2166                    test_name, val, bits, i, params
2167                );
2168                }
2169
2170                if bits == 0x33333333_33333333 {
2171                    panic!(
2172                    "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) in bp_normalized at index {} with params {:?}",
2173                    test_name, val, bits, i, params
2174                );
2175                }
2176            }
2177
2178            for (i, &val) in output.signal.iter().enumerate() {
2179                if val.is_nan() {
2180                    continue;
2181                }
2182
2183                let bits = val.to_bits();
2184
2185                if bits == 0x11111111_11111111 {
2186                    panic!(
2187                    "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) in signal at index {} with params {:?}",
2188                    test_name, val, bits, i, params
2189                );
2190                }
2191
2192                if bits == 0x22222222_22222222 {
2193                    panic!(
2194                    "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) in signal at index {} with params {:?}",
2195                    test_name, val, bits, i, params
2196                );
2197                }
2198
2199                if bits == 0x33333333_33333333 {
2200                    panic!(
2201                    "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) in signal at index {} with params {:?}",
2202                    test_name, val, bits, i, params
2203                );
2204                }
2205            }
2206
2207            for (i, &val) in output.trigger.iter().enumerate() {
2208                if val.is_nan() {
2209                    continue;
2210                }
2211
2212                let bits = val.to_bits();
2213
2214                if bits == 0x11111111_11111111 {
2215                    panic!(
2216                    "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) in trigger at index {} with params {:?}",
2217                    test_name, val, bits, i, params
2218                );
2219                }
2220
2221                if bits == 0x22222222_22222222 {
2222                    panic!(
2223                    "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) in trigger at index {} with params {:?}",
2224                    test_name, val, bits, i, params
2225                );
2226                }
2227
2228                if bits == 0x33333333_33333333 {
2229                    panic!(
2230                    "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) in trigger at index {} with params {:?}",
2231                    test_name, val, bits, i, params
2232                );
2233                }
2234            }
2235        }
2236
2237        Ok(())
2238    }
2239
2240    #[cfg(not(debug_assertions))]
2241    fn check_bandpass_no_poison(
2242        _test_name: &str,
2243        _kernel: Kernel,
2244    ) -> Result<(), Box<dyn std::error::Error>> {
2245        Ok(())
2246    }
2247
2248    fn check_bandpass_streaming(
2249        test_name: &str,
2250        kernel: Kernel,
2251    ) -> Result<(), Box<dyn std::error::Error>> {
2252        skip_if_unsupported!(kernel, test_name);
2253
2254        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2255        let candles = read_candles_from_csv(file_path)?;
2256
2257        let period = 20;
2258        let bandwidth = 0.3;
2259
2260        let input = BandPassInput::from_candles(
2261            &candles,
2262            "close",
2263            BandPassParams {
2264                period: Some(period),
2265                bandwidth: Some(bandwidth),
2266            },
2267        );
2268        let batch_output = bandpass_with_kernel(&input, kernel)?;
2269
2270        let mut stream = BandPassStream::try_new(BandPassParams {
2271            period: Some(period),
2272            bandwidth: Some(bandwidth),
2273        })?;
2274
2275        let mut stream_values = Vec::with_capacity(candles.close.len());
2276        for &price in &candles.close {
2277            let bp_val = stream.update(price);
2278            stream_values.push(bp_val);
2279        }
2280
2281        assert_eq!(batch_output.bp.len(), stream_values.len());
2282        for (i, (&b, &s)) in batch_output.bp.iter().zip(stream_values.iter()).enumerate() {
2283            if b.is_nan() {
2284                continue;
2285            }
2286            let diff = (b - s).abs();
2287            let tol = 1e-10 * b.abs().max(1.0);
2288            assert!(
2289                diff < tol,
2290                "[{}] Streaming vs batch mismatch at index {}: batch={}, stream={}, diff={}",
2291                test_name,
2292                i,
2293                b,
2294                s,
2295                diff
2296            );
2297        }
2298
2299        Ok(())
2300    }
2301
2302    #[cfg(feature = "proptest")]
2303    #[allow(clippy::float_cmp)]
2304    fn check_bandpass_property(
2305        test_name: &str,
2306        kernel: Kernel,
2307    ) -> Result<(), Box<dyn std::error::Error>> {
2308        use proptest::prelude::*;
2309        skip_if_unsupported!(kernel, test_name);
2310
2311        let strat = (2usize..=50).prop_flat_map(|period| {
2312            (
2313                prop::collection::vec(
2314                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
2315                    period..400,
2316                ),
2317                Just(period),
2318                (0.1f64..=0.9f64),
2319            )
2320        });
2321
2322        proptest::test_runner::TestRunner::default().run(&strat, |(data, period, bandwidth)| {
2323            let params = BandPassParams {
2324                period: Some(period),
2325                bandwidth: Some(bandwidth),
2326            };
2327            let input = BandPassInput::from_slice(&data, params);
2328
2329            let result = bandpass_with_kernel(&input, kernel)?;
2330
2331            let ref_result = bandpass_with_kernel(&input, Kernel::Scalar)?;
2332
2333            prop_assert_eq!(result.bp.len(), data.len());
2334            prop_assert_eq!(result.bp_normalized.len(), data.len());
2335            prop_assert_eq!(result.signal.len(), data.len());
2336            prop_assert_eq!(result.trigger.len(), data.len());
2337
2338            let hp_period = ((4.0 * period as f64) / bandwidth).round() as usize;
2339            let expected_warmup = hp_period.saturating_sub(1).max(2);
2340
2341            if data.len() >= expected_warmup {
2342                for i in 0..(expected_warmup.saturating_sub(1)).min(data.len()) {
2343                    prop_assert!(
2344                        result.bp[i].is_nan(),
2345                        "bp[{}] should be NaN during warmup but got {}",
2346                        i,
2347                        result.bp[i]
2348                    );
2349                }
2350            }
2351
2352            for i in 0..data.len() {
2353                if !result.bp[i].is_nan() {
2354                    prop_assert!(
2355                        result.bp[i].is_finite(),
2356                        "bp[{}] not finite: {}",
2357                        i,
2358                        result.bp[i]
2359                    );
2360                }
2361                if !result.bp_normalized[i].is_nan() {
2362                    prop_assert!(
2363                        result.bp_normalized[i].is_finite(),
2364                        "bp_normalized[{}] not finite: {}",
2365                        i,
2366                        result.bp_normalized[i]
2367                    );
2368                }
2369                if !result.signal[i].is_nan() {
2370                    prop_assert!(
2371                        result.signal[i].is_finite(),
2372                        "signal[{}] not finite: {}",
2373                        i,
2374                        result.signal[i]
2375                    );
2376                }
2377                if !result.trigger[i].is_nan() {
2378                    prop_assert!(
2379                        result.trigger[i].is_finite(),
2380                        "trigger[{}] not finite: {}",
2381                        i,
2382                        result.trigger[i]
2383                    );
2384                }
2385            }
2386
2387            for i in 0..data.len() {
2388                let sig = result.signal[i];
2389                if !sig.is_nan() {
2390                    prop_assert!(
2391                        sig == -1.0 || sig == 0.0 || sig == 1.0,
2392                        "signal[{}] = {} not in {{-1, 0, 1}}",
2393                        i,
2394                        sig
2395                    );
2396                }
2397            }
2398
2399            for i in 0..data.len() {
2400                let norm = result.bp_normalized[i];
2401                if !norm.is_nan() {
2402                    prop_assert!(
2403                        norm >= -1.0 - 1e-9 && norm <= 1.0 + 1e-9,
2404                        "bp_normalized[{}] = {} not in [-1, 1]",
2405                        i,
2406                        norm
2407                    );
2408                }
2409            }
2410
2411            if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-9)
2412                && data.len() > expected_warmup + 20
2413            {
2414                let check_start = (expected_warmup + 10).min(data.len() - 1);
2415                for i in check_start..data.len() {
2416                    if !result.bp[i].is_nan() {
2417                        let tolerance = 1e-6 * data[0].abs().max(1.0);
2418                        prop_assert!(
2419                            result.bp[i].abs() < tolerance,
2420                            "bp[{}] = {} not near 0 for constant data (tolerance={})",
2421                            i,
2422                            result.bp[i],
2423                            tolerance
2424                        );
2425                    }
2426                }
2427            }
2428
2429            for i in 0..data.len() {
2430                let bn = result.bp_normalized[i];
2431                let tr = result.trigger[i];
2432                let sig = result.signal[i];
2433
2434                if !bn.is_nan() && !tr.is_nan() && !sig.is_nan() {
2435                    if (bn - tr).abs() < 1e-12 {
2436                        prop_assert_eq!(
2437                            sig,
2438                            0.0,
2439                            "signal[{}] should be 0 when bp_normalized≈trigger",
2440                            i
2441                        );
2442                    } else if bn < tr {
2443                        prop_assert_eq!(
2444                            sig,
2445                            1.0,
2446                            "signal[{}] should be 1 when bp_normalized<trigger ({} < {})",
2447                            i,
2448                            bn,
2449                            tr
2450                        );
2451                    } else {
2452                        prop_assert_eq!(
2453                            sig,
2454                            -1.0,
2455                            "signal[{}] should be -1 when bp_normalized>trigger ({} > {})",
2456                            i,
2457                            bn,
2458                            tr
2459                        );
2460                    }
2461                }
2462            }
2463
2464            if data.len() > 50 {
2465                let last_quarter_start = (data.len() * 3) / 4;
2466                let mut max_abs_bp = 0.0f64;
2467                let mut max_abs_input = 0.0f64;
2468
2469                for i in last_quarter_start..data.len() {
2470                    if !result.bp[i].is_nan() {
2471                        max_abs_bp = max_abs_bp.max(result.bp[i].abs());
2472                    }
2473                    max_abs_input = max_abs_input.max(data[i].abs());
2474                }
2475
2476                if max_abs_input > 0.0 {
2477                    let amplification = max_abs_bp / max_abs_input;
2478                    prop_assert!(
2479							amplification < 100.0,
2480							"Recursive filter may be unstable: amplification={} at period={}, bandwidth={}",
2481							amplification, period, bandwidth
2482						);
2483                }
2484            }
2485
2486            if period == 2 {
2487                let non_nan_count = result.bp.iter().filter(|&&x| !x.is_nan()).count();
2488                prop_assert!(
2489                    non_nan_count >= data.len().saturating_sub(expected_warmup),
2490                    "period=2 should produce mostly valid output after warmup"
2491                );
2492            }
2493
2494            for i in 0..data.len() {
2495                let bp = result.bp[i];
2496                let bp_ref = ref_result.bp[i];
2497                let bp_norm = result.bp_normalized[i];
2498                let bp_norm_ref = ref_result.bp_normalized[i];
2499                let sig = result.signal[i];
2500                let sig_ref = ref_result.signal[i];
2501                let trig = result.trigger[i];
2502                let trig_ref = ref_result.trigger[i];
2503
2504                if !bp.is_finite() || !bp_ref.is_finite() {
2505                    prop_assert_eq!(
2506                        bp.to_bits(),
2507                        bp_ref.to_bits(),
2508                        "bp finite/NaN mismatch at {}",
2509                        i
2510                    );
2511                } else {
2512                    let ulp_diff = bp.to_bits().abs_diff(bp_ref.to_bits());
2513                    prop_assert!(
2514                        (bp - bp_ref).abs() <= 1e-9 || ulp_diff <= 4,
2515                        "bp mismatch at {}: {} vs {} (ULP={})",
2516                        i,
2517                        bp,
2518                        bp_ref,
2519                        ulp_diff
2520                    );
2521                }
2522
2523                if !bp_norm.is_finite() || !bp_norm_ref.is_finite() {
2524                    prop_assert_eq!(
2525                        bp_norm.to_bits(),
2526                        bp_norm_ref.to_bits(),
2527                        "bp_normalized finite/NaN mismatch at {}",
2528                        i
2529                    );
2530                } else {
2531                    let ulp_diff = bp_norm.to_bits().abs_diff(bp_norm_ref.to_bits());
2532                    prop_assert!(
2533                        (bp_norm - bp_norm_ref).abs() <= 1e-9 || ulp_diff <= 4,
2534                        "bp_normalized mismatch at {}: {} vs {} (ULP={})",
2535                        i,
2536                        bp_norm,
2537                        bp_norm_ref,
2538                        ulp_diff
2539                    );
2540                }
2541
2542                if !sig.is_nan() && !sig_ref.is_nan() {
2543                    prop_assert_eq!(
2544                        sig,
2545                        sig_ref,
2546                        "signal mismatch at {}: {} vs {}",
2547                        i,
2548                        sig,
2549                        sig_ref
2550                    );
2551                } else {
2552                    prop_assert_eq!(
2553                        sig.is_nan(),
2554                        sig_ref.is_nan(),
2555                        "signal NaN mismatch at {}",
2556                        i
2557                    );
2558                }
2559
2560                if !trig.is_finite() || !trig_ref.is_finite() {
2561                    prop_assert_eq!(
2562                        trig.to_bits(),
2563                        trig_ref.to_bits(),
2564                        "trigger finite/NaN mismatch at {}",
2565                        i
2566                    );
2567                } else {
2568                    let ulp_diff = trig.to_bits().abs_diff(trig_ref.to_bits());
2569                    prop_assert!(
2570                        (trig - trig_ref).abs() <= 1e-9 || ulp_diff <= 4,
2571                        "trigger mismatch at {}: {} vs {} (ULP={})",
2572                        i,
2573                        trig,
2574                        trig_ref,
2575                        ulp_diff
2576                    );
2577                }
2578            }
2579
2580            Ok(())
2581        })?;
2582
2583        Ok(())
2584    }
2585
2586    #[cfg(not(feature = "proptest"))]
2587    fn check_bandpass_property(
2588        test_name: &str,
2589        kernel: Kernel,
2590    ) -> Result<(), Box<dyn std::error::Error>> {
2591        skip_if_unsupported!(kernel, test_name);
2592        Ok(())
2593    }
2594
2595    generate_all_bandpass_tests!(
2596        check_bandpass_partial_params,
2597        check_bandpass_accuracy,
2598        check_bandpass_default_candles,
2599        check_bandpass_zero_period,
2600        check_bandpass_period_exceeds_length,
2601        check_bandpass_very_small_dataset,
2602        check_bandpass_reinput,
2603        check_bandpass_nan_handling,
2604        check_bandpass_no_poison,
2605        check_bandpass_streaming,
2606        check_bandpass_property
2607    );
2608    fn check_batch_default_row(
2609        test: &str,
2610        kernel: Kernel,
2611    ) -> Result<(), Box<dyn std::error::Error>> {
2612        skip_if_unsupported!(kernel, test);
2613
2614        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2615        let c = read_candles_from_csv(file)?;
2616
2617        let output = BandPassBatchBuilder::new()
2618            .kernel(kernel)
2619            .apply_candles(&c, "close")?;
2620
2621        let def = BandPassParams::default();
2622        let row_idx = output.row_for_params(&def).expect("default row missing");
2623        let (bp_slice, _bpn_slice, _sig_slice, _trg_slice) =
2624            output.row_slices(row_idx).expect("row missing");
2625
2626        assert_eq!(bp_slice.len(), c.close.len());
2627
2628        let expected = [
2629            -236.23678021132827,
2630            -247.4846395608195,
2631            -242.3788746078502,
2632            -212.89589193350128,
2633            -179.97293838509464,
2634        ];
2635        let start = bp_slice.len() - 5;
2636        for (i, &v) in bp_slice[start..].iter().enumerate() {
2637            assert!(
2638                (v - expected[i]).abs() < 1e-1,
2639                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2640            );
2641        }
2642        Ok(())
2643    }
2644
2645    macro_rules! gen_batch_tests {
2646        ($fn_name:ident) => {
2647            paste::paste! {
2648                #[test] fn [<$fn_name _scalar>]()      {
2649                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2650                }
2651                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2652                #[test] fn [<$fn_name _avx2>]()        {
2653                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2654                }
2655                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2656                #[test] fn [<$fn_name _avx512>]()      {
2657                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2658                }
2659                #[test] fn [<$fn_name _auto_detect>]() {
2660                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2661                }
2662            }
2663        };
2664    }
2665
2666    #[cfg(debug_assertions)]
2667    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
2668        skip_if_unsupported!(kernel, test);
2669
2670        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2671        let c = read_candles_from_csv(file)?;
2672
2673        let output = BandPassBatchBuilder::new()
2674            .kernel(kernel)
2675            .period_range(5, 50, 5)
2676            .bandwidth_range(0.1, 0.9, 0.1)
2677            .apply_candles(&c, "close")?;
2678
2679        for row_idx in 0..output.rows {
2680            let params = &output.combos[row_idx];
2681            let (bp_row, bpn_row, sig_row, trg_row) = output.row_slices(row_idx).unwrap();
2682
2683            for (col_idx, &val) in bp_row.iter().enumerate() {
2684                if val.is_nan() {
2685                    continue;
2686                }
2687
2688                let bits = val.to_bits();
2689
2690                if bits == 0x11111111_11111111 {
2691                    panic!(
2692                    "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) in bp at row {} col {} with params {:?}",
2693                    test, val, bits, row_idx, col_idx, params
2694                );
2695                }
2696
2697                if bits == 0x22222222_22222222 {
2698                    panic!(
2699                    "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) in bp at row {} col {} with params {:?}",
2700                    test, val, bits, row_idx, col_idx, params
2701                );
2702                }
2703
2704                if bits == 0x33333333_33333333 {
2705                    panic!(
2706                    "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) in bp at row {} col {} with params {:?}",
2707                    test, val, bits, row_idx, col_idx, params
2708                );
2709                }
2710            }
2711
2712            for (col_idx, &val) in bpn_row.iter().enumerate() {
2713                if val.is_nan() {
2714                    continue;
2715                }
2716
2717                let bits = val.to_bits();
2718
2719                if bits == 0x11111111_11111111 {
2720                    panic!(
2721                    "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) in bp_normalized at row {} col {} with params {:?}",
2722                    test, val, bits, row_idx, col_idx, params
2723                );
2724                }
2725
2726                if bits == 0x22222222_22222222 {
2727                    panic!(
2728                    "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) in bp_normalized at row {} col {} with params {:?}",
2729                    test, val, bits, row_idx, col_idx, params
2730                );
2731                }
2732
2733                if bits == 0x33333333_33333333 {
2734                    panic!(
2735                    "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) in bp_normalized at row {} col {} with params {:?}",
2736                    test, val, bits, row_idx, col_idx, params
2737                );
2738                }
2739            }
2740
2741            for (col_idx, &val) in sig_row.iter().enumerate() {
2742                if val.is_nan() {
2743                    continue;
2744                }
2745
2746                let bits = val.to_bits();
2747
2748                if bits == 0x11111111_11111111 {
2749                    panic!(
2750                    "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) in signal at row {} col {} with params {:?}",
2751                    test, val, bits, row_idx, col_idx, params
2752                );
2753                }
2754
2755                if bits == 0x22222222_22222222 {
2756                    panic!(
2757                    "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) in signal at row {} col {} with params {:?}",
2758                    test, val, bits, row_idx, col_idx, params
2759                );
2760                }
2761
2762                if bits == 0x33333333_33333333 {
2763                    panic!(
2764                    "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) in signal at row {} col {} with params {:?}",
2765                    test, val, bits, row_idx, col_idx, params
2766                );
2767                }
2768            }
2769
2770            for (col_idx, &val) in trg_row.iter().enumerate() {
2771                if val.is_nan() {
2772                    continue;
2773                }
2774
2775                let bits = val.to_bits();
2776
2777                if bits == 0x11111111_11111111 {
2778                    panic!(
2779                    "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) in trigger at row {} col {} with params {:?}",
2780                    test, val, bits, row_idx, col_idx, params
2781                );
2782                }
2783
2784                if bits == 0x22222222_22222222 {
2785                    panic!(
2786                    "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) in trigger at row {} col {} with params {:?}",
2787                    test, val, bits, row_idx, col_idx, params
2788                );
2789                }
2790
2791                if bits == 0x33333333_33333333 {
2792                    panic!(
2793                    "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) in trigger at row {} col {} with params {:?}",
2794                    test, val, bits, row_idx, col_idx, params
2795                );
2796                }
2797            }
2798        }
2799
2800        Ok(())
2801    }
2802
2803    #[cfg(not(debug_assertions))]
2804    fn check_batch_no_poison(
2805        _test: &str,
2806        _kernel: Kernel,
2807    ) -> Result<(), Box<dyn std::error::Error>> {
2808        Ok(())
2809    }
2810
2811    gen_batch_tests!(check_batch_default_row);
2812    gen_batch_tests!(check_batch_no_poison);
2813}
2814
2815#[cfg(feature = "python")]
2816#[pyfunction(name = "bandpass")]
2817#[pyo3(signature = (data, period=20, bandwidth=0.3, kernel=None))]
2818pub fn bandpass_py<'py>(
2819    py: Python<'py>,
2820    data: numpy::PyReadonlyArray1<'py, f64>,
2821    period: usize,
2822    bandwidth: f64,
2823    kernel: Option<&str>,
2824) -> PyResult<Bound<'py, PyDict>> {
2825    use numpy::PyArrayMethods;
2826    let slice_in = data.as_slice()?;
2827    let kern = validate_kernel(kernel, false)?;
2828    let params = BandPassParams {
2829        period: Some(period),
2830        bandwidth: Some(bandwidth),
2831    };
2832    let inp = BandPassInput::from_slice(slice_in, params);
2833
2834    let out = py
2835        .allow_threads(|| bandpass_with_kernel(&inp, kern))
2836        .map_err(|e| PyValueError::new_err(e.to_string()))?;
2837
2838    let dict = PyDict::new(py);
2839    dict.set_item("bp", out.bp.into_pyarray(py))?;
2840    dict.set_item("bp_normalized", out.bp_normalized.into_pyarray(py))?;
2841    dict.set_item("signal", out.signal.into_pyarray(py))?;
2842    dict.set_item("trigger", out.trigger.into_pyarray(py))?;
2843
2844    Ok(dict.into())
2845}
2846
2847#[cfg(feature = "python")]
2848#[pyclass(name = "BandPassStream")]
2849pub struct BandPassStreamPy {
2850    stream: BandPassStream,
2851}
2852
2853#[cfg(feature = "python")]
2854#[pymethods]
2855impl BandPassStreamPy {
2856    #[new]
2857    fn new(period: usize, bandwidth: f64) -> PyResult<Self> {
2858        let params = BandPassParams {
2859            period: Some(period),
2860            bandwidth: Some(bandwidth),
2861        };
2862        let stream =
2863            BandPassStream::try_new(params).map_err(|e| PyValueError::new_err(e.to_string()))?;
2864        Ok(BandPassStreamPy { stream })
2865    }
2866
2867    fn update(&mut self, value: f64) -> f64 {
2868        self.stream.update(value)
2869    }
2870}
2871
2872#[cfg(feature = "python")]
2873#[pyfunction(name = "bandpass_batch")]
2874#[pyo3(signature = (data, period_range, bandwidth_range, kernel=None))]
2875pub fn bandpass_batch_py<'py>(
2876    py: Python<'py>,
2877    data: numpy::PyReadonlyArray1<'py, f64>,
2878    period_range: (usize, usize, usize),
2879    bandwidth_range: (f64, f64, f64),
2880    kernel: Option<&str>,
2881) -> PyResult<Bound<'py, PyDict>> {
2882    use numpy::{PyArray1, PyArrayMethods};
2883    let slice_in = data.as_slice()?;
2884    let sweep = BandPassBatchRange {
2885        period: period_range,
2886        bandwidth: bandwidth_range,
2887    };
2888
2889    let kern = validate_kernel(kernel, true)?;
2890    let output = py
2891        .allow_threads(|| bandpass_batch_with_kernel(slice_in, &sweep, kern))
2892        .map_err(|e| PyValueError::new_err(e.to_string()))?;
2893
2894    let rows = output.rows;
2895    let cols = output.cols;
2896
2897    unsafe {
2898        let total = rows
2899            .checked_mul(cols)
2900            .ok_or_else(|| PyValueError::new_err("bandpass_batch: rows*cols overflow"))?;
2901        let bp_arr = PyArray1::<f64>::new(py, [total], false);
2902        let bpn_arr = PyArray1::<f64>::new(py, [total], false);
2903        let sig_arr = PyArray1::<f64>::new(py, [total], false);
2904        let trg_arr = PyArray1::<f64>::new(py, [total], false);
2905
2906        bp_arr.as_slice_mut()?.copy_from_slice(&output.bp);
2907        bpn_arr
2908            .as_slice_mut()?
2909            .copy_from_slice(&output.bp_normalized);
2910        sig_arr.as_slice_mut()?.copy_from_slice(&output.signal);
2911        trg_arr.as_slice_mut()?.copy_from_slice(&output.trigger);
2912
2913        let d = PyDict::new(py);
2914        d.set_item("bp", bp_arr.reshape((rows, cols))?)?;
2915        d.set_item("bp_normalized", bpn_arr.reshape((rows, cols))?)?;
2916        d.set_item("signal", sig_arr.reshape((rows, cols))?)?;
2917        d.set_item("trigger", trg_arr.reshape((rows, cols))?)?;
2918        d.set_item(
2919            "periods",
2920            output
2921                .combos
2922                .iter()
2923                .map(|p| p.period.unwrap() as u64)
2924                .collect::<Vec<_>>()
2925                .into_pyarray(py),
2926        )?;
2927        d.set_item(
2928            "bandwidths",
2929            output
2930                .combos
2931                .iter()
2932                .map(|p| p.bandwidth.unwrap())
2933                .collect::<Vec<_>>()
2934                .into_pyarray(py),
2935        )?;
2936        Ok(d)
2937    }
2938}
2939
2940#[cfg(all(feature = "python", feature = "cuda"))]
2941#[pyfunction(name = "bandpass_cuda_batch_dev")]
2942#[pyo3(signature = (close_f32, period_range, bandwidth_range, device_id=0))]
2943pub fn bandpass_cuda_batch_dev_py<'py>(
2944    py: Python<'py>,
2945    close_f32: numpy::PyReadonlyArray1<'py, f32>,
2946    period_range: (usize, usize, usize),
2947    bandwidth_range: (f64, f64, f64),
2948    device_id: usize,
2949) -> PyResult<Bound<'py, PyDict>> {
2950    use numpy::IntoPyArray;
2951    if !cuda_available() {
2952        return Err(PyValueError::new_err("CUDA not available"));
2953    }
2954    let slice = close_f32.as_slice()?;
2955    let sweep = BandPassBatchRange {
2956        period: period_range,
2957        bandwidth: bandwidth_range,
2958    };
2959    let (outputs, combos, dev_id, ctx) = py.allow_threads(|| -> PyResult<_> {
2960        let cuda =
2961            CudaBandpass::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
2962        let dev_id = cuda.device_id();
2963        let ctx = cuda.context_arc();
2964        let res = cuda
2965            .bandpass_batch_dev(slice, &sweep)
2966            .map_err(|e| PyValueError::new_err(e.to_string()))?;
2967        cuda.stream()
2968            .synchronize()
2969            .map_err(|e| PyValueError::new_err(e.to_string()))?;
2970        Ok((res.outputs, res.combos, dev_id, ctx))
2971    })?;
2972
2973    let d = PyDict::new(py);
2974    d.set_item(
2975        "bp",
2976        Py::new(
2977            py,
2978            BandPassDeviceArrayF32Py {
2979                inner: outputs.first,
2980                _ctx: ctx.clone(),
2981                device_id: dev_id,
2982                stream: 0,
2983            },
2984        )?,
2985    )?;
2986    d.set_item(
2987        "bp_normalized",
2988        Py::new(
2989            py,
2990            BandPassDeviceArrayF32Py {
2991                inner: outputs.second,
2992                _ctx: ctx.clone(),
2993                device_id: dev_id,
2994                stream: 0,
2995            },
2996        )?,
2997    )?;
2998    d.set_item(
2999        "signal",
3000        Py::new(
3001            py,
3002            BandPassDeviceArrayF32Py {
3003                inner: outputs.third,
3004                _ctx: ctx.clone(),
3005                device_id: dev_id,
3006                stream: 0,
3007            },
3008        )?,
3009    )?;
3010    d.set_item(
3011        "trigger",
3012        Py::new(
3013            py,
3014            BandPassDeviceArrayF32Py {
3015                inner: outputs.fourth,
3016                _ctx: ctx,
3017                device_id: dev_id,
3018                stream: 0,
3019            },
3020        )?,
3021    )?;
3022
3023    let periods: Vec<usize> = combos.iter().map(|p| p.period.unwrap()).collect();
3024    let bands: Vec<f64> = combos.iter().map(|p| p.bandwidth.unwrap()).collect();
3025    d.set_item("periods", periods.into_pyarray(py))?;
3026    d.set_item("bandwidths", bands.into_pyarray(py))?;
3027    d.set_item("rows", combos.len())?;
3028    d.set_item("cols", slice.len())?;
3029    Ok(d)
3030}
3031
3032#[cfg(all(feature = "python", feature = "cuda"))]
3033#[pyfunction(name = "bandpass_cuda_many_series_one_param_dev")]
3034#[pyo3(signature = (data_tm_f32, period, bandwidth, device_id=0))]
3035pub fn bandpass_cuda_many_series_one_param_dev_py<'py>(
3036    py: Python<'py>,
3037    data_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
3038    period: usize,
3039    bandwidth: f64,
3040    device_id: usize,
3041) -> PyResult<Bound<'py, PyDict>> {
3042    if !cuda_available() {
3043        return Err(PyValueError::new_err("CUDA not available"));
3044    }
3045    let shape = data_tm_f32.shape();
3046    if shape.len() != 2 {
3047        return Err(PyValueError::new_err("expected 2D array"));
3048    }
3049    let rows = shape[0];
3050    let cols = shape[1];
3051    let flat = data_tm_f32.as_slice()?;
3052    let params = BandPassParams {
3053        period: Some(period),
3054        bandwidth: Some(bandwidth),
3055    };
3056
3057    let (outputs, dev_id, ctx) = py.allow_threads(|| -> PyResult<_> {
3058        let cuda =
3059            CudaBandpass::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
3060        let dev_id = cuda.device_id();
3061        let ctx = cuda.context_arc();
3062        let out = cuda
3063            .bandpass_many_series_one_param_time_major_dev(flat, cols, rows, &params)
3064            .map_err(|e| PyValueError::new_err(e.to_string()))?;
3065        cuda.stream()
3066            .synchronize()
3067            .map_err(|e| PyValueError::new_err(e.to_string()))?;
3068        Ok((out, dev_id, ctx))
3069    })?;
3070
3071    let d = PyDict::new(py);
3072    d.set_item(
3073        "bp",
3074        Py::new(
3075            py,
3076            BandPassDeviceArrayF32Py {
3077                inner: outputs.first,
3078                _ctx: ctx.clone(),
3079                device_id: dev_id,
3080                stream: 0,
3081            },
3082        )?,
3083    )?;
3084    d.set_item(
3085        "bp_normalized",
3086        Py::new(
3087            py,
3088            BandPassDeviceArrayF32Py {
3089                inner: outputs.second,
3090                _ctx: ctx.clone(),
3091                device_id: dev_id,
3092                stream: 0,
3093            },
3094        )?,
3095    )?;
3096    d.set_item(
3097        "signal",
3098        Py::new(
3099            py,
3100            BandPassDeviceArrayF32Py {
3101                inner: outputs.third,
3102                _ctx: ctx.clone(),
3103                device_id: dev_id,
3104                stream: 0,
3105            },
3106        )?,
3107    )?;
3108    d.set_item(
3109        "trigger",
3110        Py::new(
3111            py,
3112            BandPassDeviceArrayF32Py {
3113                inner: outputs.fourth,
3114                _ctx: ctx,
3115                device_id: dev_id,
3116                stream: 0,
3117            },
3118        )?,
3119    )?;
3120    d.set_item("rows", rows)?;
3121    d.set_item("cols", cols)?;
3122    d.set_item("period", period)?;
3123    d.set_item("bandwidth", bandwidth)?;
3124    Ok(d)
3125}
3126
3127#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
3128#[derive(Serialize, Deserialize)]
3129pub struct BandPassJsResult {
3130    pub values: Vec<f64>,
3131    pub rows: usize,
3132    pub cols: usize,
3133}
3134
3135#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
3136#[wasm_bindgen(js_name = "bandpass_js")]
3137pub fn bandpass_js(data: &[f64], period: usize, bandwidth: f64) -> Result<JsValue, JsValue> {
3138    let input = BandPassInput::from_slice(
3139        data,
3140        BandPassParams {
3141            period: Some(period),
3142            bandwidth: Some(bandwidth),
3143        },
3144    );
3145    let out = bandpass_with_kernel(&input, detect_best_kernel())
3146        .map_err(|e| JsValue::from_str(&e.to_string()))?;
3147    let cols = data.len();
3148    let mut values = Vec::with_capacity(4 * cols);
3149    values.extend_from_slice(&out.bp);
3150    values.extend_from_slice(&out.bp_normalized);
3151    values.extend_from_slice(&out.signal);
3152    values.extend_from_slice(&out.trigger);
3153    serde_wasm_bindgen::to_value(&BandPassJsResult {
3154        values,
3155        rows: 4,
3156        cols,
3157    })
3158    .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
3159}
3160
3161#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
3162#[derive(Serialize, Deserialize)]
3163pub struct BandPassBatchConfig {
3164    pub period_range: (usize, usize, usize),
3165    pub bandwidth_range: (f64, f64, f64),
3166}
3167
3168#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
3169#[derive(Serialize, Deserialize)]
3170pub struct BandPassBatchJsOutput {
3171    pub values: Vec<f64>,
3172    pub rows: usize,
3173    pub cols: usize,
3174    pub combos: Vec<BandPassParams>,
3175}
3176
3177#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
3178#[wasm_bindgen(js_name = "bandpass_batch")]
3179pub fn bandpass_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
3180    let cfg: BandPassBatchConfig = serde_wasm_bindgen::from_value(config)
3181        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
3182    let sweep = BandPassBatchRange {
3183        period: cfg.period_range,
3184        bandwidth: cfg.bandwidth_range,
3185    };
3186    let out = bandpass_batch_with_kernel(data, &sweep, Kernel::Auto)
3187        .map_err(|e| JsValue::from_str(&e.to_string()))?;
3188
3189    let rows = out.rows;
3190    let cols = out.cols;
3191
3192    let total = rows
3193        .checked_mul(cols)
3194        .and_then(|v| v.checked_mul(4))
3195        .ok_or_else(|| JsValue::from_str("bandpass_batch_js: rows*cols overflow"))?;
3196    let mut values = Vec::with_capacity(total);
3197    values.extend_from_slice(&out.bp);
3198    values.extend_from_slice(&out.bp_normalized);
3199    values.extend_from_slice(&out.signal);
3200    values.extend_from_slice(&out.trigger);
3201
3202    let js_output = js_sys::Object::new();
3203    js_sys::Reflect::set(
3204        &js_output,
3205        &JsValue::from_str("values"),
3206        &serde_wasm_bindgen::to_value(&values).unwrap(),
3207    )?;
3208    js_sys::Reflect::set(
3209        &js_output,
3210        &JsValue::from_str("combos"),
3211        &JsValue::from_f64(out.combos.len() as f64),
3212    )?;
3213    js_sys::Reflect::set(
3214        &js_output,
3215        &JsValue::from_str("outputs"),
3216        &JsValue::from_f64(4.0),
3217    )?;
3218    js_sys::Reflect::set(
3219        &js_output,
3220        &JsValue::from_str("cols"),
3221        &JsValue::from_f64(cols as f64),
3222    )?;
3223    Ok(JsValue::from(js_output))
3224}
3225
3226#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
3227#[wasm_bindgen(js_name = "bandpass_batch_metadata")]
3228pub fn bandpass_batch_metadata_js(
3229    period_start: usize,
3230    period_end: usize,
3231    period_step: usize,
3232    bandwidth_start: f64,
3233    bandwidth_end: f64,
3234    bandwidth_step: f64,
3235) -> Result<JsValue, JsValue> {
3236    let sweep = BandPassBatchRange {
3237        period: (period_start, period_end, period_step),
3238        bandwidth: (bandwidth_start, bandwidth_end, bandwidth_step),
3239    };
3240    let combos = expand_grid(&sweep);
3241
3242    let mut flat = Vec::with_capacity(combos.len() * 2);
3243    for combo in &combos {
3244        flat.push(combo.period.unwrap() as f64);
3245        flat.push(combo.bandwidth.unwrap());
3246    }
3247    serde_wasm_bindgen::to_value(&flat).map_err(|e| JsValue::from_str(&e.to_string()))
3248}
3249
3250#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
3251#[wasm_bindgen]
3252pub fn bandpass_into(
3253    in_ptr: *const f64,
3254    len: usize,
3255    period: usize,
3256    bandwidth: f64,
3257    bp_ptr: *mut f64,
3258    bpn_ptr: *mut f64,
3259    sig_ptr: *mut f64,
3260    trg_ptr: *mut f64,
3261) -> Result<(), JsValue> {
3262    if in_ptr.is_null()
3263        || bp_ptr.is_null()
3264        || bpn_ptr.is_null()
3265        || sig_ptr.is_null()
3266        || trg_ptr.is_null()
3267    {
3268        return Err(JsValue::from_str("null pointer in bandpass_into"));
3269    }
3270
3271    if in_ptr == bp_ptr as *const f64
3272        || in_ptr == bpn_ptr as *const f64
3273        || in_ptr == sig_ptr as *const f64
3274        || in_ptr == trg_ptr as *const f64
3275    {
3276        return Err(JsValue::from_str(
3277            "input and output pointers must not alias",
3278        ));
3279    }
3280
3281    let out_ptrs = [bp_ptr, bpn_ptr, sig_ptr, trg_ptr];
3282    for i in 0..out_ptrs.len() {
3283        for j in i + 1..out_ptrs.len() {
3284            if out_ptrs[i] == out_ptrs[j] {
3285                return Err(JsValue::from_str(
3286                    "output pointers must not alias with each other",
3287                ));
3288            }
3289        }
3290    }
3291
3292    unsafe {
3293        let data = std::slice::from_raw_parts(in_ptr, len);
3294        let mut bp = std::slice::from_raw_parts_mut(bp_ptr, len);
3295        let mut bpn = std::slice::from_raw_parts_mut(bpn_ptr, len);
3296        let mut sig = std::slice::from_raw_parts_mut(sig_ptr, len);
3297        let mut trg = std::slice::from_raw_parts_mut(trg_ptr, len);
3298        let input = BandPassInput::from_slice(
3299            data,
3300            BandPassParams {
3301                period: Some(period),
3302                bandwidth: Some(bandwidth),
3303            },
3304        );
3305        bandpass_into_slice(
3306            &mut bp,
3307            &mut bpn,
3308            &mut sig,
3309            &mut trg,
3310            &input,
3311            detect_best_kernel(),
3312        )
3313        .map_err(|e| JsValue::from_str(&e.to_string()))
3314    }
3315}
3316
3317#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
3318#[wasm_bindgen]
3319pub fn bandpass_alloc(len: usize) -> *mut f64 {
3320    let mut vec = Vec::<f64>::with_capacity(len);
3321    let ptr = vec.as_mut_ptr();
3322    std::mem::forget(vec);
3323    ptr
3324}
3325
3326#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
3327#[wasm_bindgen]
3328pub fn bandpass_free(ptr: *mut f64, len: usize) {
3329    if !ptr.is_null() {
3330        unsafe {
3331            let _ = Vec::from_raw_parts(ptr, len, len);
3332        }
3333    }
3334}