1#[cfg(feature = "python")]
2use crate::utilities::kernel_validation::validate_kernel;
3#[cfg(feature = "python")]
4use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
5#[cfg(feature = "python")]
6use pyo3::exceptions::PyBufferError;
7#[cfg(feature = "python")]
8use pyo3::exceptions::PyValueError;
9#[cfg(feature = "python")]
10use pyo3::prelude::*;
11#[cfg(feature = "python")]
12use pyo3::types::PyDict;
13
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use serde::{Deserialize, Serialize};
16#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
17use wasm_bindgen::prelude::*;
18
19use crate::utilities::data_loader::Candles;
20use crate::utilities::enums::Kernel;
21use crate::utilities::helpers::{
22 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
23 make_uninit_matrix,
24};
25#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
26use core::arch::x86_64::*;
27#[cfg(not(target_arch = "wasm32"))]
28use rayon::prelude::*;
29use std::collections::VecDeque;
30use std::convert::AsRef;
31use std::error::Error;
32use std::mem::ManuallyDrop;
33use thiserror::Error;
34
35#[cfg(all(feature = "python", feature = "cuda"))]
36use crate::cuda::cuda_available;
37#[cfg(all(feature = "python", feature = "cuda"))]
38use crate::cuda::oscillators::chop_wrapper::DeviceArrayF32 as DeviceArrayF32Chop;
39#[cfg(all(feature = "python", feature = "cuda"))]
40use crate::cuda::oscillators::CudaChop;
41#[cfg(all(feature = "python", feature = "cuda"))]
42use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
43
/// Borrowed price series consumed by the CHOP routines in this file.
///
/// Both variants expose a high, a low and a close series; the lifetime `'a`
/// ties the input to the storage it borrows from.
#[derive(Debug, Clone)]
pub enum ChopData<'a> {
    /// Read the `high`, `low` and `close` fields of a [`Candles`] value.
    Candles(&'a Candles),
    /// Read from three caller-supplied slices.
    Slice {
        /// High series.
        high: &'a [f64],
        /// Low series.
        low: &'a [f64],
        /// Close series.
        close: &'a [f64],
    },
}
53
/// Result of a single-parameter CHOP run.
#[derive(Debug, Clone)]
pub struct ChopOutput {
    /// One value per input bar (same length as the close series).
    pub values: Vec<f64>,
}
58
/// Optional tuning parameters for CHOP.
///
/// A field left as `None` is replaced by the library default wherever the
/// parameters are consumed (period 14, scalar 100.0, drift 1).
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct ChopParams {
    /// Length of the rolling window, in bars.
    pub period: Option<usize>,
    /// Multiplier applied to the log ratio.
    pub scalar: Option<f64>,
    /// Smoothing length used for the ATR term.
    pub drift: Option<usize>,
}

