1#[cfg(all(feature = "python", feature = "cuda"))]
2use numpy::PyUntypedArrayMethods;
3#[cfg(feature = "python")]
4use numpy::{IntoPyArray, PyArray1};
5#[cfg(feature = "python")]
6use pyo3::exceptions::PyValueError;
7#[cfg(feature = "python")]
8use pyo3::prelude::*;
9#[cfg(feature = "python")]
10use pyo3::types::PyDict;
11
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde::{Deserialize, Serialize};
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17#[cfg(all(feature = "python", feature = "cuda"))]
18use crate::cuda::bandpass_wrapper::CudaBandpass;
19#[cfg(all(feature = "python", feature = "cuda"))]
20use crate::cuda::cuda_available;
21#[cfg(all(feature = "python", feature = "cuda"))]
22use crate::cuda::moving_averages::DeviceArrayF32;
23use crate::indicators::highpass::{highpass, HighPassError, HighPassInput, HighPassParams};
24use crate::utilities::data_loader::{source_type, Candles};
25#[cfg(all(feature = "python", feature = "cuda"))]
26use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
27use crate::utilities::enums::Kernel;
28use crate::utilities::helpers::{
29 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
30 make_uninit_matrix,
31};
32#[cfg(feature = "python")]
33use crate::utilities::kernel_validation::validate_kernel;
34use aligned_vec::{AVec, CACHELINE_ALIGN};
35#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
36use core::arch::x86_64::*;
37#[cfg(all(feature = "python", feature = "cuda"))]
38use cust::context::Context;
39#[cfg(all(feature = "python", feature = "cuda"))]
40use cust::memory::DeviceBuffer;
41#[cfg(not(target_arch = "wasm32"))]
42use rayon::prelude::*;
43use std::convert::AsRef;
44use std::f64::consts::PI;
45#[cfg(all(feature = "python", feature = "cuda"))]
46use std::sync::Arc;
47use thiserror::Error;
48
49impl<'a> AsRef<[f64]> for BandPassInput<'a> {
50 #[inline(always)]
51 fn as_ref(&self) -> &[f64] {
52 match &self.data {
53 BandPassData::Slice(slice) => slice,
54 BandPassData::Candles { candles, source } => source_type(candles, source),
55 }
56 }
57}
58
59#[derive(Debug, Clone)]
60pub enum BandPassData<'a> {
61 Candles {
62 candles: &'a Candles,
63 source: &'a str,
64 },
65 Slice(&'a [f64]),
66}
67
/// Tunable band-pass parameters.
///
/// A `None` field means "use the default" (`period = 20`, `bandwidth = 0.3`),
/// which is resolved by the consumers of these parameters.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(all(target_arch = "wasm32", feature = "wasm"), derive(Serialize, Deserialize))]
pub struct BandPassParams {
    /// Period, in samples, that the pass band is centred on.
    pub period: Option<usize>,
    /// Relative width of the pass band (validated by the entry points that use it).
    pub bandwidth: Option<f64>,
}
77
78impl Default for BandPassParams {
79 fn default() -> Self {
80 Self {
81 period: Some(20),
82 bandwidth: Some(0.3),
83 }
84 }
85}
86
87#[derive(Debug, Clone)]
88pub struct BandPassInput<'a> {
89 pub data: BandPassData<'a>,
90 pub params: BandPassParams,
91}
92
93impl<'a> BandPassInput<'a> {
94 #[inline]
95 pub fn from_candles(c: &'a Candles, s: &'a str, p: BandPassParams) -> Self {
96 Self {
97 data: BandPassData::Candles {
98 candles: c,
99 source: s,
100 },
101 params: p,
102 }
103 }
104 #[inline]
105 pub fn from_slice(sl: &'a [f64], p: BandPassParams) -> Self {
106 Self {
107 data: BandPassData::Slice(sl),
108 params: p,
109 }
110 }
111 #[inline]
112 pub fn with_default_candles(c: &'a Candles) -> Self {
113 Self::from_candles(c, "close", BandPassParams::default())
114 }
115 #[inline]
116 pub fn get_period(&self) -> usize {
117 self.params.period.unwrap_or(20)
118 }
119 #[inline]
120 pub fn get_bandwidth(&self) -> f64 {
121 self.params.bandwidth.unwrap_or(0.3)
122 }
123}
124
125#[derive(Copy, Clone, Debug)]
126pub struct BandPassBuilder {
127 period: Option<usize>,
128 bandwidth: Option<f64>,
129 kernel: Kernel,
130}
131
132impl Default for BandPassBuilder {
133 fn default() -> Self {
134 Self {
135 period: None,
136 bandwidth: None,
137 kernel: Kernel::Auto,
138 }
139 }
140}
141
142impl BandPassBuilder {
143 #[inline(always)]
144 pub fn new() -> Self {
145 Self::default()
146 }
147 #[inline(always)]
148 pub fn period(mut self, n: usize) -> Self {
149 self.period = Some(n);
150 self
151 }
152 #[inline(always)]
153 pub fn bandwidth(mut self, b: f64) -> Self {
154 self.bandwidth = Some(b);
155 self
156 }
157 #[inline(always)]
158 pub fn kernel(mut self, k: Kernel) -> Self {
159 self.kernel = k;
160 self
161 }
162 #[inline(always)]
163 pub fn apply(self, c: &Candles) -> Result<BandPassOutput, BandPassError> {
164 let p = BandPassParams {
165 period: self.period,
166 bandwidth: self.bandwidth,
167 };
168 let i = BandPassInput::from_candles(c, "close", p);
169 bandpass_with_kernel(&i, self.kernel)
170 }
171 #[inline(always)]
172 pub fn apply_slice(self, d: &[f64]) -> Result<BandPassOutput, BandPassError> {
173 let p = BandPassParams {
174 period: self.period,
175 bandwidth: self.bandwidth,
176 };
177 let i = BandPassInput::from_slice(d, p);
178 bandpass_with_kernel(&i, self.kernel)
179 }
180 #[inline(always)]
181 pub fn into_stream(self) -> Result<BandPassStream, BandPassError> {
182 let p = BandPassParams {
183 period: self.period,
184 bandwidth: self.bandwidth,
185 };
186 BandPassStream::try_new(p)
187 }
188}
189
190#[derive(Debug, Error)]
191pub enum BandPassError {
192 #[error("bandpass: Input data slice is empty.")]
193 EmptyInputData,
194 #[error("bandpass: All values are NaN.")]
195 AllValuesNaN,
196 #[error("bandpass: Invalid period: period={period}, data_len={data_len}")]
197 InvalidPeriod { period: usize, data_len: usize },
198 #[error("bandpass: Not enough valid data: needed={needed}, valid={valid}")]
199 NotEnoughValidData { needed: usize, valid: usize },
200 #[error("bandpass: Invalid bandwidth={bandwidth}")]
201 InvalidBandwidth { bandwidth: f64 },
202 #[error("bandpass: hp_period too small ({hp_period})")]
203 HpPeriodTooSmall { hp_period: usize },
204 #[error("bandpass: trigger_period too small ({trigger_period})")]
205 TriggerPeriodTooSmall { trigger_period: usize },
206 #[error("bandpass: Output length mismatch: expected={expected}, got={got}")]
207 OutputLengthMismatch { expected: usize, got: usize },
208 #[error("bandpass: Invalid range: start={start}, end={end}, step={step}")]
209 InvalidRange {
210 start: usize,
211 end: usize,
212 step: usize,
213 },
214 #[error("bandpass: Invalid kernel for batch: {0:?}")]
215 InvalidKernelForBatch(Kernel),
216 #[error(transparent)]
217 HighPassError(#[from] HighPassError),
218}
219
/// The four series produced by a band-pass run. Each has the input's length, with NaN
/// during warm-up.
#[derive(Clone, Debug)]
pub struct BandPassOutput {
    /// Raw band-pass filter output.
    pub bp: Vec<f64>,
    /// `bp` divided by its decaying running absolute peak.
    pub bp_normalized: Vec<f64>,
    /// `1.0`, `-1.0` or `0.0` from comparing `bp_normalized` against `trigger`.
    pub signal: Vec<f64>,
    /// High-pass of `bp_normalized`.
    pub trigger: Vec<f64>,
}
227
228#[inline]
229pub fn bandpass(input: &BandPassInput) -> Result<BandPassOutput, BandPassError> {
230 bandpass_with_kernel(input, Kernel::Auto)
231}
232
233#[inline(always)]
234fn bandpass_prepare<'a>(
235 input: &'a BandPassInput,
236 kernel: Kernel,
237) -> Result<(&'a [f64], usize, usize, f64, usize, usize, Kernel), BandPassError> {
238 let data: &[f64] = input.as_ref();
239 let len = data.len();
240 let period = input.get_period();
241 let bandwidth = input.get_bandwidth();
242
243 if len == 0 {
244 return Err(BandPassError::EmptyInputData);
245 }
246 let first_valid = data
247 .iter()
248 .position(|x| x.is_finite())
249 .ok_or(BandPassError::AllValuesNaN)?;
250 if period == 0 || period > len {
251 return Err(BandPassError::InvalidPeriod {
252 period,
253 data_len: len,
254 });
255 }
256 if len - first_valid < period {
257 return Err(BandPassError::NotEnoughValidData {
258 needed: period,
259 valid: len - first_valid,
260 });
261 }
262 if !(0.0..=1.0).contains(&bandwidth) || !bandwidth.is_finite() {
263 return Err(BandPassError::InvalidBandwidth { bandwidth });
264 }
265
266 let hp_period = (4.0 * period as f64 / bandwidth).round() as usize;
267 if hp_period < 2 {
268 return Err(BandPassError::HpPeriodTooSmall { hp_period });
269 }
270 let trigger_period = ((period as f64 / bandwidth) / 1.5).round() as usize;
271 if trigger_period < 2 {
272 return Err(BandPassError::TriggerPeriodTooSmall { trigger_period });
273 }
274
275 let chosen = match kernel {
276 Kernel::Auto => detect_best_kernel(),
277 k => k,
278 };
279 Ok((
280 data,
281 len,
282 period,
283 bandwidth,
284 hp_period,
285 trigger_period,
286 chosen,
287 ))
288}
289
290pub fn bandpass_with_kernel(
291 input: &BandPassInput,
292 kernel: Kernel,
293) -> Result<BandPassOutput, BandPassError> {
294 let (data, len, period, bandwidth, hp_period, trigger_period, chosen) =
295 bandpass_prepare(input, kernel)?;
296
297 let mut hp_params = HighPassParams::default();
298 hp_params.period = Some(hp_period);
299 let hp = highpass(&HighPassInput::from_slice(data, hp_params))?.values;
300
301 let first_valid_hp = hp.iter().position(|&x| !x.is_nan()).unwrap_or(len);
302 let warmup_bp = first_valid_hp.saturating_add(2).max(2).min(len);
303
304 let beta = (2.0 * std::f64::consts::PI / period as f64).cos();
305 let gamma = (2.0 * std::f64::consts::PI * bandwidth / period as f64).cos();
306 let alpha = 1.0 / gamma - ((1.0 / (gamma * gamma)) - 1.0).sqrt();
307
308 let mut bp = alloc_with_nan_prefix(len, warmup_bp);
309
310 let bp_start = warmup_bp.saturating_sub(2);
311 unsafe {
312 match chosen {
313 Kernel::Scalar | Kernel::ScalarBatch => {
314 bandpass_scalar(&hp[bp_start..], period, alpha, beta, &mut bp[bp_start..])
315 }
316 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
317 Kernel::Avx2 | Kernel::Avx2Batch => {
318 bandpass_avx2(&hp[bp_start..], period, alpha, beta, &mut bp[bp_start..])
319 }
320 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
321 Kernel::Avx512 | Kernel::Avx512Batch => {
322 bandpass_avx512(&hp[bp_start..], period, alpha, beta, &mut bp[bp_start..])
323 }
324 _ => unreachable!(),
325 }
326 }
327
328 for v in &mut bp[..warmup_bp] {
329 *v = f64::NAN;
330 }
331
332 let mut bp_normalized = alloc_with_nan_prefix(len, warmup_bp);
333
334 let k = 0.991;
335 let mut peak_prev = 0.0f64;
336 for i in warmup_bp..len {
337 peak_prev *= k;
338 let abs_bp = bp[i].abs();
339 if abs_bp > peak_prev {
340 peak_prev = abs_bp;
341 }
342 bp_normalized[i] = if peak_prev != 0.0 {
343 bp[i] / peak_prev
344 } else {
345 0.0
346 };
347 }
348
349 let mut trigger = alloc_with_nan_prefix(len, warmup_bp);
350 if warmup_bp < len {
351 let mut trigger_params = HighPassParams::default();
352 trigger_params.period = Some(trigger_period);
353 let trig_inp = HighPassInput::from_slice(&bp_normalized[warmup_bp..], trigger_params);
354 crate::indicators::moving_averages::highpass::highpass_into_slice(
355 &mut trigger[warmup_bp..],
356 &trig_inp,
357 chosen,
358 )?;
359 }
360
361 let first_trig = trigger.iter().position(|x| !x.is_nan()).unwrap_or(len);
362 let warm_sig = warmup_bp.max(first_trig);
363 let mut signal = alloc_with_nan_prefix(len, warm_sig);
364 for i in warm_sig..len {
365 let bn = bp_normalized[i];
366 let tr = trigger[i];
367 signal[i] = if bn < tr {
368 1.0
369 } else if bn > tr {
370 -1.0
371 } else {
372 0.0
373 };
374 }
375
376 Ok(BandPassOutput {
377 bp,
378 bp_normalized,
379 signal,
380 trigger,
381 })
382}
383
384#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
385#[inline]
386pub fn bandpass_into(
387 input: &BandPassInput,
388 bp_out: &mut [f64],
389 bp_normalized_out: &mut [f64],
390 signal_out: &mut [f64],
391 trigger_out: &mut [f64],
392) -> Result<(), BandPassError> {
393 bandpass_into_slice(
394 bp_out,
395 bp_normalized_out,
396 signal_out,
397 trigger_out,
398 input,
399 Kernel::Auto,
400 )
401}
402
403#[inline]
404pub fn bandpass_into_slice(
405 bp_dst: &mut [f64],
406 bpn_dst: &mut [f64],
407 sig_dst: &mut [f64],
408 trig_dst: &mut [f64],
409 input: &BandPassInput,
410 kernel: Kernel,
411) -> Result<(), BandPassError> {
412 let (data, len, period, bandwidth, hp_period, trigger_period, chosen) =
413 bandpass_prepare(input, kernel)?;
414 if bp_dst.len() != len || bpn_dst.len() != len || sig_dst.len() != len || trig_dst.len() != len
415 {
416 return Err(BandPassError::OutputLengthMismatch {
417 expected: len,
418 got: *[bp_dst.len(), bpn_dst.len(), sig_dst.len(), trig_dst.len()]
419 .iter()
420 .min()
421 .unwrap_or(&0),
422 });
423 }
424
425 let mut hp_params = HighPassParams::default();
426 hp_params.period = Some(hp_period);
427 let hp = highpass(&HighPassInput::from_slice(data, hp_params))?.values;
428
429 let first_valid_hp = hp.iter().position(|&x| !x.is_nan()).unwrap_or(len);
430 let warm_bp = first_valid_hp.saturating_add(2).max(2).min(len);
431
432 let beta = (2.0 * std::f64::consts::PI / period as f64).cos();
433 let gamma = (2.0 * std::f64::consts::PI * bandwidth / period as f64).cos();
434 let alpha = 1.0 / gamma - ((1.0 / (gamma * gamma)) - 1.0).sqrt();
435
436 let bp_start = warm_bp.saturating_sub(2);
437 unsafe {
438 match chosen {
439 Kernel::Scalar | Kernel::ScalarBatch => bandpass_scalar(
440 &hp[bp_start..],
441 period,
442 alpha,
443 beta,
444 &mut bp_dst[bp_start..],
445 ),
446 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
447 Kernel::Avx2 | Kernel::Avx2Batch => bandpass_avx2(
448 &hp[bp_start..],
449 period,
450 alpha,
451 beta,
452 &mut bp_dst[bp_start..],
453 ),
454 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
455 Kernel::Avx512 | Kernel::Avx512Batch => bandpass_avx512(
456 &hp[bp_start..],
457 period,
458 alpha,
459 beta,
460 &mut bp_dst[bp_start..],
461 ),
462 _ => unreachable!(),
463 }
464 }
465
466 for v in &mut bp_dst[..warm_bp] {
467 *v = f64::NAN;
468 }
469
470 for v in &mut bpn_dst[..warm_bp] {
471 *v = f64::NAN;
472 }
473 let k = 0.991;
474 let mut peak = 0.0f64;
475 for i in warm_bp..len {
476 peak *= k;
477 let v = bp_dst[i];
478 let av = v.abs();
479 if av > peak {
480 peak = av;
481 }
482 bpn_dst[i] = if peak != 0.0 { v / peak } else { 0.0 };
483 }
484
485 for v in trig_dst.iter_mut() {
486 *v = f64::NAN;
487 }
488 if warm_bp < len {
489 let mut trigger_params = HighPassParams::default();
490 trigger_params.period = Some(trigger_period);
491 let trig_inp = HighPassInput::from_slice(&bpn_dst[warm_bp..], trigger_params);
492
493 crate::indicators::moving_averages::highpass::highpass_into_slice(
494 &mut trig_dst[warm_bp..],
495 &trig_inp,
496 chosen,
497 )?;
498 }
499
500 let first_trig = trig_dst.iter().position(|x| !x.is_nan()).unwrap_or(len);
501 let warm_sig = warm_bp.max(first_trig);
502 for v in &mut sig_dst[..warm_sig] {
503 *v = f64::NAN;
504 }
505 for i in warm_sig..len {
506 let bn = bpn_dst[i];
507 let tr = trig_dst[i];
508 sig_dst[i] = if bn < tr {
509 1.0
510 } else if bn > tr {
511 -1.0
512 } else {
513 0.0
514 };
515 }
516
517 Ok(())
518}
519
/// Scalar two-pole band-pass recursion over an already high-passed series.
///
/// The first two outputs (when present) are copied from `hp` as seeds. After that, for
/// `i >= 2`:
/// `out[i] = -alpha*out[i-2] + beta*(1+alpha)*out[i-1] + 0.5*(1-alpha)*(hp[i]-hp[i-2])`,
/// evaluated with fused multiply-adds. Only `out[..hp.len()]` is written. `period` is
/// accepted for signature parity with the other kernels and is not used.
///
/// # Panics
/// If `out` is shorter than `hp`.
#[inline(always)]
pub fn bandpass_scalar(hp: &[f64], period: usize, alpha: f64, beta: f64, out: &mut [f64]) {
    let _ = period;
    let n = hp.len();

    // Seed the recursion with up to two raw high-pass values.
    let seed = n.min(2);
    out[..seed].copy_from_slice(&hp[..seed]);
    if n <= 2 {
        return;
    }

    let gain = 0.5 * (1.0 - alpha);
    let fb1 = beta * (1.0 + alpha);
    let fb2 = -alpha;

    let (mut prev2, mut prev1) = (out[0], out[1]);
    for idx in 2..n {
        let diff = hp[idx] - hp[idx - 2];
        let y = fb2.mul_add(prev2, fb1.mul_add(prev1, gain * diff));
        out[idx] = y;
        prev2 = prev1;
        prev1 = y;
    }
}
585
/// AVX2 entry point for the band-pass recursion.
///
/// Currently delegates to the shared unchecked routine. `period` is accepted for
/// signature parity and is not used.
///
/// # Panics
/// If `out` is shorter than `hp`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn bandpass_avx2(hp: &[f64], period: usize, alpha: f64, beta: f64, out: &mut [f64]) {
    let _ = period;
    // This is a safe `pub fn`, so it must enforce the unchecked routine's length
    // precondition itself rather than trust callers.
    assert!(
        out.len() >= hp.len(),
        "bandpass_avx2: output ({}) shorter than input ({})",
        out.len(),
        hp.len()
    );
    // SAFETY: `out` is at least as long as `hp` (asserted above). This assumes that is the
    // only precondition of `bandpass_scalar_unchecked`; confirm against its definition.
    unsafe { bandpass_scalar_unchecked(hp, alpha, beta, out) }
}
591
/// AVX-512 entry point for the band-pass recursion.
///
/// Currently delegates to the shared unchecked routine. `period` is accepted for
/// signature parity and is not used.
///
/// # Panics
/// If `out` is shorter than `hp`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn bandpass_avx512(hp: &[f64], period: usize, alpha: f64, beta: f64, out: &mut [f64]) {
    let _ = period;
    // This is a safe `pub fn`, so it must enforce the unchecked routine's length
    // precondition itself rather than trust callers.
    assert!(
        out.len() >= hp.len(),
        "bandpass_avx512: output ({}) shorter than input ({})",
        out.len(),
        hp.len()
    );
    // SAFETY: `out` is at least as long as `hp` (asserted above). This assumes that is the
    // only precondition of `bandpass_scalar_unchecked`; confirm against its definition.
    unsafe { bandpass_scalar_unchecked(hp, alpha, beta, out) }
}
597
/// AVX-512 entry point intended for short periods.
///
/// Currently delegates to the shared unchecked routine. `period` is accepted for
/// signature parity and is not used.
///
/// # Panics
/// If `out` is shorter than `hp`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn bandpass_avx512_short(hp: &[f64], period: usize, alpha: f64, beta: f64, out: &mut [f64]) {
    let _ = period;
    // This is a safe `pub fn`, so it must enforce the unchecked routine's length
    // precondition itself rather than trust callers.
    assert!(
        out.len() >= hp.len(),
        "bandpass_avx512_short: output ({}) shorter than input ({})",
        out.len(),
        hp.len()
    );
    // SAFETY: `out` is at least as long as `hp` (asserted above). This assumes that is the
    // only precondition of `bandpass_scalar_unchecked`; confirm against its definition.
    unsafe { bandpass_scalar_unchecked(hp, alpha, beta, out) }
}
603
/// AVX-512 entry point intended for long periods.
///
/// Currently delegates to the shared unchecked routine. `period` is accepted for
/// signature parity and is not used.
///
/// # Panics
/// If `out` is shorter than `hp`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn bandpass_avx512_long(hp: &[f64], period: usize, alpha: f64, beta: f64, out: &mut [f64]) {
    let _ = period;
    // This is a safe `pub fn`, so it must enforce the unchecked routine's length
    // precondition itself rather than trust callers.
    assert!(
        out.len() >= hp.len(),
        "bandpass_avx512_long: output ({}) shorter than input ({})",
        out.len(),
        hp.len()
    );
    // SAFETY: `out` is at least as long as `hp` (asserted above). This assumes that is the
    // only precondition of `bandpass_scalar_unchecked`; confirm against its definition.
    unsafe { bandpass_scalar_unchecked(hp, alpha, beta, out) }
}
609
610#[derive(Debug, Clone)]
611pub struct BandPassStream {
612 period: usize,
613
614 c0: f64,
615 c1: f64,
616 c2: f64,
617
618 hp_stream: crate::indicators::highpass::HighPassStream,
619
620 hp_z1: f64,
621 hp_z2: f64,
622 y_z1: f64,
623 y_z2: f64,
624
625 hp_valid: u8,
626}
627
628impl BandPassStream {
629 #[inline]
630 pub fn try_new(params: BandPassParams) -> Result<Self, BandPassError> {
631 let period = params.period.unwrap_or(20);
632 if period < 2 {
633 return Err(BandPassError::InvalidPeriod {
634 period,
635 data_len: 0,
636 });
637 }
638 let bandwidth = params.bandwidth.unwrap_or(0.3);
639 if !(0.0..=1.0).contains(&bandwidth) || !bandwidth.is_finite() {
640 return Err(BandPassError::InvalidBandwidth { bandwidth });
641 }
642
643 let hp_period = (4.0 * period as f64 / bandwidth).round() as usize;
644 if hp_period < 2 {
645 return Err(BandPassError::HpPeriodTooSmall { hp_period });
646 }
647
648 let mut hp_params = HighPassParams::default();
649 hp_params.period = Some(hp_period);
650 let hp_stream = crate::indicators::highpass::HighPassStream::try_new(hp_params)?;
651
652 use std::f64::consts::PI;
653 let beta = (2.0 * PI / period as f64).cos();
654 let gamma = (2.0 * PI * bandwidth / period as f64).cos();
655
656 #[inline(always)]
657 fn alpha_from_gamma(gamma: f64) -> f64 {
658 let g2 = gamma * gamma;
659 let s = (1.0 - g2).sqrt();
660 (1.0 - s) / gamma
661 }
662 let alpha = alpha_from_gamma(gamma);
663
664 let c0 = 0.5 * (1.0 - alpha);
665 let c1 = beta * (1.0 + alpha);
666 let c2 = -alpha;
667
668 Ok(Self {
669 period,
670 c0,
671 c1,
672 c2,
673 hp_stream,
674 hp_z1: 0.0,
675 hp_z2: 0.0,
676 y_z1: 0.0,
677 y_z2: 0.0,
678 hp_valid: 0,
679 })
680 }
681
682 #[inline(always)]
683 pub fn update(&mut self, value: f64) -> f64 {
684 let hp = self.hp_stream.update(value);
685
686 if !hp.is_finite() {
687 return f64::NAN;
688 }
689
690 if self.hp_valid < 2 {
691 let y = hp;
692
693 self.hp_z2 = self.hp_z1;
694 self.hp_z1 = hp;
695
696 self.y_z2 = self.y_z1;
697 self.y_z1 = y;
698
699 self.hp_valid += 1;
700 return f64::NAN;
701 }
702
703 let delta = hp - self.hp_z2;
704 let y = self
705 .c2
706 .mul_add(self.y_z2, self.c1.mul_add(self.y_z1, self.c0 * delta));
707
708 self.hp_z2 = self.hp_z1;
709 self.hp_z1 = hp;
710
711 self.y_z2 = self.y_z1;
712 self.y_z1 = y;
713
714 y
715 }
716}
717
/// Inclusive `(start, end, step)` sweeps for a batch band-pass run.
///
/// A zero step (or `start == end`) yields just `start`.
#[derive(Debug, Clone)]
pub struct BandPassBatchRange {
    /// Period sweep.
    pub period: (usize, usize, usize),
    /// Bandwidth sweep.
    pub bandwidth: (f64, f64, f64),
}
723
724impl Default for BandPassBatchRange {
725 fn default() -> Self {
726 Self {
727 period: (20, 20, 0),
728 bandwidth: (0.3, 0.549, 0.001),
729 }
730 }
731}
732
733#[derive(Clone, Debug, Default)]
734pub struct BandPassBatchBuilder {
735 range: BandPassBatchRange,
736 kernel: Kernel,
737}
738
739impl BandPassBatchBuilder {
740 pub fn new() -> Self {
741 Self::default()
742 }
743 pub fn kernel(mut self, k: Kernel) -> Self {
744 self.kernel = k;
745 self
746 }
747 #[inline]
748 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
749 self.range.period = (start, end, step);
750 self
751 }
752 #[inline]
753 pub fn period_static(mut self, p: usize) -> Self {
754 self.range.period = (p, p, 0);
755 self
756 }
757 #[inline]
758 pub fn bandwidth_range(mut self, start: f64, end: f64, step: f64) -> Self {
759 self.range.bandwidth = (start, end, step);
760 self
761 }
762 #[inline]
763 pub fn bandwidth_static(mut self, b: f64) -> Self {
764 self.range.bandwidth = (b, b, 0.0);
765 self
766 }
767 pub fn apply_slice(self, data: &[f64]) -> Result<BandPassBatchOutput, BandPassError> {
768 bandpass_batch_with_kernel(data, &self.range, self.kernel)
769 }
770 pub fn with_default_slice(
771 data: &[f64],
772 k: Kernel,
773 ) -> Result<BandPassBatchOutput, BandPassError> {
774 BandPassBatchBuilder::new().kernel(k).apply_slice(data)
775 }
776 pub fn apply_candles(
777 self,
778 c: &Candles,
779 src: &str,
780 ) -> Result<BandPassBatchOutput, BandPassError> {
781 let slice = source_type(c, src);
782 self.apply_slice(slice)
783 }
784 pub fn with_default_candles(c: &Candles) -> Result<BandPassBatchOutput, BandPassError> {
785 BandPassBatchBuilder::new()
786 .kernel(Kernel::Auto)
787 .apply_candles(c, "close")
788 }
789}
790
791#[derive(Clone, Debug)]
792pub struct BandPassBatchOutput {
793 pub bp: Vec<f64>,
794 pub bp_normalized: Vec<f64>,
795 pub signal: Vec<f64>,
796 pub trigger: Vec<f64>,
797 pub combos: Vec<BandPassParams>,
798 pub rows: usize,
799 pub cols: usize,
800}
801
802impl BandPassBatchOutput {
803 #[inline]
804 pub fn row_for_params(&self, p: &BandPassParams) -> Option<usize> {
805 self.combos.iter().position(|c| {
806 c.period.unwrap_or(20) == p.period.unwrap_or(20)
807 && (c.bandwidth.unwrap_or(0.3) - p.bandwidth.unwrap_or(0.3)).abs() < 1e-12
808 })
809 }
810 #[inline]
811 pub fn row_slices(&self, row: usize) -> Option<(&[f64], &[f64], &[f64], &[f64])> {
812 if row >= self.rows {
813 return None;
814 }
815 let (r, c) = (row, self.cols);
816 Some((
817 &self.bp[r * self.cols..r * self.cols + c],
818 &self.bp_normalized[r * self.cols..r * self.cols + c],
819 &self.signal[r * self.cols..r * self.cols + c],
820 &self.trigger[r * self.cols..r * self.cols + c],
821 ))
822 }
823}
824
825#[inline(always)]
826fn expand_grid(r: &BandPassBatchRange) -> Vec<BandPassParams> {
827 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, BandPassError> {
828 if step == 0 || start == end {
829 return Ok(vec![start]);
830 }
831 let mut vals = Vec::new();
832 if start < end {
833 let mut v = start;
834 while v <= end {
835 vals.push(v);
836 match v.checked_add(step) {
837 Some(next) => {
838 if next == v {
839 break;
840 }
841 v = next;
842 }
843 None => break,
844 }
845 }
846 } else {
847 let mut v = start;
848 loop {
849 vals.push(v);
850 if v == end {
851 break;
852 }
853 let next = v.saturating_sub(step);
854 if next == v {
855 break;
856 }
857 v = next;
858 if v < end {
859 break;
860 }
861 }
862 }
863 if vals.is_empty() {
864 return Err(BandPassError::InvalidRange { start, end, step });
865 }
866 Ok(vals)
867 }
868 fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, BandPassError> {
869 if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
870 return Ok(vec![start]);
871 }
872 let mut vals = Vec::new();
873 if start <= end {
874 let mut x = start;
875 loop {
876 vals.push(x);
877 if x >= end {
878 break;
879 }
880 let next = x + step;
881 if !next.is_finite() || next == x {
882 break;
883 }
884 x = next;
885 if x > end + 1e-12 {
886 break;
887 }
888 }
889 } else {
890 let mut x = start;
891 loop {
892 vals.push(x);
893 if x <= end {
894 break;
895 }
896 let next = x - step.abs();
897 if !next.is_finite() || next == x {
898 break;
899 }
900 x = next;
901 if x < end - 1e-12 {
902 break;
903 }
904 }
905 }
906 if vals.is_empty() {
907 return Err(BandPassError::InvalidRange {
908 start: start as usize,
909 end: end as usize,
910 step: step.abs() as usize,
911 });
912 }
913 Ok(vals)
914 }
915 let periods = match axis_usize(r.period) {
916 Ok(v) => v,
917 Err(_) => return Vec::new(),
918 };
919 let bandwidths = match axis_f64(r.bandwidth) {
920 Ok(v) => v,
921 Err(_) => return Vec::new(),
922 };
923 let mut out = Vec::with_capacity(periods.len() * bandwidths.len());
924 for &p in &periods {
925 for &b in &bandwidths {
926 out.push(BandPassParams {
927 period: Some(p),
928 bandwidth: Some(b),
929 });
930 }
931 }
932 out
933}
934
// Python handle for a band-pass result held in CUDA device memory.
// (`//` comments rather than `///` so the Python-visible class docstring is unchanged.)
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "BandPassDeviceArrayF32", unsendable)]
pub struct BandPassDeviceArrayF32Py {
    // Ordinal of the CUDA device holding `inner`; reported by `__dlpack_device__`.
    pub(crate) device_id: u32,
    // Stream handle stored as an integer (presumably the stream that produced the data;
    // confirm at construction sites).
    pub(crate) stream: usize,
    // Row-major `rows x cols` f32 device buffer.
    // Keep this declared BEFORE `_ctx`: fields drop in declaration order, so the buffer is
    // released while its CUDA context is still alive.
    pub(crate) inner: DeviceArrayF32,
    // Held (not otherwise read here) to keep the CUDA context alive as long as the buffer.
    pub(crate) _ctx: Arc<Context>,
}
947
// Python-facing interop for `BandPassDeviceArrayF32Py`: exposes the wrapped device buffer
// through the CUDA Array Interface and DLPack. Plain `//` comments are used on purpose,
// because `///` doc comments on pyo3 items become the Python-visible `__doc__`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl BandPassDeviceArrayF32Py {
    // `__cuda_array_interface__` (version 3) describing a (rows, cols) little-endian
    // float32 ("<f4") array with row-major byte strides (cols * 4, 4). `data` is
    // (device pointer, false), where `false` is the read-only flag, so the buffer is
    // advertised as writable.
    // NOTE(review): no "stream" entry is published even though the struct stores a
    // `stream` handle; confirm consumers do not need it for synchronisation.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let inner = &self.inner;
        let d = PyDict::new(py);
        d.set_item("shape", (inner.rows, inner.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item(
            "strides",
            (
                inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        d.set_item("data", (inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    // DLPack device descriptor: device type 2 (kDLCUDA in DLPack's `DLDeviceType`) plus
    // the CUDA device ordinal the buffer lives on.
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        Ok((2, self.device_id as i32))
    }

    // DLPack export of the buffer as a 2-D tensor. Ownership of the device buffer moves
    // into the exported object; see the placeholder swap below.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__()?;
        // A requested `dl_device` that differs from the allocation device is always an
        // error; `copy=True` only changes the message, since cross-device copies are not
        // implemented. A `dl_device` that does not extract as (int, int) is ignored.
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }
        // NOTE(review): the consumer's `stream` argument is accepted but ignored; no
        // synchronisation with it happens here.
        let _ = stream;

        // Move the real buffer out of `self`, leaving an empty 0x0 placeholder behind.
        // A second `__dlpack__` call on the same object therefore exports an empty array.
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        // `buf` is passed by value, so the exporter owns the device allocation from here.
        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1023
1024pub fn bandpass_batch_with_kernel(
1025 data: &[f64],
1026 sweep: &BandPassBatchRange,
1027 k: Kernel,
1028) -> Result<BandPassBatchOutput, BandPassError> {
1029 let kernel = match k {
1030 Kernel::Auto => detect_best_batch_kernel(),
1031 other if other.is_batch() => other,
1032 other => return Err(BandPassError::InvalidKernelForBatch(other)),
1033 };
1034 let simd = match kernel {
1035 Kernel::Avx512Batch => Kernel::Avx512,
1036 Kernel::Avx2Batch => Kernel::Avx2,
1037 Kernel::ScalarBatch => Kernel::Scalar,
1038 _ => unreachable!(),
1039 };
1040 bandpass_batch_par_slice(data, sweep, simd)
1041}
1042
1043pub fn bandpass_batch_slice(
1044 data: &[f64],
1045 sweep: &BandPassBatchRange,
1046 kern: Kernel,
1047) -> Result<BandPassBatchOutput, BandPassError> {
1048 bandpass_batch_inner(data, sweep, kern, false)
1049}
1050
1051pub fn bandpass_batch_par_slice(
1052 data: &[f64],
1053 sweep: &BandPassBatchRange,
1054 kern: Kernel,
1055) -> Result<BandPassBatchOutput, BandPassError> {
1056 bandpass_batch_inner(data, sweep, kern, true)
1057}
1058
1059fn bandpass_batch_inner(
1060 data: &[f64],
1061 sweep: &BandPassBatchRange,
1062 kern: Kernel,
1063 parallel: bool,
1064) -> Result<BandPassBatchOutput, BandPassError> {
1065 let combos = expand_grid(sweep);
1066 if combos.is_empty() {
1067 return Err(BandPassError::InvalidRange {
1068 start: sweep.period.0,
1069 end: sweep.period.1,
1070 step: sweep.period.2,
1071 });
1072 }
1073
1074 let cols = data.len();
1075 if cols == 0 {
1076 return Err(BandPassError::EmptyInputData);
1077 }
1078 let rows = combos.len();
1079 rows.checked_mul(cols).ok_or(BandPassError::InvalidRange {
1080 start: sweep.period.0,
1081 end: sweep.period.1,
1082 step: sweep.period.2,
1083 })?;
1084
1085 let mut bp_mu = make_uninit_matrix(rows, cols);
1086 let mut bpn_mu = make_uninit_matrix(rows, cols);
1087 let mut sig_mu = make_uninit_matrix(rows, cols);
1088 let mut trg_mu = make_uninit_matrix(rows, cols);
1089
1090 let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
1091 let warms_bp: Vec<usize> = combos
1092 .iter()
1093 .map(|p| {
1094 let period = p.period.unwrap();
1095 let bandwidth = p.bandwidth.unwrap();
1096 let hp_p = (4.0 * period as f64 / bandwidth).round() as usize;
1097 let warm_hp = first + hp_p - 1;
1098 warm_hp.max(2)
1099 })
1100 .collect();
1101 let warms_trg: Vec<usize> = combos
1102 .iter()
1103 .map(|p| {
1104 let period = p.period.unwrap();
1105 let bandwidth = p.bandwidth.unwrap();
1106 let hp_p = (4.0 * period as f64 / bandwidth).round() as usize;
1107 let trig_p = ((period as f64 / bandwidth) / 1.5).round() as usize;
1108 let warm_hp = first + hp_p - 1;
1109 let warm_bp = warm_hp.max(2);
1110 warm_bp + trig_p - 1
1111 })
1112 .collect();
1113
1114 init_matrix_prefixes(&mut bp_mu, cols, &warms_bp);
1115 init_matrix_prefixes(&mut bpn_mu, cols, &warms_bp);
1116 init_matrix_prefixes(&mut trg_mu, cols, &warms_trg);
1117 init_matrix_prefixes(&mut sig_mu, cols, &warms_trg);
1118
1119 let mut bp_guard = core::mem::ManuallyDrop::new(bp_mu);
1120 let mut bpn_guard = core::mem::ManuallyDrop::new(bpn_mu);
1121 let mut sig_guard = core::mem::ManuallyDrop::new(sig_mu);
1122 let mut trg_guard = core::mem::ManuallyDrop::new(trg_mu);
1123
1124 let bp_out: &mut [f64] = unsafe {
1125 core::slice::from_raw_parts_mut(bp_guard.as_mut_ptr() as *mut f64, bp_guard.len())
1126 };
1127 let bpn_out: &mut [f64] = unsafe {
1128 core::slice::from_raw_parts_mut(bpn_guard.as_mut_ptr() as *mut f64, bpn_guard.len())
1129 };
1130 let sig_out: &mut [f64] = unsafe {
1131 core::slice::from_raw_parts_mut(sig_guard.as_mut_ptr() as *mut f64, sig_guard.len())
1132 };
1133 let trg_out: &mut [f64] = unsafe {
1134 core::slice::from_raw_parts_mut(trg_guard.as_mut_ptr() as *mut f64, trg_guard.len())
1135 };
1136
1137 bandpass_batch_inner_into(
1138 data, &combos, kern, parallel, bp_out, bpn_out, sig_out, trg_out,
1139 )?;
1140
1141 let bp = unsafe {
1142 Vec::from_raw_parts(
1143 bp_guard.as_mut_ptr() as *mut f64,
1144 bp_guard.len(),
1145 bp_guard.capacity(),
1146 )
1147 };
1148 let bpn = unsafe {
1149 Vec::from_raw_parts(
1150 bpn_guard.as_mut_ptr() as *mut f64,
1151 bpn_guard.len(),
1152 bpn_guard.capacity(),
1153 )
1154 };
1155 let sig = unsafe {
1156 Vec::from_raw_parts(
1157 sig_guard.as_mut_ptr() as *mut f64,
1158 sig_guard.len(),
1159 sig_guard.capacity(),
1160 )
1161 };
1162 let trg = unsafe {
1163 Vec::from_raw_parts(
1164 trg_guard.as_mut_ptr() as *mut f64,
1165 trg_guard.len(),
1166 trg_guard.capacity(),
1167 )
1168 };
1169
1170 Ok(BandPassBatchOutput {
1171 bp,
1172 bp_normalized: bpn,
1173 signal: sig,
1174 trigger: trg,
1175 combos,
1176 rows,
1177 cols,
1178 })
1179}
1180
1181#[inline(always)]
1182fn bandpass_batch_inner_into(
1183 data: &[f64],
1184 combos: &[BandPassParams],
1185 kern: Kernel,
1186 parallel: bool,
1187 bp_out: &mut [f64],
1188 bpn_out: &mut [f64],
1189 sig_out: &mut [f64],
1190 trg_out: &mut [f64],
1191) -> Result<(), BandPassError> {
1192 let rows = combos.len();
1193 let cols = data.len();
1194 let chosen = match kern {
1195 Kernel::Auto => detect_best_batch_kernel(),
1196 k => k,
1197 };
1198 let simd = match chosen {
1199 Kernel::ScalarBatch => Kernel::Scalar,
1200 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1201 Kernel::Avx2Batch => Kernel::Avx2,
1202 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1203 Kernel::Avx512Batch => Kernel::Avx512,
1204 _ => Kernel::Scalar,
1205 };
1206
1207 struct RowMeta {
1208 period: usize,
1209 bandwidth: f64,
1210 hp_p: usize,
1211 trig_p: usize,
1212 ksel: Kernel,
1213 }
1214
1215 let mut metas = Vec::with_capacity(rows);
1216 for &p in combos.iter() {
1217 let period = p.period.unwrap();
1218 let bandwidth = p.bandwidth.unwrap();
1219 let (_d, _len, _per, _bw, hp_p, trig_p, ksel) =
1220 bandpass_prepare(&BandPassInput::from_slice(data, p), simd)?;
1221 metas.push(RowMeta {
1222 period,
1223 bandwidth,
1224 hp_p,
1225 trig_p,
1226 ksel,
1227 });
1228 }
1229
1230 use std::collections::HashMap;
1231 let mut hp_cache: HashMap<usize, Vec<f64>> = HashMap::new();
1232 for meta in metas.iter() {
1233 if !hp_cache.contains_key(&meta.hp_p) {
1234 let mut hp_params = HighPassParams::default();
1235 hp_params.period = Some(meta.hp_p);
1236 let hp = highpass(&HighPassInput::from_slice(data, hp_params))?.values;
1237 hp_cache.insert(meta.hp_p, hp);
1238 }
1239 }
1240
1241 if parallel {
1242 #[cfg(not(target_arch = "wasm32"))]
1243 {
1244 use rayon::prelude::*;
1245 use std::sync::atomic::{AtomicPtr, Ordering};
1246
1247 let bp_ptr = AtomicPtr::new(bp_out.as_mut_ptr());
1248 let bpn_ptr = AtomicPtr::new(bpn_out.as_mut_ptr());
1249 let sig_ptr = AtomicPtr::new(sig_out.as_mut_ptr());
1250 let trg_ptr = AtomicPtr::new(trg_out.as_mut_ptr());
1251
1252 (0..rows)
1253 .into_par_iter()
1254 .try_for_each(|row| -> Result<(), BandPassError> {
1255 let m = &metas[row];
1256 let period = m.period;
1257 let bandwidth = m.bandwidth;
1258
1259 let bp_r = unsafe {
1260 std::slice::from_raw_parts_mut(
1261 bp_ptr.load(Ordering::Relaxed).add(row * cols),
1262 cols,
1263 )
1264 };
1265 let bpn_r = unsafe {
1266 std::slice::from_raw_parts_mut(
1267 bpn_ptr.load(Ordering::Relaxed).add(row * cols),
1268 cols,
1269 )
1270 };
1271 let sig_r = unsafe {
1272 std::slice::from_raw_parts_mut(
1273 sig_ptr.load(Ordering::Relaxed).add(row * cols),
1274 cols,
1275 )
1276 };
1277 let trg_r = unsafe {
1278 std::slice::from_raw_parts_mut(
1279 trg_ptr.load(Ordering::Relaxed).add(row * cols),
1280 cols,
1281 )
1282 };
1283 let trig_p = m.trig_p;
1284 let ksel = m.ksel;
1285 let hp = hp_cache.get(&m.hp_p).expect("hp cache missing");
1286
1287 let first_valid_hp = hp.iter().position(|&x| !x.is_nan()).unwrap_or(cols);
1288 let warm_bp = first_valid_hp.saturating_add(2).max(2).min(cols);
1289 let bp_start = warm_bp.saturating_sub(2);
1290
1291 let beta = (2.0 * std::f64::consts::PI / period as f64).cos();
1292 let gamma = (2.0 * std::f64::consts::PI * bandwidth / period as f64).cos();
1293 let alpha = 1.0 / gamma - ((1.0 / (gamma * gamma)) - 1.0).sqrt();
1294
1295 unsafe {
1296 match ksel {
1297 Kernel::Scalar => bandpass_scalar(
1298 &hp[bp_start..],
1299 period,
1300 alpha,
1301 beta,
1302 &mut bp_r[bp_start..],
1303 ),
1304 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1305 Kernel::Avx2 => bandpass_avx2(
1306 &hp[bp_start..],
1307 period,
1308 alpha,
1309 beta,
1310 &mut bp_r[bp_start..],
1311 ),
1312 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1313 Kernel::Avx512 => bandpass_avx512(
1314 &hp[bp_start..],
1315 period,
1316 alpha,
1317 beta,
1318 &mut bp_r[bp_start..],
1319 ),
1320 _ => unreachable!(),
1321 }
1322 }
1323
1324 for v in &mut bp_r[..warm_bp] {
1325 *v = f64::NAN;
1326 }
1327
1328 let k = 0.991;
1329 let mut peak = 0.0f64;
1330 for i in warm_bp..cols {
1331 peak *= k;
1332 let v = bp_r[i];
1333 let av = v.abs();
1334 if av > peak {
1335 peak = av;
1336 }
1337 bpn_r[i] = if peak != 0.0 { v / peak } else { 0.0 };
1338 }
1339
1340 for v in trg_r.iter_mut() {
1341 *v = f64::NAN;
1342 }
1343 if warm_bp < cols {
1344 let mut trig_params = HighPassParams::default();
1345 trig_params.period = Some(trig_p);
1346 let trig_inp = HighPassInput::from_slice(&bpn_r[warm_bp..], trig_params);
1347 crate::indicators::moving_averages::highpass::highpass_into_slice(
1348 &mut trg_r[warm_bp..],
1349 &trig_inp,
1350 ksel,
1351 )?;
1352 }
1353
1354 let first_tr = trg_r.iter().position(|x| !x.is_nan()).unwrap_or(cols);
1355 let warm_sig = warm_bp.max(first_tr);
1356 for v in &mut bpn_r[..warm_bp] {
1357 if !v.is_nan() {
1358 *v = f64::NAN;
1359 }
1360 }
1361 for v in &mut sig_r[..warm_sig] {
1362 *v = f64::NAN;
1363 }
1364 for i in warm_sig..cols {
1365 let bn = bpn_r[i];
1366 let tr = trg_r[i];
1367 sig_r[i] = if bn < tr {
1368 1.0
1369 } else if bn > tr {
1370 -1.0
1371 } else {
1372 0.0
1373 };
1374 }
1375
1376 Ok(())
1377 })?;
1378 }
1379 #[cfg(target_arch = "wasm32")]
1380 {
1381 for row in 0..rows {
1382 let m = &metas[row];
1383 let period = m.period;
1384 let bandwidth = m.bandwidth;
1385
1386 let bp_r = &mut bp_out[row * cols..(row + 1) * cols];
1387 let bpn_r = &mut bpn_out[row * cols..(row + 1) * cols];
1388 let sig_r = &mut sig_out[row * cols..(row + 1) * cols];
1389 let trg_r = &mut trg_out[row * cols..(row + 1) * cols];
1390 let trig_p = m.trig_p;
1391 let ksel = m.ksel;
1392 let hp = hp_cache.get(&m.hp_p).expect("hp cache missing");
1393
1394 let first_valid_hp = hp.iter().position(|&x| !x.is_nan()).unwrap_or(cols);
1395 let warm_bp = first_valid_hp.saturating_add(2).max(2).min(cols);
1396 let bp_start = warm_bp.saturating_sub(2);
1397
1398 let beta = (2.0 * std::f64::consts::PI / period as f64).cos();
1399 let gamma = (2.0 * std::f64::consts::PI * bandwidth / period as f64).cos();
1400 let alpha = 1.0 / gamma - ((1.0 / (gamma * gamma)) - 1.0).sqrt();
1401
1402 unsafe {
1403 match ksel {
1404 Kernel::Scalar => bandpass_scalar(
1405 &hp[bp_start..],
1406 period,
1407 alpha,
1408 beta,
1409 &mut bp_r[bp_start..],
1410 ),
1411 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1412 Kernel::Avx2 => bandpass_avx2(
1413 &hp[bp_start..],
1414 period,
1415 alpha,
1416 beta,
1417 &mut bp_r[bp_start..],
1418 ),
1419 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1420 Kernel::Avx512 => bandpass_avx512(
1421 &hp[bp_start..],
1422 period,
1423 alpha,
1424 beta,
1425 &mut bp_r[bp_start..],
1426 ),
1427 _ => unreachable!(),
1428 }
1429 }
1430
1431 for v in &mut bp_r[..warm_bp] {
1432 *v = f64::NAN;
1433 }
1434
1435 let k = 0.991;
1436 let mut peak = 0.0f64;
1437 for i in warm_bp..cols {
1438 peak *= k;
1439 let v = bp_r[i];
1440 let av = v.abs();
1441 if av > peak {
1442 peak = av;
1443 }
1444 bpn_r[i] = if peak != 0.0 { v / peak } else { 0.0 };
1445 }
1446
1447 for v in trg_r.iter_mut() {
1448 *v = f64::NAN;
1449 }
1450 if warm_bp < cols {
1451 let mut trig_params = HighPassParams::default();
1452 trig_params.period = Some(trig_p);
1453 let trig_inp = HighPassInput::from_slice(&bpn_r[warm_bp..], trig_params);
1454 crate::indicators::moving_averages::highpass::highpass_into_slice(
1455 &mut trg_r[warm_bp..],
1456 &trig_inp,
1457 ksel,
1458 )?;
1459 }
1460
1461 let first_tr = trg_r.iter().position(|x| !x.is_nan()).unwrap_or(cols);
1462 let warm_sig = warm_bp.max(first_tr);
1463 for v in &mut bpn_r[..warm_bp] {
1464 if !v.is_nan() {
1465 *v = f64::NAN;
1466 }
1467 }
1468 for v in &mut sig_r[..warm_sig] {
1469 *v = f64::NAN;
1470 }
1471 for i in warm_sig..cols {
1472 let bn = bpn_r[i];
1473 let tr = trg_r[i];
1474 sig_r[i] = if bn < tr {
1475 1.0
1476 } else if bn > tr {
1477 -1.0
1478 } else {
1479 0.0
1480 };
1481 }
1482 }
1483 }
1484 } else {
1485 for row in 0..rows {
1486 let p = combos[row];
1487 let period = p.period.unwrap();
1488 let bandwidth = p.bandwidth.unwrap();
1489
1490 let bp_r = &mut bp_out[row * cols..(row + 1) * cols];
1491 let bpn_r = &mut bpn_out[row * cols..(row + 1) * cols];
1492 let sig_r = &mut sig_out[row * cols..(row + 1) * cols];
1493 let trg_r = &mut trg_out[row * cols..(row + 1) * cols];
1494
1495 let (_d, _len, _per, _bw, hp_p, trig_p, ksel) =
1496 bandpass_prepare(&BandPassInput::from_slice(data, p), simd)?;
1497
1498 let mut hp_params = HighPassParams::default();
1499 hp_params.period = Some(hp_p);
1500 let hp = highpass(&HighPassInput::from_slice(data, hp_params))?.values;
1501
1502 let first_valid_hp = hp.iter().position(|&x| !x.is_nan()).unwrap_or(cols);
1503 let warm_bp = first_valid_hp.saturating_add(2).max(2).min(cols);
1504 let bp_start = warm_bp.saturating_sub(2);
1505
1506 let beta = (2.0 * std::f64::consts::PI / period as f64).cos();
1507 let gamma = (2.0 * std::f64::consts::PI * bandwidth / period as f64).cos();
1508 let alpha = 1.0 / gamma - ((1.0 / (gamma * gamma)) - 1.0).sqrt();
1509
1510 unsafe {
1511 match ksel {
1512 Kernel::Scalar => {
1513 bandpass_scalar(&hp[bp_start..], period, alpha, beta, &mut bp_r[bp_start..])
1514 }
1515 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1516 Kernel::Avx2 => {
1517 bandpass_avx2(&hp[bp_start..], period, alpha, beta, &mut bp_r[bp_start..])
1518 }
1519 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1520 Kernel::Avx512 => {
1521 bandpass_avx512(&hp[bp_start..], period, alpha, beta, &mut bp_r[bp_start..])
1522 }
1523 _ => unreachable!(),
1524 }
1525 }
1526
1527 for v in &mut bp_r[..warm_bp] {
1528 *v = f64::NAN;
1529 }
1530
1531 let k = 0.991;
1532 let mut peak = 0.0f64;
1533 for i in warm_bp..cols {
1534 peak *= k;
1535 let v = bp_r[i];
1536 let av = v.abs();
1537 if av > peak {
1538 peak = av;
1539 }
1540 bpn_r[i] = if peak != 0.0 { v / peak } else { 0.0 };
1541 }
1542
1543 for v in trg_r.iter_mut() {
1544 *v = f64::NAN;
1545 }
1546 if warm_bp < cols {
1547 let mut trig_params = HighPassParams::default();
1548 trig_params.period = Some(trig_p);
1549 let trig_vec =
1550 highpass(&HighPassInput::from_slice(&bpn_r[warm_bp..], trig_params))?.values;
1551 trg_r[warm_bp..].copy_from_slice(&trig_vec);
1552 }
1553
1554 let first_tr = trg_r.iter().position(|x| !x.is_nan()).unwrap_or(cols);
1555 let warm_sig = warm_bp.max(first_tr);
1556 for v in &mut bpn_r[..warm_bp] {
1557 if !v.is_nan() {
1558 *v = f64::NAN;
1559 }
1560 }
1561 for v in &mut sig_r[..warm_sig] {
1562 *v = f64::NAN;
1563 }
1564 for i in warm_sig..cols {
1565 let bn = bpn_r[i];
1566 let tr = trg_r[i];
1567 sig_r[i] = if bn < tr {
1568 1.0
1569 } else if bn > tr {
1570 -1.0
1571 } else {
1572 0.0
1573 };
1574 }
1575 }
1576 }
1577
1578 Ok(())
1579}
1580
1581#[inline(always)]
1582pub fn bandpass_row_scalar(hp: &[f64], period: usize, alpha: f64, beta: f64, out: &mut [f64]) {
1583 bandpass_scalar(hp, period, alpha, beta, out)
1584}
1585
/// Band-pass row kernel for the AVX2 dispatch slot.
///
/// Currently evaluates the scalar fused-multiply-add recurrence in
/// `bandpass_scalar_unchecked`; no AVX2 intrinsics are used here. `period` is
/// accepted only for signature parity with `bandpass_row_scalar`. Writes
/// `out[..hp.len()]`.
///
/// # Panics
/// Panics if `out` is shorter than `hp`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn bandpass_row_avx2(hp: &[f64], period: usize, alpha: f64, beta: f64, out: &mut [f64]) {
    let _ = period;
    assert!(
        out.len() >= hp.len(),
        "bandpass_row_avx2: out ({}) is shorter than hp ({})",
        out.len(),
        hp.len()
    );
    // SAFETY: the assertion above establishes `out.len() >= hp.len()`, the
    // precondition of `bandpass_scalar_unchecked`.
    unsafe { bandpass_scalar_unchecked(hp, alpha, beta, out) }
}
1591
/// Band-pass row kernel for the AVX-512 dispatch slot.
///
/// Currently evaluates the scalar fused-multiply-add recurrence in
/// `bandpass_scalar_unchecked`; no AVX-512 intrinsics are used here. `period`
/// is accepted only for signature parity with `bandpass_row_scalar`. Writes
/// `out[..hp.len()]`.
///
/// # Panics
/// Panics if `out` is shorter than `hp`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn bandpass_row_avx512(hp: &[f64], period: usize, alpha: f64, beta: f64, out: &mut [f64]) {
    let _ = period;
    assert!(
        out.len() >= hp.len(),
        "bandpass_row_avx512: out ({}) is shorter than hp ({})",
        out.len(),
        hp.len()
    );
    // SAFETY: the assertion above establishes `out.len() >= hp.len()`, the
    // precondition of `bandpass_scalar_unchecked`.
    unsafe { bandpass_scalar_unchecked(hp, alpha, beta, out) }
}
1597
/// Band-pass row kernel for the AVX-512 "short" dispatch slot.
///
/// Currently evaluates the scalar fused-multiply-add recurrence in
/// `bandpass_scalar_unchecked`; no AVX-512 intrinsics are used here. `period`
/// is accepted only for signature parity with `bandpass_row_scalar`. Writes
/// `out[..hp.len()]`.
///
/// # Panics
/// Panics if `out` is shorter than `hp`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn bandpass_row_avx512_short(
    hp: &[f64],
    period: usize,
    alpha: f64,
    beta: f64,
    out: &mut [f64],
) {
    let _ = period;
    assert!(
        out.len() >= hp.len(),
        "bandpass_row_avx512_short: out ({}) is shorter than hp ({})",
        out.len(),
        hp.len()
    );
    // SAFETY: the assertion above establishes `out.len() >= hp.len()`, the
    // precondition of `bandpass_scalar_unchecked`.
    unsafe { bandpass_scalar_unchecked(hp, alpha, beta, out) }
}
1609
/// Band-pass row kernel for the AVX-512 "long" dispatch slot.
///
/// Currently evaluates the scalar fused-multiply-add recurrence in
/// `bandpass_scalar_unchecked`; no AVX-512 intrinsics are used here. `period`
/// is accepted only for signature parity with `bandpass_row_scalar`. Writes
/// `out[..hp.len()]`.
///
/// # Panics
/// Panics if `out` is shorter than `hp`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn bandpass_row_avx512_long(hp: &[f64], period: usize, alpha: f64, beta: f64, out: &mut [f64]) {
    let _ = period;
    assert!(
        out.len() >= hp.len(),
        "bandpass_row_avx512_long: out ({}) is shorter than hp ({})",
        out.len(),
        hp.len()
    );
    // SAFETY: the assertion above establishes `out.len() >= hp.len()`, the
    // precondition of `bandpass_scalar_unchecked`.
    unsafe { bandpass_scalar_unchecked(hp, alpha, beta, out) }
}
1615
/// Evaluates the band-pass IIR recurrence over `hp`, writing into `out`.
///
/// The first two outputs are copied from `hp`. For every later index `i`:
/// ```text
/// out[i] = -alpha * out[i - 2]
///        + beta * (1 + alpha) * out[i - 1]
///        + 0.5 * (1 - alpha) * (hp[i] - hp[i - 2])
/// ```
/// This is evaluated with nested `mul_add`. Only `out[..hp.len()]` is touched,
/// and an empty `hp` is a no-op.
///
/// # Safety
///
/// Callers must ensure `out.len() >= hp.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn bandpass_scalar_unchecked(hp: &[f64], alpha: f64, beta: f64, out: &mut [f64]) {
    let n = hp.len();
    // Re-slicing turns a violated precondition into a panic instead of an
    // out-of-bounds write.
    let out = &mut out[..n];

    if n == 0 {
        return;
    }
    out[0] = hp[0];
    if n == 1 {
        return;
    }
    out[1] = hp[1];

    let input_gain = 0.5 * (1.0 - alpha);
    let lag1_coeff = beta * (1.0 + alpha);
    let lag2_coeff = -alpha;

    let mut lag2 = hp[0];
    let mut lag1 = hp[1];
    // Each window is [hp[i - 2], hp[i - 1], hp[i]] for i = 2, 3, ...
    for (slot, win) in out[2..].iter_mut().zip(hp.windows(3)) {
        let next = lag2_coeff.mul_add(lag2, lag1_coeff.mul_add(lag1, input_gain * (win[2] - win[0])));
        *slot = next;
        lag2 = lag1;
        lag1 = next;
    }
}
1687
1688#[inline(always)]
1689fn expand_grid_for_bandpass(r: &BandPassBatchRange) -> Vec<BandPassParams> {
1690 expand_grid(r)
1691}
1692
1693#[cfg(test)]
1694mod tests {
1695 use super::*;
1696 use crate::skip_if_unsupported;
1697 use crate::utilities::data_loader::read_candles_from_csv;
1698
1699 #[test]
1700 fn test_bandpass_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
1701 let len = 512usize;
1702 let mut data = Vec::with_capacity(len);
1703 for i in 0..len {
1704 let t = i as f64 * 0.07;
1705
1706 data.push((t).sin() * 0.8 + (0.33 * t).sin() * 0.2 + 0.001 * i as f64);
1707 }
1708
1709 let input = BandPassInput::from_slice(&data, BandPassParams::default());
1710
1711 let base = bandpass(&input)?;
1712
1713 let mut bp = vec![0.0; len];
1714 let mut bpn = vec![0.0; len];
1715 let mut sig = vec![0.0; len];
1716 let mut trg = vec![0.0; len];
1717
1718 #[allow(unused_mut)]
1719 let mut _ok = bandpass_into(&input, &mut bp, &mut bpn, &mut sig, &mut trg)?;
1720
1721 assert_eq!(bp.len(), base.bp.len());
1722 assert_eq!(bpn.len(), base.bp_normalized.len());
1723 assert_eq!(sig.len(), base.signal.len());
1724 assert_eq!(trg.len(), base.trigger.len());
1725
1726 fn eq_or_both_nan(a: f64, b: f64) -> bool {
1727 (a.is_nan() && b.is_nan()) || (a == b)
1728 }
1729
1730 for i in 0..len {
1731 assert!(
1732 eq_or_both_nan(bp[i], base.bp[i]),
1733 "bp mismatch at {}: {} vs {}",
1734 i,
1735 bp[i],
1736 base.bp[i]
1737 );
1738 assert!(
1739 eq_or_both_nan(bpn[i], base.bp_normalized[i]),
1740 "bpn mismatch at {}: {} vs {}",
1741 i,
1742 bpn[i],
1743 base.bp_normalized[i]
1744 );
1745 assert!(
1746 eq_or_both_nan(sig[i], base.signal[i]),
1747 "sig mismatch at {}: {} vs {}",
1748 i,
1749 sig[i],
1750 base.signal[i]
1751 );
1752 assert!(
1753 eq_or_both_nan(trg[i], base.trigger[i]),
1754 "trg mismatch at {}: {} vs {}",
1755 i,
1756 trg[i],
1757 base.trigger[i]
1758 );
1759 }
1760
1761 Ok(())
1762 }
1763
1764 fn check_bandpass_partial_params(
1765 test_name: &str,
1766 kernel: Kernel,
1767 ) -> Result<(), Box<dyn std::error::Error>> {
1768 skip_if_unsupported!(kernel, test_name);
1769 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1770 let candles = read_candles_from_csv(file_path)?;
1771 let default_params = BandPassParams::default();
1772 let input = BandPassInput::from_candles(&candles, "close", default_params);
1773 let output = bandpass_with_kernel(&input, kernel)?;
1774 assert_eq!(output.bp.len(), candles.close.len());
1775 Ok(())
1776 }
1777 fn check_bandpass_accuracy(
1778 test_name: &str,
1779 kernel: Kernel,
1780 ) -> Result<(), Box<dyn std::error::Error>> {
1781 skip_if_unsupported!(kernel, test_name);
1782 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1783 let candles = read_candles_from_csv(file_path)?;
1784 let input = BandPassInput::with_default_candles(&candles);
1785 let result = bandpass_with_kernel(&input, kernel)?;
1786 let expected_bp_last_five = [
1787 -236.23678021132827,
1788 -247.4846395608195,
1789 -242.3788746078502,
1790 -212.89589193350128,
1791 -179.97293838509464,
1792 ];
1793 let expected_bp_normalized_last_five = [
1794 -0.4399672555578846,
1795 -0.4651011734720517,
1796 -0.4596426251402882,
1797 -0.40739824942488945,
1798 -0.3475245023284841,
1799 ];
1800 let expected_signal_last_five = [-1.0, 1.0, 1.0, 1.0, 1.0];
1801 let expected_trigger_last_five = [
1802 -0.4746908356434579,
1803 -0.4353877348116954,
1804 -0.3727126131420441,
1805 -0.2746336628365846,
1806 -0.18240018384226137,
1807 ];
1808 let start = result.bp.len().saturating_sub(5);
1809 assert!(result.bp.len() >= 5);
1810 assert!(result.bp_normalized.len() >= 5);
1811 assert!(result.signal.len() >= 5);
1812 assert!(result.trigger.len() >= 5);
1813
1814 let first_bp = result
1815 .bp
1816 .iter()
1817 .position(|x| !x.is_nan())
1818 .unwrap_or(result.bp.len());
1819 let first_bpn = result
1820 .bp_normalized
1821 .iter()
1822 .position(|x| !x.is_nan())
1823 .unwrap_or(result.bp_normalized.len());
1824 let first_sig = result
1825 .signal
1826 .iter()
1827 .position(|x| !x.is_nan())
1828 .unwrap_or(result.signal.len());
1829 let first_trig = result
1830 .trigger
1831 .iter()
1832 .position(|x| !x.is_nan())
1833 .unwrap_or(result.trigger.len());
1834
1835 if first_sig >= start {
1836 panic!(
1837 "Signal values are all NaN in the last 5 indices. Debug info:\n\
1838 first_bp: {}, first_bpn: {}, first_sig: {}, first_trig: {}, start: {}, len: {}\n\
1839 bp[start]: {:?}, bpn[start]: {:?}, trig[start]: {:?}",
1840 first_bp,
1841 first_bpn,
1842 first_sig,
1843 first_trig,
1844 start,
1845 result.bp.len(),
1846 result.bp[start],
1847 result.bp_normalized[start],
1848 result.trigger[start]
1849 );
1850 }
1851 for (i, &value) in result.bp[start..].iter().enumerate() {
1852 assert!(
1853 (value - expected_bp_last_five[i]).abs() < 1e-1,
1854 "BP value mismatch at index {}: expected {}, got {}",
1855 i,
1856 expected_bp_last_five[i],
1857 value
1858 );
1859 }
1860 for (i, &value) in result.bp_normalized[start..].iter().enumerate() {
1861 assert!(
1862 (value - expected_bp_normalized_last_five[i]).abs() < 1e-1,
1863 "BP Normalized value mismatch at index {}: expected {}, got {}",
1864 i,
1865 expected_bp_normalized_last_five[i],
1866 value
1867 );
1868 }
1869 for (i, &value) in result.signal[start..].iter().enumerate() {
1870 if value.is_nan() {
1871 continue;
1872 }
1873 assert!(
1874 (value - expected_signal_last_five[i]).abs() < 1e-1,
1875 "Signal value mismatch at index {}: expected {}, got {}",
1876 i,
1877 expected_signal_last_five[i],
1878 value
1879 );
1880 }
1881 for (i, &value) in result.trigger[start..].iter().enumerate() {
1882 assert!(
1883 (value - expected_trigger_last_five[i]).abs() < 1e-1,
1884 "Trigger value mismatch at index {}: expected {}, got {}",
1885 i,
1886 expected_trigger_last_five[i],
1887 value
1888 );
1889 }
1890
1891 for val in &result.bp {
1892 if !val.is_nan() {
1893 assert!(val.is_finite(), "bp value not finite: {}", val);
1894 }
1895 }
1896 for val in &result.bp_normalized {
1897 if !val.is_nan() {
1898 assert!(val.is_finite(), "bp_normalized value not finite: {}", val);
1899 }
1900 }
1901 for val in &result.signal {
1902 if !val.is_nan() {
1903 assert!(val.is_finite(), "signal value not finite: {}", val);
1904 }
1905 }
1906 for val in &result.trigger {
1907 if !val.is_nan() {
1908 assert!(val.is_finite(), "trigger value not finite: {}", val);
1909 }
1910 }
1911 Ok(())
1912 }
1913 fn check_bandpass_default_candles(
1914 test_name: &str,
1915 kernel: Kernel,
1916 ) -> Result<(), Box<dyn std::error::Error>> {
1917 skip_if_unsupported!(kernel, test_name);
1918 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1919 let candles = read_candles_from_csv(file_path)?;
1920 let input = BandPassInput::with_default_candles(&candles);
1921 match input.data {
1922 BandPassData::Candles { source, .. } => assert_eq!(source, "close"),
1923 _ => panic!("Expected BandPassData::Candles"),
1924 }
1925 let output = bandpass_with_kernel(&input, kernel)?;
1926 assert_eq!(output.bp.len(), candles.close.len());
1927 Ok(())
1928 }
1929 fn check_bandpass_zero_period(
1930 test_name: &str,
1931 kernel: Kernel,
1932 ) -> Result<(), Box<dyn std::error::Error>> {
1933 skip_if_unsupported!(kernel, test_name);
1934 let input_data = [10.0, 20.0, 30.0];
1935 let params = BandPassParams {
1936 period: Some(0),
1937 bandwidth: Some(0.3),
1938 };
1939 let input = BandPassInput::from_slice(&input_data, params);
1940 let res = bandpass_with_kernel(&input, kernel);
1941 assert!(res.is_err());
1942 Ok(())
1943 }
1944 fn check_bandpass_period_exceeds_length(
1945 test_name: &str,
1946 kernel: Kernel,
1947 ) -> Result<(), Box<dyn std::error::Error>> {
1948 skip_if_unsupported!(kernel, test_name);
1949 let data_small = [10.0, 20.0, 30.0];
1950 let params = BandPassParams {
1951 period: Some(10),
1952 bandwidth: Some(0.3),
1953 };
1954 let input = BandPassInput::from_slice(&data_small, params);
1955 let res = bandpass_with_kernel(&input, kernel);
1956 assert!(res.is_err());
1957 Ok(())
1958 }
1959 fn check_bandpass_very_small_dataset(
1960 test_name: &str,
1961 kernel: Kernel,
1962 ) -> Result<(), Box<dyn std::error::Error>> {
1963 skip_if_unsupported!(kernel, test_name);
1964 let single_point = [42.0];
1965 let params = BandPassParams {
1966 period: Some(20),
1967 bandwidth: Some(0.3),
1968 };
1969 let input = BandPassInput::from_slice(&single_point, params);
1970 let res = bandpass_with_kernel(&input, kernel);
1971 assert!(res.is_err());
1972 Ok(())
1973 }
1974 fn check_bandpass_reinput(
1975 test_name: &str,
1976 kernel: Kernel,
1977 ) -> Result<(), Box<dyn std::error::Error>> {
1978 skip_if_unsupported!(kernel, test_name);
1979 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1980 let candles = read_candles_from_csv(file_path)?;
1981 let first_params = BandPassParams {
1982 period: Some(20),
1983 bandwidth: Some(0.3),
1984 };
1985 let first_input = BandPassInput::from_candles(&candles, "close", first_params);
1986 let first_result = bandpass_with_kernel(&first_input, kernel)?;
1987 let second_params = BandPassParams {
1988 period: Some(30),
1989 bandwidth: Some(0.5),
1990 };
1991 let second_input = BandPassInput::from_slice(&first_result.bp, second_params);
1992 let second_result = bandpass_with_kernel(&second_input, kernel)?;
1993 assert_eq!(second_result.bp.len(), first_result.bp.len());
1994 Ok(())
1995 }
1996 fn check_bandpass_nan_handling(
1997 test_name: &str,
1998 kernel: Kernel,
1999 ) -> Result<(), Box<dyn std::error::Error>> {
2000 skip_if_unsupported!(kernel, test_name);
2001 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2002 let candles = read_candles_from_csv(file_path)?;
2003 let input = BandPassInput::from_candles(
2004 &candles,
2005 "close",
2006 BandPassParams {
2007 period: Some(20),
2008 bandwidth: Some(0.3),
2009 },
2010 );
2011 let res = bandpass_with_kernel(&input, kernel)?;
2012 assert_eq!(res.bp.len(), candles.close.len());
2013
2014 let first_bp = res
2015 .bp
2016 .iter()
2017 .position(|x| !x.is_nan())
2018 .unwrap_or(res.bp.len());
2019 let first_bpn = res
2020 .bp_normalized
2021 .iter()
2022 .position(|x| !x.is_nan())
2023 .unwrap_or(res.bp_normalized.len());
2024 let first_sig = res
2025 .signal
2026 .iter()
2027 .position(|x| !x.is_nan())
2028 .unwrap_or(res.signal.len());
2029 let first_trig = res
2030 .trigger
2031 .iter()
2032 .position(|x| !x.is_nan())
2033 .unwrap_or(res.trigger.len());
2034
2035 let warmup_end = first_bp.max(first_bpn).max(first_sig).max(first_trig);
2036 if warmup_end < res.bp.len() {
2037 for i in warmup_end..res.bp.len() {
2038 assert!(!res.bp[i].is_nan(), "bp[{}] is NaN after warmup", i);
2039 assert!(
2040 !res.bp_normalized[i].is_nan(),
2041 "bp_normalized[{}] is NaN after warmup",
2042 i
2043 );
2044 assert!(!res.signal[i].is_nan(), "signal[{}] is NaN after warmup", i);
2045 assert!(
2046 !res.trigger[i].is_nan(),
2047 "trigger[{}] is NaN after warmup",
2048 i
2049 );
2050 }
2051 }
2052 Ok(())
2053 }
2054
2055 macro_rules! generate_all_bandpass_tests {
2056 ($($test_fn:ident),*) => {
2057 paste::paste! {
2058 $(
2059 #[test]
2060 fn [<$test_fn _scalar_f64>]() {
2061 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2062 }
2063 )*
2064 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2065 $(
2066 #[test]
2067 fn [<$test_fn _avx2_f64>]() {
2068 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2069 }
2070 #[test]
2071 fn [<$test_fn _avx512_f64>]() {
2072 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2073 }
2074 )*
2075 }
2076 }
2077 }
2078
2079 #[cfg(debug_assertions)]
2080 fn check_bandpass_no_poison(
2081 test_name: &str,
2082 kernel: Kernel,
2083 ) -> Result<(), Box<dyn std::error::Error>> {
2084 skip_if_unsupported!(kernel, test_name);
2085
2086 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2087 let candles = read_candles_from_csv(file_path)?;
2088
2089 let test_params = vec![
2090 BandPassParams {
2091 period: Some(10),
2092 bandwidth: Some(0.2),
2093 },
2094 BandPassParams {
2095 period: Some(20),
2096 bandwidth: Some(0.3),
2097 },
2098 BandPassParams {
2099 period: Some(30),
2100 bandwidth: Some(0.4),
2101 },
2102 BandPassParams {
2103 period: Some(50),
2104 bandwidth: Some(0.5),
2105 },
2106 BandPassParams {
2107 period: Some(5),
2108 bandwidth: Some(0.1),
2109 },
2110 BandPassParams {
2111 period: Some(100),
2112 bandwidth: Some(0.8),
2113 },
2114 ];
2115
2116 for params in test_params {
2117 let input = BandPassInput::from_candles(&candles, "close", params.clone());
2118 let output = bandpass_with_kernel(&input, kernel)?;
2119
2120 for (i, &val) in output.bp.iter().enumerate() {
2121 if val.is_nan() {
2122 continue;
2123 }
2124
2125 let bits = val.to_bits();
2126
2127 if bits == 0x11111111_11111111 {
2128 panic!(
2129 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) in bp at index {} with params {:?}",
2130 test_name, val, bits, i, params
2131 );
2132 }
2133
2134 if bits == 0x22222222_22222222 {
2135 panic!(
2136 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) in bp at index {} with params {:?}",
2137 test_name, val, bits, i, params
2138 );
2139 }
2140
2141 if bits == 0x33333333_33333333 {
2142 panic!(
2143 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) in bp at index {} with params {:?}",
2144 test_name, val, bits, i, params
2145 );
2146 }
2147 }
2148
2149 for (i, &val) in output.bp_normalized.iter().enumerate() {
2150 if val.is_nan() {
2151 continue;
2152 }
2153
2154 let bits = val.to_bits();
2155
2156 if bits == 0x11111111_11111111 {
2157 panic!(
2158 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) in bp_normalized at index {} with params {:?}",
2159 test_name, val, bits, i, params
2160 );
2161 }
2162
2163 if bits == 0x22222222_22222222 {
2164 panic!(
2165 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) in bp_normalized at index {} with params {:?}",
2166 test_name, val, bits, i, params
2167 );
2168 }
2169
2170 if bits == 0x33333333_33333333 {
2171 panic!(
2172 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) in bp_normalized at index {} with params {:?}",
2173 test_name, val, bits, i, params
2174 );
2175 }
2176 }
2177
2178 for (i, &val) in output.signal.iter().enumerate() {
2179 if val.is_nan() {
2180 continue;
2181 }
2182
2183 let bits = val.to_bits();
2184
2185 if bits == 0x11111111_11111111 {
2186 panic!(
2187 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) in signal at index {} with params {:?}",
2188 test_name, val, bits, i, params
2189 );
2190 }
2191
2192 if bits == 0x22222222_22222222 {
2193 panic!(
2194 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) in signal at index {} with params {:?}",
2195 test_name, val, bits, i, params
2196 );
2197 }
2198
2199 if bits == 0x33333333_33333333 {
2200 panic!(
2201 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) in signal at index {} with params {:?}",
2202 test_name, val, bits, i, params
2203 );
2204 }
2205 }
2206
2207 for (i, &val) in output.trigger.iter().enumerate() {
2208 if val.is_nan() {
2209 continue;
2210 }
2211
2212 let bits = val.to_bits();
2213
2214 if bits == 0x11111111_11111111 {
2215 panic!(
2216 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) in trigger at index {} with params {:?}",
2217 test_name, val, bits, i, params
2218 );
2219 }
2220
2221 if bits == 0x22222222_22222222 {
2222 panic!(
2223 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) in trigger at index {} with params {:?}",
2224 test_name, val, bits, i, params
2225 );
2226 }
2227
2228 if bits == 0x33333333_33333333 {
2229 panic!(
2230 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) in trigger at index {} with params {:?}",
2231 test_name, val, bits, i, params
2232 );
2233 }
2234 }
2235 }
2236
2237 Ok(())
2238 }
2239
2240 #[cfg(not(debug_assertions))]
2241 fn check_bandpass_no_poison(
2242 _test_name: &str,
2243 _kernel: Kernel,
2244 ) -> Result<(), Box<dyn std::error::Error>> {
2245 Ok(())
2246 }
2247
2248 fn check_bandpass_streaming(
2249 test_name: &str,
2250 kernel: Kernel,
2251 ) -> Result<(), Box<dyn std::error::Error>> {
2252 skip_if_unsupported!(kernel, test_name);
2253
2254 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2255 let candles = read_candles_from_csv(file_path)?;
2256
2257 let period = 20;
2258 let bandwidth = 0.3;
2259
2260 let input = BandPassInput::from_candles(
2261 &candles,
2262 "close",
2263 BandPassParams {
2264 period: Some(period),
2265 bandwidth: Some(bandwidth),
2266 },
2267 );
2268 let batch_output = bandpass_with_kernel(&input, kernel)?;
2269
2270 let mut stream = BandPassStream::try_new(BandPassParams {
2271 period: Some(period),
2272 bandwidth: Some(bandwidth),
2273 })?;
2274
2275 let mut stream_values = Vec::with_capacity(candles.close.len());
2276 for &price in &candles.close {
2277 let bp_val = stream.update(price);
2278 stream_values.push(bp_val);
2279 }
2280
2281 assert_eq!(batch_output.bp.len(), stream_values.len());
2282 for (i, (&b, &s)) in batch_output.bp.iter().zip(stream_values.iter()).enumerate() {
2283 if b.is_nan() {
2284 continue;
2285 }
2286 let diff = (b - s).abs();
2287 let tol = 1e-10 * b.abs().max(1.0);
2288 assert!(
2289 diff < tol,
2290 "[{}] Streaming vs batch mismatch at index {}: batch={}, stream={}, diff={}",
2291 test_name,
2292 i,
2293 b,
2294 s,
2295 diff
2296 );
2297 }
2298
2299 Ok(())
2300 }
2301
2302 #[cfg(feature = "proptest")]
2303 #[allow(clippy::float_cmp)]
2304 fn check_bandpass_property(
2305 test_name: &str,
2306 kernel: Kernel,
2307 ) -> Result<(), Box<dyn std::error::Error>> {
2308 use proptest::prelude::*;
2309 skip_if_unsupported!(kernel, test_name);
2310
2311 let strat = (2usize..=50).prop_flat_map(|period| {
2312 (
2313 prop::collection::vec(
2314 (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
2315 period..400,
2316 ),
2317 Just(period),
2318 (0.1f64..=0.9f64),
2319 )
2320 });
2321
2322 proptest::test_runner::TestRunner::default().run(&strat, |(data, period, bandwidth)| {
2323 let params = BandPassParams {
2324 period: Some(period),
2325 bandwidth: Some(bandwidth),
2326 };
2327 let input = BandPassInput::from_slice(&data, params);
2328
2329 let result = bandpass_with_kernel(&input, kernel)?;
2330
2331 let ref_result = bandpass_with_kernel(&input, Kernel::Scalar)?;
2332
2333 prop_assert_eq!(result.bp.len(), data.len());
2334 prop_assert_eq!(result.bp_normalized.len(), data.len());
2335 prop_assert_eq!(result.signal.len(), data.len());
2336 prop_assert_eq!(result.trigger.len(), data.len());
2337
2338 let hp_period = ((4.0 * period as f64) / bandwidth).round() as usize;
2339 let expected_warmup = hp_period.saturating_sub(1).max(2);
2340
2341 if data.len() >= expected_warmup {
2342 for i in 0..(expected_warmup.saturating_sub(1)).min(data.len()) {
2343 prop_assert!(
2344 result.bp[i].is_nan(),
2345 "bp[{}] should be NaN during warmup but got {}",
2346 i,
2347 result.bp[i]
2348 );
2349 }
2350 }
2351
2352 for i in 0..data.len() {
2353 if !result.bp[i].is_nan() {
2354 prop_assert!(
2355 result.bp[i].is_finite(),
2356 "bp[{}] not finite: {}",
2357 i,
2358 result.bp[i]
2359 );
2360 }
2361 if !result.bp_normalized[i].is_nan() {
2362 prop_assert!(
2363 result.bp_normalized[i].is_finite(),
2364 "bp_normalized[{}] not finite: {}",
2365 i,
2366 result.bp_normalized[i]
2367 );
2368 }
2369 if !result.signal[i].is_nan() {
2370 prop_assert!(
2371 result.signal[i].is_finite(),
2372 "signal[{}] not finite: {}",
2373 i,
2374 result.signal[i]
2375 );
2376 }
2377 if !result.trigger[i].is_nan() {
2378 prop_assert!(
2379 result.trigger[i].is_finite(),
2380 "trigger[{}] not finite: {}",
2381 i,
2382 result.trigger[i]
2383 );
2384 }
2385 }
2386
2387 for i in 0..data.len() {
2388 let sig = result.signal[i];
2389 if !sig.is_nan() {
2390 prop_assert!(
2391 sig == -1.0 || sig == 0.0 || sig == 1.0,
2392 "signal[{}] = {} not in {{-1, 0, 1}}",
2393 i,
2394 sig
2395 );
2396 }
2397 }
2398
2399 for i in 0..data.len() {
2400 let norm = result.bp_normalized[i];
2401 if !norm.is_nan() {
2402 prop_assert!(
2403 norm >= -1.0 - 1e-9 && norm <= 1.0 + 1e-9,
2404 "bp_normalized[{}] = {} not in [-1, 1]",
2405 i,
2406 norm
2407 );
2408 }
2409 }
2410
2411 if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-9)
2412 && data.len() > expected_warmup + 20
2413 {
2414 let check_start = (expected_warmup + 10).min(data.len() - 1);
2415 for i in check_start..data.len() {
2416 if !result.bp[i].is_nan() {
2417 let tolerance = 1e-6 * data[0].abs().max(1.0);
2418 prop_assert!(
2419 result.bp[i].abs() < tolerance,
2420 "bp[{}] = {} not near 0 for constant data (tolerance={})",
2421 i,
2422 result.bp[i],
2423 tolerance
2424 );
2425 }
2426 }
2427 }
2428
2429 for i in 0..data.len() {
2430 let bn = result.bp_normalized[i];
2431 let tr = result.trigger[i];
2432 let sig = result.signal[i];
2433
2434 if !bn.is_nan() && !tr.is_nan() && !sig.is_nan() {
2435 if (bn - tr).abs() < 1e-12 {
2436 prop_assert_eq!(
2437 sig,
2438 0.0,
2439 "signal[{}] should be 0 when bp_normalized≈trigger",
2440 i
2441 );
2442 } else if bn < tr {
2443 prop_assert_eq!(
2444 sig,
2445 1.0,
2446 "signal[{}] should be 1 when bp_normalized<trigger ({} < {})",
2447 i,
2448 bn,
2449 tr
2450 );
2451 } else {
2452 prop_assert_eq!(
2453 sig,
2454 -1.0,
2455 "signal[{}] should be -1 when bp_normalized>trigger ({} > {})",
2456 i,
2457 bn,
2458 tr
2459 );
2460 }
2461 }
2462 }
2463
2464 if data.len() > 50 {
2465 let last_quarter_start = (data.len() * 3) / 4;
2466 let mut max_abs_bp = 0.0f64;
2467 let mut max_abs_input = 0.0f64;
2468
2469 for i in last_quarter_start..data.len() {
2470 if !result.bp[i].is_nan() {
2471 max_abs_bp = max_abs_bp.max(result.bp[i].abs());
2472 }
2473 max_abs_input = max_abs_input.max(data[i].abs());
2474 }
2475
2476 if max_abs_input > 0.0 {
2477 let amplification = max_abs_bp / max_abs_input;
2478 prop_assert!(
2479 amplification < 100.0,
2480 "Recursive filter may be unstable: amplification={} at period={}, bandwidth={}",
2481 amplification, period, bandwidth
2482 );
2483 }
2484 }
2485
2486 if period == 2 {
2487 let non_nan_count = result.bp.iter().filter(|&&x| !x.is_nan()).count();
2488 prop_assert!(
2489 non_nan_count >= data.len().saturating_sub(expected_warmup),
2490 "period=2 should produce mostly valid output after warmup"
2491 );
2492 }
2493
2494 for i in 0..data.len() {
2495 let bp = result.bp[i];
2496 let bp_ref = ref_result.bp[i];
2497 let bp_norm = result.bp_normalized[i];
2498 let bp_norm_ref = ref_result.bp_normalized[i];
2499 let sig = result.signal[i];
2500 let sig_ref = ref_result.signal[i];
2501 let trig = result.trigger[i];
2502 let trig_ref = ref_result.trigger[i];
2503
2504 if !bp.is_finite() || !bp_ref.is_finite() {
2505 prop_assert_eq!(
2506 bp.to_bits(),
2507 bp_ref.to_bits(),
2508 "bp finite/NaN mismatch at {}",
2509 i
2510 );
2511 } else {
2512 let ulp_diff = bp.to_bits().abs_diff(bp_ref.to_bits());
2513 prop_assert!(
2514 (bp - bp_ref).abs() <= 1e-9 || ulp_diff <= 4,
2515 "bp mismatch at {}: {} vs {} (ULP={})",
2516 i,
2517 bp,
2518 bp_ref,
2519 ulp_diff
2520 );
2521 }
2522
2523 if !bp_norm.is_finite() || !bp_norm_ref.is_finite() {
2524 prop_assert_eq!(
2525 bp_norm.to_bits(),
2526 bp_norm_ref.to_bits(),
2527 "bp_normalized finite/NaN mismatch at {}",
2528 i
2529 );
2530 } else {
2531 let ulp_diff = bp_norm.to_bits().abs_diff(bp_norm_ref.to_bits());
2532 prop_assert!(
2533 (bp_norm - bp_norm_ref).abs() <= 1e-9 || ulp_diff <= 4,
2534 "bp_normalized mismatch at {}: {} vs {} (ULP={})",
2535 i,
2536 bp_norm,
2537 bp_norm_ref,
2538 ulp_diff
2539 );
2540 }
2541
2542 if !sig.is_nan() && !sig_ref.is_nan() {
2543 prop_assert_eq!(
2544 sig,
2545 sig_ref,
2546 "signal mismatch at {}: {} vs {}",
2547 i,
2548 sig,
2549 sig_ref
2550 );
2551 } else {
2552 prop_assert_eq!(
2553 sig.is_nan(),
2554 sig_ref.is_nan(),
2555 "signal NaN mismatch at {}",
2556 i
2557 );
2558 }
2559
2560 if !trig.is_finite() || !trig_ref.is_finite() {
2561 prop_assert_eq!(
2562 trig.to_bits(),
2563 trig_ref.to_bits(),
2564 "trigger finite/NaN mismatch at {}",
2565 i
2566 );
2567 } else {
2568 let ulp_diff = trig.to_bits().abs_diff(trig_ref.to_bits());
2569 prop_assert!(
2570 (trig - trig_ref).abs() <= 1e-9 || ulp_diff <= 4,
2571 "trigger mismatch at {}: {} vs {} (ULP={})",
2572 i,
2573 trig,
2574 trig_ref,
2575 ulp_diff
2576 );
2577 }
2578 }
2579
2580 Ok(())
2581 })?;
2582
2583 Ok(())
2584 }
2585
    /// Fallback used when the `proptest` feature is disabled.
    ///
    /// It only invokes `skip_if_unsupported!` and then returns `Ok(())` without
    /// running any property checks, so the generated test names still exist.
    #[cfg(not(feature = "proptest"))]
    fn check_bandpass_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        Ok(())
    }
