1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5 make_uninit_matrix,
6};
7use aligned_vec::{AVec, CACHELINE_ALIGN};
8#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
9use core::arch::x86_64::*;
10#[cfg(not(target_arch = "wasm32"))]
11use rayon::prelude::*;
12use std::collections::VecDeque;
13use std::convert::AsRef;
14use std::mem::ManuallyDrop;
15use thiserror::Error;
16
17#[cfg(all(feature = "python", feature = "cuda"))]
18use crate::cuda::moving_averages::DeviceArrayF32;
19#[cfg(all(feature = "python", feature = "cuda"))]
20use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
21#[cfg(feature = "python")]
22use crate::utilities::kernel_validation::validate_kernel;
23#[cfg(all(feature = "python", feature = "cuda"))]
24use cust::context::Context;
25#[cfg(all(feature = "python", feature = "cuda"))]
26use cust::memory::DeviceBuffer;
27#[cfg(feature = "python")]
28use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
29#[cfg(feature = "python")]
30use pyo3::exceptions::{PyTypeError, PyValueError};
31#[cfg(feature = "python")]
32use pyo3::prelude::*;
33#[cfg(feature = "python")]
34use pyo3::types::PyDict;
35#[cfg(all(feature = "python", feature = "cuda"))]
36use std::sync::Arc;
37
38#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
39use serde::{Deserialize, Serialize};
40#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
41use wasm_bindgen::prelude::*;
42
43#[derive(Debug, Clone)]
44pub enum ChandeData<'a> {
45 Candles {
46 candles: &'a Candles,
47 },
48 Slices {
49 high: &'a [f64],
50 low: &'a [f64],
51 close: &'a [f64],
52 },
53}
54
/// Result of a single-series Chande computation.
#[derive(Debug, Clone)]
pub struct ChandeOutput {
    /// One value per input bar; the warm-up prefix is `NaN`.
    pub values: Vec<f64>,
}
59
/// Tunable Chande parameters; any `None` falls back to the defaults
/// (period 22, multiplier 3.0, direction "long").
#[derive(Debug, Clone)]
#[cfg_attr(all(target_arch = "wasm32", feature = "wasm"), derive(Serialize, Deserialize))]
pub struct ChandeParams {
    /// Look-back length for both the rolling extreme and the ATR.
    pub period: Option<usize>,
    /// Multiplier applied to the ATR.
    pub mult: Option<f64>,
    /// "long" or "short".
    pub direction: Option<String>,
}
70
71impl Default for ChandeParams {
72 fn default() -> Self {
73 Self {
74 period: Some(22),
75 mult: Some(3.0),
76 direction: Some("long".into()),
77 }
78 }
79}
80
81#[derive(Debug, Clone)]
82pub struct ChandeInput<'a> {
83 pub data: ChandeData<'a>,
84 pub params: ChandeParams,
85}
86
87impl<'a> ChandeInput<'a> {
88 #[inline]
89 pub fn from_candles(c: &'a Candles, p: ChandeParams) -> Self {
90 Self {
91 data: ChandeData::Candles { candles: c },
92 params: p,
93 }
94 }
95 #[inline]
96 pub fn from_slices(high: &'a [f64], low: &'a [f64], close: &'a [f64], p: ChandeParams) -> Self {
97 Self {
98 data: ChandeData::Slices { high, low, close },
99 params: p,
100 }
101 }
102 #[inline]
103 pub fn with_default_candles(c: &'a Candles) -> Self {
104 Self::from_candles(c, ChandeParams::default())
105 }
106 #[inline]
107 pub fn get_period(&self) -> usize {
108 self.params.period.unwrap_or(22)
109 }
110 #[inline]
111 pub fn get_mult(&self) -> f64 {
112 self.params.mult.unwrap_or(3.0)
113 }
114 #[inline]
115 pub fn get_direction(&self) -> &str {
116 self.params.direction.as_deref().unwrap_or("long")
117 }
118 #[inline]
119 pub fn borrow_slices(&self) -> (&[f64], &[f64], &[f64]) {
120 match &self.data {
121 ChandeData::Candles { candles } => (
122 source_type(candles, "high"),
123 source_type(candles, "low"),
124 source_type(candles, "close"),
125 ),
126 ChandeData::Slices { high, low, close } => (high, low, close),
127 }
128 }
129}
130
131#[derive(Clone, Debug)]
132pub struct ChandeBuilder {
133 period: Option<usize>,
134 mult: Option<f64>,
135 direction: Option<String>,
136 kernel: Kernel,
137}
138
139impl Default for ChandeBuilder {
140 fn default() -> Self {
141 Self {
142 period: None,
143 mult: None,
144 direction: None,
145 kernel: Kernel::Auto,
146 }
147 }
148}
149impl ChandeBuilder {
150 #[inline(always)]
151 pub fn new() -> Self {
152 Self::default()
153 }
154 #[inline(always)]
155 pub fn period(mut self, n: usize) -> Self {
156 self.period = Some(n);
157 self
158 }
159 #[inline(always)]
160 pub fn mult(mut self, m: f64) -> Self {
161 self.mult = Some(m);
162 self
163 }
164 #[inline(always)]
165 pub fn direction<S: Into<String>>(mut self, d: S) -> Self {
166 self.direction = Some(d.into());
167 self
168 }
169 #[inline(always)]
170 pub fn kernel(mut self, k: Kernel) -> Self {
171 self.kernel = k;
172 self
173 }
174
175 #[inline(always)]
176 pub fn apply(self, c: &Candles) -> Result<ChandeOutput, ChandeError> {
177 let p = ChandeParams {
178 period: self.period,
179 mult: self.mult,
180 direction: self.direction,
181 };
182 let i = ChandeInput::from_candles(c, p);
183 chande_with_kernel(&i, self.kernel)
184 }
185
186 #[inline(always)]
187 pub fn apply_slices(
188 self,
189 high: &[f64],
190 low: &[f64],
191 close: &[f64],
192 ) -> Result<ChandeOutput, ChandeError> {
193 let p = ChandeParams {
194 period: self.period,
195 mult: self.mult,
196 direction: self.direction,
197 };
198 let i = ChandeInput::from_slices(high, low, close, p);
199 chande_with_kernel(&i, self.kernel)
200 }
201
202 #[inline(always)]
203 pub fn into_stream(self) -> Result<ChandeStream, ChandeError> {
204 let p = ChandeParams {
205 period: self.period,
206 mult: self.mult,
207 direction: self.direction,
208 };
209 ChandeStream::try_new(p)
210 }
211}
212
213#[derive(Debug, Error)]
214pub enum ChandeError {
215 #[error("chande: Input series are empty.")]
216 EmptyInputData,
217 #[error("chande: All values are NaN.")]
218 AllValuesNaN,
219 #[error("chande: Invalid period: period={period}, data_len={data_len}")]
220 InvalidPeriod { period: usize, data_len: usize },
221 #[error("chande: not enough valid data: needed={needed}, valid={valid}")]
222 NotEnoughValidData { needed: usize, valid: usize },
223 #[error("chande: input length mismatch: high={h}, low={l}, close={c}")]
224 DataLengthMismatch { h: usize, l: usize, c: usize },
225 #[error("chande: Invalid direction: {direction}")]
226 InvalidDirection { direction: String },
227 #[error("chande: output length mismatch: expected={expected}, got={got}")]
228 OutputLengthMismatch { expected: usize, got: usize },
229 #[error("chande: invalid range: start={start}, end={end}, step={step}")]
230 InvalidRange {
231 start: isize,
232 end: isize,
233 step: isize,
234 },
235 #[error("chande: invalid kernel for batch: {0:?}")]
236 InvalidKernelForBatch(Kernel),
237 #[error("chande: invalid input: {0}")]
238 InvalidInput(String),
239}
240
/// Index of the first bar whose high, low and close are all non-NaN.
///
/// Only the common prefix of the three slices is scanned; `None` if no such bar.
#[inline]
fn first_valid3(h: &[f64], l: &[f64], c: &[f64]) -> Option<usize> {
    h.iter()
        .zip(l)
        .zip(c)
        .position(|((hv, lv), cv)| !(hv.is_nan() || lv.is_nan() || cv.is_nan()))
}
246
247#[inline]
248pub fn chande(input: &ChandeInput) -> Result<ChandeOutput, ChandeError> {
249 chande_with_kernel(input, Kernel::Auto)
250}
251
252pub fn chande_with_kernel(
253 input: &ChandeInput,
254 kernel: Kernel,
255) -> Result<ChandeOutput, ChandeError> {
256 let (high, low, close) = input.borrow_slices();
257 if high.is_empty() || low.is_empty() || close.is_empty() {
258 return Err(ChandeError::EmptyInputData);
259 }
260 if !(high.len() == low.len() && low.len() == close.len()) {
261 return Err(ChandeError::DataLengthMismatch {
262 h: high.len(),
263 l: low.len(),
264 c: close.len(),
265 });
266 }
267
268 let len = high.len();
269 let first = first_valid3(high, low, close).ok_or(ChandeError::AllValuesNaN)?;
270 let period = input.get_period();
271 let mult = input.get_mult();
272 let dir = {
273 let d = input.get_direction();
274 if d.eq_ignore_ascii_case("long") {
275 "long"
276 } else if d.eq_ignore_ascii_case("short") {
277 "short"
278 } else {
279 return Err(ChandeError::InvalidDirection {
280 direction: d.to_string(),
281 });
282 }
283 };
284 if period == 0 || period > len {
285 return Err(ChandeError::InvalidPeriod {
286 period,
287 data_len: len,
288 });
289 }
290 if len - first < period {
291 return Err(ChandeError::NotEnoughValidData {
292 needed: period,
293 valid: len - first,
294 });
295 }
296
297 let chosen = match kernel {
298 Kernel::Auto => detect_best_kernel(),
299 other => other,
300 };
301
302 let chosen = match (
303 chosen,
304 cfg!(all(feature = "nightly-avx", target_arch = "x86_64")),
305 ) {
306 (Kernel::Avx512 | Kernel::Avx512Batch, false)
307 | (Kernel::Avx2 | Kernel::Avx2Batch, false) => Kernel::Scalar,
308 (k, _) => k,
309 };
310
311 let warmup = first + period - 1;
312 let mut out = alloc_with_nan_prefix(len, warmup);
313 unsafe {
314 match chosen {
315 Kernel::Scalar | Kernel::ScalarBatch => {
316 chande_scalar(high, low, close, period, mult, dir, first, &mut out)
317 }
318 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
319 Kernel::Avx2 | Kernel::Avx2Batch => {
320 chande_avx2(high, low, close, period, mult, dir, first, &mut out)
321 }
322 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
323 Kernel::Avx512 | Kernel::Avx512Batch => {
324 chande_avx512(high, low, close, period, mult, dir, first, &mut out)
325 }
326 _ => unreachable!(),
327 }
328 }
329 Ok(ChandeOutput { values: out })
330}
331
332#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
333#[inline]
334pub fn chande_into(input: &ChandeInput, out: &mut [f64]) -> Result<(), ChandeError> {
335 chande_into_slice(out, input, Kernel::Auto)
336}
337
338#[inline]
339pub fn chande_compute_into(
340 high: &[f64],
341 low: &[f64],
342 close: &[f64],
343 period: usize,
344 mult: f64,
345 direction: &str,
346 kernel: Kernel,
347 out: &mut [f64],
348) -> Result<(), ChandeError> {
349 if high.is_empty() || low.is_empty() || close.is_empty() {
350 return Err(ChandeError::EmptyInputData);
351 }
352 if !(high.len() == low.len() && low.len() == close.len()) {
353 return Err(ChandeError::DataLengthMismatch {
354 h: high.len(),
355 l: low.len(),
356 c: close.len(),
357 });
358 }
359 if out.len() != high.len() {
360 return Err(ChandeError::OutputLengthMismatch {
361 expected: high.len(),
362 got: out.len(),
363 });
364 }
365 let len = high.len();
366 let first = first_valid3(high, low, close).ok_or(ChandeError::AllValuesNaN)?;
367 if period == 0 || period > len {
368 return Err(ChandeError::InvalidPeriod {
369 period,
370 data_len: len,
371 });
372 }
373 if len - first < period {
374 return Err(ChandeError::NotEnoughValidData {
375 needed: period,
376 valid: len - first,
377 });
378 }
379 let dir = if direction.eq_ignore_ascii_case("long") {
380 "long"
381 } else if direction.eq_ignore_ascii_case("short") {
382 "short"
383 } else {
384 return Err(ChandeError::InvalidDirection {
385 direction: direction.to_string(),
386 });
387 };
388
389 let warmup = first + period - 1;
390 let warmup_end = warmup.min(out.len());
391 for v in &mut out[..warmup_end] {
392 *v = f64::NAN;
393 }
394
395 let chosen = match kernel {
396 Kernel::Auto => detect_best_kernel(),
397 k => k,
398 };
399
400 let chosen = match (
401 chosen,
402 cfg!(all(feature = "nightly-avx", target_arch = "x86_64")),
403 ) {
404 (Kernel::Avx512 | Kernel::Avx512Batch, false)
405 | (Kernel::Avx2 | Kernel::Avx2Batch, false) => Kernel::Scalar,
406 (k, _) => k,
407 };
408
409 unsafe {
410 match chosen {
411 Kernel::Scalar | Kernel::ScalarBatch => {
412 chande_scalar(high, low, close, period, mult, dir, first, out)
413 }
414 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
415 Kernel::Avx2 | Kernel::Avx2Batch => {
416 chande_avx2(high, low, close, period, mult, dir, first, out)
417 }
418 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
419 Kernel::Avx512 | Kernel::Avx512Batch => {
420 chande_avx512(high, low, close, period, mult, dir, first, out)
421 }
422 _ => unreachable!(),
423 }
424 }
425 Ok(())
426}
427
428#[inline]
429pub fn chande_into_slice(
430 dst: &mut [f64],
431 input: &ChandeInput,
432 kern: Kernel,
433) -> Result<(), ChandeError> {
434 let (high, low, close) = input.borrow_slices();
435 let p = input.get_period();
436 let m = input.get_mult();
437 let d = input.get_direction();
438 chande_compute_into(high, low, close, p, m, d, kern, dst)
439}
440
/// Scalar reference kernel for the Chande exit.
///
/// For every bar `i >= first + period - 1` it writes:
/// * long (`dir == "long"`): `max(high[i-period+1..=i]) - mult * atr`
/// * short (any other `dir`): `min(low[i-period+1..=i]) + mult * atr`
///
/// `atr` is seeded with the mean true range of the first `period` bars starting
/// at `first`. After that it is smoothed as `atr += (tr - atr) / period`.
/// Outputs before the warm-up index are left untouched.
///
/// Returns without writing if `period == 0` or `first >= high.len()`.
/// Without the `period == 0` guard, `first + period - 1` would underflow.
/// Panics on out-of-bounds indexing if `low`, `close` or `out` is shorter than `high`.
#[inline]
pub fn chande_scalar(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    mult: f64,
    dir: &str,
    first: usize,
    out: &mut [f64],
) {
    let len = high.len();
    if period == 0 || first >= len {
        return;
    }

    if dir == "long" {
        chande_scalar_directional(
            high,
            low,
            close,
            period,
            first,
            high,
            |kept: f64, new: f64| kept <= new,
            |rma: f64, ext: f64| (-rma).mul_add(mult, ext),
            out,
        );
    } else {
        chande_scalar_directional(
            high,
            low,
            close,
            period,
            first,
            low,
            |kept: f64, new: f64| kept >= new,
            |rma: f64, ext: f64| rma.mul_add(mult, ext),
            out,
        );
    }
}