impl Default for ChopParams {
    /// Returns the explicit defaults: period 14, scalar 100.0, drift 1.
    fn default() -> Self {
        let period = Some(14);
        let scalar = Some(100.0);
        let drift = Some(1);
        ChopParams {
            period,
            scalar,
            drift,
        }
    }
}
78
/// A CHOP request: the price data to read plus the parameters to apply.
#[derive(Debug, Clone)]
pub struct ChopInput<'a> {
    /// Source of the high/low/close series.
    pub data: ChopData<'a>,
    /// Period, scalar and drift; unset fields fall back to defaults in the getters.
    pub params: ChopParams,
}
84
85impl<'a> ChopInput<'a> {
86 #[inline]
87 pub fn from_candles(candles: &'a Candles, params: ChopParams) -> Self {
88 Self {
89 data: ChopData::Candles(candles),
90 params,
91 }
92 }
93 #[inline]
94 pub fn from_slices(
95 high: &'a [f64],
96 low: &'a [f64],
97 close: &'a [f64],
98 params: ChopParams,
99 ) -> Self {
100 Self {
101 data: ChopData::Slice { high, low, close },
102 params,
103 }
104 }
105 #[inline]
106 pub fn with_default_candles(candles: &'a Candles) -> Self {
107 Self {
108 data: ChopData::Candles(candles),
109 params: ChopParams::default(),
110 }
111 }
112 #[inline]
113 pub fn get_period(&self) -> usize {
114 self.params.period.unwrap_or(14)
115 }
116 #[inline]
117 pub fn get_scalar(&self) -> f64 {
118 self.params.scalar.unwrap_or(100.0)
119 }
120 #[inline]
121 pub fn get_drift(&self) -> usize {
122 self.params.drift.unwrap_or(1)
123 }
124}
125
126impl<'a> AsRef<[f64]> for ChopInput<'a> {
127 #[inline(always)]
128 fn as_ref(&self) -> &[f64] {
129 match &self.data {
130 ChopData::Candles(candles) => candles.close.as_slice(),
131 ChopData::Slice { close, .. } => close,
132 }
133 }
134}
135
/// Fluent builder that collects optional CHOP parameters and a kernel choice
/// before running the indicator or creating a stream.
#[derive(Copy, Clone, Debug)]
pub struct ChopBuilder {
    // Each `None` is forwarded as-is into `ChopParams`, so defaults are
    // resolved by the consumer rather than by the builder.
    period: Option<usize>,
    scalar: Option<f64>,
    drift: Option<usize>,
    // Kernel passed to `chop_with_kernel` by `apply` / `apply_slices`.
    kernel: Kernel,
}
143impl Default for ChopBuilder {
144 fn default() -> Self {
145 Self {
146 period: None,
147 scalar: None,
148 drift: None,
149 kernel: Kernel::Auto,
150 }
151 }
152}
153impl ChopBuilder {
154 #[inline(always)]
155 pub fn new() -> Self {
156 Self::default()
157 }
158 #[inline(always)]
159 pub fn period(mut self, n: usize) -> Self {
160 self.period = Some(n);
161 self
162 }
163 #[inline(always)]
164 pub fn scalar(mut self, s: f64) -> Self {
165 self.scalar = Some(s);
166 self
167 }
168 #[inline(always)]
169 pub fn drift(mut self, d: usize) -> Self {
170 self.drift = Some(d);
171 self
172 }
173 #[inline(always)]
174 pub fn kernel(mut self, k: Kernel) -> Self {
175 self.kernel = k;
176 self
177 }
178 #[inline(always)]
179 pub fn apply(self, c: &Candles) -> Result<ChopOutput, ChopError> {
180 let params = ChopParams {
181 period: self.period,
182 scalar: self.scalar,
183 drift: self.drift,
184 };
185 let input = ChopInput::from_candles(c, params);
186 chop_with_kernel(&input, self.kernel)
187 }
188 #[inline(always)]
189 pub fn apply_slices(
190 self,
191 high: &[f64],
192 low: &[f64],
193 close: &[f64],
194 ) -> Result<ChopOutput, ChopError> {
195 let params = ChopParams {
196 period: self.period,
197 scalar: self.scalar,
198 drift: self.drift,
199 };
200 let input = ChopInput::from_slices(high, low, close, params);
201 chop_with_kernel(&input, self.kernel)
202 }
203 #[inline(always)]
204 pub fn into_stream(self) -> Result<ChopStream, ChopError> {
205 let params = ChopParams {
206 period: self.period,
207 scalar: self.scalar,
208 drift: self.drift,
209 };
210 ChopStream::try_new(params)
211 }
212}
213
/// Errors reported by the CHOP entry points in this file.
#[derive(Debug, Error)]
pub enum ChopError {
    /// The input series has length zero.
    #[error("chop: Empty data provided.")]
    EmptyData,
    /// The period is zero or longer than the data.
    #[error("chop: Invalid period: period={period}, data length={data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// No index has high, low and close all non-NaN.
    #[error("chop: All relevant data (high/low/close) are NaN.")]
    AllValuesNaN,
    /// Fewer bars remain after the first valid index than the period requires.
    #[error("chop: Not enough valid data: needed={needed}, valid={valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A caller-provided output buffer does not match the input length.
    #[error("chop: output length mismatch: expected={expected}, got={got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A non-batch kernel was passed to the batch entry point.
    #[error("chop: invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// A sweep range could not be expanded.
    #[error("chop: invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// Generic invalid-argument error carrying a message.
    #[error("chop: invalid input: {0}")]
    InvalidInput(String),
    /// Used in this file for mismatched series lengths and for `drift == 0`.
    #[error("chop: underlying function failed: {0}")]
    UnderlyingFunctionFailed(String),
}
239
/// Computes CHOP for `input`, letting the library pick the kernel.
///
/// Equivalent to `chop_with_kernel(input, Kernel::Auto)`; see that function
/// for the error conditions.
#[inline]
pub fn chop(input: &ChopInput) -> Result<ChopOutput, ChopError> {
    chop_with_kernel(input, Kernel::Auto)
}
244
245pub fn chop_with_kernel(input: &ChopInput, kernel: Kernel) -> Result<ChopOutput, ChopError> {
246 let (high, low, close) = match &input.data {
247 ChopData::Candles(candles) => (
248 candles.high.as_slice(),
249 candles.low.as_slice(),
250 candles.close.as_slice(),
251 ),
252 ChopData::Slice { high, low, close } => (*high, *low, *close),
253 };
254
255 if !(high.len() == low.len() && low.len() == close.len()) {
256 return Err(ChopError::UnderlyingFunctionFailed(
257 "mismatched input lengths".to_string(),
258 ));
259 }
260
261 let len = close.len();
262 if len == 0 {
263 return Err(ChopError::EmptyData);
264 }
265
266 let period = input.get_period();
267 if period == 0 || period > len {
268 return Err(ChopError::InvalidPeriod {
269 period,
270 data_len: len,
271 });
272 }
273 let drift = input.get_drift();
274 if drift == 0 {
275 return Err(ChopError::UnderlyingFunctionFailed(
276 "Invalid drift=0 for ATR".to_string(),
277 ));
278 }
279 let scalar = input.get_scalar();
280
281 let first_valid_idx = match (0..len).find(|&i| {
282 let (h, l, c) = (high[i], low[i], close[i]);
283 !(h.is_nan() || l.is_nan() || c.is_nan())
284 }) {
285 Some(idx) => idx,
286 None => return Err(ChopError::AllValuesNaN),
287 };
288 if (len - first_valid_idx) < period {
289 return Err(ChopError::NotEnoughValidData {
290 needed: period,
291 valid: len - first_valid_idx,
292 });
293 }
294
295 let warmup_period = first_valid_idx + period - 1;
296 let mut out = alloc_with_nan_prefix(len, warmup_period);
297
298 let chosen = match kernel {
299 Kernel::Auto => detect_best_kernel(),
300 other => other,
301 };
302
303 unsafe {
304 match chosen {
305 Kernel::Scalar | Kernel::ScalarBatch => chop_scalar(
306 high,
307 low,
308 close,
309 period,
310 drift,
311 scalar,
312 first_valid_idx,
313 &mut out,
314 ),
315 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
316 Kernel::Avx2 | Kernel::Avx2Batch => chop_avx2(
317 high,
318 low,
319 close,
320 period,
321 drift,
322 scalar,
323 first_valid_idx,
324 &mut out,
325 ),
326 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
327 Kernel::Avx512 | Kernel::Avx512Batch => chop_avx512(
328 high,
329 low,
330 close,
331 period,
332 drift,
333 scalar,
334 first_valid_idx,
335 &mut out,
336 ),
337 _ => unreachable!(),
338 }
339 }
340 Ok(ChopOutput { values: out })
341}
342
343#[inline]
344pub fn chop_into_slice(dst: &mut [f64], input: &ChopInput, kern: Kernel) -> Result<(), ChopError> {
345 let (high, low, close) = match &input.data {
346 ChopData::Candles(candles) => (
347 candles.high.as_slice(),
348 candles.low.as_slice(),
349 candles.close.as_slice(),
350 ),
351 ChopData::Slice { high, low, close } => (*high, *low, *close),
352 };
353
354 if !(high.len() == low.len() && low.len() == close.len()) {
355 return Err(ChopError::UnderlyingFunctionFailed(
356 "mismatched input lengths".to_string(),
357 ));
358 }
359
360 let len = close.len();
361 if len == 0 {
362 return Err(ChopError::EmptyData);
363 }
364
365 if dst.len() != len {
366 return Err(ChopError::OutputLengthMismatch {
367 expected: len,
368 got: dst.len(),
369 });
370 }
371
372 let period = input.get_period();
373 if period == 0 || period > len {
374 return Err(ChopError::InvalidPeriod {
375 period,
376 data_len: len,
377 });
378 }
379 let drift = input.get_drift();
380 if drift == 0 {
381 return Err(ChopError::UnderlyingFunctionFailed(
382 "Invalid drift=0 for ATR".to_string(),
383 ));
384 }
385 let scalar = input.get_scalar();
386
387 let first_valid_idx = match (0..len).find(|&i| {
388 let (h, l, c) = (high[i], low[i], close[i]);
389 !(h.is_nan() || l.is_nan() || c.is_nan())
390 }) {
391 Some(idx) => idx,
392 None => return Err(ChopError::AllValuesNaN),
393 };
394 if (len - first_valid_idx) < period {
395 return Err(ChopError::NotEnoughValidData {
396 needed: period,
397 valid: len - first_valid_idx,
398 });
399 }
400
401 let chosen = match kern {
402 Kernel::Auto => detect_best_kernel(),
403 other => other,
404 };
405
406 unsafe {
407 match chosen {
408 Kernel::Scalar | Kernel::ScalarBatch => chop_scalar(
409 high,
410 low,
411 close,
412 period,
413 drift,
414 scalar,
415 first_valid_idx,
416 dst,
417 ),
418 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
419 Kernel::Avx2 | Kernel::Avx2Batch => chop_avx2(
420 high,
421 low,
422 close,
423 period,
424 drift,
425 scalar,
426 first_valid_idx,
427 dst,
428 ),
429 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
430 Kernel::Avx512 | Kernel::Avx512Batch => chop_avx512(
431 high,
432 low,
433 close,
434 period,
435 drift,
436 scalar,
437 first_valid_idx,
438 dst,
439 ),
440 _ => unreachable!(),
441 }
442 }
443
444 let warmup_end = first_valid_idx + period - 1;
445 for v in &mut dst[..warmup_end] {
446 *v = f64::NAN;
447 }
448
449 Ok(())
450}
451
452#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
453#[inline]
454pub fn chop_into(input: &ChopInput, out: &mut [f64]) -> Result<(), ChopError> {
455 let len = match &input.data {
456 ChopData::Candles(c) => c.close.len(),
457 ChopData::Slice { close, .. } => close.len(),
458 };
459 if out.len() != len {
460 return Err(ChopError::OutputLengthMismatch {
461 expected: len,
462 got: out.len(),
463 });
464 }
465 chop_into_slice(out, input, Kernel::Auto)
466}
467
/// Scalar CHOP kernel.
///
/// For every bar `i >= first_valid_idx + period - 1` this writes
/// `scalar * (log10(sum of ATR over the window) - log10(highest high - lowest low)) / log10(period)`
/// into `out[i]`, or NaN when either the range or the ATR sum is not positive.
/// Entries before that index are left untouched.
///
/// The ATR term is seeded with the mean of the first `drift` true ranges and
/// then smoothed with factor `1 / drift`; bars before the seed contribute 0
/// to the rolling sum.
///
/// # Safety
/// No unsafe operation is performed in this body. Callers are expected to
/// pass `period >= 1`, `drift >= 1`, `first_valid_idx < close.len()` and an
/// `out` at least as long as the inputs; violating these leads to a panic
/// (index out of bounds / arithmetic underflow) or meaningless output.
#[inline]
pub unsafe fn chop_scalar(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    drift: usize,
    scalar: f64,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    debug_assert!(high.len() == low.len() && low.len() == close.len());
    let len = close.len();
    if len == 0 {
        return;
    }

    let alpha = 1.0 / (drift as f64);
    let logp = (period as f64).log10();

    // Ring buffer of the last `period` ATR terms and their running sum.
    let mut atr_window = vec![0.0_f64; period];
    let mut slot: usize = 0;
    let mut atr_sum: f64 = 0.0;

    // Smoothed ATR; stays NaN until `drift` true ranges have been seen.
    let mut rma = f64::NAN;
    let mut seed_tr: f64 = 0.0;

    // Monotonic index deques: the front is the window's highest high / lowest low.
    let mut max_idx: VecDeque<usize> = VecDeque::with_capacity(period);
    let mut min_idx: VecDeque<usize> = VecDeque::with_capacity(period);

    let mut prev_close = close[first_valid_idx];

    for i in first_valid_idx..len {
        let (hi, lo) = (high[i], low[i]);
        let rel = i - first_valid_idx;

        // True range; the first valid bar has no previous close to compare against.
        let hl = hi - lo;
        let tr = if rel == 0 {
            hl
        } else {
            let hc = (hi - prev_close).abs();
            let lc = (lo - prev_close).abs();
            hl.max(hc).max(lc)
        };
        prev_close = close[i];

        if rel < drift {
            seed_tr = if rel == 0 { tr } else { seed_tr + tr };
            if rel + 1 == drift {
                rma = seed_tr / drift as f64;
            }
        } else {
            rma += alpha * (tr - rma);
        }

        // A NaN ATR (not seeded yet, or poisoned by NaN input) counts as 0.
        let atr_term = if rma.is_nan() { 0.0 } else { rma };
        atr_sum -= atr_window[slot];
        atr_window[slot] = atr_term;
        atr_sum += atr_term;
        slot += 1;
        if slot == period {
            slot = 0;
        }

        // Evict indices that fell out of the window, then restore monotonicity.
        let win_start = i.saturating_sub(period - 1);
        while max_idx.front().map_or(false, |&j| j < win_start) {
            max_idx.pop_front();
        }
        while min_idx.front().map_or(false, |&j| j < win_start) {
            min_idx.pop_front();
        }
        while max_idx.back().map_or(false, |&j| high[j] <= hi) {
            max_idx.pop_back();
        }
        max_idx.push_back(i);
        while min_idx.back().map_or(false, |&j| low[j] >= lo) {
            min_idx.pop_back();
        }
        min_idx.push_back(i);

        if rel >= period - 1 {
            let highest = high[*max_idx.front().unwrap()];
            let lowest = low[*min_idx.front().unwrap()];
            let range = highest - lowest;
            out[i] = if range > 0.0 && atr_sum > 0.0 {
                (scalar * (atr_sum.log10() - range.log10())) / logp
            } else {
                f64::NAN
            };
        }
    }
}
594
/// AVX2 entry point for CHOP.
///
/// In this file it contains no AVX2-specific code: it forwards every
/// argument unchanged to `chop_scalar`.
///
/// # Safety
/// No requirements beyond those of `chop_scalar` are visible here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn chop_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    drift: usize,
    scalar: f64,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    chop_scalar(
        high,
        low,
        close,
        period,
        drift,
        scalar,
        first_valid_idx,
        out,
    )
}
618
/// AVX-512 entry point for CHOP.
///
/// In this file it contains no AVX-512-specific code: it forwards every
/// argument unchanged to `chop_scalar`.
///
/// # Safety
/// No requirements beyond those of `chop_scalar` are visible here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn chop_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    drift: usize,
    scalar: f64,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    chop_scalar(
        high,
        low,
        close,
        period,
        drift,
        scalar,
        first_valid_idx,
        out,
    )
}
642
/// "Short period" AVX-512 variant; currently a plain forward to `chop_avx512`.
///
/// # Safety
/// No requirements beyond those of `chop_avx512` are visible here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn chop_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    drift: usize,
    scalar: f64,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    chop_avx512(
        high,
        low,
        close,
        period,
        drift,
        scalar,
        first_valid_idx,
        out,
    )
}
/// "Long period" AVX-512 variant; currently a plain forward to `chop_avx512`.
///
/// # Safety
/// No requirements beyond those of `chop_avx512` are visible here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn chop_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    drift: usize,
    scalar: f64,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    chop_avx512(
        high,
        low,
        close,
        period,
        drift,
        scalar,
        first_valid_idx,
        out,
    )
}
689
690#[inline(always)]
691pub fn chop_batch_with_kernel(
692 high: &[f64],
693 low: &[f64],
694 close: &[f64],
695 sweep: &ChopBatchRange,
696 k: Kernel,
697) -> Result<ChopBatchOutput, ChopError> {
698 let kernel = match k {
699 Kernel::Auto => detect_best_batch_kernel(),
700 other if other.is_batch() => other,
701 other => return Err(ChopError::InvalidKernelForBatch(other)),
702 };
703 let simd = match kernel {
704 Kernel::Avx512Batch => Kernel::Avx512,
705 Kernel::Avx2Batch => Kernel::Avx2,
706 Kernel::ScalarBatch => Kernel::Scalar,
707 _ => unreachable!(),
708 };
709 chop_batch_par_slice(high, low, close, sweep, simd)
710}
711
/// Sweep description for batch runs: one `(start, end, step)` triple per parameter.
///
/// A step of zero (or `start == end`) collapses the axis to the single value `start`.
#[derive(Clone, Debug)]
pub struct ChopBatchRange {
    /// Period axis as `(start, end, step)`.
    pub period: (usize, usize, usize),
    /// Scalar axis as `(start, end, step)`.
    pub scalar: (f64, f64, f64),
    /// Drift axis as `(start, end, step)`.
    pub drift: (usize, usize, usize),
}