2594
    // Expand each single-series check into per-kernel #[test] functions.
    // NOTE(review): the kernel set and how each check's Result is handled are defined
    // by the `generate_all_bandpass_tests!` macro, which is not visible in this chunk.
    generate_all_bandpass_tests!(
        check_bandpass_partial_params,
        check_bandpass_accuracy,
        check_bandpass_default_candles,
        check_bandpass_zero_period,
        check_bandpass_period_exceeds_length,
        check_bandpass_very_small_dataset,
        check_bandpass_reinput,
        check_bandpass_nan_handling,
        check_bandpass_no_poison,
        check_bandpass_streaming,
        check_bandpass_property
    );
2608 fn check_batch_default_row(
2609 test: &str,
2610 kernel: Kernel,
2611 ) -> Result<(), Box<dyn std::error::Error>> {
2612 skip_if_unsupported!(kernel, test);
2613
2614 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2615 let c = read_candles_from_csv(file)?;
2616
2617 let output = BandPassBatchBuilder::new()
2618 .kernel(kernel)
2619 .apply_candles(&c, "close")?;
2620
2621 let def = BandPassParams::default();
2622 let row_idx = output.row_for_params(&def).expect("default row missing");
2623 let (bp_slice, _bpn_slice, _sig_slice, _trg_slice) =
2624 output.row_slices(row_idx).expect("row missing");
2625
2626 assert_eq!(bp_slice.len(), c.close.len());
2627
2628 let expected = [
2629 -236.23678021132827,
2630 -247.4846395608195,
2631 -242.3788746078502,
2632 -212.89589193350128,
2633 -179.97293838509464,
2634 ];
2635 let start = bp_slice.len() - 5;
2636 for (i, &v) in bp_slice[start..].iter().enumerate() {
2637 assert!(
2638 (v - expected[i]).abs() < 1e-1,
2639 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2640 );
2641 }
2642 Ok(())
2643 }
2644
2645 macro_rules! gen_batch_tests {
2646 ($fn_name:ident) => {
2647 paste::paste! {
2648 #[test] fn [<$fn_name _scalar>]() {
2649 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2650 }
2651 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2652 #[test] fn [<$fn_name _avx2>]() {
2653 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2654 }
2655 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2656 #[test] fn [<$fn_name _avx512>]() {
2657 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2658 }
2659 #[test] fn [<$fn_name _auto_detect>]() {
2660 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2661 }
2662 }
2663 };
2664 }
2665
2666 #[cfg(debug_assertions)]
2667 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
2668 skip_if_unsupported!(kernel, test);
2669
2670 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2671 let c = read_candles_from_csv(file)?;
2672
2673 let output = BandPassBatchBuilder::new()
2674 .kernel(kernel)
2675 .period_range(5, 50, 5)
2676 .bandwidth_range(0.1, 0.9, 0.1)
2677 .apply_candles(&c, "close")?;
2678
2679 for row_idx in 0..output.rows {
2680 let params = &output.combos[row_idx];
2681 let (bp_row, bpn_row, sig_row, trg_row) = output.row_slices(row_idx).unwrap();
2682
2683 for (col_idx, &val) in bp_row.iter().enumerate() {
2684 if val.is_nan() {
2685 continue;
2686 }
2687
2688 let bits = val.to_bits();
2689
2690 if bits == 0x11111111_11111111 {
2691 panic!(
2692 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) in bp at row {} col {} with params {:?}",
2693 test, val, bits, row_idx, col_idx, params
2694 );
2695 }
2696
2697 if bits == 0x22222222_22222222 {
2698 panic!(
2699 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) in bp at row {} col {} with params {:?}",
2700 test, val, bits, row_idx, col_idx, params
2701 );
2702 }
2703
2704 if bits == 0x33333333_33333333 {
2705 panic!(
2706 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) in bp at row {} col {} with params {:?}",
2707 test, val, bits, row_idx, col_idx, params
2708 );
2709 }
2710 }
2711
2712 for (col_idx, &val) in bpn_row.iter().enumerate() {
2713 if val.is_nan() {
2714 continue;
2715 }
2716
2717 let bits = val.to_bits();
2718
2719 if bits == 0x11111111_11111111 {
2720 panic!(
2721 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) in bp_normalized at row {} col {} with params {:?}",
2722 test, val, bits, row_idx, col_idx, params
2723 );
2724 }
2725
2726 if bits == 0x22222222_22222222 {
2727 panic!(
2728 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) in bp_normalized at row {} col {} with params {:?}",
2729 test, val, bits, row_idx, col_idx, params
2730 );
2731 }
2732
2733 if bits == 0x33333333_33333333 {
2734 panic!(
2735 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) in bp_normalized at row {} col {} with params {:?}",
2736 test, val, bits, row_idx, col_idx, params
2737 );
2738 }
2739 }
2740
2741 for (col_idx, &val) in sig_row.iter().enumerate() {
2742 if val.is_nan() {
2743 continue;
2744 }
2745
2746 let bits = val.to_bits();
2747
2748 if bits == 0x11111111_11111111 {
2749 panic!(
2750 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) in signal at row {} col {} with params {:?}",
2751 test, val, bits, row_idx, col_idx, params
2752 );
2753 }
2754
2755 if bits == 0x22222222_22222222 {
2756 panic!(
2757 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) in signal at row {} col {} with params {:?}",
2758 test, val, bits, row_idx, col_idx, params
2759 );
2760 }
2761
2762 if bits == 0x33333333_33333333 {
2763 panic!(
2764 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) in signal at row {} col {} with params {:?}",
2765 test, val, bits, row_idx, col_idx, params
2766 );
2767 }
2768 }
2769
2770 for (col_idx, &val) in trg_row.iter().enumerate() {
2771 if val.is_nan() {
2772 continue;
2773 }
2774
2775 let bits = val.to_bits();
2776
2777 if bits == 0x11111111_11111111 {
2778 panic!(
2779 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) in trigger at row {} col {} with params {:?}",
2780 test, val, bits, row_idx, col_idx, params
2781 );
2782 }
2783
2784 if bits == 0x22222222_22222222 {
2785 panic!(
2786 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) in trigger at row {} col {} with params {:?}",
2787 test, val, bits, row_idx, col_idx, params
2788 );
2789 }
2790
2791 if bits == 0x33333333_33333333 {
2792 panic!(
2793 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) in trigger at row {} col {} with params {:?}",
2794 test, val, bits, row_idx, col_idx, params
2795 );
2796 }
2797 }
2798 }
2799
2800 Ok(())
2801 }
2802
    /// Stand-in for builds without `debug_assertions`: the poison scan above is only
    /// compiled in debug builds, so this variant returns `Ok(())` without checking.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(
        _test: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }
2810
    // Instantiate scalar-batch, AVX2/AVX-512 (nightly-avx, x86_64 only) and
    // auto-detect variants of each batch check via `gen_batch_tests!`.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
2813}
2814
/// Python binding: computes the band-pass indicator over a 1-D float64 NumPy array.
///
/// Returns a dict with the `bp`, `bp_normalized`, `signal` and `trigger` output arrays.
/// The computation runs with the GIL released.
///
/// # Errors
/// Raises `ValueError` when `kernel` fails validation or the computation fails
/// (e.g. invalid `period` / `bandwidth` for the given data).
#[cfg(feature = "python")]
#[pyfunction(name = "bandpass")]
#[pyo3(signature = (data, period=20, bandwidth=0.3, kernel=None))]
pub fn bandpass_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    bandwidth: f64,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::PyArrayMethods;

    let series = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;
    let input = BandPassInput::from_slice(
        series,
        BandPassParams {
            period: Some(period),
            bandwidth: Some(bandwidth),
        },
    );

    let result = py
        .allow_threads(|| bandpass_with_kernel(&input, chosen_kernel))
        .map_err(|err| PyValueError::new_err(err.to_string()))?;

    let out = PyDict::new(py);
    for (key, values) in [
        ("bp", result.bp),
        ("bp_normalized", result.bp_normalized),
        ("signal", result.signal),
        ("trigger", result.trigger),
    ] {
        out.set_item(key, values.into_pyarray(py))?;
    }
    Ok(out)
}
2846
/// Python class `BandPassStream`: incremental (value-at-a-time) band-pass filter
/// wrapping [`BandPassStream`].
#[cfg(feature = "python")]
#[pyclass(name = "BandPassStream")]
pub struct BandPassStreamPy {
    stream: BandPassStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl BandPassStreamPy {
    /// Creates a stream; raises `ValueError` if the parameters are rejected.
    #[new]
    fn new(period: usize, bandwidth: f64) -> PyResult<Self> {
        BandPassStream::try_new(BandPassParams {
            period: Some(period),
            bandwidth: Some(bandwidth),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one value and returns the stream's latest output.
    fn update(&mut self, value: f64) -> f64 {
        self.stream.update(value)
    }
}
2871
/// Python binding: evaluates the band-pass indicator over a grid of
/// `(start, end, step)` period and bandwidth ranges.
///
/// Returns a dict with 2-D float64 arrays `bp`, `bp_normalized`, `signal` and
/// `trigger` of shape `(rows, cols)`, one row per parameter combination. It also
/// contains `periods` (uint64) and `bandwidths` (float64) arrays describing each row.
/// The computation runs with the GIL released.
///
/// # Errors
/// - Raises `ValueError` for an invalid kernel name, a failed computation, or a
///   `rows * cols` overflow.
/// - Raises a Python exception from `reshape` if an output buffer's length does not
///   match `rows * cols`.
#[cfg(feature = "python")]
#[pyfunction(name = "bandpass_batch")]
#[pyo3(signature = (data, period_range, bandwidth_range, kernel=None))]
pub fn bandpass_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    bandwidth_range: (f64, f64, f64),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{PyArray1, PyArrayMethods};
    let slice_in = data.as_slice()?;
    let sweep = BandPassBatchRange {
        period: period_range,
        bandwidth: bandwidth_range,
    };

    let kern = validate_kernel(kernel, true)?;
    let output = py
        .allow_threads(|| bandpass_batch_with_kernel(slice_in, &sweep, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let rows = output.rows;
    let cols = output.cols;
    rows.checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("bandpass_batch: rows*cols overflow"))?;

    // Copy a flat row-major buffer into a new NumPy array and view it as (rows, cols).
    // `from_slice` + `reshape` needs no `unsafe` (unlike an uninitialised
    // `PyArray::new` + `as_slice_mut`), and a length mismatch becomes a Python
    // error instead of a Rust panic.
    let to_matrix =
        |flat: &[f64]| PyArray1::<f64>::from_slice(py, flat).reshape((rows, cols));

    let d = PyDict::new(py);
    d.set_item("bp", to_matrix(&output.bp[..])?)?;
    d.set_item("bp_normalized", to_matrix(&output.bp_normalized[..])?)?;
    d.set_item("signal", to_matrix(&output.signal[..])?)?;
    d.set_item("trigger", to_matrix(&output.trigger[..])?)?;
    d.set_item(
        "periods",
        output
            .combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    d.set_item(
        "bandwidths",
        output
            .combos
            .iter()
            .map(|p| p.bandwidth.unwrap())
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    Ok(d)
}
2939
/// Python binding: runs the band-pass parameter sweep on a CUDA device.
///
/// The returned dict holds:
/// - `bp`, `bp_normalized`, `signal` and `trigger` as device-array handles;
/// - `periods` / `bandwidths` arrays with one entry per row;
/// - the `rows` / `cols` shape.
///
/// Device work, including the final stream synchronize, runs with the GIL released.
///
/// # Errors
/// Raises `ValueError` if CUDA is unavailable, or if constructing `CudaBandpass`,
/// launching the batch, or synchronizing the stream fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "bandpass_cuda_batch_dev")]
#[pyo3(signature = (close_f32, period_range, bandwidth_range, device_id=0))]
pub fn bandpass_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    close_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    bandwidth_range: (f64, f64, f64),
    device_id: usize,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::IntoPyArray;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = close_f32.as_slice()?;
    let sweep = BandPassBatchRange {
        period: period_range,
        bandwidth: bandwidth_range,
    };

    let (outputs, combos, dev_id, ctx) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaBandpass::new(device_id)
            .map_err(|err| PyValueError::new_err(err.to_string()))?;
        let dev_id = cuda.device_id();
        let ctx = cuda.context_arc();
        let batch = cuda
            .bandpass_batch_dev(prices, &sweep)
            .map_err(|err| PyValueError::new_err(err.to_string()))?;
        cuda.stream()
            .synchronize()
            .map_err(|err| PyValueError::new_err(err.to_string()))?;
        Ok((batch.outputs, batch.combos, dev_id, ctx))
    })?;

    // Wrap one device buffer; every handle stores its own clone of the context Arc.
    let wrap = |inner| {
        Py::new(
            py,
            BandPassDeviceArrayF32Py {
                inner,
                _ctx: ctx.clone(),
                device_id: dev_id,
                stream: 0,
            },
        )
    };

    let d = PyDict::new(py);
    d.set_item("bp", wrap(outputs.first)?)?;
    d.set_item("bp_normalized", wrap(outputs.second)?)?;
    d.set_item("signal", wrap(outputs.third)?)?;
    d.set_item("trigger", wrap(outputs.fourth)?)?;

    let periods: Vec<usize> = combos.iter().map(|p| p.period.unwrap()).collect();
    let bands: Vec<f64> = combos.iter().map(|p| p.bandwidth.unwrap()).collect();
    d.set_item("periods", periods.into_pyarray(py))?;
    d.set_item("bandwidths", bands.into_pyarray(py))?;
    d.set_item("rows", combos.len())?;
    d.set_item("cols", prices.len())?;
    Ok(d)
}
3031
/// Python binding: applies one band-pass parameter set to many series on a CUDA device.
///
/// `data_tm_f32` is a 2-D float32 array. Its shape is read as `(rows, cols)` and
/// forwarded as `(cols, rows)` to the time-major CUDA entry point
/// (NOTE(review): presumably rows = time steps, cols = series; confirm).
///
/// The returned dict holds:
/// - `bp`, `bp_normalized`, `signal` and `trigger` as device-array handles;
/// - the `rows` / `cols` shape;
/// - the `period` / `bandwidth` used.
///
/// # Errors
/// Raises `ValueError` if CUDA is unavailable, the array is not 2-D, or any CUDA step
/// (construction, launch, stream synchronize) fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "bandpass_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, bandwidth, device_id=0))]
pub fn bandpass_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    data_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    period: usize,
    bandwidth: f64,
    device_id: usize,
) -> PyResult<Bound<'py, PyDict>> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let shape = data_tm_f32.shape();
    if shape.len() != 2 {
        return Err(PyValueError::new_err("expected 2D array"));
    }
    let rows = shape[0];
    let cols = shape[1];
    let flat = data_tm_f32.as_slice()?;
    let params = BandPassParams {
        period: Some(period),
        bandwidth: Some(bandwidth),
    };

    let (outputs, dev_id, ctx) = py.allow_threads(|| -> PyResult<_> {
        let cuda =
            CudaBandpass::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let dev_id = cuda.device_id();
        let ctx = cuda.context_arc();
        // Previously written as the mangled token `¶ms`, which does not compile.
        let out = cuda
            .bandpass_many_series_one_param_time_major_dev(flat, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.stream()
            .synchronize()
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((out, dev_id, ctx))
    })?;

    // Wrap one device buffer; every handle stores its own clone of the context Arc.
    let wrap = |inner| {
        Py::new(
            py,
            BandPassDeviceArrayF32Py {
                inner,
                _ctx: ctx.clone(),
                device_id: dev_id,
                stream: 0,
            },
        )
    };

    let d = PyDict::new(py);
    d.set_item("bp", wrap(outputs.first)?)?;
    d.set_item("bp_normalized", wrap(outputs.second)?)?;
    d.set_item("signal", wrap(outputs.third)?)?;
    d.set_item("trigger", wrap(outputs.fourth)?)?;
    d.set_item("rows", rows)?;
    d.set_item("cols", cols)?;
    d.set_item("period", period)?;
    d.set_item("bandwidth", bandwidth)?;
    Ok(d)
}
3126
/// Flattened result handed to JavaScript by `bandpass_js`.
///
/// `values` stores `rows` series of `cols` values each, back to back. `bandpass_js`
/// fills it with `bp`, `bp_normalized`, `signal` and `trigger` in that order and
/// sets `rows` to 4.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct BandPassJsResult {
    /// Concatenated output series (`rows * cols` values).
    pub values: Vec<f64>,
    /// Number of series stored in `values`.
    pub rows: usize,
    /// Length of each series (the input length).
    pub cols: usize,
}
3134
/// WASM binding: computes the band-pass indicator for `data` with the best
/// available kernel.
///
/// Returns a serialized `BandPassJsResult` with `rows = 4` and `cols = data.len()`.
/// Its `values` hold `bp`, `bp_normalized`, `signal` and `trigger` back to back.
///
/// # Errors
/// Returns a JS string error if the computation or serialization fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "bandpass_js")]
pub fn bandpass_js(data: &[f64], period: usize, bandwidth: f64) -> Result<JsValue, JsValue> {
    let params = BandPassParams {
        period: Some(period),
        bandwidth: Some(bandwidth),
    };
    let input = BandPassInput::from_slice(data, params);
    let computed = bandpass_with_kernel(&input, detect_best_kernel())
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let values = [
        &computed.bp[..],
        &computed.bp_normalized[..],
        &computed.signal[..],
        &computed.trigger[..],
    ]
    .concat();

    let payload = BandPassJsResult {
        values,
        rows: 4,
        cols: data.len(),
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
3160
/// Batch-sweep configuration deserialized from the JS `config` argument of
/// `bandpass_batch`.
///
/// Each range is a `(start, end, step)` triple, mirroring `BandPassBatchRange`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct BandPassBatchConfig {
    pub period_range: (usize, usize, usize),
    pub bandwidth_range: (f64, f64, f64),
}
3167
/// Serializable batch output shape: flattened `values` with their `rows` x `cols`
/// layout plus the parameter combination of each row.
///
/// NOTE(review): the `bandpass_batch` WASM binding in this chunk builds its result
/// with `js_sys::Object` instead of this struct; confirm whether it is still used.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct BandPassBatchJsOutput {
    pub values: Vec<f64>,
    pub rows: usize,
    pub cols: usize,
    pub combos: Vec<BandPassParams>,
}
3176
/// WASM binding: runs the band-pass parameter sweep and returns a JS object
/// `{ values, combos, outputs, cols }`.
///
/// - `values` holds four consecutive blocks: all `bp` rows, then all `bp_normalized`,
///   `signal` and `trigger` rows.
/// - `combos` is the number of parameter rows.
/// - `outputs` is always 4.
/// - `cols` is the series length.
///
/// # Errors
/// Returns a JS string error for an invalid config, a failed computation, a size
/// overflow, or a serialization failure. A serialization failure used to panic via
/// `unwrap`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "bandpass_batch")]
pub fn bandpass_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let cfg: BandPassBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
    let sweep = BandPassBatchRange {
        period: cfg.period_range,
        bandwidth: cfg.bandwidth_range,
    };
    let out = bandpass_batch_with_kernel(data, &sweep, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let rows = out.rows;
    let cols = out.cols;

    // Four outputs of rows * cols values each.
    let total = rows
        .checked_mul(cols)
        .and_then(|v| v.checked_mul(4))
        .ok_or_else(|| JsValue::from_str("bandpass_batch_js: rows*cols overflow"))?;
    let mut values = Vec::with_capacity(total);
    values.extend_from_slice(&out.bp);
    values.extend_from_slice(&out.bp_normalized);
    values.extend_from_slice(&out.signal);
    values.extend_from_slice(&out.trigger);

    let values_js = serde_wasm_bindgen::to_value(&values)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))?;

    let js_output = js_sys::Object::new();
    js_sys::Reflect::set(&js_output, &JsValue::from_str("values"), &values_js)?;
    js_sys::Reflect::set(
        &js_output,
        &JsValue::from_str("combos"),
        &JsValue::from_f64(out.combos.len() as f64),
    )?;
    js_sys::Reflect::set(
        &js_output,
        &JsValue::from_str("outputs"),
        &JsValue::from_f64(4.0),
    )?;
    js_sys::Reflect::set(
        &js_output,
        &JsValue::from_str("cols"),
        &JsValue::from_f64(cols as f64),
    )?;
    Ok(JsValue::from(js_output))
}
3225
/// WASM binding: lists the parameter combinations a batch sweep would evaluate.
///
/// Returns a flat array `[period0, bandwidth0, period1, bandwidth1, ...]`, with one
/// pair per combination in the order produced by `expand_grid`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "bandpass_batch_metadata")]
pub fn bandpass_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
    bandwidth_start: f64,
    bandwidth_end: f64,
    bandwidth_step: f64,
) -> Result<JsValue, JsValue> {
    let sweep = BandPassBatchRange {
        period: (period_start, period_end, period_step),
        bandwidth: (bandwidth_start, bandwidth_end, bandwidth_step),
    };
    let grid = expand_grid(&sweep);

    let pairs: Vec<f64> = grid
        .iter()
        .flat_map(|combo| [combo.period.unwrap() as f64, combo.bandwidth.unwrap()])
        .collect();

    serde_wasm_bindgen::to_value(&pairs).map_err(|e| JsValue::from_str(&e.to_string()))
}
3249
/// WASM binding: computes the band-pass indicator from `len` input values at `in_ptr`
/// into four caller-provided output buffers of `len` values each.
///
/// Every pointer must be non-null and refer to at least `len` properly aligned `f64`s
/// (e.g. buffers from `bandpass_alloc`). The input range and the four output ranges
/// must be pairwise disjoint. Any overlap, not only identical start addresses, is
/// rejected, because creating overlapping shared and mutable slices is undefined
/// behaviour.
///
/// # Errors
/// Returns a JS string error for null pointers, overlapping buffers, or a failed
/// computation.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn bandpass_into(
    in_ptr: *const f64,
    len: usize,
    period: usize,
    bandwidth: f64,
    bp_ptr: *mut f64,
    bpn_ptr: *mut f64,
    sig_ptr: *mut f64,
    trg_ptr: *mut f64,
) -> Result<(), JsValue> {
    /// True when the `len`-element `f64` ranges starting at `a` and `b` share any
    /// byte. Identical start addresses always count as overlapping (even for
    /// `len == 0`), matching the previous equality-based check.
    fn overlaps(a: *const f64, b: *const f64, len: usize) -> bool {
        let bytes = len.saturating_mul(std::mem::size_of::<f64>());
        let (a0, b0) = (a as usize, b as usize);
        a0 == b0 || (a0 < b0.saturating_add(bytes) && b0 < a0.saturating_add(bytes))
    }

    if in_ptr.is_null()
        || bp_ptr.is_null()
        || bpn_ptr.is_null()
        || sig_ptr.is_null()
        || trg_ptr.is_null()
    {
        return Err(JsValue::from_str("null pointer in bandpass_into"));
    }

    let out_ptrs = [bp_ptr, bpn_ptr, sig_ptr, trg_ptr];

    if out_ptrs.iter().any(|&out| overlaps(in_ptr, out, len)) {
        return Err(JsValue::from_str(
            "input and output pointers must not alias",
        ));
    }

    for (i, &a) in out_ptrs.iter().enumerate() {
        if out_ptrs[i + 1..].iter().any(|&b| overlaps(a, b, len)) {
            return Err(JsValue::from_str(
                "output pointers must not alias with each other",
            ));
        }
    }

    // SAFETY: all five pointers are non-null, and the checks above guarantee the
    // input range and the four output ranges are pairwise disjoint, so one shared
    // and four exclusive slices never alias. The caller must guarantee each pointer
    // refers to at least `len` valid, aligned `f64` values.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let mut bp = std::slice::from_raw_parts_mut(bp_ptr, len);
        let mut bpn = std::slice::from_raw_parts_mut(bpn_ptr, len);
        let mut sig = std::slice::from_raw_parts_mut(sig_ptr, len);
        let mut trg = std::slice::from_raw_parts_mut(trg_ptr, len);
        let input = BandPassInput::from_slice(
            data,
            BandPassParams {
                period: Some(period),
                bandwidth: Some(bandwidth),
            },
        );
        bandpass_into_slice(
            &mut bp,
            &mut bpn,
            &mut sig,
            &mut trg,
            &input,
            detect_best_kernel(),
        )
        .map_err(|e| JsValue::from_str(&e.to_string()))
    }
}
3316
/// WASM binding: reserves room for `len` `f64` values and returns the (uninitialized)
/// buffer pointer.
///
/// Ownership passes to the caller, who must release the buffer with
/// `bandpass_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn bandpass_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop deliberately leaks the Vec so the allocation outlives this call.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
3325
/// WASM binding: releases a buffer previously returned by `bandpass_alloc(len)`.
///
/// `len` must equal the value passed to `bandpass_alloc`. A null pointer is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn bandpass_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: the caller guarantees `ptr` came from
    // `Vec::<f64>::with_capacity(len)` in `bandpass_alloc`, so capacity `len`
    // matches the original allocation.
    // Length 0 is used because the buffer may never have been (fully) written;
    // claiming `len` initialized elements would violate `from_raw_parts`'s
    // contract. `f64` has no drop glue, so only the allocation itself is released.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}