/// One directional pass of [`chande_scalar`].
///
/// `extreme` is the series whose rolling extreme is tracked (`high` for long,
/// `low` for short). `drop_back(kept, new)` reports when a queued value can no
/// longer be the window extreme. `band(rma, ext)` turns the smoothed true range
/// and the extreme into the stop. Caller guarantees `period >= 1` and `first < high.len()`.
#[inline(always)]
fn chande_scalar_directional<D, B>(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    extreme: &[f64],
    drop_back: D,
    band: B,
    out: &mut [f64],
) where
    D: Fn(f64, f64) -> bool,
    B: Fn(f64, f64) -> f64,
{
    let len = high.len();
    let alpha = 1.0 / period as f64;
    let warmup = first + period - 1;

    // Indices into `extreme`, monotonic so the front is the window extreme.
    let mut dq: VecDeque<usize> = VecDeque::with_capacity(period);
    let mut sum_tr = 0.0f64;
    let mut rma = 0.0f64;
    let mut prev_close = close[first];

    for i in first..len {
        let hi = high[i];
        let lo = low[i];
        let hl = hi - lo;
        // The first valid bar has no previous close, so its true range is just high - low.
        let tr = if i == first {
            hl
        } else {
            hl.max((hi - prev_close).abs()).max((lo - prev_close).abs())
        };

        // Once the window is full, drop indices that slid out of [i+1-period, i].
        if i >= warmup {
            let window_start = i + 1 - period;
            while let Some(&j) = dq.front() {
                if j >= window_start {
                    break;
                }
                dq.pop_front();
            }
        }

        let cur = extreme[i];
        while let Some(&j) = dq.back() {
            if !drop_back(extreme[j], cur) {
                break;
            }
            dq.pop_back();
        }
        dq.push_back(i);

        if i < warmup {
            sum_tr += tr;
        } else {
            rma = if i == warmup {
                // Seed with the arithmetic mean of the first `period` true ranges.
                sum_tr += tr;
                sum_tr / period as f64
            } else {
                alpha.mul_add(tr - rma, rma)
            };
            let ext = extreme[*dq.front().expect("deque holds at least the current bar")];
            out[i] = band(rma, ext);
        }

        prev_close = close[i];
    }
}
568
/// AVX2 entry point for the single-series Chande exit.
///
/// Delegates to `chande_fast_unchecked`, which runs the same loop as
/// `chande_scalar` using unchecked pointer access.
/// Returns without writing if `period == 0`.
///
/// # Panics
/// Panics if `low`, `close` or `out` is shorter than `high`. The kernel indexes
/// all four over `first..high.len()` without bounds checks, so a short slice
/// would otherwise be undefined behaviour reachable from safe code.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn chande_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    mult: f64,
    dir: &str,
    first: usize,
    out: &mut [f64],
) {
    assert!(
        low.len() >= high.len() && close.len() >= high.len() && out.len() >= high.len(),
        "chande_avx2: low/close/out must be at least as long as high"
    );
    if period == 0 {
        // `first + period - 1` inside the kernel would underflow.
        return;
    }
    // SAFETY: the kernel only touches indices in `first..high.len()` (it returns early
    // when `first >= high.len()`), and the assert above guarantees that `low`, `close`
    // and `out` cover that whole range.
    unsafe { chande_fast_unchecked(high, low, close, period, mult, dir, first, out) }
}
583
/// AVX-512 entry point for the single-series Chande exit.
///
/// Delegates to `chande_fast_unchecked`, which runs the same loop as
/// `chande_scalar` using unchecked pointer access.
/// Returns without writing if `period == 0`.
///
/// # Panics
/// Panics if `low`, `close` or `out` is shorter than `high`. The kernel indexes
/// all four over `first..high.len()` without bounds checks, so a short slice
/// would otherwise be undefined behaviour reachable from safe code.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn chande_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    mult: f64,
    dir: &str,
    first: usize,
    out: &mut [f64],
) {
    assert!(
        low.len() >= high.len() && close.len() >= high.len() && out.len() >= high.len(),
        "chande_avx512: low/close/out must be at least as long as high"
    );
    if period == 0 {
        // `first + period - 1` inside the kernel would underflow.
        return;
    }
    // SAFETY: the kernel only touches indices in `first..high.len()` (it returns early
    // when `first >= high.len()`), and the assert above guarantees that `low`, `close`
    // and `out` cover that whole range.
    unsafe { chande_fast_unchecked(high, low, close, period, mult, dir, first, out) }
}
598
/// Unchecked AVX-512 entry point; forwards every argument to `chande_fast_unchecked`.
///
/// NOTE(review): despite the `_short` suffix, the direction comes from `dir`,
/// which is forwarded unchanged. The function does not force "short"; confirm
/// this is intended.
///
/// # Safety
/// `low`, `close` and `out` must each be at least `high.len()` elements long.
/// The callee reads and writes indices `first..high.len()` through raw pointers
/// without bounds checks. `period` should be non-zero, because
/// `first + period - 1` is computed in the callee.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn chande_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    mult: f64,
    dir: &str,
    first: usize,
    out: &mut [f64],
) {
    chande_fast_unchecked(high, low, close, period, mult, dir, first, out)
}
613
/// Unchecked AVX-512 entry point; forwards every argument to `chande_fast_unchecked`.
///
/// NOTE(review): despite the `_long` suffix, the direction comes from `dir`,
/// which is forwarded unchanged. The function does not force "long"; confirm
/// this is intended.
///
/// # Safety
/// `low`, `close` and `out` must each be at least `high.len()` elements long.
/// The callee reads and writes indices `first..high.len()` through raw pointers
/// without bounds checks. `period` should be non-zero, because
/// `first + period - 1` is computed in the callee.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn chande_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    mult: f64,
    dir: &str,
    first: usize,
    out: &mut [f64],
) {
    chande_fast_unchecked(high, low, close, period, mult, dir, first, out)
}
628
/// Shared loop behind the AVX entry points: the same algorithm as `chande_scalar`,
/// but with raw-pointer loads and stores instead of bounds-checked indexing.
///
/// `dir == "long"` writes `max(high window) - mult * atr`. Any other value writes
/// `min(low window) + mult * atr`. Only indices `first + period - 1 ..` of `out`
/// are written. Returns without writing if `period == 0` or `first >= high.len()`.
///
/// True range uses `f64::max`, exactly like `chande_scalar`. A NaN previous close
/// therefore yields `high - low` here as well. The earlier hand-written
/// `if a >= b` selection propagated the NaN, so results depended on the kernel.
///
/// # Safety
/// `low`, `close` and `out` must each be at least `high.len()` elements long.
/// Indices `first..high.len()` are accessed without bounds checks.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn chande_fast_unchecked(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    mult: f64,
    dir: &str,
    first: usize,
    out: &mut [f64],
) {
    use std::collections::VecDeque;
    let len = high.len();
    if period == 0 || first >= len {
        return;
    }
    let alpha = 1.0 / period as f64;
    let warmup = first + period - 1;

    let hp = high.as_ptr();
    let lp = low.as_ptr();
    let cp = close.as_ptr();
    let op = out.as_mut_ptr();

    let mut prev_close = *cp.add(first);
    let mut sum_tr = 0.0f64;
    let mut rma = 0.0f64;

    if dir == "long" {
        // Indices of candidate window maxima of `high`, in decreasing value order.
        let mut dq: VecDeque<usize> = VecDeque::with_capacity(period);
        for i in first..len {
            let hi = *hp.add(i);
            let lo = *lp.add(i);
            let hl = hi - lo;
            let tr = if i == first {
                hl
            } else {
                hl.max((hi - prev_close).abs()).max((lo - prev_close).abs())
            };

            if i >= warmup {
                let window_start = i + 1 - period;
                while let Some(&j) = dq.front() {
                    if j < window_start {
                        dq.pop_front();
                    } else {
                        break;
                    }
                }
            }
            while let Some(&j) = dq.back() {
                if *hp.add(j) <= hi {
                    dq.pop_back();
                } else {
                    break;
                }
            }
            dq.push_back(i);

            if i < warmup {
                sum_tr += tr;
            } else if i == warmup {
                // Seed with the mean true range of the first `period` bars.
                sum_tr += tr;
                rma = sum_tr / period as f64;
                let max_h = *hp.add(*dq.front().unwrap());
                *op.add(i) = (-rma).mul_add(mult, max_h);
            } else {
                rma = alpha.mul_add(tr - rma, rma);
                let max_h = *hp.add(*dq.front().unwrap());
                *op.add(i) = (-rma).mul_add(mult, max_h);
            }
            prev_close = *cp.add(i);
        }
    } else {
        // Indices of candidate window minima of `low`, in increasing value order.
        let mut dq: VecDeque<usize> = VecDeque::with_capacity(period);
        for i in first..len {
            let hi = *hp.add(i);
            let lo = *lp.add(i);
            let hl = hi - lo;
            let tr = if i == first {
                hl
            } else {
                hl.max((hi - prev_close).abs()).max((lo - prev_close).abs())
            };

            if i >= warmup {
                let window_start = i + 1 - period;
                while let Some(&j) = dq.front() {
                    if j < window_start {
                        dq.pop_front();
                    } else {
                        break;
                    }
                }
            }
            while let Some(&j) = dq.back() {
                if *lp.add(j) >= lo {
                    dq.pop_back();
                } else {
                    break;
                }
            }
            dq.push_back(i);

            if i < warmup {
                sum_tr += tr;
            } else if i == warmup {
                // Seed with the mean true range of the first `period` bars.
                sum_tr += tr;
                rma = sum_tr / period as f64;
                let min_l = *lp.add(*dq.front().unwrap());
                *op.add(i) = rma.mul_add(mult, min_l);
            } else {
                rma = alpha.mul_add(tr - rma, rma);
                let min_l = *lp.add(*dq.front().unwrap());
                *op.add(i) = rma.mul_add(mult, min_l);
            }
            prev_close = *cp.add(i);
        }
    }
}
764
/// Incremental, bar-by-bar Chande exit calculator.
///
/// The rolling extreme (highs for long, lows for short) is tracked in a
/// monotonic deque of `(value, bar_index)` pairs. The true range is smoothed
/// with `alpha = 1 / period`.
#[derive(Debug, Clone)]
pub struct ChandeStream {
    // Configuration.
    period: usize,
    mult: f64,
    direction: String,
    is_long: bool,
    /// Smoothing factor, `1 / period`.
    alpha: f64,