impl Default for ChopBatchRange {
    /// Fixed period 14 and drift 1; the scalar axis runs from 100.0 to 124.9 in steps of 0.1.
    fn default() -> Self {
        let period = (14, 14, 0);
        let scalar = (100.0, 124.9, 0.1);
        let drift = (1, 1, 0);
        ChopBatchRange {
            period,
            scalar,
            drift,
        }
    }
}
727
/// Fluent builder for batch (parameter sweep) CHOP runs.
#[derive(Clone, Debug, Default)]
pub struct ChopBatchBuilder {
    // Sweep axes handed to `chop_batch_with_kernel`.
    range: ChopBatchRange,
    // Kernel handed to `chop_batch_with_kernel`; starts as `Kernel::default()`.
    kernel: Kernel,
}
733impl ChopBatchBuilder {
734 pub fn new() -> Self {
735 Self::default()
736 }
737 pub fn kernel(mut self, k: Kernel) -> Self {
738 self.kernel = k;
739 self
740 }
741 #[inline]
742 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
743 self.range.period = (start, end, step);
744 self
745 }
746 #[inline]
747 pub fn period_static(mut self, p: usize) -> Self {
748 self.range.period = (p, p, 0);
749 self
750 }
751 #[inline]
752 pub fn scalar_range(mut self, start: f64, end: f64, step: f64) -> Self {
753 self.range.scalar = (start, end, step);
754 self
755 }
756 #[inline]
757 pub fn scalar_static(mut self, s: f64) -> Self {
758 self.range.scalar = (s, s, 0.0);
759 self
760 }
761 #[inline]
762 pub fn drift_range(mut self, start: usize, end: usize, step: usize) -> Self {
763 self.range.drift = (start, end, step);
764 self
765 }
766 #[inline]
767 pub fn drift_static(mut self, d: usize) -> Self {
768 self.range.drift = (d, d, 0);
769 self
770 }
771 pub fn apply_slices(
772 self,
773 high: &[f64],
774 low: &[f64],
775 close: &[f64],
776 ) -> Result<ChopBatchOutput, ChopError> {
777 chop_batch_with_kernel(high, low, close, &self.range, self.kernel)
778 }
779}
780
781#[derive(Clone, Debug)]
782pub struct ChopBatchOutput {
783 pub values: Vec<f64>,
784 pub combos: Vec<ChopParams>,
785 pub rows: usize,
786 pub cols: usize,
787}
788impl ChopBatchOutput {
789 pub fn row_for_params(&self, p: &ChopParams) -> Option<usize> {
790 self.combos.iter().position(|c| {
791 c.period.unwrap_or(14) == p.period.unwrap_or(14)
792 && (c.scalar.unwrap_or(100.0) - p.scalar.unwrap_or(100.0)).abs() < 1e-12
793 && c.drift.unwrap_or(1) == p.drift.unwrap_or(1)
794 })
795 }
796 pub fn values_for(&self, p: &ChopParams) -> Option<&[f64]> {
797 self.row_for_params(p).map(|row| {
798 let start = row * self.cols;
799 &self.values[start..start + self.cols]
800 })
801 }
802}
803
804#[inline(always)]
805fn expand_grid(r: &ChopBatchRange) -> Result<Vec<ChopParams>, ChopError> {
806 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, ChopError> {
807 if step == 0 || start == end {
808 return Ok(vec![start]);
809 }
810 let mut out = Vec::new();
811 if start < end {
812 let mut v = start;
813 while v <= end {
814 out.push(v);
815 match v.checked_add(step) {
816 Some(next) => {
817 if next == v {
818 break;
819 }
820 v = next;
821 }
822 None => break,
823 }
824 }
825 } else {
826 let mut v = start;
827 while v >= end {
828 out.push(v);
829 if v < end + step {
830 break;
831 }
832 v -= step;
833 if v == 0 {
834 break;
835 }
836 }
837 }
838 if out.is_empty() {
839 return Err(ChopError::InvalidRange { start, end, step });
840 }
841 Ok(out)
842 }
843 fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, ChopError> {
844 if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
845 return Ok(vec![start]);
846 }
847 let mut v = Vec::new();
848 if start <= end && step > 0.0 {
849 let mut x = start;
850 while x <= end + 1e-12 {
851 v.push(x);
852 x += step;
853 }
854 } else if start >= end && step < 0.0 {
855 let mut x = start;
856 while x >= end - 1e-12 {
857 v.push(x);
858 x += step;
859 }
860 } else {
861 return Err(ChopError::InvalidInput(
862 "axis_f64 step direction invalid".into(),
863 ));
864 }
865 if v.is_empty() {
866 return Err(ChopError::InvalidRange {
867 start: start as usize,
868 end: end as usize,
869 step: step.abs() as usize,
870 });
871 }
872 Ok(v)
873 }
874 let periods = axis_usize(r.period)?;
875 let scalars = axis_f64(r.scalar)?;
876 let drifts = axis_usize(r.drift)?;
877 let cap = periods
878 .len()
879 .checked_mul(scalars.len())
880 .and_then(|x| x.checked_mul(drifts.len()))
881 .ok_or_else(|| ChopError::InvalidInput("rows*cols overflow".into()))?;
882 let mut out = Vec::with_capacity(cap);
883 for &p in &periods {
884 for &s in &scalars {
885 for &d in &drifts {
886 out.push(ChopParams {
887 period: Some(p),
888 scalar: Some(s),
889 drift: Some(d),
890 });
891 }
892 }
893 }
894 Ok(out)
895}
896
/// Runs a parameter sweep row by row on the calling thread.
///
/// `kern` is matched against the per-row kernels (`Scalar`, `Avx2`, `Avx512`)
/// inside `chop_batch_inner`; any other value reaches `unreachable!()` there.
#[inline(always)]
pub fn chop_batch_slice(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    sweep: &ChopBatchRange,
    kern: Kernel,
) -> Result<ChopBatchOutput, ChopError> {
    chop_batch_inner(high, low, close, sweep, kern, false)
}
/// Runs a parameter sweep with the `parallel` flag set.
///
/// `chop_batch_inner` distributes rows with rayon on non-wasm32 targets and
/// falls back to a sequential loop on wasm32. `kern` must be a per-row kernel
/// (`Scalar`, `Avx2`, `Avx512`); any other value reaches `unreachable!()` there.
#[inline(always)]
pub fn chop_batch_par_slice(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    sweep: &ChopBatchRange,
    kern: Kernel,
) -> Result<ChopBatchOutput, ChopError> {
    chop_batch_inner(high, low, close, sweep, kern, true)
}
917#[inline(always)]
918fn chop_batch_inner(
919 high: &[f64],
920 low: &[f64],
921 close: &[f64],
922 sweep: &ChopBatchRange,
923 kern: Kernel,
924 parallel: bool,
925) -> Result<ChopBatchOutput, ChopError> {
926 let combos = expand_grid(sweep)?;
927
928 if !(high.len() == low.len() && low.len() == close.len()) {
929 return Err(ChopError::UnderlyingFunctionFailed(
930 "mismatched input lengths".to_string(),
931 ));
932 }
933
934 let len = close.len();
935 let first = (0..len)
936 .find(|&i| !(high[i].is_nan() || low[i].is_nan() || close[i].is_nan()))
937 .ok_or(ChopError::AllValuesNaN)?;
938 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
939 if len - first < max_p {
940 return Err(ChopError::NotEnoughValidData {
941 needed: max_p,
942 valid: len - first,
943 });
944 }
945
946 let rows = combos.len();
947 let cols = len;
948 rows.checked_mul(cols)
949 .ok_or_else(|| ChopError::InvalidInput("rows*cols overflow".into()))?;
950 let mut buf_mu = make_uninit_matrix(rows, cols);
951
952 let warm: Vec<usize> = combos
953 .iter()
954 .map(|c| first + c.period.unwrap() - 1)
955 .collect();
956 init_matrix_prefixes(&mut buf_mu, cols, &warm);
957
958 let mut buf_guard = ManuallyDrop::new(buf_mu);
959 let values: &mut [f64] = unsafe {
960 core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
961 };
962 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
963 let ChopParams {
964 period,
965 scalar,
966 drift,
967 } = combos[row].clone();
968 let p = period.unwrap();
969 let s = scalar.unwrap();
970 let d = drift.unwrap();
971 match kern {
972 Kernel::Scalar => chop_row_scalar(high, low, close, first, p, d, s, out_row),
973 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
974 Kernel::Avx2 => chop_row_avx2(high, low, close, first, p, d, s, out_row),
975 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
976 Kernel::Avx512 => chop_row_avx512(high, low, close, first, p, d, s, out_row),
977 _ => unreachable!(),
978 }
979 };
980 if parallel {
981 #[cfg(not(target_arch = "wasm32"))]
982 {
983 values
984 .par_chunks_mut(cols)
985 .enumerate()
986 .for_each(|(row, slice)| do_row(row, slice));
987 }
988
989 #[cfg(target_arch = "wasm32")]
990 {
991 for (row, slice) in values.chunks_mut(cols).enumerate() {
992 do_row(row, slice);
993 }
994 }
995 } else {
996 for (row, slice) in values.chunks_mut(cols).enumerate() {
997 do_row(row, slice);
998 }
999 }
1000 let values = unsafe {
1001 Vec::from_raw_parts(
1002 buf_guard.as_mut_ptr() as *mut f64,
1003 buf_guard.len(),
1004 buf_guard.capacity(),
1005 )
1006 };
1007
1008 Ok(ChopBatchOutput {
1009 values,
1010 combos,
1011 rows,
1012 cols,
1013 })
1014}
1015
/// Batch-row adapter for the scalar kernel.
///
/// Only reorders the arguments: the batch code passes
/// `(first, period, drift, scalar)`, while `chop_scalar` expects
/// `(period, drift, scalar, first_valid_idx)`.
///
/// # Safety
/// No requirements beyond those of `chop_scalar` are visible here.
#[inline(always)]
unsafe fn chop_row_scalar(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    drift: usize,
    scalar: f64,
    out: &mut [f64],
) {
    chop_scalar(high, low, close, period, drift, scalar, first, out)
}
1029
/// Batch-row adapter for the AVX2 kernel; reorders the arguments and
/// forwards to `chop_avx2`.
///
/// # Safety
/// No requirements beyond those of `chop_avx2` are visible here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn chop_row_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    drift: usize,
    scalar: f64,
    out: &mut [f64],
) {
    chop_avx2(high, low, close, period, drift, scalar, first, out)
}
1044
/// Batch-row adapter for the AVX-512 kernel.
///
/// Dispatches on the period: `period <= 32` goes to the "short" variant,
/// anything larger to the "long" variant. In this file both variants forward
/// to `chop_avx512`.
///
/// # Safety
/// No requirements beyond those of the variants it calls are visible here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn chop_row_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    drift: usize,
    scalar: f64,
    out: &mut [f64],
) {
    if period <= 32 {
        chop_row_avx512_short(high, low, close, first, period, drift, scalar, out)
    } else {
        chop_row_avx512_long(high, low, close, first, period, drift, scalar, out)
    }
}
/// Short-period batch-row variant; reorders the arguments and forwards to
/// `chop_avx512`.
///
/// # Safety
/// No requirements beyond those of `chop_avx512` are visible here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn chop_row_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    drift: usize,
    scalar: f64,
    out: &mut [f64],
) {
    chop_avx512(high, low, close, period, drift, scalar, first, out)
}
/// Long-period batch-row variant; reorders the arguments and forwards to
/// `chop_avx512`.
///
/// # Safety
/// No requirements beyond those of `chop_avx512` are visible here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn chop_row_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    drift: usize,
    scalar: f64,
    out: &mut [f64],
) {
    chop_avx512(high, low, close, period, drift, scalar, first, out)
}
1091
/// Entry of the streaming high/low deques: a price together with the
/// zero-based index of the bar it came from.
#[derive(Copy, Clone, Debug)]
struct Node {
    // Bar index, used to evict entries that have left the window.
    idx: u64,
    // The high or low price of that bar.
    val: f64,
}
1097
/// Incremental CHOP calculator that consumes one bar per `update` call.
#[derive(Debug, Clone)]
pub struct ChopStream {
    // Resolved parameters (defaults already applied by `try_new`).
    period: usize,
    drift: usize,
    scalar: f64,

    // Precomputed constants: `1 / drift` and `scalar / ln(period)`.
    inv_drift: f64,
    scale_ln: f64,

    // Ring buffer of the last `period` ATR terms, its write position and running sum.
    atr_ring: Vec<f64>,
    ring_idx: usize,
    rolling_sum_atr: f64,

    // Monotonic deques whose fronts hold the window's highest high / lowest low.
    dq_high: VecDeque<Node>,
    dq_low: VecDeque<Node>,

    // Smoothed ATR (NaN until `drift` bars were seen), the true-range sum used
    // to seed it, the number of bars consumed and the previous bar's close.
    rma_atr: f64,
    sum_tr: f64,
    count: u64,
    prev_close: f64,
}
1119impl ChopStream {
1120 #[inline]
1121 pub fn try_new(params: ChopParams) -> Result<Self, ChopError> {
1122 let period = params.period.unwrap_or(14);
1123 if period == 0 {
1124 return Err(ChopError::InvalidPeriod {
1125 period,
1126 data_len: 0,
1127 });
1128 }
1129 let drift = params.drift.unwrap_or(1);
1130 if drift == 0 {
1131 return Err(ChopError::UnderlyingFunctionFailed(
1132 "Invalid drift=0 for ATR".to_string(),
1133 ));
1134 }
1135 let scalar = params.scalar.unwrap_or(100.0);
1136
1137 let inv_drift = 1.0 / (drift as f64);
1138 let scale_ln = scalar / (period as f64).ln();
1139
1140 Ok(Self {
1141 period,
1142 drift,
1143 scalar,
1144 inv_drift,
1145 scale_ln,
1146
1147 atr_ring: vec![0.0; period],
1148 ring_idx: 0,
1149 rolling_sum_atr: 0.0,
1150
1151 dq_high: VecDeque::with_capacity(period),
1152 dq_low: VecDeque::with_capacity(period),
1153
1154 rma_atr: f64::NAN,
1155 sum_tr: 0.0,
1156 count: 0,
1157 prev_close: f64::NAN,
1158 })
1159 }
1160
1161 #[inline]
1162 pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
1163 let idx_ring = self.ring_idx;
1164 self.ring_idx = (self.ring_idx + 1) % self.period;
1165 self.count = self.count.saturating_add(1);
1166 let this_idx = self.count - 1;
1167
1168 let tr = if self.count == 1 {
1169 self.prev_close = close;
1170 self.sum_tr = high - low;
1171 high - low
1172 } else {
1173 let hl = high - low;
1174 let hc = (high - self.prev_close).abs();
1175 let lc = (low - self.prev_close).abs();
1176 self.prev_close = close;
1177 hl.max(hc).max(lc)
1178 };
1179
1180 if (self.count as usize) <= self.drift {
1181 if self.count != 1 {
1182 self.sum_tr += tr;
1183 }
1184 if (self.count as usize) == self.drift {
1185 self.rma_atr = self.sum_tr * self.inv_drift;
1186 }
1187 } else {
1188 self.rma_atr += self.inv_drift * (tr - self.rma_atr);
1189 }
1190
1191 let current_atr = if (self.count as usize) < self.drift {
1192 f64::NAN
1193 } else {
1194 self.rma_atr
1195 };
1196
1197 let newest = if current_atr.is_nan() {
1198 0.0
1199 } else {
1200 current_atr
1201 };
1202 let oldest = self.atr_ring[idx_ring];
1203 self.atr_ring[idx_ring] = newest;
1204 self.rolling_sum_atr += newest - oldest;
1205
1206 let win_start = self.count.saturating_sub(self.period as u64);
1207
1208 while let Some(&front) = self.dq_high.front() {
1209 if front.idx < win_start {
1210 self.dq_high.pop_front();
1211 } else {
1212 break;
1213 }
1214 }
1215 while let Some(&front) = self.dq_low.front() {
1216 if front.idx < win_start {
1217 self.dq_low.pop_front();
1218 } else {
1219 break;
1220 }
1221 }
1222
1223 while let Some(&back) = self.dq_high.back() {
1224 if back.val <= high {
1225 self.dq_high.pop_back();
1226 } else {
1227 break;
1228 }
1229 }
1230 self.dq_high.push_back(Node {
1231 idx: this_idx,
1232 val: high,
1233 });
1234
1235 while let Some(&back) = self.dq_low.back() {
1236 if back.val >= low {
1237 self.dq_low.pop_back();
1238 } else {
1239 break;
1240 }
1241 }
1242 self.dq_low.push_back(Node {
1243 idx: this_idx,
1244 val: low,
1245 });
1246
1247 if self.count >= self.period as u64 {
1248 let range = self.dq_high.front().unwrap().val - self.dq_low.front().unwrap().val;
1249 if range > 0.0 && self.rolling_sum_atr > 0.0 {
1250 let ratio = self.rolling_sum_atr / range;
1251 let y = if (ratio - 1.0).abs() < 1e-8 {
1252 self.scale_ln * (ratio - 1.0).ln_1p()
1253 } else {
1254 self.scale_ln * ratio.ln()
1255 };
1256 Some(y)
1257 } else {
1258 Some(f64::NAN)
1259 }
1260 } else {
1261 None
1262 }
1263 }
1264}
1265
1266#[cfg(test)]
1267mod tests {
1268 use super::*;
1269 use crate::skip_if_unsupported;
1270 use crate::utilities::data_loader::read_candles_from_csv;
1271 use std::error::Error;
1272
    /// Checks that writing into a caller-provided buffer produces the same
    /// values (and the same NaN positions) as the allocating `chop` call.
    #[test]
    fn test_chop_into_matches_api() -> Result<(), Box<dyn Error>> {
        // Deterministic synthetic bars with low <= close <= high.
        let n = 256usize;
        let mut high = Vec::with_capacity(n);
        let mut low = Vec::with_capacity(n);
        let mut close = Vec::with_capacity(n);
        for i in 0..n {
            let t = i as f64;
            let base = 100.0 + (t * 0.07).sin() * 2.0 + (t * 0.013).cos();
            let h0 = base + 1.0 + 0.15 * (t * 0.31).sin();
            let l0 = base - 1.0 - 0.12 * (t * 0.23).cos();
            let (lo, hi) = if l0 <= h0 { (l0, h0) } else { (h0, l0) };
            // Clamp the close into [lo, hi].
            let mut c0 = 0.5 * (lo + hi) + 0.2 * (t * 0.17).sin();
            if c0 < lo {
                c0 = lo;
            }
            if c0 > hi {
                c0 = hi;
            }
            high.push(hi);
            low.push(lo);
            close.push(c0);
        }

        let input = ChopInput::from_slices(&high, &low, &close, ChopParams::default());

        let baseline = chop(&input)?.values;

        // The buffer starts as zeros so unwritten warm-up entries would show up
        // as a NaN mismatch.
        let mut out = vec![0.0; n];
        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
        {
            chop_into(&input, &mut out)?;
        }
        // `chop_into` is not compiled for wasm builds; use the slice variant there.
        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
        {
            chop_into_slice(&mut out, &input, Kernel::Auto)?;
        }

        assert_eq!(baseline.len(), out.len());
        for (i, (&a, &b)) in baseline.iter().zip(out.iter()).enumerate() {
            if a.is_nan() || b.is_nan() {
                assert!(a.is_nan() && b.is_nan(), "NaN mismatch at index {}", i);
            } else {
                assert!(
                    (a - b).abs() <= 1e-12,
                    "Value mismatch at index {}: {} vs {}",
                    i,
                    a,
                    b
                );
            }
        }
        Ok(())
    }
1327 fn check_chop_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1328 skip_if_unsupported!(kernel, test_name);
1329 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1330 let candles = read_candles_from_csv(file_path)?;
1331 let partial_params = ChopParams {
1332 period: Some(30),
1333 scalar: None,
1334 drift: None,
1335 };
1336 let input_partial = ChopInput::from_candles(&candles, partial_params);
1337 let output_partial = chop_with_kernel(&input_partial, kernel)?;
1338 assert_eq!(output_partial.values.len(), candles.close.len());
1339 Ok(())
1340 }
1341 fn check_chop_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1342 skip_if_unsupported!(kernel, test_name);
1343 let expected_final_5 = [
1344 49.98214330294626,
1345 48.90450693742312,
1346 46.63648608318844,
1347 46.19823574588033,
1348 56.22876423352909,
1349 ];
1350 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1351 let candles = read_candles_from_csv(file_path)?;
1352 let input = ChopInput::with_default_candles(&candles);
1353 let result = chop_with_kernel(&input, kernel)?;
1354 let start_idx = result.values.len() - 5;
1355 for (i, &exp) in expected_final_5.iter().enumerate() {
1356 let idx = start_idx + i;
1357 let got = result.values[idx];
1358 assert!(
1359 (got - exp).abs() < 1e-4,
1360 "[{}] CHOP at idx {}: got {}, expected {}",
1361 test_name,
1362 idx,
1363 got,
1364 exp
1365 );
1366 }
1367 Ok(())
1368 }
1369 fn check_chop_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1370 skip_if_unsupported!(kernel, test_name);
1371 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1372 let candles = read_candles_from_csv(file_path)?;
1373 let input = ChopInput::with_default_candles(&candles);
1374 match input.data {
1375 ChopData::Candles(_) => {}
1376 _ => panic!("Expected ChopData::Candles variant"),
1377 }
1378 let output = chop_with_kernel(&input, kernel)?;
1379 assert_eq!(output.values.len(), candles.close.len());
1380 Ok(())
1381 }
1382 fn check_chop_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1383 skip_if_unsupported!(kernel, test_name);
1384 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1385 let candles = read_candles_from_csv(file_path)?;
1386 let params = ChopParams {
1387 period: Some(0),
1388 ..Default::default()
1389 };
1390 let input = ChopInput::from_candles(&candles, params);
1391 let result = chop_with_kernel(&input, kernel);
1392 assert!(
1393 result.is_err(),
1394 "[{}] Expected error for zero period",
1395 test_name
1396 );
1397 Ok(())
1398 }
1399 fn check_chop_period_exceeds_length(
1400 test_name: &str,
1401 kernel: Kernel,
1402 ) -> Result<(), Box<dyn Error>> {
1403 skip_if_unsupported!(kernel, test_name);
1404 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1405 let candles = read_candles_from_csv(file_path)?;
1406 let params = ChopParams {
1407 period: Some(999999),
1408 ..Default::default()
1409 };
1410 let input = ChopInput::from_candles(&candles, params);
1411 let result = chop_with_kernel(&input, kernel);
1412 assert!(
1413 result.is_err(),
1414 "[{}] Expected error for huge period",
1415 test_name
1416 );
1417 Ok(())
1418 }
1419 fn check_chop_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1420 skip_if_unsupported!(kernel, test_name);
1421 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1422 let candles = read_candles_from_csv(file_path)?;
1423 let input = ChopInput::with_default_candles(&candles);
1424 let result = chop_with_kernel(&input, kernel)?;
1425 let check_index = 240;
1426 if result.values.len() > check_index {
1427 let all_nan = result.values[check_index..].iter().all(|&x| x.is_nan());
1428 assert!(
1429 !all_nan,
1430 "[{}] All CHOP values from index {} onward are NaN.",
1431 test_name, check_index
1432 );
1433 }
1434 Ok(())
1435 }
1436 fn check_chop_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1437 skip_if_unsupported!(kernel, test_name);
1438 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1439 let candles = read_candles_from_csv(file_path)?;
1440 let period = 14;
1441 let scalar = 100.0;
1442 let drift = 1;
1443 let input = ChopInput::from_candles(
1444 &candles,
1445 ChopParams {
1446 period: Some(period),
1447 scalar: Some(scalar),
1448 drift: Some(drift),
1449 },
1450 );
1451 let batch_output = chop_with_kernel(&input, kernel)?.values;
1452 let mut stream = ChopStream::try_new(ChopParams {
1453 period: Some(period),
1454 scalar: Some(scalar),
1455 drift: Some(drift),
1456 })?;
1457 let mut stream_values = Vec::with_capacity(candles.close.len());
1458 for i in 0..candles.close.len() {
1459 let res = stream.update(candles.high[i], candles.low[i], candles.close[i]);
1460 match res {
1461 Some(chop_val) => stream_values.push(chop_val),
1462 None => stream_values.push(f64::NAN),
1463 }
1464 }
1465 assert_eq!(batch_output.len(), stream_values.len());
1466 for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1467 if b.is_nan() && s.is_nan() {
1468 continue;
1469 }
1470 let diff = (b - s).abs();
1471 assert!(
1472 diff < 1e-9,
1473 "[{}] CHOP streaming mismatch at idx {}: batch={}, stream={}, diff={}",
1474 test_name,
1475 i,
1476 b,
1477 s,
1478 diff
1479 );
1480 }
1481 Ok(())
1482 }
1483
1484 #[cfg(debug_assertions)]
1485 fn check_chop_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1486 skip_if_unsupported!(kernel, test_name);
1487
1488 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1489 let candles = read_candles_from_csv(file_path)?;
1490
1491 let input = ChopInput::with_default_candles(&candles);
1492 let output = chop_with_kernel(&input, kernel)?;
1493
1494 for (i, &val) in output.values.iter().enumerate() {
1495 if val.is_nan() {
1496 continue;
1497 }
1498
1499 let bits = val.to_bits();
1500
1501 if bits == 0x11111111_11111111 {
1502 panic!(
1503 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {}",
1504 test_name, val, bits, i
1505 );
1506 }
1507
1508 if bits == 0x22222222_22222222 {
1509 panic!(
1510 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {}",
1511 test_name, val, bits, i
1512 );
1513 }
1514
1515 if bits == 0x33333333_33333333 {
1516 panic!(
1517 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {}",
1518 test_name, val, bits, i
1519 );
1520 }
1521 }
1522
1523 let param_combinations = vec![
1524 ChopParams {
1525 period: Some(10),
1526 scalar: Some(50.0),
1527 drift: Some(1),
1528 },
1529 ChopParams {
1530 period: Some(20),
1531 scalar: Some(100.0),
1532 drift: Some(2),
1533 },
1534 ChopParams {
1535 period: Some(30),
1536 scalar: Some(150.0),
1537 drift: Some(3),
1538 },
1539 ];
1540
1541 for params in param_combinations {
1542 let input = ChopInput::from_candles(&candles, params);
1543 let output = chop_with_kernel(&input, kernel)?;
1544
1545 for (i, &val) in output.values.iter().enumerate() {
1546 if val.is_nan() {
1547 continue;
1548 }
1549
1550 let bits = val.to_bits();
1551
1552 if bits == 0x11111111_11111111 {
1553 panic!(
1554 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with params {:?}",
1555 test_name, val, bits, i, input.params
1556 );
1557 }
1558
1559 if bits == 0x22222222_22222222 {
1560 panic!(
1561 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with params {:?}",
1562 test_name, val, bits, i, input.params
1563 );
1564 }
1565
1566 if bits == 0x33333333_33333333 {
1567 panic!(
1568 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with params {:?}",
1569 test_name, val, bits, i, input.params
1570 );
1571 }
1572 }
1573 }
1574
1575 Ok(())
1576 }
1577
1578 #[cfg(not(debug_assertions))]
1579 fn check_chop_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1580 Ok(())
1581 }
1582
1583 macro_rules! generate_all_chop_tests {
1584 ($($test_fn:ident),*) => {
1585 paste::paste! {
1586 $(
1587 #[test]
1588 fn [<$test_fn _scalar_f64>]() {
1589 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1590 }
1591 )*
1592 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1593 $(
1594 #[test]
1595 fn [<$test_fn _avx2_f64>]() {
1596 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1597 }
1598 #[test]
1599 fn [<$test_fn _avx512_f64>]() {
1600 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1601 }
1602 )*
1603 }
1604 }
1605 }
1606 #[cfg(not(feature = "proptest"))]
1607 generate_all_chop_tests!(
1608 check_chop_partial_params,
1609 check_chop_accuracy,
1610 check_chop_default_candles,
1611 check_chop_zero_period,
1612 check_chop_period_exceeds_length,
1613 check_chop_nan_handling,
1614 check_chop_streaming,
1615 check_chop_no_poison
1616 );
1617
1618 #[cfg(feature = "proptest")]
1619 generate_all_chop_tests!(
1620 check_chop_partial_params,
1621 check_chop_accuracy,
1622 check_chop_default_candles,
1623 check_chop_zero_period,
1624 check_chop_period_exceeds_length,
1625 check_chop_nan_handling,
1626 check_chop_streaming,
1627 check_chop_no_poison,
1628 check_chop_property
1629 );
1630
    /// Property test for CHOP over randomly generated OHLC series.
    ///
    /// Each case draws a series length (50..400), a base price, a volatility
    /// fraction, a trend fraction, four uniform random factors per bar, one of
    /// five market shapes (`market_type` 0..5), and the indicator parameters
    /// (`period` 5..50, `scalar` 50..200, `drift` 1..5). The generated series
    /// is run through `kernel` and through `Kernel::Scalar`, and the results
    /// are checked for warmup NaNs, loose value bounds, NaN on a zero-range
    /// window, and agreement between the two kernels.
    #[cfg(feature = "proptest")]
    fn check_chop_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Tuple order: base_price, volatility, trend, per-bar random factors,
        // market_type, size, period, scalar, drift.
        let strat = (50usize..400).prop_flat_map(|size| {
            (
                10.0f64..1000.0f64,
                0.0f64..0.1f64,
                -0.02f64..0.02f64,
                prop::collection::vec((0.0f64..1.0, 0.0f64..1.0, 0.0f64..1.0, 0.0f64..1.0), size),
                0u8..5,
                Just(size),
                5usize..50,
                50.0f64..200.0f64,
                1usize..5,
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(base_price, volatility, trend, random_factors, market_type, size, period, scalar, drift)| {
                // Build the synthetic OHLC series bar by bar.
                let mut high_data = Vec::with_capacity(size);
                let mut low_data = Vec::with_capacity(size);
                let mut close_data = Vec::with_capacity(size);
                let mut open_data = Vec::with_capacity(size);

                let mut current_price = base_price;

                for i in 0..size {
                    let (r1, r2, r3, r4) = random_factors[i];
                    let range = current_price * volatility;

                    // Shape of the bar depends on the market type drawn for this case.
                    let (open, high, low, close) = match market_type {
                        0 => {
                            // Close = price + 50-100% of `range` + the signed trend term;
                            // the high gets an extra bump from r4.
                            let open = current_price;
                            let close = current_price + range * (0.5 + r1 * 0.5) + (trend * current_price);
                            let high = close.max(open) + range * r2 * 0.3;
                            let low = close.min(open) - range * r3 * 0.2;

                            let high_adjusted = high + range * r4 * 0.1;
                            current_price = close;
                            (open, high_adjusted, low, close)
                        }
                        1 => {
                            // Close is pushed below the open by 50-100% of `range`
                            // plus |trend|; the low gets an extra drop from r4.
                            let open = current_price;
                            let close = current_price - range * (0.5 + r1 * 0.5) - (trend.abs() * current_price);
                            let high = close.max(open) + range * r2 * 0.2;
                            let low = close.min(open) - range * r3 * 0.3;

                            let low_adjusted = low - range * r4 * 0.1;
                            current_price = close;
                            (open, high, low_adjusted, close)
                        }
                        2 => {
                            // Random direction per bar; the running price is blended
                            // back toward `base_price` (15% / 85%) instead of
                            // following the close.
                            let open = current_price;
                            let direction = if r1 > 0.5 { 1.0 } else { -1.0 };
                            let close = current_price + direction * range * r2 * 0.5;
                            let high = open.max(close) + range * r3 * 0.4;
                            let low = open.min(close) - range * r4 * 0.4;

                            current_price = base_price * 0.15 + current_price * 0.85;
                            (open, high, low, close)
                        }
                        3 => {
                            // Close moves up to +/- one `range`; wicks extend up to
                            // 1.2x `range`, with an extra upper wick from r4.
                            let open = current_price;
                            let close = current_price + range * (r1 - 0.5) * 2.0;
                            let high = open.max(close) + range * r2 * 1.2;
                            let low = open.min(close) - range * r3 * 1.2;

                            let high_wick = high + range * r4 * 0.3;
                            current_price = close;
                            (open, high_wick, low, close)
                        }
                        4 | _ => {
                            // Nearly flat market: moves are about 1% of `range`.
                            let tiny_move = range * 0.01 * (r1 - 0.5);
                            let open = current_price;
                            let close = current_price + tiny_move;

                            if r2 < 0.1 {
                                // Fully flat bar (O = H = L = C); the running price
                                // is left unchanged.
                                let price = current_price;
                                (price, price, price, price)
                            } else {
                                // Tiny wicks of at most 0.1% of `range`.
                                let high = open.max(close) + range * 0.001 * r3;
                                let low = open.min(close) - range * 0.001 * r4;
                                current_price = close;
                                (open, high, low, close)
                            }
                        }
                    };

                    // Force high/low to bracket open and close.
                    let high_final = high.max(open).max(close);
                    let low_final = low.min(open).min(close);

                    // Sanity checks on the generator itself (debug builds only).
                    debug_assert!(high_final >= low_final, "High must be >= Low");
                    debug_assert!(high_final >= open && high_final >= close, "High must be >= Open and Close");
                    debug_assert!(low_final <= open && low_final <= close, "Low must be <= Open and Close");

                    open_data.push(open);
                    high_data.push(high_final);
                    low_data.push(low_final);
                    close_data.push(close);
                }

                // Run the kernel under test and the scalar reference on the same input.
                let params = ChopParams {
                    period: Some(period),
                    scalar: Some(scalar),
                    drift: Some(drift),
                };
                let input = ChopInput::from_slices(&high_data, &low_data, &close_data, params.clone());

                let result = chop_with_kernel(&input, kernel)?;
                let reference = chop_with_kernel(&input, Kernel::Scalar)?;

                // Warmup length: first bar with no NaN in high/low/close, plus period - 1.
                let first_valid_idx = (0..size).find(|&i| {
                    !(high_data[i].is_nan() || low_data[i].is_nan() || close_data[i].is_nan())
                }).unwrap_or(0);
                let warmup_period = first_valid_idx + period - 1;

                // Non-NaN outputs from windows with a non-trivial range, used for
                // the aggregate checks after the loop.
                let mut valid_chop_values = Vec::new();

                for i in 0..size {
                    let y = result.values[i];
                    let r = reference.values[i];

                    // Every output is either NaN or finite (never +/- infinity).
                    prop_assert!(
                        y.is_nan() || y.is_finite(),
                        "[{}] CHOP at index {} is not finite or NaN: {}",
                        test_name, i, y
                    );

                    // Outputs before the warmup index must be NaN.
                    if i < warmup_period {
                        prop_assert!(
                            y.is_nan(),
                            "[{}] CHOP at index {} should be NaN during warmup but got: {}",
                            test_name, i, y
                        );
                    }

                    // Post-warmup checks on windows whose inputs are all non-NaN.
                    if i >= warmup_period && !high_data[i].is_nan() && !low_data[i].is_nan() && !close_data[i].is_nan() {
                        let window_start = i.saturating_sub(period - 1);
                        let window_valid = (window_start..=i).all(|j| {
                            !high_data[j].is_nan() && !low_data[j].is_nan() && !close_data[j].is_nan()
                        });

                        if window_valid {
                            // Highest high minus lowest low over the last `period` bars.
                            let window_high_max = (window_start..=i).map(|j| high_data[j]).fold(f64::NEG_INFINITY, f64::max);
                            let window_low_min = (window_start..=i).map(|j| low_data[j]).fold(f64::INFINITY, f64::min);
                            let range = window_high_max - window_low_min;

                            if range > 1e-10 {
                                if !y.is_nan() {
                                    // Loose bound: |y| may not exceed 1.5 * scalar.
                                    let normalized_bound = scalar * 1.5;
                                    prop_assert!(
                                        y >= -normalized_bound && y <= normalized_bound,
                                        "[{}] CHOP at index {} out of reasonable bounds: {} (scalar={}, bounds=±{})",
                                        test_name, i, y, scalar, normalized_bound
                                    );

                                    valid_chop_values.push(y);
                                }
                            } else if range == 0.0 {
                                // A window with exactly zero range must yield NaN.
                                prop_assert!(
                                    y.is_nan(),
                                    "[{}] CHOP at index {} should be NaN when range=0 but got: {}",
                                    test_name, i, y
                                );
                            } else {
                                // Range in (0, 1e-10]: only finiteness is required.
                                prop_assert!(
                                    y.is_nan() || y.is_finite(),
                                    "[{}] CHOP at index {} should be finite or NaN with tiny range: {}",
                                    test_name, i, y
                                );
                            }
                        }
                    }

                    // Kernel agreement: within 1e-9 absolute or 10 ULPs when both
                    // are finite; otherwise NaN-ness must match.
                    if y.is_finite() && r.is_finite() {
                        let ulp_diff = y.to_bits().abs_diff(r.to_bits());
                        prop_assert!(
                            (y - r).abs() <= 1e-9 || ulp_diff <= 10,
                            "[{}] Kernel mismatch at index {}: {} vs {} (ULP diff={})",
                            test_name, i, y, r, ulp_diff
                        );
                    } else if y.is_nan() != r.is_nan() {
                        prop_assert!(
                            false,
                            "[{}] NaN mismatch at index {}: kernel={}, scalar={}",
                            test_name, i, y.is_nan(), r.is_nan()
                        );
                    }

                    // Flat candle (high ~= low) after warmup: output must still be
                    // NaN or finite.
                    if (high_data[i] - low_data[i]).abs() < 1e-10 && i >= warmup_period {
                        prop_assert!(
                            y.is_nan() || y.is_finite(),
                            "[{}] CHOP at flat candle index {} is invalid: {}",
                            test_name, i, y
                        );
                    }

                }

                // Aggregate checks, only when more than 20 values were collected.
                if valid_chop_values.len() > 20 {
                    let avg_chop = valid_chop_values.iter().sum::<f64>() / valid_chop_values.len() as f64;
                    let median_idx = valid_chop_values.len() / 2;
                    let mut sorted_values = valid_chop_values.clone();
                    sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
                    let median_chop = sorted_values[median_idx];

                    match market_type {
                        0 | 1 => {
                            // Market types 0 and 1: average and median must be finite.
                            prop_assert!(
                                avg_chop.is_finite() && median_chop.is_finite(),
                                "[{}] Trending market (type {}) has non-finite CHOP: avg={}, median={}",
                                test_name, market_type, avg_chop, median_chop
                            );

                            let threshold = scalar * 0.6;
                            if avg_chop > threshold && median_chop > threshold {
                                // NOTE(review): `prop_assert!(true)` can never fail,
                                // so this threshold comparison checks nothing.
                                prop_assert!(true);
                            }
                        }
                        2 => {
                            // Market type 2: average and median must be finite.
                            prop_assert!(
                                avg_chop.is_finite() && median_chop.is_finite(),
                                "[{}] Choppy market has non-finite CHOP: avg={}, median={}",
                                test_name, avg_chop, median_chop
                            );

                            let threshold = scalar * 0.3;
                            if avg_chop < threshold && median_chop < threshold {
                                // NOTE(review): no-op assertion, as above.
                                prop_assert!(true);
                            }
                        }
                        3 => {
                            // Market type 3: only the average is required to be finite.
                            prop_assert!(
                                avg_chop.is_finite(),
                                "[{}] Volatile market has non-finite average CHOP: {}",
                                test_name, avg_chop
                            );
                        }
                        4 => {
                            // Market type 4: a finite average must lie in [-scalar, scalar].
                            if avg_chop.is_finite() {
                                prop_assert!(
                                    avg_chop >= -scalar && avg_chop <= scalar,
                                    "[{}] Flat market CHOP out of bounds: avg={}, scalar={}",
                                    test_name, avg_chop, scalar
                                );
                            }
                        }
                        _ => {}
                    }
                }

                // Segment stability: compare the averages of [period, 2*period)
                // and [period, 3*period). Enforced only for market type 4.
                if size >= period * 3 {
                    let seg1_end = period * 2;
                    let seg2_start = period;
                    let seg2_end = period * 3;

                    if seg1_end < size && seg2_end < size {
                        let seg1_values: Vec<f64> = result.values[period..seg1_end]
                            .iter()
                            .filter(|v| v.is_finite())
                            .cloned()
                            .collect();
                        let seg2_values: Vec<f64> = result.values[seg2_start..seg2_end]
                            .iter()
                            .filter(|v| v.is_finite())
                            .cloned()
                            .collect();

                        if !seg1_values.is_empty() && !seg2_values.is_empty() {
                            let seg1_avg = seg1_values.iter().sum::<f64>() / seg1_values.len() as f64;
                            let seg2_avg = seg2_values.iter().sum::<f64>() / seg2_values.len() as f64;

                            // Relative difference of the two averages must stay below 0.8.
                            if market_type == 4 && seg1_avg.abs() > 1e-6 && seg2_avg.abs() > 1e-6 {
                                let diff_ratio = (seg1_avg - seg2_avg).abs() / seg1_avg.abs().max(seg2_avg.abs());
                                prop_assert!(
                                    diff_ratio < 0.8,
                                    "[{}] Flat market segments have inconsistent CHOP: seg1_avg={}, seg2_avg={}, diff_ratio={}",
                                    test_name, seg1_avg, seg2_avg, diff_ratio
                                );
                            }
                        }
                    }
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }
1976
1977 #[cfg(test)]
1978
1979 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1980 skip_if_unsupported!(kernel, test);
1981
1982 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1983 let c = read_candles_from_csv(file)?;
1984
1985 let high = c.high.as_slice();
1986 let low = c.low.as_slice();
1987 let close = c.close.as_slice();
1988
1989 let output = ChopBatchBuilder::new()
1990 .kernel(kernel)
1991 .apply_slices(high, low, close)?;
1992
1993 let def = ChopParams::default();
1994 let row = output.values_for(&def).expect("default row missing");
1995 assert_eq!(row.len(), close.len());
1996
1997 let expected = [
1998 49.98214330294626,
1999 48.90450693742312,
2000 46.63648608318844,
2001 46.19823574588033,
2002 56.22876423352909,
2003 ];
2004 let start = row.len().saturating_sub(5);
2005 for (i, &v) in row[start..].iter().enumerate() {
2006 assert!(
2007 (v - expected[i]).abs() < 1e-4,
2008 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2009 );
2010 }
2011 Ok(())
2012 }
2013
2014 fn check_batch_param_row_lookup(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2015 skip_if_unsupported!(kernel, test);
2016 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2017 let c = read_candles_from_csv(file)?;
2018 let high = c.high.as_slice();
2019 let low = c.low.as_slice();
2020 let close = c.close.as_slice();
2021
2022 let builder = ChopBatchBuilder::new()
2023 .kernel(kernel)
2024 .period_range(14, 16, 1)
2025 .scalar_range(100.0, 102.0, 1.0)
2026 .drift_range(1, 2, 1);
2027
2028 let out = builder.apply_slices(high, low, close)?;
2029
2030 for p in 14..=16 {
2031 for s in [100.0, 101.0, 102.0] {
2032 for d in 1..=2 {
2033 let params = ChopParams {
2034 period: Some(p),
2035 scalar: Some(s),
2036 drift: Some(d),
2037 };
2038 let row = out.values_for(¶ms);
2039 assert!(
2040 row.is_some(),
2041 "[{test}] No row for params: period={p}, scalar={s}, drift={d}"
2042 );
2043 }
2044 }
2045 }
2046 Ok(())
2047 }
2048
2049 fn check_batch_huge_period(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2050 skip_if_unsupported!(kernel, test);
2051 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2052 let c = read_candles_from_csv(file)?;
2053 let high = c.high.as_slice();
2054 let low = c.low.as_slice();
2055 let close = c.close.as_slice();
2056
2057 let builder = ChopBatchBuilder::new()
2058 .kernel(kernel)
2059 .period_range(100_000, 100_001, 1);
2060 let result = builder.apply_slices(high, low, close);
2061 assert!(result.is_err(), "[{test}] Expected error for huge period");
2062 Ok(())
2063 }
2064
2065 #[cfg(debug_assertions)]
2066 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2067 skip_if_unsupported!(kernel, test);
2068
2069 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2070 let c = read_candles_from_csv(file)?;
2071
2072 let high = c.high.as_slice();
2073 let low = c.low.as_slice();
2074 let close = c.close.as_slice();
2075
2076 let output = ChopBatchBuilder::new()
2077 .kernel(kernel)
2078 .period_range(10, 30, 10)
2079 .scalar_range(50.0, 150.0, 50.0)
2080 .drift_range(1, 3, 1)
2081 .apply_slices(high, low, close)?;
2082
2083 for (idx, &val) in output.values.iter().enumerate() {
2084 if val.is_nan() {
2085 continue;
2086 }
2087
2088 let bits = val.to_bits();
2089 let row = idx / output.cols;
2090 let col = idx % output.cols;
2091
2092 if bits == 0x11111111_11111111 {
2093 panic!(
2094 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {})",
2095 test, val, bits, row, col, idx
2096 );
2097 }
2098
2099 if bits == 0x22222222_22222222 {
2100 panic!(
2101 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {})",
2102 test, val, bits, row, col, idx
2103 );
2104 }
2105
2106 if bits == 0x33333333_33333333 {
2107 panic!(
2108 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {})",
2109 test, val, bits, row, col, idx
2110 );
2111 }
2112 }
2113
2114 Ok(())
2115 }
2116
2117 #[cfg(not(debug_assertions))]
2118 fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2119 Ok(())
2120 }
2121
2122 macro_rules! gen_batch_tests {
2123 ($fn_name:ident) => {
2124 paste::paste! {
2125 #[test] fn [<$fn_name _scalar>]() {
2126 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2127 }
2128 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2129 #[test] fn [<$fn_name _avx2>]() {
2130 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2131 }
2132 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2133 #[test] fn [<$fn_name _avx512>]() {
2134 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2135 }
2136 #[test] fn [<$fn_name _auto_detect>]() {
2137 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2138 }
2139 }
2140 };
2141 }
2142 gen_batch_tests!(check_batch_default_row);
2143 gen_batch_tests!(check_batch_param_row_lookup);
2144 gen_batch_tests!(check_batch_huge_period);
2145 gen_batch_tests!(check_batch_no_poison);
2146}
2147
2148#[inline(always)]
2149fn chop_batch_inner_into(
2150 high: &[f64],
2151 low: &[f64],
2152 close: &[f64],
2153 sweep: &ChopBatchRange,
2154 kern: Kernel,
2155 parallel: bool,
2156 out: &mut [f64],
2157) -> Result<Vec<ChopParams>, ChopError> {
2158 let combos = expand_grid(sweep)?;
2159
2160 if !(high.len() == low.len() && low.len() == close.len()) {
2161 return Err(ChopError::UnderlyingFunctionFailed(
2162 "mismatched input lengths".to_string(),
2163 ));
2164 }
2165
2166 let len = close.len();
2167 if len == 0 {
2168 return Err(ChopError::EmptyData);
2169 }
2170
2171 let first = (0..len)
2172 .find(|&i| !(high[i].is_nan() || low[i].is_nan() || close[i].is_nan()))
2173 .ok_or(ChopError::AllValuesNaN)?;
2174 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
2175 if len - first < max_p {
2176 return Err(ChopError::NotEnoughValidData {
2177 needed: max_p,
2178 valid: len - first,
2179 });
2180 }
2181
2182 let rows = combos.len();
2183 let cols = len;
2184 let expected_len = rows
2185 .checked_mul(cols)
2186 .ok_or_else(|| ChopError::InvalidInput("rows*cols overflow".into()))?;
2187 if out.len() != expected_len {
2188 return Err(ChopError::OutputLengthMismatch {
2189 expected: expected_len,
2190 got: out.len(),
2191 });
2192 }
2193
2194 let out_mu: &mut [std::mem::MaybeUninit<f64>] = unsafe {
2195 core::slice::from_raw_parts_mut(
2196 out.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
2197 out.len(),
2198 )
2199 };
2200
2201 let warm: Vec<usize> = combos
2202 .iter()
2203 .map(|c| first + c.period.unwrap() - 1)
2204 .collect();
2205 init_matrix_prefixes(out_mu, cols, &warm);
2206
2207 let actual = match kern {
2208 Kernel::Auto => detect_best_batch_kernel(),
2209 k => k,
2210 };
2211 let simd = match actual {
2212 Kernel::Avx512Batch => Kernel::Avx512,
2213 Kernel::Avx2Batch => Kernel::Avx2,
2214 Kernel::ScalarBatch => Kernel::Scalar,
2215 _ => actual,
2216 };
2217
2218 let do_row = |row: usize, row_mu: &mut [std::mem::MaybeUninit<f64>]| unsafe {
2219 let ChopParams {
2220 period,
2221 scalar,
2222 drift,
2223 } = combos[row];
2224 let p = period.unwrap();
2225 let s = scalar.unwrap();
2226 let d = drift.unwrap();
2227
2228 let row_out: &mut [f64] =
2229 core::slice::from_raw_parts_mut(row_mu.as_mut_ptr() as *mut f64, row_mu.len());
2230 match simd {
2231 Kernel::Scalar => chop_row_scalar(high, low, close, first, p, d, s, row_out),
2232 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2233 Kernel::Avx2 => chop_row_avx2(high, low, close, first, p, d, s, row_out),
2234 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2235 Kernel::Avx512 => chop_row_avx512(high, low, close, first, p, d, s, row_out),
2236 _ => unreachable!(),
2237 }
2238 };
2239
2240 if parallel {
2241 #[cfg(not(target_arch = "wasm32"))]
2242 {
2243 use rayon::prelude::*;
2244 out_mu
2245 .par_chunks_mut(cols)
2246 .enumerate()
2247 .for_each(|(r, sl)| do_row(r, sl));
2248 }
2249 #[cfg(target_arch = "wasm32")]
2250 {
2251 for (r, sl) in out_mu.chunks_mut(cols).enumerate() {
2252 do_row(r, sl);
2253 }
2254 }
2255 } else {
2256 for (r, sl) in out_mu.chunks_mut(cols).enumerate() {
2257 do_row(r, sl);
2258 }
2259 }
2260
2261 Ok(combos)
2262}
2263
// Python entry point `chop(high, low, close, period, scalar, drift, kernel=None)`.
// Reads the three inputs as slices, validates the kernel name, runs the
// computation inside `allow_threads`, and returns the values as a 1-D NumPy
// array. Computation errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "chop")]
#[pyo3(signature = (high, low, close, period, scalar, drift, kernel=None))]
pub fn chop_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    period: usize,
    scalar: f64,
    drift: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    use numpy::PyArrayMethods;

    let (h, l, c) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let kern = validate_kernel(kernel, false)?;

    let params = ChopParams {
        period: Some(period),
        scalar: Some(scalar),
        drift: Some(drift),
    };
    let input = ChopInput::from_slices(h, l, c, params);

    let computed = py.allow_threads(|| chop_with_kernel(&input, kern).map(|o| o.values));
    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