    // Running state.
    /// Running TR sum during warm-up, smoothed ATR afterwards.
    atr: f64,
    /// Close of the previous bar, used for the true range.
    close_prev: f64,
    /// Index assigned to the next bar.
    t: usize,
    /// Bars consumed during warm-up.
    warm: usize,
    /// Set once `period` bars have been seen.
    filled: bool,

    // Monotonic deques of (value, bar index).
    max_deque: VecDeque<(f64, usize)>,
    min_deque: VecDeque<(f64, usize)>,
}
783
784impl ChandeStream {
785 pub fn try_new(params: ChandeParams) -> Result<Self, ChandeError> {
786 let period = params.period.unwrap_or(22);
787 let mult = params.mult.unwrap_or(3.0);
788 let direction = params
789 .direction
790 .unwrap_or_else(|| "long".into())
791 .to_lowercase();
792
793 if period == 0 {
794 return Err(ChandeError::InvalidPeriod {
795 period,
796 data_len: 0,
797 });
798 }
799 if direction != "long" && direction != "short" {
800 return Err(ChandeError::InvalidDirection { direction });
801 }
802
803 let is_long = direction == "long";
804 Ok(Self {
805 period,
806 mult,
807 direction,
808 is_long,
809 alpha: 1.0 / period as f64,
810 atr: 0.0,
811 close_prev: f64::NAN,
812 t: 0,
813 warm: 0,
814 filled: false,
815 max_deque: std::collections::VecDeque::with_capacity(period),
816 min_deque: std::collections::VecDeque::with_capacity(period),
817 })
818 }
819
820 #[inline(always)]
821 fn evict_old(&mut self) {
822 let window_start = self.t.saturating_sub(self.period - 1);
823 if self.is_long {
824 while let Some(&(_, idx)) = self.max_deque.front() {
825 if idx < window_start {
826 self.max_deque.pop_front();
827 } else {
828 break;
829 }
830 }
831 } else {
832 while let Some(&(_, idx)) = self.min_deque.front() {
833 if idx < window_start {
834 self.min_deque.pop_front();
835 } else {
836 break;
837 }
838 }
839 }
840 }
841
842 #[inline(always)]
843 fn push_max(&mut self, v: f64) {
844 while let Some(&(back, _)) = self.max_deque.back() {
845 if back <= v {
846 self.max_deque.pop_back();
847 } else {
848 break;
849 }
850 }
851 self.max_deque.push_back((v, self.t));
852 }
853
854 #[inline(always)]
855 fn push_min(&mut self, v: f64) {
856 while let Some(&(back, _)) = self.min_deque.back() {
857 if back >= v {
858 self.min_deque.pop_back();
859 } else {
860 break;
861 }
862 }
863 self.min_deque.push_back((v, self.t));
864 }
865
866 #[inline(always)]
867 fn tr(&self, high: f64, low: f64) -> f64 {
868 if self.warm == 0 {
869 high - low
870 } else {
871 let max_h = if high > self.close_prev {
872 high
873 } else {
874 self.close_prev
875 };
876 let min_l = if low < self.close_prev {
877 low
878 } else {
879 self.close_prev
880 };
881 max_h - min_l
882 }
883 }
884
885 #[inline(always)]
886 pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
887 let tr = self.tr(high, low);
888
889 if !self.filled {
890 if self.is_long {
891 self.push_max(high);
892 } else {
893 self.push_min(low);
894 }
895 self.atr += tr;
896 self.warm += 1;
897
898 let now_ready = self.warm == self.period;
899 if now_ready {
900 self.atr *= self.alpha;
901 self.filled = true;
902 }
903
904 self.close_prev = close;
905 self.t = self.t.wrapping_add(1);
906
907 if !now_ready {
908 return None;
909 }
910
911 if self.is_long {
912 let m = self.max_deque.front().unwrap().0;
913 Some((-self.atr).mul_add(self.mult, m))
914 } else {
915 let m = self.min_deque.front().unwrap().0;
916 Some(self.atr.mul_add(self.mult, m))
917 }
918 } else {
919 self.evict_old();
920 if self.is_long {
921 self.push_max(high);
922 } else {
923 self.push_min(low);
924 }
925
926 self.atr = self.alpha.mul_add(tr - self.atr, self.atr);
927
928 self.close_prev = close;
929 self.t = self.t.wrapping_add(1);
930
931 if self.is_long {
932 let m = self.max_deque.front().unwrap().0;
933 Some((-self.atr).mul_add(self.mult, m))
934 } else {
935 let m = self.min_deque.front().unwrap().0;
936 Some(self.atr.mul_add(self.mult, m))
937 }
938 }
939 }
940}
941
/// Parameter sweep for batch runs; each axis is `(start, end, step)`.
///
/// A zero step or `start == end` yields the single value `start`.
#[derive(Clone, Debug)]
pub struct ChandeBatchRange {
    /// Period axis.
    pub period: (usize, usize, usize),
    /// ATR-multiplier axis.
    pub mult: (f64, f64, f64),
}
947
948impl Default for ChandeBatchRange {
949 fn default() -> Self {
950 Self {
951 period: (22, 22, 0),
952 mult: (3.0, 3.249, 0.001),
953 }
954 }
955}
956
957#[derive(Clone, Debug, Default)]
958pub struct ChandeBatchBuilder {
959 range: ChandeBatchRange,
960 direction: String,
961 kernel: Kernel,
962}
963
964impl ChandeBatchBuilder {
965 pub fn new() -> Self {
966 Self {
967 range: ChandeBatchRange::default(),
968 direction: "long".into(),
969 kernel: Kernel::Auto,
970 }
971 }
972 pub fn kernel(mut self, k: Kernel) -> Self {
973 self.kernel = k;
974 self
975 }
976 pub fn direction<S: Into<String>>(mut self, d: S) -> Self {
977 self.direction = d.into();
978 self
979 }
980
981 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
982 self.range.period = (start, end, step);
983 self
984 }
985 pub fn period_static(mut self, p: usize) -> Self {
986 self.range.period = (p, p, 0);
987 self
988 }
989 pub fn mult_range(mut self, start: f64, end: f64, step: f64) -> Self {
990 self.range.mult = (start, end, step);
991 self
992 }
993 pub fn mult_static(mut self, m: f64) -> Self {
994 self.range.mult = (m, m, 0.0);
995 self
996 }
997
998 pub fn apply_candles(self, c: &Candles) -> Result<ChandeBatchOutput, ChandeError> {
999 let high = source_type(c, "high");
1000 let low = source_type(c, "low");
1001 let close = source_type(c, "close");
1002 chande_batch_with_kernel(high, low, close, &self.range, &self.direction, self.kernel)
1003 }
1004
1005 pub fn apply_slices(
1006 self,
1007 high: &[f64],
1008 low: &[f64],
1009 close: &[f64],
1010 ) -> Result<ChandeBatchOutput, ChandeError> {
1011 chande_batch_with_kernel(high, low, close, &self.range, &self.direction, self.kernel)
1012 }
1013}
1014
1015pub fn chande_batch_with_kernel(
1016 high: &[f64],
1017 low: &[f64],
1018 close: &[f64],
1019 sweep: &ChandeBatchRange,
1020 direction: &str,
1021 k: Kernel,
1022) -> Result<ChandeBatchOutput, ChandeError> {
1023 let kernel = match k {
1024 Kernel::Auto => detect_best_batch_kernel(),
1025 other if other.is_batch() => other,
1026 other => {
1027 return Err(ChandeError::InvalidKernelForBatch(other));
1028 }
1029 };
1030 let simd = match kernel {
1031 Kernel::Avx512Batch => Kernel::Avx512,
1032 Kernel::Avx2Batch => Kernel::Avx2,
1033 Kernel::ScalarBatch => Kernel::Scalar,
1034 _ => unreachable!(),
1035 };
1036 chande_batch_par_slice(high, low, close, sweep, direction, simd)
1037}
1038
1039#[derive(Clone, Debug)]
1040pub struct ChandeBatchOutput {
1041 pub values: Vec<f64>,
1042 pub combos: Vec<ChandeParams>,
1043 pub rows: usize,
1044 pub cols: usize,
1045}
1046impl ChandeBatchOutput {
1047 pub fn row_for_params(&self, p: &ChandeParams) -> Option<usize> {
1048 self.combos.iter().position(|c| {
1049 c.period.unwrap_or(22) == p.period.unwrap_or(22)
1050 && (c.mult.unwrap_or(3.0) - p.mult.unwrap_or(3.0)).abs() < 1e-12
1051 && c.direction.as_deref().unwrap_or("long")
1052 == p.direction.as_deref().unwrap_or("long")
1053 })
1054 }
1055 pub fn values_for(&self, p: &ChandeParams) -> Option<&[f64]> {
1056 self.row_for_params(p).map(|row| {
1057 let start = row * self.cols;
1058 &self.values[start..start + self.cols]
1059 })
1060 }
1061}
1062
1063#[inline(always)]
1064fn expand_grid(r: &ChandeBatchRange, dir: &str) -> Result<Vec<ChandeParams>, ChandeError> {
1065 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, ChandeError> {
1066 if step == 0 || start == end {
1067 return Ok(vec![start]);
1068 }
1069
1070 if start < end {
1071 if step == 0 {
1072 return Ok(vec![start]);
1073 }
1074 Ok((start..=end).step_by(step).collect())
1075 } else {
1076 let step_i = step as isize;
1077 if step_i == 0 {
1078 return Ok(vec![start]);
1079 }
1080 let mut vals = Vec::new();
1081 let mut x = start as isize;
1082 let end_i = end as isize;
1083 while x >= end_i {
1084 vals.push(x as usize);
1085 x = x.saturating_sub(step_i);
1086 if step_i <= 0 {
1087 break;
1088 }
1089 }
1090 if vals.is_empty() {
1091 return Err(ChandeError::InvalidRange {
1092 start: start as isize,
1093 end: end as isize,
1094 step: step as isize,
1095 });
1096 }
1097 Ok(vals)
1098 }
1099 }
1100 fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, ChandeError> {
1101 if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
1102 return Ok(vec![start]);
1103 }
1104 let mut v = Vec::new();
1105 if start < end {
1106 let mut x = start;
1107 while x <= end + 1e-12 {
1108 v.push(x);
1109 x += step;
1110 }
1111 } else {
1112 let mut x = start;
1113 let st = -step.abs();
1114 while x >= end - 1e-12 {
1115 v.push(x);
1116 x += st;
1117 }
1118 }
1119 if v.is_empty() {
1120 return Err(ChandeError::InvalidRange {
1121 start: start as isize,
1122 end: end as isize,
1123 step: step as isize,
1124 });
1125 }
1126 Ok(v)
1127 }
1128 let periods = axis_usize(r.period)?;
1129 let mults = axis_f64(r.mult)?;
1130
1131 let cap = periods
1132 .len()
1133 .checked_mul(mults.len())
1134 .ok_or(ChandeError::InvalidRange {
1135 start: 0,
1136 end: 0,
1137 step: 0,
1138 })?;
1139 let mut out = Vec::with_capacity(cap);
1140 for &p in &periods {
1141 for &m in &mults {
1142 out.push(ChandeParams {
1143 period: Some(p),
1144 mult: Some(m),
1145 direction: Some(dir.to_string()),
1146 });
1147 }
1148 }
1149 Ok(out)
1150}
1151
1152#[inline(always)]
1153pub fn chande_batch_slice(
1154 high: &[f64],
1155 low: &[f64],
1156 close: &[f64],
1157 sweep: &ChandeBatchRange,
1158 dir: &str,
1159 kern: Kernel,
1160) -> Result<ChandeBatchOutput, ChandeError> {
1161 chande_batch_inner(high, low, close, sweep, dir, kern, false)
1162}
1163
1164#[inline(always)]
1165pub fn chande_batch_par_slice(
1166 high: &[f64],
1167 low: &[f64],
1168 close: &[f64],
1169 sweep: &ChandeBatchRange,
1170 dir: &str,
1171 kern: Kernel,
1172) -> Result<ChandeBatchOutput, ChandeError> {
1173 chande_batch_inner(high, low, close, sweep, dir, kern, true)
1174}
1175
1176#[inline(always)]
1177fn chande_batch_inner(
1178 high: &[f64],
1179 low: &[f64],
1180 close: &[f64],
1181 sweep: &ChandeBatchRange,
1182 dir: &str,
1183 kern: Kernel,
1184 parallel: bool,
1185) -> Result<ChandeBatchOutput, ChandeError> {
1186 if high.is_empty() || low.is_empty() || close.is_empty() {
1187 return Err(ChandeError::EmptyInputData);
1188 }
1189 if !(high.len() == low.len() && low.len() == close.len()) {
1190 return Err(ChandeError::DataLengthMismatch {
1191 h: high.len(),
1192 l: low.len(),
1193 c: close.len(),
1194 });
1195 }
1196
1197 let combos = expand_grid(sweep, dir)?;
1198 if combos.is_empty() {
1199 return Err(ChandeError::InvalidRange {
1200 start: 0,
1201 end: 0,
1202 step: 0,
1203 });
1204 }
1205 let first = first_valid3(high, low, close).ok_or(ChandeError::AllValuesNaN)?;
1206 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1207 if high.len() - first < max_p {
1208 return Err(ChandeError::NotEnoughValidData {
1209 needed: max_p,
1210 valid: high.len() - first,
1211 });
1212 }
1213 let rows = combos.len();
1214 let cols = high.len();
1215
1216 let _total = rows
1217 .checked_mul(cols)
1218 .ok_or(ChandeError::InvalidInput("rows*cols overflow".into()))?;
1219
1220 let warmup_periods: Vec<usize> = combos
1221 .iter()
1222 .map(|c| first + c.period.unwrap() - 1)
1223 .collect();
1224
1225 let mut buf_mu = make_uninit_matrix(rows, cols);
1226 init_matrix_prefixes(&mut buf_mu, cols, &warmup_periods);
1227
1228 let mut buf_guard = ManuallyDrop::new(buf_mu);
1229 let values_slice: &mut [f64] = unsafe {
1230 core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
1231 };
1232
1233 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
1234 let period = combos[row].period.unwrap();
1235 let mult = combos[row].mult.unwrap();
1236 let direction = combos[row].direction.as_deref().unwrap();
1237 match kern {
1238 Kernel::Scalar => {
1239 chande_row_scalar(high, low, close, first, period, mult, direction, out_row)
1240 }
1241 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1242 Kernel::Avx2 => {
1243 chande_row_avx2(high, low, close, first, period, mult, direction, out_row)
1244 }
1245 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1246 Kernel::Avx512 => {
1247 chande_row_avx512(high, low, close, first, period, mult, direction, out_row)
1248 }
1249 _ => unreachable!(),
1250 }
1251 };
1252 if parallel {
1253 #[cfg(not(target_arch = "wasm32"))]
1254 {
1255 values_slice
1256 .par_chunks_mut(cols)
1257 .enumerate()
1258 .for_each(|(row, slice)| do_row(row, slice));
1259 }
1260
1261 #[cfg(target_arch = "wasm32")]
1262 {
1263 for (row, slice) in values_slice.chunks_mut(cols).enumerate() {
1264 do_row(row, slice);
1265 }
1266 }
1267 } else {
1268 for (row, slice) in values_slice.chunks_mut(cols).enumerate() {
1269 do_row(row, slice);
1270 }
1271 }
1272
1273 let values = unsafe {
1274 Vec::from_raw_parts(
1275 buf_guard.as_mut_ptr() as *mut f64,
1276 buf_guard.len(),
1277 buf_guard.capacity(),
1278 )
1279 };
1280
1281 Ok(ChandeBatchOutput {
1282 values,
1283 combos,
1284 rows,
1285 cols,
1286 })
1287}
1288
1289#[inline(always)]
1290fn chande_batch_inner_into(
1291 high: &[f64],
1292 low: &[f64],
1293 close: &[f64],
1294 sweep: &ChandeBatchRange,
1295 dir: &str,
1296 kern: Kernel,
1297 parallel: bool,
1298 out: &mut [f64],
1299) -> Result<Vec<ChandeParams>, ChandeError> {
1300 if high.is_empty() || low.is_empty() || close.is_empty() {
1301 return Err(ChandeError::EmptyInputData);
1302 }
1303 if !(high.len() == low.len() && low.len() == close.len()) {
1304 return Err(ChandeError::DataLengthMismatch {
1305 h: high.len(),
1306 l: low.len(),
1307 c: close.len(),
1308 });
1309 }
1310
1311 let combos = expand_grid(sweep, dir)?;
1312 if combos.is_empty() {
1313 return Err(ChandeError::InvalidRange {
1314 start: 0,
1315 end: 0,
1316 step: 0,
1317 });
1318 }
1319
1320 let first = first_valid3(high, low, close).ok_or(ChandeError::AllValuesNaN)?;
1321
1322 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1323 if high.len() - first < max_p {
1324 return Err(ChandeError::NotEnoughValidData {
1325 needed: max_p,
1326 valid: high.len() - first,
1327 });
1328 }
1329
1330 let cols = high.len();
1331
1332 let expected = combos
1333 .len()
1334 .checked_mul(cols)
1335 .ok_or_else(|| ChandeError::InvalidInput("rows*cols overflow".into()))?;
1336 if out.len() != expected {
1337 return Err(ChandeError::OutputLengthMismatch {
1338 expected,
1339 got: out.len(),
1340 });
1341 }
1342
1343 let actual_kern = match kern {
1344 Kernel::Auto => detect_best_batch_kernel(),
1345 k => k,
1346 };
1347
1348 for (row, combo) in combos.iter().enumerate() {
1349 let warmup = first + combo.period.unwrap() - 1;
1350 let row_start = row * cols;
1351 for i in 0..warmup.min(cols) {
1352 out[row_start + i] = f64::NAN;
1353 }
1354 }
1355
1356 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
1357 let period = combos[row].period.unwrap();
1358 let mult = combos[row].mult.unwrap();
1359 let direction = combos[row].direction.as_deref().unwrap();
1360 match actual_kern {
1361 Kernel::Scalar | Kernel::ScalarBatch => {
1362 chande_row_scalar(high, low, close, first, period, mult, direction, out_row)
1363 }
1364 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1365 Kernel::Avx2 | Kernel::Avx2Batch => {
1366 chande_row_avx2(high, low, close, first, period, mult, direction, out_row)
1367 }
1368 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1369 Kernel::Avx512 | Kernel::Avx512Batch => {
1370 chande_row_avx512(high, low, close, first, period, mult, direction, out_row)
1371 }
1372 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1373 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
1374 chande_row_scalar(high, low, close, first, period, mult, direction, out_row)
1375 }
1376 Kernel::Auto => unreachable!("Auto kernel should have been resolved"),
1377 }
1378 };
1379
1380 if parallel {
1381 #[cfg(not(target_arch = "wasm32"))]
1382 {
1383 out.par_chunks_mut(cols)
1384 .enumerate()
1385 .for_each(|(row, slice)| do_row(row, slice));
1386 }
1387
1388 #[cfg(target_arch = "wasm32")]
1389 {
1390 for (row, slice) in out.chunks_mut(cols).enumerate() {
1391 do_row(row, slice);
1392 }
1393 }
1394 } else {
1395 for (row, slice) in out.chunks_mut(cols).enumerate() {
1396 do_row(row, slice);
1397 }
1398 }
1399
1400 Ok(combos)
1401}
1402
1403#[inline(always)]
1404unsafe fn chande_row_scalar(
1405 high: &[f64],
1406 low: &[f64],
1407 close: &[f64],
1408 first: usize,
1409 period: usize,
1410 mult: f64,
1411 dir: &str,
1412 out: &mut [f64],
1413) {
1414 chande_scalar(high, low, close, period, mult, dir, first, out);
1415}
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
/// AVX2-slot row kernel for the batch drivers.
///
/// Despite the name, this wrapper contains no AVX2 code of its own. It forwards to
/// `chande_fast_unchecked`, reordering the batch-kernel arguments so that `first`
/// comes after `dir`. Callers must uphold `chande_fast_unchecked`'s safety contract.
unsafe fn chande_row_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    mult: f64,
    dir: &str,
    out: &mut [f64],
) {
    // Same argument reorder as the scalar row kernel.
    chande_fast_unchecked(high, low, close, period, mult, dir, first, out)
}
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
/// AVX-512-slot row kernel for the batch drivers.
///
/// Dispatches on the period: periods up to 32 go to `chande_row_avx512_short`, longer
/// ones go to `chande_row_avx512_long`. Arguments are passed through unchanged.
unsafe fn chande_row_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    mult: f64,
    dir: &str,
    out: &mut [f64],
) {
    match period {
        0..=32 => chande_row_avx512_short(high, low, close, first, period, mult, dir, out),
        _ => chande_row_avx512_long(high, low, close, first, period, mult, dir, out),
    }
}
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
/// Short-period (`period <= 32`) branch of the AVX-512 row kernel.
///
/// Currently forwards to `chande_fast_unchecked`, with `first` moved after `dir`, just
/// like the long-period branch.
unsafe fn chande_row_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    mult: f64,
    dir: &str,
    out: &mut [f64],
) {
    chande_fast_unchecked(high, low, close, period, mult, dir, first, out)
}
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
/// Long-period (`period > 32`) branch of the AVX-512 row kernel.
///
/// Currently forwards to `chande_fast_unchecked`, with `first` moved after `dir`, just
/// like the short-period branch.
unsafe fn chande_row_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    mult: f64,
    dir: &str,
    out: &mut [f64],
) {
    chande_fast_unchecked(high, low, close, period, mult, dir, first, out)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;

    /// Candle fixture shared by every data-driven test in this module.
    const CANDLE_CSV: &str = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";

    /// Reference values for the last five bars of `CANDLE_CSV`.
    ///
    /// Two tests check against them: `check_chande_accuracy` (explicit period 22,
    /// mult 3.0, "long") and `check_batch_default_row` (the `ChandeParams::default()`
    /// row of a batch run).
    const EXPECTED_LAST_FIVE: [f64; 5] = [
        59444.14115983658,
        58576.49837984401,
        58649.1120898511,
        58724.56154031242,
        58713.39965211639,
    ];

    /// Result type of the check helpers and of the generated `#[test]` wrappers.
    ///
    /// The wrappers return it so that an `Err` from a helper fails the test instead of
    /// being discarded.
    type TestResult = Result<(), Box<dyn std::error::Error>>;

    /// Exact equality that also treats two NaNs as equal.
    fn eq_or_both_nan(a: f64, b: f64) -> bool {
        (a.is_nan() && b.is_nan()) || a == b
    }

    /// Debug-build poison bit patterns, each paired with the allocation helper these
    /// tests attribute it to.
    ///
    /// None of the patterns has an all-ones exponent, so none of them is a NaN. A
    /// poisoned cell therefore can never be mistaken for a legitimate warmup NaN.
    #[cfg(debug_assertions)]
    const POISON_PATTERNS: [(u64, &str); 3] = [
        (0x11111111_11111111, "alloc_with_nan_prefix"),
        (0x22222222_22222222, "init_matrix_prefixes"),
        (0x33333333_33333333, "make_uninit_matrix"),
    ];

    /// Returns the helper name whose poison pattern `val` carries, if any.
    #[cfg(debug_assertions)]
    fn poison_source(val: f64) -> Option<&'static str> {
        let bits = val.to_bits();
        POISON_PATTERNS
            .iter()
            .find(|&&(pattern, _)| pattern == bits)
            .map(|&(_, helper)| helper)
    }

    /// Panics if any cell of a row-major batch matrix (`cols` cells per row, row `r`
    /// produced by `combos[r]`) still carries a poison pattern.
    #[cfg(debug_assertions)]
    fn assert_batch_no_poison(test: &str, values: &[f64], cols: usize, combos: &[ChandeParams]) {
        for (idx, &val) in values.iter().enumerate() {
            if let Some(helper) = poison_source(val) {
                let row = idx / cols;
                let col = idx % cols;
                let params = &combos[row];
                panic!(
                    "[{}] Found {} poison value {} (0x{:016X}) at row {} col {} (flat index {}) with params: period={}, mult={}, direction={}",
                    test, helper, val, val.to_bits(), row, col, idx,
                    params.period.unwrap(), params.mult.unwrap(), params.direction.as_ref().unwrap()
                );
            }
        }
    }

    #[test]
    fn test_chande_into_matches_api() -> TestResult {
        let candles = read_candles_from_csv(CANDLE_CSV)?;

        let input = ChandeInput::with_default_candles(&candles);

        let baseline = chande(&input)?;

        let mut out = vec![0.0f64; candles.close.len()];
        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
        {
            chande_into(&input, &mut out)?;
        }
        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
        {
            chande_into_slice(&mut out, &input, Kernel::Auto)?;
        }

        assert_eq!(baseline.values.len(), out.len());

        for (i, (&expected, &got)) in baseline.values.iter().zip(&out).enumerate() {
            assert!(
                eq_or_both_nan(expected, got),
                "Mismatch at index {}: got {}, expected {}",
                i,
                got,
                expected
            );
        }
        Ok(())
    }

    fn check_chande_partial_params(test_name: &str, kernel: Kernel) -> TestResult {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLE_CSV)?;

        let default_params = ChandeParams {
            period: None,
            mult: None,
            direction: None,
        };
        let input = ChandeInput::from_candles(&candles, default_params);
        let output = chande_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());

        Ok(())
    }

    fn check_chande_accuracy(test_name: &str, kernel: Kernel) -> TestResult {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLE_CSV)?;
        let close_prices = &candles.close;

        let params = ChandeParams {
            period: Some(22),
            mult: Some(3.0),
            direction: Some("long".into()),
        };
        let input = ChandeInput::from_candles(&candles, params);
        let chande_result = chande_with_kernel(&input, kernel)?;

        assert_eq!(chande_result.values.len(), close_prices.len());

        assert!(chande_result.values.len() >= 5);
        let start_idx = chande_result.values.len() - 5;
        let actual_last_five = &chande_result.values[start_idx..];
        for (i, (&val, &exp)) in actual_last_five
            .iter()
            .zip(EXPECTED_LAST_FIVE.iter())
            .enumerate()
        {
            assert!(
                (val - exp).abs() < 1e-4,
                "[{}] Chande Exits mismatch at index {}: expected {}, got {}",
                test_name,
                i,
                exp,
                val
            );
        }

        let period = 22;
        for i in 0..(period - 1) {
            assert!(
                chande_result.values[i].is_nan(),
                "Expected leading NaN at index {}",
                i
            );
        }

        let default_input = ChandeInput::with_default_candles(&candles);
        let default_output = chande_with_kernel(&default_input, kernel)?;
        assert_eq!(default_output.values.len(), close_prices.len());
        Ok(())
    }

    fn check_chande_zero_period(test_name: &str, kernel: Kernel) -> TestResult {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLE_CSV)?;

        let params = ChandeParams {
            period: Some(0),
            mult: Some(3.0),
            direction: Some("long".into()),
        };
        let input = ChandeInput::from_candles(&candles, params);

        let res = chande_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] Chande should fail with zero period",
            test_name
        );
        Ok(())
    }

    fn check_chande_period_exceeds_length(test_name: &str, kernel: Kernel) -> TestResult {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLE_CSV)?;

        let params = ChandeParams {
            period: Some(99999),
            mult: Some(3.0),
            direction: Some("long".into()),
        };
        let input = ChandeInput::from_candles(&candles, params);

        let res = chande_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] Chande should fail with period exceeding length",
            test_name
        );
        Ok(())
    }

    fn check_chande_bad_direction(test_name: &str, kernel: Kernel) -> TestResult {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLE_CSV)?;

        let params = ChandeParams {
            period: Some(22),
            mult: Some(3.0),
            direction: Some("bad".into()),
        };
        let input = ChandeInput::from_candles(&candles, params);

        let res = chande_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] Chande should fail with bad direction",
            test_name
        );
        Ok(())
    }

    fn check_chande_nan_handling(test_name: &str, kernel: Kernel) -> TestResult {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLE_CSV)?;

        let params = ChandeParams {
            period: Some(22),
            mult: Some(3.0),
            direction: Some("long".into()),
        };
        let input = ChandeInput::from_candles(&candles, params);
        let result = chande_with_kernel(&input, kernel)?;

        // Well past the period-22 warmup, every value must be finite-or-real (non-NaN).
        if result.values.len() > 240 {
            for i in 240..result.values.len() {
                assert!(
                    !result.values[i].is_nan(),
                    "[{}] Unexpected NaN at index {}",
                    test_name,
                    i
                );
            }
        }
        Ok(())
    }

    fn check_chande_streaming(test_name: &str, kernel: Kernel) -> TestResult {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLE_CSV)?;

        let params = ChandeParams {
            period: Some(22),
            mult: Some(3.0),
            direction: Some("long".into()),
        };
        let input = ChandeInput::from_candles(&candles, params.clone());
        let batch_output = chande_with_kernel(&input, kernel)?.values;

        let mut stream = ChandeStream::try_new(params)?;
        let mut stream_values = Vec::with_capacity(candles.close.len());
        for ((&h, &l), &c) in candles.high.iter().zip(&candles.low).zip(&candles.close) {
            match stream.update(h, l, c) {
                Some(chande_val) => stream_values.push(chande_val),
                None => stream_values.push(f64::NAN),
            }
        }
        assert_eq!(batch_output.len(), stream_values.len());
        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
            if b.is_nan() && s.is_nan() {
                continue;
            }
            let diff = (b - s).abs();
            assert!(
                diff < 1e-8,
                "[{}] Chande streaming mismatch at idx {}: batch={}, stream={}, diff={}",
                test_name,
                i,
                b,
                s,
                diff
            );
        }
        Ok(())
    }

    #[cfg(debug_assertions)]
    fn check_chande_no_poison(test_name: &str, kernel: Kernel) -> TestResult {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLE_CSV)?;

        let param_combinations = vec![
            ChandeParams {
                period: Some(10),
                mult: Some(2.0),
                direction: Some("long".into()),
            },
            ChandeParams {
                period: Some(22),
                mult: Some(3.0),
                direction: Some("short".into()),
            },
            ChandeParams {
                period: Some(50),
                mult: Some(5.0),
                direction: Some("long".into()),
            },
        ];

        for params in param_combinations {
            let input = ChandeInput::from_candles(&candles, params.clone());
            let output = chande_with_kernel(&input, kernel)?;

            for (i, &val) in output.values.iter().enumerate() {
                if let Some(helper) = poison_source(val) {
                    panic!(
                        "[{}] Found {} poison value {} (0x{:016X}) at index {} with params: period={}, mult={}, direction={}",
                        test_name, helper, val, val.to_bits(), i,
                        params.period.unwrap(), params.mult.unwrap(), params.direction.as_ref().unwrap()
                    );
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_chande_no_poison(_test_name: &str, _kernel: Kernel) -> TestResult {
        Ok(())
    }

    /// Generates one `#[test]` per kernel for each listed check function.
    ///
    /// The `cfg` attributes sit inside the repetition so that they apply to every
    /// generated AVX test, not just the first one. The wrappers return the check's
    /// `Result` so that errors fail the test.
    macro_rules! generate_all_chande_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() -> TestResult {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() -> TestResult {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2)
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() -> TestResult {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)
                    }
                )*
            }
        };
    }

    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_chande_property(test_name: &str, kernel: Kernel) -> TestResult {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Strategy: a random period, a finite close series at least that long, and
        // high/low built as close plus/minus a non-negative spread.
        let strat = (1usize..=100).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                    period..400,
                )
                .prop_flat_map(move |close| {
                    let len = close.len();
                    (
                        Just(close.clone()),
                        prop::collection::vec(0.0f64..1000.0f64, len),
                        prop::collection::vec(0.0f64..1000.0f64, len),
                    )
                        .prop_map(move |(c, high_spread, low_spread)| {
                            let high: Vec<f64> = c
                                .iter()
                                .zip(&high_spread)
                                .map(|(&close_val, &spread)| close_val + spread)
                                .collect();
                            let low: Vec<f64> = c
                                .iter()
                                .zip(&low_spread)
                                .map(|(&close_val, &spread)| close_val - spread)
                                .collect();
                            (high, low, c.clone())
                        })
                }),
                Just(period),
                0.1f64..10.0f64,
                prop::bool::ANY,
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |((high, low, close), period, mult, is_long)| {
                let direction = if is_long { "long" } else { "short" };

                let candles = Candles {
                    high: high.clone(),
                    low: low.clone(),
                    close: close.clone(),
                    timestamp: vec![],
                    open: vec![],
                    volume: vec![],
                    hl2: vec![],
                    hlc3: vec![],
                    ohlc4: vec![],
                    hlcc4: vec![],
                };

                let params = ChandeParams {
                    period: Some(period),
                    mult: Some(mult),
                    direction: Some(direction.to_string()),
                };

                let input = ChandeInput::from_candles(&candles, params);

                let result = chande_with_kernel(&input, kernel);

                prop_assert!(result.is_ok(), "Chande should succeed for valid inputs");
                let output = result.unwrap();

                prop_assert_eq!(
                    output.values.len(),
                    high.len(),
                    "Output length should match input"
                );

                let first_valid = close.iter().position(|&x| !x.is_nan()).unwrap_or(0);
                let warmup_period = first_valid + period - 1;

                for i in 0..warmup_period.min(output.values.len()) {
                    prop_assert!(
                        output.values[i].is_nan(),
                        "Expected NaN during warmup at index {}",
                        i
                    );
                }

                if warmup_period < output.values.len() {
                    for i in warmup_period..output.values.len() {
                        let val = output.values[i];
                        prop_assert!(
                            val.is_finite(),
                            "Expected finite value after warmup at index {}, got {}",
                            i,
                            val
                        );
                    }
                }

                for i in warmup_period..output.values.len() {
                    let start_idx = i + 1 - period;
                    let period_high = high[start_idx..=i].iter().cloned().fold(f64::MIN, f64::max);
                    let period_low = low[start_idx..=i].iter().cloned().fold(f64::MAX, f64::min);
                    let val = output.values[i];

                    if is_long {
                        prop_assert!(
                            val <= period_high + 1e-6,
                            "Long exit {} should be <= period high {} at index {}",
                            val,
                            period_high,
                            i
                        );
                    } else {
                        prop_assert!(
                            val >= period_low - 1e-6,
                            "Short exit {} should be >= period low {} at index {}",
                            val,
                            period_low,
                            i
                        );
                    }
                }

                let ref_output = chande_with_kernel(&input, Kernel::Scalar).unwrap();
                for i in 0..output.values.len() {
                    let val = output.values[i];
                    let ref_val = ref_output.values[i];

                    if !val.is_finite() || !ref_val.is_finite() {
                        prop_assert_eq!(
                            val.to_bits(),
                            ref_val.to_bits(),
                            "NaN/Inf mismatch at index {}: {} vs {}",
                            i,
                            val,
                            ref_val
                        );
                        continue;
                    }

                    let val_bits = val.to_bits();
                    let ref_bits = ref_val.to_bits();
                    let ulp_diff = val_bits.abs_diff(ref_bits);

                    prop_assert!(
                        (val - ref_val).abs() <= 1e-9 || ulp_diff <= 4,
                        "Kernel mismatch at index {}: {} vs {} (ULP={})",
                        i,
                        val,
                        ref_val,
                        ulp_diff
                    );
                }

                if period == 1 && warmup_period < output.values.len() {
                    for i in warmup_period..output.values.len() {
                        let val = output.values[i];
                        prop_assert!(
                            val.is_finite(),
                            "Period=1 should produce finite values at index {}",
                            i
                        );

                        if is_long {
                            prop_assert!(
                                val <= high[i] + 1e-6,
                                "Period=1 long exit {} should be <= high {} at index {}",
                                val,
                                high[i],
                                i
                            );
                        } else {
                            prop_assert!(
                                val >= low[i] - 1e-6,
                                "Period=1 short exit {} should be >= low {} at index {}",
                                val,
                                low[i],
                                i
                            );
                        }
                    }
                }

                let all_same_close = close.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-12);
                let all_same_high = high.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-12);
                let all_same_low = low.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-12);

                if all_same_close
                    && all_same_high
                    && all_same_low
                    && warmup_period + 10 < output.values.len()
                {
                    let stable_start = warmup_period + period;
                    if stable_start + 2 < output.values.len() {
                        for i in stable_start..output.values.len() - 1 {
                            prop_assert!(
                                (output.values[i] - output.values[i + 1]).abs() <= 1e-6,
                                "Constant data should produce stable output at index {}: {} vs {}",
                                i,
                                output.values[i],
                                output.values[i + 1]
                            );
                        }
                    }
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }

    generate_all_chande_tests!(
        check_chande_partial_params,
        check_chande_accuracy,
        check_chande_zero_period,
        check_chande_period_exceeds_length,
        check_chande_bad_direction,
        check_chande_nan_handling,
        check_chande_streaming,
        check_chande_no_poison
    );

    #[cfg(feature = "proptest")]
    generate_all_chande_tests!(check_chande_property);

    fn check_batch_default_row(test: &str, kernel: Kernel) -> TestResult {
        skip_if_unsupported!(kernel, test);
        let c = read_candles_from_csv(CANDLE_CSV)?;
        let output = ChandeBatchBuilder::new().kernel(kernel).apply_candles(&c)?;

        let def = ChandeParams::default();
        let row = output.values_for(&def).expect("default row missing");
        assert_eq!(row.len(), c.close.len());

        let expected = EXPECTED_LAST_FIVE;
        let start = row.len() - 5;
        for (i, &v) in row[start..].iter().enumerate() {
            assert!(
                (v - expected[i]).abs() < 1e-4,
                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
            );
        }
        Ok(())
    }

    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> TestResult {
        skip_if_unsupported!(kernel, test);

        let c = read_candles_from_csv(CANDLE_CSV)?;

        let output = ChandeBatchBuilder::new()
            .kernel(kernel)
            .period_range(10, 30, 10)
            .mult_range(2.0, 5.0, 1.5)
            .direction("long")
            .apply_candles(&c)?;
        assert_batch_no_poison(test, &output.values, output.cols, &output.combos);

        let output_short = ChandeBatchBuilder::new()
            .kernel(kernel)
            .period_range(15, 45, 15)
            .mult_range(1.0, 4.0, 1.5)
            .direction("short")
            .apply_candles(&c)?;
        assert_batch_no_poison(
            test,
            &output_short.values,
            output_short.cols,
            &output_short.combos,
        );

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> TestResult {
        Ok(())
    }

    /// Generates one `#[test]` per batch kernel (plus `Auto`) for a batch check function.
    /// The wrappers return the check's `Result` so that errors fail the test.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test]
                fn [<$fn_name _scalar>]() -> TestResult {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx2>]() -> TestResult {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx512>]() -> TestResult {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
                }

                #[test]
                fn [<$fn_name _auto_detect>]() -> TestResult {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
                }
            }
        };
    }
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
}
2218
#[cfg(feature = "python")]
#[pyfunction(name = "chande")]
#[pyo3(signature = (high, low, close, period, mult, direction, kernel=None))]
/// Python wrapper computing Chande Exits for one parameter set.
///
/// Borrows the three NumPy inputs as slices and validates the optional `kernel` name.
/// The computation runs with the GIL released. Errors from the computation are raised
/// as `ValueError`; the values come back as a new 1-D NumPy array.
pub fn chande_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    period: usize,
    mult: f64,
    direction: &str,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    // Slice extraction happens in the same order (high, low, close) as the arguments.
    let (h_slice, l_slice, c_slice) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let kern = validate_kernel(kernel, false)?;

    let input = ChandeInput::from_slices(
        h_slice,
        l_slice,
        c_slice,
        ChandeParams {
            period: Some(period),
            mult: Some(mult),
            direction: Some(direction.to_string()),
        },
    );

    let output = py
        .allow_threads(|| chande_with_kernel(&input, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(output.values.into_pyarray(py))
}
2249
#[cfg(feature = "python")]
#[pyclass(name = "ChandeStream")]
/// Python-facing wrapper around the incremental `ChandeStream` calculator.
pub struct ChandeStreamPy {
    stream: ChandeStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl ChandeStreamPy {
    /// Builds a stream for the given parameters; construction errors become `ValueError`.
    #[new]
    fn new(period: usize, mult: f64, direction: &str) -> PyResult<Self> {
        let params = ChandeParams {
            period: Some(period),
            mult: Some(mult),
            direction: Some(direction.to_string()),
        };
        ChandeStream::try_new(params)
            .map(|stream| ChandeStreamPy { stream })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one bar and returns the underlying stream's output (`None` when it
    /// produces no value for this bar).
    fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
        self.stream.update(high, low, close)
    }
}
2275
#[cfg(feature = "python")]
#[pyfunction(name = "chande_batch")]
#[pyo3(signature = (high, low, close, period_range, mult_range, direction, kernel=None))]
/// Python entry point for a Chande Exit parameter sweep.
///
/// Expands `period_range` and `mult_range` (each `(start, end, step)`) for the single
/// `direction` via `expand_grid`. Every combination is computed into one freshly
/// allocated NumPy buffer with the GIL released, and the function returns a dict with:
/// - `"values"`: the buffer reshaped to `(rows, cols)`, with one row per combination
///   and `cols == len(high)`;
/// - `"periods"` (uint64 array) and `"mults"` (float64 array): per-row parameters;
/// - `"directions"`: per-row direction strings.
///
/// An invalid sweep, a `rows * cols` overflow, or any error from the batch computation
/// raises `ValueError`. An unknown `kernel` name raises whatever `validate_kernel` reports.
pub fn chande_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    mult_range: (f64, f64, f64),
    direction: &str,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let h = high.as_slice()?;
    let l = low.as_slice()?;
    let c = close.as_slice()?;

    let sweep = ChandeBatchRange {
        period: period_range,
        mult: mult_range,
    };
    let combos =
        expand_grid(&sweep, direction).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = h.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // Validate the kernel name before allocating a potentially large output array.
    let kern = validate_kernel(kernel, true)?;

    // The array is created uninitialised. `chande_batch_inner_into` NaN-fills each row's
    // warmup prefix, and the row kernels are relied on to write the rest. If the
    // computation fails, the array is dropped before it is ever returned to Python.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    py.allow_threads(|| {
        let simd = match kern {
            Kernel::Auto => detect_best_batch_kernel(),
            k => k,
        };

        // Map batch variants to their single-series counterparts; the inner driver
        // dispatches both forms to the same row kernel.
        let simd = match simd {
            Kernel::Avx512Batch => Kernel::Avx512,
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::ScalarBatch => Kernel::Scalar,
            _ => simd,
        };
        chande_batch_inner_into(h, l, c, &sweep, direction, simd, true, slice_out)
    })
    .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "mults",
        combos
            .iter()
            .map(|p| p.mult.unwrap())
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "directions",
        combos
            .iter()
            .map(|p| p.direction.as_deref().unwrap())
            .collect::<Vec<_>>(),
    )?;
    Ok(dict)
}
2352
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(
    module = "ta_indicators.cuda",
    name = "DeviceArrayF32Chande",
    unsendable
)]
/// Python handle to a 2-D `f32` Chande result that lives in CUDA device memory.
///
/// Exposed to Python as `ta_indicators.cuda.DeviceArrayF32Chande`. `unsendable`
/// tells PyO3 to confine the object to the thread that created it.
pub struct DeviceArrayF32ChandePy {
    // Device buffer plus its row/column shape. Fields drop in declaration order, so this
    // buffer is released while `_ctx_guard` below still holds the context handle; keep
    // this field first.
    pub(crate) inner: DeviceArrayF32,
    // Shared CUDA context handle, held only to keep the context alive for the lifetime
    // of this object. NOTE(review): presumably the context `inner` was allocated in.
    _ctx_guard: Arc<Context>,
    // Device ordinal reported by `__dlpack_device__` and used for DLPack export.
    _device_id: u32,
}
2364
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32ChandePy {
    /// Direct construction from Python is disallowed; instances are produced by the
    /// `chande_cuda_*` factory functions.
    #[new]
    fn py_new() -> PyResult<Self> {
        Err(PyTypeError::new_err(
            "DeviceArrayF32Chande cannot be created directly; use chande_cuda_* factories",
        ))
    }

    /// CUDA Array Interface (version 3) description of the device buffer.
    ///
    /// Reports a C-contiguous `(rows, cols)` array of little-endian `f32` (`"<f4"`), with
    /// byte strides `(cols * 4, 4)`. The raw device pointer is passed as `data`;
    /// NOTE(review): the `false` flag is presumably CAI's read-only flag, confirm.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        let itemsize = std::mem::size_of::<f32>();
        d.set_item("shape", (self.inner.rows, self.inner.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item("strides", (self.inner.cols * itemsize, itemsize))?;
        let ptr_val: usize = self.inner.buf.as_device_ptr().as_raw() as usize;
        d.set_item("data", (ptr_val, false))?;
        d.set_item("version", 3)?;
        Ok(d)
    }

    /// Returns the `(device_type, device_id)` pair for DLPack.
    ///
    /// Device type 2 is presumably DLPack's `kDLCUDA`; NOTE(review): confirm against
    /// `export_f32_cuda_dlpack_2d`.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self._device_id as i32)
    }

    /// Exports the buffer as a DLPack capsule, transferring ownership of the device
    /// allocation.
    ///
    /// If `_dl_device` is a `(type, id)` pair that differs from this array's device, this
    /// raises `ValueError`. No cross-device copy is implemented, whatever `_copy` says.
    /// `_stream` is accepted but ignored. After a successful call, `self.inner` holds an
    /// empty 0x0 placeholder buffer, so a second export yields an empty array.
    #[pyo3(signature = (_stream=None, max_version=None, _dl_device=None, _copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        _stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        _dl_device: Option<pyo3::PyObject>,
        _copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        use cust::memory::DeviceBuffer;

        let (kdl, alloc_dev) = self.__dlpack_device__();
        // A `_dl_device` that does not extract as `(i32, i32)` is silently ignored.
        if let Some(dev_obj) = _dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = _copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }
        // Stream synchronisation is not performed.
        let _ = _stream;

        // Move the real buffer out of `self`, leaving an empty placeholder, so that
        // ownership of the allocation can be handed to the DLPack capsule.
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
2443
#[cfg(all(feature = "python", feature = "cuda"))]
impl DeviceArrayF32ChandePy {
    /// Bundles a device array with the CUDA context handle and device ordinal it was
    /// produced on, so all three travel together inside the Python object.
    pub fn new(inner: DeviceArrayF32, ctx_guard: Arc<Context>, device_id: u32) -> Self {
        Self {
            _device_id: device_id,
            _ctx_guard: ctx_guard,
            inner,
        }
    }
}
2454
/// Runs a Chande (period x mult) sweep on the GPU and returns a device-resident handle.
///
/// `high`, `low` and `close` must have equal lengths. The computation happens with the GIL
/// released; CUDA and validation failures surface as `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "chande_cuda_batch_dev")]
#[pyo3(signature = (high, low, close, period_range, mult_range, direction, device_id=0))]
pub fn chande_cuda_batch_dev_py(
    py: Python<'_>,
    high: PyReadonlyArray1<'_, f32>,
    low: PyReadonlyArray1<'_, f32>,
    close: PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    mult_range: (f64, f64, f64),
    direction: &str,
    device_id: usize,
) -> PyResult<DeviceArrayF32ChandePy> {
    use crate::cuda::{cuda_available, CudaChande};

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (h, l, c) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let n = h.len();
    if l.len() != n || c.len() != n {
        return Err(PyValueError::new_err("mismatched input lengths"));
    }

    let sweep = ChandeBatchRange {
        period: period_range,
        mult: mult_range,
    };

    let (inner, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let mut cuda =
            CudaChande::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        cuda.chande_batch_dev(h, l, c, &sweep, direction)
            .map(|arr| (arr, ctx, dev_id))
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(DeviceArrayF32ChandePy::new(inner, ctx, dev_id))
}
2500
/// Runs Chande with a single parameter set over many time-major series on the GPU.
///
/// `high_tm`, `low_tm` and `close_tm` are flat buffers that must each hold exactly
/// `cols * rows` values. The result stays on the device and is returned as a
/// `DeviceArrayF32Chande` handle; failures surface as `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "chande_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm, low_tm, close_tm, cols, rows, period, mult, direction, device_id=0))]
pub fn chande_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm: PyReadonlyArray1<'_, f32>,
    low_tm: PyReadonlyArray1<'_, f32>,
    close_tm: PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    mult: f32,
    direction: &str,
    device_id: usize,
) -> PyResult<DeviceArrayF32ChandePy> {
    use crate::cuda::{cuda_available, CudaChande};

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (h, l, c) = (
        high_tm.as_slice()?,
        low_tm.as_slice()?,
        close_tm.as_slice()?,
    );
    let expected = match cols.checked_mul(rows) {
        Some(n) => n,
        None => return Err(PyValueError::new_err("rows*cols overflow")),
    };
    if [h.len(), l.len(), c.len()].iter().any(|&n| n != expected) {
        return Err(PyValueError::new_err("time-major input length mismatch"));
    }

    let (inner, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda =
            CudaChande::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        cuda.chande_many_series_one_param_time_major_dev(
            h, l, c, cols, rows, period, mult, direction,
        )
        .map(|arr| (arr, ctx, dev_id))
        .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(DeviceArrayF32ChandePy::new(inner, ctx, dev_id))
}
2555
/// Computes the Chande indicator and returns one value per input bar.
///
/// The returned vector has `high.len()` entries. Any error reported by
/// `chande_into_slice` is converted to a JS string error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn chande_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    mult: f64,
    direction: &str,
) -> Result<Vec<f64>, JsValue> {
    let mut values = vec![0.0; high.len()];
    let input = ChandeInput::from_slices(
        high,
        low,
        close,
        ChandeParams {
            period: Some(period),
            mult: Some(mult),
            direction: Some(direction.to_string()),
        },
    );
    match chande_into_slice(&mut values, &input, detect_best_kernel()) {
        Ok(_) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2577
/// Configuration object accepted by the JS `chande_batch` export (deserialized from a
/// `JsValue` with `serde_wasm_bindgen`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct ChandeBatchConfig {
    // `(start, end, step)` for the period sweep; passed straight through as
    // `ChandeBatchRange::period`. Presumably supplied from JS as a 3-element array — confirm.
    pub period_range: (usize, usize, usize),
    // `(start, end, step)` for the multiplier sweep; passed through as `ChandeBatchRange::mult`.
    pub mult_range: (f64, f64, f64),
    // Direction string applied to every combination in the sweep.
    pub direction: String,
}
2585
/// Result object returned by the JS `chande_batch` export (serialized with
/// `serde_wasm_bindgen`). Field order here is the key order of the emitted object.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct ChandeBatchJsOutput {
    // Flattened `rows x cols` result matrix, one row per entry of `combos`. Presumably
    // row-major, as in the Python binding's `reshape((rows, cols))` — confirm.
    pub values: Vec<f64>,
    // Parameter set used for each row, in row order.
    pub combos: Vec<ChandeParams>,
    // Number of parameter combinations.
    pub rows: usize,
    // Number of values per row (the input length).
    pub cols: usize,
}
2594
/// Runs a Chande sweep over every (period, mult) combination and returns a plain JS object
/// `{ values, periods, mults, directions, rows, cols }`.
///
/// `values` is a flat `Float64Array` of `rows * cols` entries. `periods`, `mults` and
/// `directions` list the parameters of each row in order. Errors from the batch routine
/// and from `Reflect.set` are returned as `JsValue` errors.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn chande_batch_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
    mult_start: f64,
    mult_end: f64,
    mult_step: f64,
    direction: &str,
) -> Result<JsValue, JsValue> {
    let sweep = ChandeBatchRange {
        period: (period_start, period_end, period_step),
        mult: (mult_start, mult_end, mult_step),
    };
    let kernel = detect_best_batch_kernel().to_non_batch();

    let batch = chande_batch_inner(high, low, close, &sweep, direction, kernel, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    // Copies a Rust slice into a freshly allocated JS Float64Array.
    let to_f64_array = |xs: &[f64]| {
        let arr = js_sys::Float64Array::new_with_length(xs.len() as u32);
        arr.copy_from(xs);
        arr
    };

    let obj = js_sys::Object::new();
    let set = |key: &str, value: &JsValue| {
        js_sys::Reflect::set(&obj, &JsValue::from_str(key), value)
    };

    set("values", &to_f64_array(batch.values.as_slice()).into())?;

    let periods: Vec<f64> = batch
        .combos
        .iter()
        .map(|p| p.period.unwrap() as f64)
        .collect();
    let mults: Vec<f64> = batch.combos.iter().map(|p| p.mult.unwrap()).collect();
    let directions = js_sys::Array::new();
    for p in &batch.combos {
        directions.push(&JsValue::from_str(p.direction.as_deref().unwrap()));
    }

    set("periods", &to_f64_array(periods.as_slice()).into())?;
    set("mults", &to_f64_array(mults.as_slice()).into())?;
    set("directions", &directions.into())?;
    set("rows", &JsValue::from_f64(batch.rows as f64))?;
    set("cols", &JsValue::from_f64(batch.cols as f64))?;

    Ok(obj.into())
}
2666
/// JS `chande_batch`: runs a sweep described by a `ChandeBatchConfig` object and returns a
/// serialized `ChandeBatchJsOutput`.
///
/// Config decoding, computation and serialization failures are each reported as a JS
/// string error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = chande_batch)]
pub fn chande_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let ChandeBatchConfig {
        period_range,
        mult_range,
        direction,
    } = serde_wasm_bindgen::from_value::<ChandeBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = ChandeBatchRange {
        period: period_range,
        mult: mult_range,
    };
    let kernel = detect_best_batch_kernel().to_non_batch();
    let batch = chande_batch_inner(high, low, close, &sweep, &direction, kernel, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = ChandeBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2693
/// Allocates uninitialised space for `len` `f64` values and hands ownership to the JS
/// caller, which is expected to release it later with `chande_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn chande_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop suppresses the Vec destructor so the allocation outlives this call.
    let mut storage = ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    storage.as_mut_ptr()
}
/// Releases a buffer previously obtained from `chande_alloc`.
///
/// `len` must be the value that was passed to `chande_alloc`; it is used as the
/// allocation's capacity so the allocator receives the same layout back. A null pointer is
/// ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn chande_free(ptr: *mut f64, len: usize) {
    // `Vec::from_raw_parts` requires a non-null pointer.
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` comes from `chande_alloc`, i.e. from a `Vec<f64>` created with capacity
    // `len` (assumes the caller passes the same `len`). The length is 0 because JS may not
    // have initialised every slot, and `from_raw_parts` requires the first `length` elements
    // to be initialised. `f64` has no destructor, so dropping the Vec only frees the memory.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2709
/// Computes Chande values from caller-owned WASM memory into `out_ptr`.
///
/// All four pointers must reference `len` contiguous `f64` values (for example buffers from
/// `chande_alloc`). The output may alias any input, fully or partially. In that case the
/// result is computed into a temporary buffer first and then copied, so no input is read
/// after it has been overwritten.
///
/// # Errors
/// Returns a JS string error if any pointer is null or if `chande_into_slice` fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn chande_into(
    h_ptr: *const f64,
    l_ptr: *const f64,
    c_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
    mult: f64,
    direction: &str,
) -> Result<(), JsValue> {
    if h_ptr.is_null() || l_ptr.is_null() || c_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to chande_into"));
    }

    // Two half-open byte ranges [start, start + span) overlap iff each starts before the
    // other ends. Comparing only start addresses would miss partial overlap.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let out_addr = out_ptr as usize;
    let overlaps_out = |src: *const f64| {
        let src_addr = src as usize;
        src_addr < out_addr.saturating_add(span) && out_addr < src_addr.saturating_add(span)
    };
    let aliased = overlaps_out(h_ptr) || overlaps_out(l_ptr) || overlaps_out(c_ptr);

    let params = ChandeParams {
        period: Some(period),
        mult: Some(mult),
        direction: Some(direction.to_string()),
    };

    // SAFETY: the pointers are non-null (checked above) and the caller guarantees each is
    // valid for `len` reads of initialised `f64` values.
    let (h, l, c) = unsafe {
        (
            std::slice::from_raw_parts(h_ptr, len),
            std::slice::from_raw_parts(l_ptr, len),
            std::slice::from_raw_parts(c_ptr, len),
        )
    };
    let input = ChandeInput::from_slices(h, l, c, params);

    if aliased {
        let mut tmp = vec![0.0; len];
        chande_into_slice(&mut tmp, &input, detect_best_kernel())
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // Finish with the shared input views before creating the exclusive output view.
        drop(input);
        // SAFETY: no reads through the input slices happen after this point, and the caller
        // guarantees `out_ptr` is valid for `len` writes.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&tmp);
    } else {
        // SAFETY: non-null, valid for `len` writes per the caller, and disjoint from every
        // input range (checked above), so this exclusive slice aliases nothing.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        chande_into_slice(out, &input, detect_best_kernel())
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(())
}
2766
/// Runs a Chande parameter sweep from caller-owned WASM memory into `out_ptr`.
///
/// The three input pointers must each reference `len` `f64` values. `out_ptr` must have
/// room for `rows * len` values, where `rows` is the number of parameter combinations the
/// sweep expands to. If the output region overlaps any input, fully or partially, results
/// are computed into a temporary buffer and copied afterwards.
///
/// Returns the number of rows written.
///
/// # Errors
/// Returns a JS string error for null pointers, an invalid sweep, `rows * len` overflow,
/// or a failure inside `chande_batch_inner_into`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn chande_batch_into(
    h_ptr: *const f64,
    l_ptr: *const f64,
    c_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    p_start: usize,
    p_end: usize,
    p_step: usize,
    m_start: f64,
    m_end: f64,
    m_step: f64,
    direction: &str,
) -> Result<usize, JsValue> {
    if h_ptr.is_null() || l_ptr.is_null() || c_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to chande_batch_into",
        ));
    }

    let sweep = ChandeBatchRange {
        period: (p_start, p_end, p_step),
        mult: (m_start, m_end, m_step),
    };
    // Expanded here only to size the output; the batch routine itself receives `sweep`.
    let combos =
        expand_grid(&sweep, direction).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    let simd = match detect_best_batch_kernel() {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        _ => Kernel::Scalar,
    };

    // Inputs span `len` values and the output spans `total` values. Half-open byte ranges
    // overlap iff each starts before the other ends.
    let elem = std::mem::size_of::<f64>();
    let in_span = len.saturating_mul(elem);
    let out_span = total.saturating_mul(elem);
    let out_addr = out_ptr as usize;
    let overlaps_out = |src: *const f64| {
        let src_addr = src as usize;
        src_addr < out_addr.saturating_add(out_span)
            && out_addr < src_addr.saturating_add(in_span)
    };
    let aliased = overlaps_out(h_ptr) || overlaps_out(l_ptr) || overlaps_out(c_ptr);

    // SAFETY: the pointers are non-null (checked above) and the caller guarantees each is
    // valid for `len` reads of initialised `f64` values.
    let (h, l, c) = unsafe {
        (
            std::slice::from_raw_parts(h_ptr, len),
            std::slice::from_raw_parts(l_ptr, len),
            std::slice::from_raw_parts(c_ptr, len),
        )
    };

    if aliased {
        let mut tmp = vec![0.0; total];
        chande_batch_inner_into(h, l, c, &sweep, direction, simd, false, tmp.as_mut_slice())
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: the input slices are not read after this point, and the caller
        // guarantees `out_ptr` is valid for `total` writes.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        out.copy_from_slice(&tmp);
    } else {
        // SAFETY: non-null, valid for `total` writes per the caller, and disjoint from every
        // input range (checked above), so this exclusive slice aliases nothing.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        chande_batch_inner_into(h, l, c, &sweep, direction, simd, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(rows)
}