2297
// Python wrapper class `ChopStream` around the Rust `ChopStream`.
#[cfg(feature = "python")]
#[pyclass(name = "ChopStream")]
pub struct ChopStreamPy {
    stream: ChopStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl ChopStreamPy {
    // Builds the stream from explicit parameters; a construction failure is
    // raised as `ValueError`.
    #[new]
    fn new(period: usize, scalar: f64, drift: usize) -> PyResult<Self> {
        let params = ChopParams {
            period: Some(period),
            scalar: Some(scalar),
            drift: Some(drift),
        };
        match ChopStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    // Forwards one bar to the wrapped stream and returns its result unchanged.
    fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
        self.stream.update(high, low, close)
    }
}
2321
// Python entry point `chop_batch(...)`. Expands the three `(start, end, step)`
// ranges into a parameter grid, allocates a flat NumPy buffer of
// rows * cols elements, fills it in place inside `allow_threads`, and returns
// a dict with the reshaped `values` matrix plus the per-row `periods`,
// `scalars` and `drifts`. Failures are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "chop_batch")]
#[pyo3(signature = (high, low, close, period_range, scalar_range, drift_range, kernel=None))]
pub fn chop_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    scalar_range: (f64, f64, f64),
    drift_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};

    let (h, l, c) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);

    let sweep = ChopBatchRange {
        period: period_range,
        scalar: scalar_range,
        drift: drift_range,
    };
    let combos = match expand_grid(&sweep) {
        Ok(grid) => grid,
        Err(e) => return Err(PyValueError::new_err(e.to_string())),
    };

    let (rows, cols) = (combos.len(), c.len());
    let Some(total) = rows.checked_mul(cols) else {
        return Err(PyValueError::new_err("rows*cols overflow"));
    };

    // The buffer is allocated uninitialised and written by the batch routine.
    let arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_slice = unsafe { arr.as_slice_mut()? };

    let kern = validate_kernel(kernel, true)?;
    let outcome = py.allow_threads(|| {
        let resolved = match kern {
            Kernel::Auto => detect_best_batch_kernel(),
            explicit => explicit,
        };
        let simd = match resolved {
            Kernel::Avx512Batch => Kernel::Avx512,
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::ScalarBatch => Kernel::Scalar,
            passthrough => passthrough,
        };
        chop_batch_inner_into(h, l, c, &sweep, simd, true, out_slice)
    });
    if let Err(e) = outcome {
        return Err(PyValueError::new_err(e.to_string()));
    }

    // Per-row parameter metadata, in the same order as the matrix rows.
    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();
    let scalars: Vec<f64> = combos.iter().map(|p| p.scalar.unwrap()).collect();
    let drifts: Vec<u64> = combos.iter().map(|p| p.drift.unwrap() as u64).collect();

    let dict = PyDict::new(py);
    dict.set_item("values", arr.reshape((rows, cols))?)?;
    dict.set_item("periods", periods.into_pyarray(py))?;
    dict.set_item("scalars", scalars.into_pyarray(py))?;
    dict.set_item("drifts", drifts.into_pyarray(py))?;
    Ok(dict)
}
2400
// Python entry point `chop_cuda_batch_dev(...)`. Runs the batch sweep through
// `CudaChop` on `device_id` and returns the resulting device array wrapped in
// `ChopDeviceArrayF32Py`; the combos returned alongside it are dropped.
// Raises `ValueError` when CUDA is unavailable or a CUDA call fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "chop_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, close_f32, period_range, scalar_range, drift_range, device_id=0))]
pub fn chop_cuda_batch_dev_py(
    py: Python<'_>,
    high_f32: numpy::PyReadonlyArray1<'_, f32>,
    low_f32: numpy::PyReadonlyArray1<'_, f32>,
    close_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    scalar_range: (f64, f64, f64),
    drift_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<ChopDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (h, l, c) = (
        high_f32.as_slice()?,
        low_f32.as_slice()?,
        close_f32.as_slice()?,
    );
    let sweep = ChopBatchRange {
        period: period_range,
        scalar: scalar_range,
        drift: drift_range,
    };

    let inner = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaChop::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.chop_batch_dev(h, l, c, &sweep)
            .map(|(arr, _combos)| arr)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(ChopDeviceArrayF32Py { inner: Some(inner) })
}
2434
// Python entry point `chop_cuda_many_series_one_param_dev`: computes CHOP for
// many series with a single (period, scalar, drift) parameter set on a CUDA
// device and returns a handle that owns the resulting device array.
//
// `high_tm_f32` / `low_tm_f32` / `close_tm_f32` are flat buffers that are
// handed, together with `cols` and `rows`, to
// `CudaChop::chop_many_series_one_param_time_major_dev`.
// NOTE(review): presumably a time-major `rows x cols` layout, going by the
// callee's name; shape validation is left to the callee.
//
// Fails with `ValueError` when CUDA is unavailable, when an input array is not
// contiguous, or when the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "chop_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, close_tm_f32, cols, rows, period, scalar=100.0, drift=1, device_id=0))]
pub fn chop_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    low_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    close_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    scalar: f64,
    drift: usize,
    device_id: usize,
) -> PyResult<ChopDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let h = high_tm_f32.as_slice()?;
    let l = low_tm_f32.as_slice()?;
    let c = close_tm_f32.as_slice()?;
    let params = ChopParams {
        period: Some(period),
        scalar: Some(scalar),
        drift: Some(drift),
    };
    // Release the GIL while the device context is created and the kernel runs.
    let inner = py.allow_threads(|| {
        let cuda = CudaChop::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        // The parameter set is passed by reference (`&params`).
        cuda.chop_many_series_one_param_time_major_dev(h, l, c, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;
    Ok(ChopDeviceArrayF32Py { inner: Some(inner) })
}
2468
// Python-visible owner of a CHOP result that lives in CUDA device memory.
// It can be read through `__cuda_array_interface__` or handed off exactly once
// through `__dlpack__`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct ChopDeviceArrayF32Py {
    // Becomes `None` once `__dlpack__` has taken the buffer.
    pub(crate) inner: Option<DeviceArrayF32Chop>,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl ChopDeviceArrayF32Py {
    // Builds the CUDA Array Interface (version 3) dictionary describing the
    // buffer as a 2-D little-endian f32 array of shape `(rows, cols)` with
    // byte strides `(cols * 4, 4)`. Errors once the buffer has been exported.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let dev = match self.inner.as_ref() {
            Some(dev) => dev,
            None => {
                return Err(PyValueError::new_err(
                    "buffer already exported via __dlpack__",
                ))
            }
        };

        let elem_bytes = std::mem::size_of::<f32>();
        let row_bytes = match dev.cols.checked_mul(elem_bytes) {
            Some(n) => n,
            None => return Err(PyValueError::new_err("byte stride overflow")),
        };

        let iface = PyDict::new(py);
        iface.set_item("shape", (dev.rows, dev.cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (row_bytes, elem_bytes))?;
        // Second tuple element is the interface's read-only flag.
        iface.set_item("data", (dev.device_ptr() as usize, false))?;

        iface.set_item("version", 3)?;
        Ok(iface)
    }

    // Returns `(device_type, device_id)`; the device type is the constant 2
    // (DLPack's CUDA device code). Errors once the buffer has been exported.
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        match self.inner.as_ref() {
            Some(dev) => Ok((2, dev.device_id as i32)),
            None => Err(PyValueError::new_err(
                "buffer already exported via __dlpack__",
            )),
        }
    }

    // Transfers ownership of the device buffer to a DLPack capsule.
    //
    // * `dl_device`: when it extracts as `(i32, i32)` and differs from this
    //   buffer's device, the call fails (`BufferError` if `copy` is true,
    //   `ValueError` otherwise). Values that do not extract are ignored.
    // * `copy`: `True` is rejected with `BufferError`; non-bool values are
    //   treated as `False`.
    // * `stream`: accepted but not used.
    // * `max_version`: forwarded to the exporter.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        let (own_type, own_id) = self.__dlpack_device__()?;

        // Evaluated lazily so the extraction happens at the same points as
        // before: missing or non-bool `copy` counts as `false`.
        let copy_requested = || {
            copy.as_ref()
                .and_then(|flag| flag.extract::<bool>(py).ok())
                .unwrap_or(false)
        };

        let requested_device = dl_device
            .as_ref()
            .and_then(|dev| dev.extract::<(i32, i32)>(py).ok());
        if let Some((req_type, req_id)) = requested_device {
            if req_type != own_type || req_id != own_id {
                return Err(if copy_requested() {
                    PyBufferError::new_err("device copy not implemented for __dlpack__")
                } else {
                    PyValueError::new_err("dl_device mismatch for chop CUDA buffer")
                });
            }
        }
        let _ = stream;

        if copy_requested() {
            return Err(PyBufferError::new_err(
                "__dlpack__(copy=True) not supported for chop CUDA buffers",
            ));
        }

        let owned = match self.inner.take() {
            Some(owned) => owned,
            None => return Err(PyValueError::new_err("__dlpack__ may only be called once")),
        };

        let (rows, cols) = (owned.rows, owned.cols);
        let buf = owned.buf;
        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, own_id, max_version_bound)
    }
}
2560
/// WASM binding: computes CHOP over `high` / `low` / `close` with the given
/// parameters and returns a freshly allocated vector with one value per
/// element of `close`.
///
/// Errors from the computation are returned as a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn chop_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    scalar: f64,
    drift: usize,
) -> Result<Vec<f64>, JsValue> {
    let params = ChopParams {
        period: Some(period),
        scalar: Some(scalar),
        drift: Some(drift),
    };
    let input = ChopInput::from_slices(high, low, close, params);

    let mut values = vec![0.0; close.len()];
    match chop_into_slice(&mut values, &input, detect_best_kernel()) {
        Ok(_) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2586
/// WASM binding: reserves room for `len` `f64` values and returns the raw
/// pointer. The allocation is deliberately not freed here; release it with
/// `chop_free` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn chop_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the vector's destructor from running, so the
    // allocation outlives this call.
    let mut buf = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buf.as_mut_ptr()
}
/// WASM binding: releases a buffer previously obtained from `chop_alloc`.
///
/// `len` must be the value that was passed to `chop_alloc`. A null pointer or
/// a zero length is a no-op.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn chop_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() || len == 0 {
        return;
    }
    // SAFETY: the caller promises `ptr` came from `chop_alloc(len)`, i.e. from
    // a `Vec<f64>` created with capacity `len`. The vector is rebuilt with
    // length 0 because nothing guarantees the caller ever initialized the
    // elements; only the capacity matters for deallocation.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2605
/// WASM binding: computes CHOP from raw input pointers and writes `len`
/// values to `out_ptr`.
///
/// All four pointers must be non-null and valid for `len` `f64` elements.
/// `out_ptr` may overlap any of the inputs (in-place use): in that case the
/// result is computed into a scratch buffer and copied out afterwards, so no
/// mutable slice ever aliases an input slice. With overlapping buffers the
/// output is left untouched when the computation fails.
///
/// Errors (null pointer or computation failure) are returned as a string
/// `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn chop_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
    scalar: f64,
    drift: usize,
) -> Result<(), JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || close_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer to chop_into"));
    }

    let params = ChopParams {
        period: Some(period),
        scalar: Some(scalar),
        drift: Some(drift),
    };

    // Byte-range overlap test between an input buffer and the output buffer.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let out_start = out_ptr as usize;
    let overlaps_out = |p: *const f64| {
        let start = p as usize;
        start < out_start.saturating_add(span) && out_start < start.saturating_add(span)
    };
    let aliased = overlaps_out(high_ptr) || overlaps_out(low_ptr) || overlaps_out(close_ptr);

    // SAFETY: the pointers are non-null and the caller guarantees each is
    // valid for `len` elements. A `&mut [f64]` over the output is only created
    // when it does not overlap any input; otherwise the output is written
    // through the raw pointer after the input borrows have ended.
    unsafe {
        if aliased {
            let scratch = {
                let h = std::slice::from_raw_parts(high_ptr, len);
                let l = std::slice::from_raw_parts(low_ptr, len);
                let c = std::slice::from_raw_parts(close_ptr, len);
                let input = ChopInput::from_slices(h, l, c, params);
                let mut scratch = vec![0.0f64; len];
                chop_into_slice(&mut scratch, &input, detect_best_kernel())
                    .map_err(|e| JsValue::from_str(&e.to_string()))?;
                scratch
            };
            std::ptr::copy_nonoverlapping(scratch.as_ptr(), out_ptr, len);
            Ok(())
        } else {
            let h = std::slice::from_raw_parts(high_ptr, len);
            let l = std::slice::from_raw_parts(low_ptr, len);
            let c = std::slice::from_raw_parts(close_ptr, len);
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            let input = ChopInput::from_slices(h, l, c, params);
            chop_into_slice(out, &input, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))
        }
    }
}
2640
/// Sweep configuration deserialized from the JS `config` argument of
/// `chop_batch`.
///
/// Each tuple is copied as-is into the matching `ChopBatchRange` field, where
/// it is interpreted as `(start, end, step)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct ChopBatchConfig {
    /// `(start, end, step)` for the period parameter.
    pub period_range: (usize, usize, usize),
    /// `(start, end, step)` for the scalar parameter.
    pub scalar_range: (f64, f64, f64),
    /// `(start, end, step)` for the drift parameter.
    pub drift_range: (usize, usize, usize),
}
2648
/// Result object serialized back to JS by `chop_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct ChopBatchJsOutput {
    /// Flat buffer of `rows * cols` values.
    /// NOTE(review): presumably row-major, one row per entry of `combos`;
    /// confirm against `chop_batch_inner_into`.
    pub values: Vec<f64>,
    /// Parameter combinations returned by the batch computation.
    pub combos: Vec<ChopParams>,
    /// Number of parameter combinations in the expanded sweep.
    pub rows: usize,
    /// Length of the `close` input series.
    pub cols: usize,
}
2657
/// WASM binding exported as `chop_batch`: runs a CHOP parameter sweep described
/// by `config` (see `ChopBatchConfig`) and returns a serialized
/// `ChopBatchJsOutput`.
///
/// Every failure (bad config, invalid sweep, size overflow, computation or
/// serialization error) is reported as a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = chop_batch)]
pub fn chop_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let ChopBatchConfig {
        period_range,
        scalar_range,
        drift_range,
    } = serde_wasm_bindgen::from_value::<ChopBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
    let grid = ChopBatchRange {
        period: period_range,
        scalar: scalar_range,
        drift: drift_range,
    };

    // The grid is expanded once up front only to size the output buffer.
    let rows = match expand_grid(&grid) {
        Ok(expanded) => expanded.len(),
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };
    let cols = close.len();
    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => return Err(JsValue::from_str("rows*cols overflow")),
    };
    let mut values = vec![0.0f64; total];

    let kernel = detect_best_kernel();
    let combos = chop_batch_inner_into(high, low, close, &grid, kernel, false, &mut values)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = ChopBatchJsOutput {
        values,
        combos,
        rows,
        cols,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2702
/// WASM binding: runs a CHOP parameter sweep from raw input pointers and
/// writes `rows * len` values to `out_ptr`, where `rows` is the number of
/// parameter combinations in the expanded sweep. Returns `rows`.
///
/// The input pointers must be valid for `len` `f64` elements and `out_ptr`
/// for `rows * len` elements. If the output region overlaps any input, the
/// sweep is computed into a scratch buffer and copied out afterwards, so no
/// mutable slice ever aliases an input slice. With overlapping buffers the
/// output is left untouched when the computation fails.
///
/// Errors (null pointer, invalid sweep, size overflow, computation failure)
/// are returned as a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn chop_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
    scalar_start: f64,
    scalar_end: f64,
    scalar_step: f64,
    drift_start: usize,
    drift_end: usize,
    drift_step: usize,
) -> Result<usize, JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || close_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer to chop_batch_into"));
    }

    let sweep = ChopBatchRange {
        period: (period_start, period_end, period_step),
        scalar: (scalar_start, scalar_end, scalar_step),
        drift: (drift_start, drift_end, drift_step),
    };
    let rows = expand_grid(&sweep)
        .map_err(|e| JsValue::from_str(&e.to_string()))?
        .len();
    let total = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    // Byte-range overlap test between an input buffer (`len` elements) and the
    // output buffer (`total` elements).
    let elem = std::mem::size_of::<f64>();
    let in_span = len.saturating_mul(elem);
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(total.saturating_mul(elem));
    let overlaps_out = |p: *const f64| {
        let start = p as usize;
        start < out_end && out_start < start.saturating_add(in_span)
    };
    let aliased = overlaps_out(high_ptr) || overlaps_out(low_ptr) || overlaps_out(close_ptr);

    // SAFETY: the pointers are non-null and the caller guarantees the inputs
    // are valid for `len` elements and the output for `rows * len` elements.
    // A `&mut [f64]` over the output is only created when it does not overlap
    // any input; otherwise the output is written through the raw pointer after
    // the input borrows have ended.
    unsafe {
        if aliased {
            let scratch = {
                let h = std::slice::from_raw_parts(high_ptr, len);
                let l = std::slice::from_raw_parts(low_ptr, len);
                let c = std::slice::from_raw_parts(close_ptr, len);
                let mut scratch = vec![0.0f64; total];
                chop_batch_inner_into(h, l, c, &sweep, detect_best_kernel(), false, &mut scratch)
                    .map_err(|e| JsValue::from_str(&e.to_string()))?;
                scratch
            };
            std::ptr::copy_nonoverlapping(scratch.as_ptr(), out_ptr, total);
        } else {
            let h = std::slice::from_raw_parts(high_ptr, len);
            let l = std::slice::from_raw_parts(low_ptr, len);
            let c = std::slice::from_raw_parts(close_ptr, len);
            let out = std::slice::from_raw_parts_mut(out_ptr, total);
            chop_batch_inner_into(h, l, c, &sweep, detect_best_kernel(), false, out)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }
    Ok(rows)
}