1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::{PyDict, PyList};
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::{source_type, CandleFieldFlags, Candles};
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
19 make_uninit_matrix,
20};
21#[cfg(feature = "python")]
22use crate::utilities::kernel_validation::validate_kernel;
23use aligned_vec::{AVec, CACHELINE_ALIGN};
24
25#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
26use core::arch::x86_64::*;
27
28#[cfg(not(target_arch = "wasm32"))]
29use rayon::prelude::*;
30
31use std::convert::AsRef;
32use std::error::Error;
33use std::mem::MaybeUninit;
34use thiserror::Error;
35
36#[cfg(all(feature = "python", feature = "cuda"))]
37use crate::cuda::alphatrend_wrapper::CudaAlphaTrend;
38use crate::indicators::mfi::{mfi_with_kernel, MfiInput, MfiParams};
39use crate::indicators::rsi::{rsi_with_kernel, RsiInput, RsiParams};
40#[cfg(all(feature = "python", feature = "cuda"))]
41use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
42#[cfg(all(feature = "python", feature = "cuda"))]
43use cust::context::Context;
44#[cfg(all(feature = "python", feature = "cuda"))]
45use cust::memory::DeviceBuffer;
46#[cfg(all(feature = "python", feature = "cuda"))]
47use std::sync::Arc;
48
49impl<'a> AsRef<[f64]> for AlphaTrendInput<'a> {
50 #[inline(always)]
51 fn as_ref(&self) -> &[f64] {
52 match &self.data {
53 AlphaTrendData::Slices { close, .. } => close,
54 AlphaTrendData::Candles { candles, .. } => &candles.close,
55 }
56 }
57}
58
/// Where the AlphaTrend calculation reads its open/high/low/close/volume
/// series from.
///
/// Both variants only borrow the data for `'a`; nothing is copied.
#[derive(Debug, Clone)]
pub enum AlphaTrendData<'a> {
    /// All five series are taken from the fields of one `Candles` value.
    Candles {
        candles: &'a Candles,
    },
    /// Each series is supplied as its own slice. The batch entry points
    /// reject inputs whose `open`, `high`, `low` or `volume` length differs
    /// from `close`.
    Slices {
        open: &'a [f64],
        high: &'a [f64],
        low: &'a [f64],
        close: &'a [f64],
        volume: &'a [f64],
    },
}
72
/// Result of a batch AlphaTrend run; both vectors have the input length.
#[derive(Debug, Clone)]
pub struct AlphaTrendOutput {
    /// The alpha trend line. Entries before `first + period - 1` are NaN,
    /// where `first` is the index of the first non-NaN close.
    pub k1: Vec<f64>,
    /// `k1` delayed by two bars (`k2[i] == k1[i - 2]`); it therefore has two
    /// more leading NaN entries than `k1`.
    pub k2: Vec<f64>,
}
78
/// Tunable parameters of the AlphaTrend indicator.
///
/// Every field is optional; `None` falls back to the same value that
/// `Default` provides (`1.0`, `14`, `false`).
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct AlphaTrendParams {
    /// Multiplier applied to the averaged true range when building the
    /// upper/lower band candidates. Must be finite and strictly positive.
    pub coeff: Option<f64>,
    /// Window length used for the true-range average and for the momentum
    /// oscillator (MFI or RSI). Must be at least 1.
    pub period: Option<usize>,
    /// When `true`, momentum comes from RSI of the close series instead of
    /// MFI of the typical price and volume.
    pub no_volume: Option<bool>,
}
89
90impl Default for AlphaTrendParams {
91 fn default() -> Self {
92 Self {
93 coeff: Some(1.0),
94 period: Some(14),
95 no_volume: Some(false),
96 }
97 }
98}
99
/// Everything a batch AlphaTrend call needs: the borrowed price/volume data
/// plus the parameters to run with.
#[derive(Debug, Clone)]
pub struct AlphaTrendInput<'a> {
    /// Borrowed OHLCV source.
    pub data: AlphaTrendData<'a>,
    /// Indicator parameters; unset fields fall back to defaults when read
    /// through the `get_*` accessors.
    pub params: AlphaTrendParams,
}
105
106impl<'a> AlphaTrendInput<'a> {
107 #[inline]
108 pub fn from_candles(c: &'a Candles, p: AlphaTrendParams) -> Self {
109 Self {
110 data: AlphaTrendData::Candles { candles: c },
111 params: p,
112 }
113 }
114
115 #[inline]
116 pub fn from_slices(
117 open: &'a [f64],
118 high: &'a [f64],
119 low: &'a [f64],
120 close: &'a [f64],
121 volume: &'a [f64],
122 p: AlphaTrendParams,
123 ) -> Self {
124 Self {
125 data: AlphaTrendData::Slices {
126 open,
127 high,
128 low,
129 close,
130 volume,
131 },
132 params: p,
133 }
134 }
135
136 #[inline]
137 pub fn with_default_candles(c: &'a Candles) -> Self {
138 Self::from_candles(c, AlphaTrendParams::default())
139 }
140
141 #[inline]
142 pub fn get_coeff(&self) -> f64 {
143 self.params.coeff.unwrap_or(1.0)
144 }
145
146 #[inline]
147 pub fn get_period(&self) -> usize {
148 self.params.period.unwrap_or(14)
149 }
150
151 #[inline]
152 pub fn get_no_volume(&self) -> bool {
153 self.params.no_volume.unwrap_or(false)
154 }
155}
156
/// Fluent builder that collects AlphaTrend parameters and a kernel choice
/// before running the batch calculation or creating a stream.
#[derive(Copy, Clone, Debug)]
pub struct AlphaTrendBuilder {
    // Unset (`None`) parameters are forwarded as-is so the downstream
    // defaults (1.0 / 14 / false) apply.
    coeff: Option<f64>,
    period: Option<usize>,
    no_volume: Option<bool>,
    // Kernel passed to `alphatrend_with_kernel` by the `apply*` methods.
    kernel: Kernel,
}
164
165impl Default for AlphaTrendBuilder {
166 fn default() -> Self {
167 Self {
168 coeff: None,
169 period: None,
170 no_volume: None,
171 kernel: Kernel::Auto,
172 }
173 }
174}
175
176impl AlphaTrendBuilder {
177 #[inline(always)]
178 pub fn new() -> Self {
179 Self::default()
180 }
181
182 #[inline(always)]
183 pub fn coeff(mut self, val: f64) -> Self {
184 self.coeff = Some(val);
185 self
186 }
187
188 #[inline(always)]
189 pub fn period(mut self, val: usize) -> Self {
190 self.period = Some(val);
191 self
192 }
193
194 #[inline(always)]
195 pub fn no_volume(mut self, val: bool) -> Self {
196 self.no_volume = Some(val);
197 self
198 }
199
200 #[inline(always)]
201 pub fn kernel(mut self, k: Kernel) -> Self {
202 self.kernel = k;
203 self
204 }
205
206 #[inline(always)]
207 pub fn apply(self, c: &Candles) -> Result<AlphaTrendOutput, AlphaTrendError> {
208 let p = AlphaTrendParams {
209 coeff: self.coeff,
210 period: self.period,
211 no_volume: self.no_volume,
212 };
213 let i = AlphaTrendInput::from_candles(c, p);
214 alphatrend_with_kernel(&i, self.kernel)
215 }
216
217 #[inline(always)]
218 pub fn apply_slice(
219 self,
220 open: &[f64],
221 high: &[f64],
222 low: &[f64],
223 close: &[f64],
224 volume: &[f64],
225 ) -> Result<AlphaTrendOutput, AlphaTrendError> {
226 let p = AlphaTrendParams {
227 coeff: self.coeff,
228 period: self.period,
229 no_volume: self.no_volume,
230 };
231 let i = AlphaTrendInput::from_slices(open, high, low, close, volume, p);
232 alphatrend_with_kernel(&i, self.kernel)
233 }
234
235 #[inline(always)]
236 pub fn into_stream(self) -> Result<AlphaTrendStream, AlphaTrendError> {
237 let p = AlphaTrendParams {
238 coeff: self.coeff,
239 period: self.period,
240 no_volume: self.no_volume,
241 };
242 AlphaTrendStream::try_new(p)
243 }
244}
245
/// Errors reported by the AlphaTrend entry points.
#[derive(Debug, Error)]
pub enum AlphaTrendError {
    /// The close series has length zero.
    #[error("alphatrend: Input data slice is empty.")]
    EmptyInputData,

    /// Every close value is NaN, so no starting index exists.
    #[error("alphatrend: All values are NaN.")]
    AllValuesNaN,

    /// `period` is zero or larger than the input length.
    #[error("alphatrend: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },

    /// Fewer than `needed` bars remain after the first non-NaN close.
    #[error("alphatrend: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },

    /// The input series do not all share the length of `close`.
    #[error("alphatrend: Inconsistent data lengths")]
    InconsistentDataLengths,

    /// A caller-provided output buffer does not match the input length.
    #[error("alphatrend: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },

    /// `coeff` is not a finite, strictly positive number.
    #[error("alphatrend: Invalid coefficient: {coeff}")]
    InvalidCoeff { coeff: f64 },

    /// The nested RSI computation failed; `msg` carries its error text.
    #[error("alphatrend: RSI calculation failed: {msg}")]
    RsiError { msg: String },

    /// The nested MFI computation failed; `msg` carries its error text.
    #[error("alphatrend: MFI calculation failed: {msg}")]
    MfiError { msg: String },

    // The remaining variants are not produced by the code in this part of
    // the file; presumably they belong to the batch/sweep and binding paths
    // defined elsewhere — TODO confirm.
    #[error("alphatrend: Invalid range (usize): start={start} end={end} step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },

    #[error("alphatrend: Invalid range (f64): start={start} end={end} step={step}")]
    InvalidRangeF64 { start: f64, end: f64, step: f64 },

    #[error("alphatrend: Invalid kernel for batch path: {0:?}")]
    InvalidKernelForBatch(Kernel),

    #[error("alphatrend: invalid input: {0}")]
    InvalidInput(String),
}
291
292#[inline]
293pub fn alphatrend(input: &AlphaTrendInput) -> Result<AlphaTrendOutput, AlphaTrendError> {
294 alphatrend_with_kernel(input, Kernel::Auto)
295}
296
297pub fn alphatrend_with_kernel(
298 input: &AlphaTrendInput,
299 kernel: Kernel,
300) -> Result<AlphaTrendOutput, AlphaTrendError> {
301 let (open, high, low, close, volume, coeff, period, no_volume, first, chosen) =
302 alphatrend_prepare(input, kernel)?;
303
304 let len = close.len();
305 let warm = first + period - 1;
306
307 let mut k1 = alloc_with_nan_prefix(len, warm);
308 let mut k2 = alloc_with_nan_prefix(len, warm + 2);
309
310 alphatrend_compute_into(
311 open, high, low, close, volume, coeff, period, no_volume, first, chosen, &mut k1, &mut k2,
312 )?;
313
314 Ok(AlphaTrendOutput { k1, k2 })
315}
316
317#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
318#[inline]
319pub fn alphatrend_into(
320 input: &AlphaTrendInput,
321 out_k1: &mut [f64],
322 out_k2: &mut [f64],
323) -> Result<(), AlphaTrendError> {
324 alphatrend_into_slices(out_k1, out_k2, input, Kernel::Auto)
325}
326
327#[inline]
328pub fn alphatrend_into_slices(
329 dst_k1: &mut [f64],
330 dst_k2: &mut [f64],
331 input: &AlphaTrendInput,
332 kern: Kernel,
333) -> Result<(), AlphaTrendError> {
334 let (open, high, low, close, volume, coeff, period, no_volume, first, chosen) =
335 alphatrend_prepare(input, kern)?;
336
337 if dst_k1.len() != close.len() {
338 return Err(AlphaTrendError::OutputLengthMismatch {
339 expected: close.len(),
340 got: dst_k1.len(),
341 });
342 }
343 if dst_k2.len() != close.len() {
344 return Err(AlphaTrendError::OutputLengthMismatch {
345 expected: close.len(),
346 got: dst_k2.len(),
347 });
348 }
349
350 let warm = first + period - 1;
351 let k1_warm_end = warm.min(dst_k1.len());
352 let k2_warm_end = (warm + 2).min(dst_k2.len());
353 for v in &mut dst_k1[..k1_warm_end] {
354 *v = f64::NAN;
355 }
356 for v in &mut dst_k2[..k2_warm_end] {
357 *v = f64::NAN;
358 }
359
360 alphatrend_compute_into(
361 open, high, low, close, volume, coeff, period, no_volume, first, chosen, dst_k1, dst_k2,
362 )?;
363
364 Ok(())
365}
366
367#[inline(always)]
368fn alphatrend_prepare<'a>(
369 input: &'a AlphaTrendInput,
370 kernel: Kernel,
371) -> Result<
372 (
373 &'a [f64],
374 &'a [f64],
375 &'a [f64],
376 &'a [f64],
377 &'a [f64],
378 f64,
379 usize,
380 bool,
381 usize,
382 Kernel,
383 ),
384 AlphaTrendError,
385> {
386 let (open, high, low, close, volume) = match &input.data {
387 AlphaTrendData::Candles { candles } => (
388 &candles.open[..],
389 &candles.high[..],
390 &candles.low[..],
391 &candles.close[..],
392 &candles.volume[..],
393 ),
394 AlphaTrendData::Slices {
395 open,
396 high,
397 low,
398 close,
399 volume,
400 } => (*open, *high, *low, *close, *volume),
401 };
402
403 let len = close.len();
404
405 if len == 0 {
406 return Err(AlphaTrendError::EmptyInputData);
407 }
408
409 if open.len() != len || high.len() != len || low.len() != len || volume.len() != len {
410 return Err(AlphaTrendError::InconsistentDataLengths);
411 }
412
413 let first = close
414 .iter()
415 .position(|x| !x.is_nan())
416 .ok_or(AlphaTrendError::AllValuesNaN)?;
417
418 let coeff = input.get_coeff();
419 let period = input.get_period();
420 let no_volume = input.get_no_volume();
421
422 if period == 0 || period > len {
423 return Err(AlphaTrendError::InvalidPeriod {
424 period,
425 data_len: len,
426 });
427 }
428
429 if len - first < period {
430 return Err(AlphaTrendError::NotEnoughValidData {
431 needed: period,
432 valid: len - first,
433 });
434 }
435
436 if coeff <= 0.0 || !coeff.is_finite() {
437 return Err(AlphaTrendError::InvalidCoeff { coeff });
438 }
439
440 let chosen = match kernel {
441 Kernel::Auto => Kernel::Scalar,
442 k => k,
443 };
444
445 Ok((
446 open, high, low, close, volume, coeff, period, no_volume, first, chosen,
447 ))
448}
449
/// Dispatches to the kernel implementation selected by `kernel` and writes
/// the results into `out_k1` / `out_k2`.
///
/// Inputs are expected to have been validated by `alphatrend_prepare`
/// (in particular, `kernel` must no longer be `Kernel::Auto`).
///
/// # Panics
/// Hits `unreachable!()` for any `Kernel` variant not listed in the match.
#[inline(always)]
fn alphatrend_compute_into(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    coeff: f64,
    period: usize,
    no_volume: bool,
    first: usize,
    kernel: Kernel,
    out_k1: &mut [f64],
    out_k2: &mut [f64],
) -> Result<(), AlphaTrendError> {
    // The `unsafe` block is required because the SIMD kernels are `unsafe fn`.
    unsafe {
        // On wasm32 with simd128 the scalar kernels are routed through the
        // simd128 entry point instead.
        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
        {
            if matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch) {
                return alphatrend_simd128(
                    open, high, low, close, volume, coeff, period, no_volume, first, out_k1, out_k2,
                );
            }
        }

        match kernel {
            Kernel::Scalar | Kernel::ScalarBatch => alphatrend_scalar(
                open, high, low, close, volume, coeff, period, no_volume, first, out_k1, out_k2,
            ),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 | Kernel::Avx2Batch => alphatrend_avx2(
                open, high, low, close, volume, coeff, period, no_volume, first, out_k1, out_k2,
            ),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 | Kernel::Avx512Batch => alphatrend_avx512(
                open, high, low, close, volume, coeff, period, no_volume, first, out_k1, out_k2,
            ),
            // Without the AVX build configuration, AVX requests silently
            // fall back to the scalar kernel.
            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
                alphatrend_scalar(
                    open, high, low, close, volume, coeff, period, no_volume, first, out_k1, out_k2,
                )
            }
            _ => unreachable!(),
        }
    }
}
497
498#[inline]
499pub fn alphatrend_scalar(
500 _open: &[f64],
501 high: &[f64],
502 low: &[f64],
503 close: &[f64],
504 volume: &[f64],
505 coeff: f64,
506 period: usize,
507 no_volume: bool,
508 first_val: usize,
509 out_k1: &mut [f64],
510 out_k2: &mut [f64],
511) -> Result<(), AlphaTrendError> {
512 let len = close.len();
513 let warmup = first_val + period - 1;
514
515 let mut tr_mu = make_uninit_matrix(1, len);
516 let tr: &mut [f64] =
517 unsafe { core::slice::from_raw_parts_mut(tr_mu.as_mut_ptr() as *mut f64, len) };
518
519 if first_val < len {
520 tr[first_val] = high[first_val] - low[first_val];
521 }
522 for i in (first_val + 1)..len {
523 let hl = high[i] - low[i];
524 let hc = (high[i] - close[i - 1]).abs();
525 let lc = (low[i] - close[i - 1]).abs();
526 tr[i] = hl.max(hc).max(lc);
527 }
528
529 let momentum_values: Vec<f64> = if no_volume {
530 let rsi_params = RsiParams {
531 period: Some(period),
532 };
533 let rsi_input = RsiInput::from_slice(close, rsi_params);
534 rsi_with_kernel(&rsi_input, Kernel::Scalar)
535 .map_err(|e| AlphaTrendError::RsiError { msg: e.to_string() })?
536 .values
537 } else {
538 let mut hlc3_mu = make_uninit_matrix(1, len);
539 let hlc3: &mut [f64] =
540 unsafe { core::slice::from_raw_parts_mut(hlc3_mu.as_mut_ptr() as *mut f64, len) };
541 for i in 0..len {
542 hlc3[i] = (high[i] + low[i] + close[i]) / 3.0;
543 }
544 let mfi_params = MfiParams {
545 period: Some(period),
546 };
547 let mfi_input = MfiInput::from_slices(hlc3, volume, mfi_params);
548 mfi_with_kernel(&mfi_input, Kernel::Scalar)
549 .map_err(|e| AlphaTrendError::MfiError { msg: e.to_string() })?
550 .values
551 };
552
553 if warmup < len {
554 let mut sum = 0.0f64;
555 for j in first_val..=warmup {
556 sum += tr[j];
557 }
558
559 let mut prev_alpha = f64::NAN;
560 let mut prev1 = f64::NAN;
561 let mut prev2 = f64::NAN;
562
563 for i in warmup..len {
564 let a = sum / period as f64;
565
566 let up_t = low[i] - a * coeff;
567 let down_t = high[i] + a * coeff;
568 let m_check = momentum_values[i] >= 50.0;
569
570 let cur = if i == warmup {
571 if m_check {
572 up_t
573 } else {
574 down_t
575 }
576 } else if m_check {
577 if up_t < prev_alpha {
578 prev_alpha
579 } else {
580 up_t
581 }
582 } else {
583 if down_t > prev_alpha {
584 prev_alpha
585 } else {
586 down_t
587 }
588 };
589
590 out_k1[i] = cur;
591 if i >= warmup + 2 {
592 out_k2[i] = prev2;
593 }
594
595 prev2 = prev1;
596 prev1 = cur;
597 prev_alpha = cur;
598
599 if i + 1 < len {
600 sum += tr[i + 1] - tr[i + 1 - period];
601 }
602 }
603 }
604
605 for v in &mut out_k1[..warmup.min(len)] {
606 *v = f64::NAN;
607 }
608 for v in &mut out_k2[..(warmup + 2).min(len)] {
609 *v = f64::NAN;
610 }
611
612 Ok(())
613}
614
/// AVX2 kernel for AlphaTrend.
///
/// The true range and typical price are computed four lanes at a time; the
/// alpha recurrence itself is sequential and runs as scalar code.
///
/// # Safety
/// * The CPU must support the `avx2` and `fma` target features.
/// * `high`, `low`, `out_k1` and `out_k2` must be at least `close.len()`
///   long, because they are accessed without bounds checks.
/// * `period >= 1` and `first_val + period <= close.len()`: the seed sum
///   below reads `tr[first_val..=warmup]` unchecked and, unlike the scalar
///   kernel, is not guarded by `warmup < len`.
///
/// NOTE(review): results can differ from `alphatrend_scalar` — the tail
/// true-range loop and `fast_max` / `fast_min` order their comparisons
/// differently when a NaN is involved, the typical price multiplies by
/// `1.0 / 3.0` instead of dividing by `3.0`, and the bands use fused
/// `mul_add`. Confirm the intended tolerance.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn alphatrend_avx2(
    _open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    coeff: f64,
    period: usize,
    no_volume: bool,
    first_val: usize,
    out_k1: &mut [f64],
    out_k2: &mut [f64],
) -> Result<(), AlphaTrendError> {
    use core::arch::x86_64::*;

    // Lane-wise |x|: clears the sign bit of every element.
    #[inline(always)]
    unsafe fn mm256_abs_pd(x: __m256d) -> __m256d {
        let sign = _mm256_set1_pd(-0.0);
        _mm256_andnot_pd(sign, x)
    }

    let len = close.len();
    let warmup = first_val + period - 1;
    let p_f = period as f64;

    // Scratch buffer for the true range; only indices >= first_val are
    // written and read.
    let mut tr_mu = make_uninit_matrix(1, len);
    let tr: &mut [f64] = core::slice::from_raw_parts_mut(tr_mu.as_mut_ptr() as *mut f64, len);

    // The first valid bar has no previous close.
    if first_val < len {
        *tr.get_unchecked_mut(first_val) =
            *high.get_unchecked(first_val) - *low.get_unchecked(first_val);
    }

    // Vectorised true range, four bars per iteration.
    let mut i = first_val + 1;
    while i + 4 <= len {
        let hv = _mm256_loadu_pd(high.as_ptr().add(i));
        let lv = _mm256_loadu_pd(low.as_ptr().add(i));
        let pc = _mm256_loadu_pd(close.as_ptr().add(i - 1));

        let hl = _mm256_sub_pd(hv, lv);
        let hc = mm256_abs_pd(_mm256_sub_pd(hv, pc));
        let lc = mm256_abs_pd(_mm256_sub_pd(lv, pc));

        let m1 = _mm256_max_pd(hl, hc);
        let m = _mm256_max_pd(m1, lc);
        _mm256_storeu_pd(tr.as_mut_ptr().add(i), m);
        i += 4;
    }
    // Scalar tail for the remaining (< 4) bars.
    while i < len {
        let hi = *high.get_unchecked(i);
        let lo = *low.get_unchecked(i);
        let pc = *close.get_unchecked(i - 1);
        let hl = hi - lo;
        let hc = (hi - pc).abs();
        let lc = (lo - pc).abs();
        let m = if hl >= hc { hl } else { hc };
        *tr.get_unchecked_mut(i) = if m >= lc { m } else { lc };
        i += 1;
    }

    // Momentum oscillator: RSI of close, or MFI of typical price and volume.
    let momentum_values: Vec<f64> = if no_volume {
        let rsi_params = RsiParams {
            period: Some(period),
        };
        let rsi_input = RsiInput::from_slice(close, rsi_params);
        rsi_with_kernel(&rsi_input, Kernel::Avx2)
            .map_err(|e| AlphaTrendError::RsiError { msg: e.to_string() })?
            .values
    } else {
        // Every element of this buffer is written before it is read.
        let mut hlc3_mu = make_uninit_matrix(1, len);
        let hlc3: &mut [f64] =
            core::slice::from_raw_parts_mut(hlc3_mu.as_mut_ptr() as *mut f64, len);

        let inv3 = _mm256_set1_pd(1.0 / 3.0);
        let mut j = 0usize;
        while j + 4 <= len {
            let hv = _mm256_loadu_pd(high.as_ptr().add(j));
            let lv = _mm256_loadu_pd(low.as_ptr().add(j));
            let cv = _mm256_loadu_pd(close.as_ptr().add(j));
            let s = _mm256_add_pd(_mm256_add_pd(hv, lv), cv);
            let h3 = _mm256_mul_pd(s, inv3);
            _mm256_storeu_pd(hlc3.as_mut_ptr().add(j), h3);
            j += 4;
        }
        while j < len {
            *hlc3.get_unchecked_mut(j) =
                (*high.get_unchecked(j) + *low.get_unchecked(j) + *close.get_unchecked(j))
                    * (1.0 / 3.0);
            j += 1;
        }

        let mfi_params = MfiParams {
            period: Some(period),
        };
        let mfi_input = MfiInput::from_slices(hlc3, volume, mfi_params);
        mfi_with_kernel(&mfi_input, Kernel::Avx2)
            .map_err(|e| AlphaTrendError::MfiError { msg: e.to_string() })?
            .values
    };

    // Seed the rolling true-range sum with the first full window.
    let mut sum = 0.0f64;
    {
        let mut j = first_val;
        while j <= warmup {
            sum += *tr.get_unchecked(j);
            j += 1;
        }
    }

    // Comparison-based max/min; returns `b` whenever the comparison is false,
    // which includes any comparison against NaN.
    #[inline(always)]
    fn fast_max(a: f64, b: f64) -> f64 {
        if a >= b {
            a
        } else {
            b
        }
    }
    #[inline(always)]
    fn fast_min(a: f64, b: f64) -> f64 {
        if a <= b {
            a
        } else {
            b
        }
    }

    let mut prev2 = f64::NAN;
    let mut prev1 = f64::NAN;
    let mut prev_alpha = f64::NAN;

    // Sequential alpha recurrence.
    let mut k = warmup;
    while k < len {
        let a = sum / p_f;
        let hi = *high.get_unchecked(k);
        let lo = *low.get_unchecked(k);
        let up = (-coeff).mul_add(a, lo);
        let dn = coeff.mul_add(a, hi);
        let m_ge_50 = *momentum_values.get_unchecked(k) >= 50.0;

        let alpha = if k == warmup {
            if m_ge_50 {
                up
            } else {
                dn
            }
        } else if m_ge_50 {
            fast_max(up, prev_alpha)
        } else {
            fast_min(dn, prev_alpha)
        };

        *out_k1.get_unchecked_mut(k) = alpha;
        // k2 is k1 delayed by two bars.
        if k >= warmup + 2 {
            *out_k2.get_unchecked_mut(k) = prev2;
        }

        prev2 = prev1;
        prev1 = alpha;
        prev_alpha = alpha;

        // Slide the window: add the next bar, drop the oldest one.
        let nxt = k + 1;
        if nxt < len {
            sum += *tr.get_unchecked(nxt) - *tr.get_unchecked(nxt - period);
        }
        k += 1;
    }

    // Warm-up prefixes are NaN.
    for v in &mut out_k1[..warmup.min(len)] {
        *v = f64::NAN;
    }
    for v in &mut out_k2[..(warmup + 2).min(len)] {
        *v = f64::NAN;
    }

    Ok(())
}
793
/// AVX-512 kernel for AlphaTrend.
///
/// The true range and typical price are computed eight lanes at a time; the
/// alpha recurrence itself is sequential and runs as scalar code.
///
/// # Safety
/// * The CPU must support the `avx512f` and `fma` target features.
/// * `high`, `low`, `out_k1` and `out_k2` must be at least `close.len()`
///   long, because they are accessed without bounds checks.
/// * `period >= 1` and `first_val + period <= close.len()`: the seed sum
///   below reads `tr[first_val..=warmup]` unchecked and, unlike the scalar
///   kernel, is not guarded by `warmup < len`.
///
/// NOTE(review): `_mm512_andnot_pd` is listed by Intel under AVX512DQ while
/// only `avx512f,fma` are enabled here — confirm kernel detection also
/// requires DQ.
///
/// NOTE(review): results can differ from `alphatrend_scalar` — the tail
/// true-range loop and `fast_max` / `fast_min` order their comparisons
/// differently when a NaN is involved, the typical price multiplies by
/// `1.0 / 3.0` instead of dividing by `3.0`, and the bands use fused
/// `mul_add`. Confirm the intended tolerance.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn alphatrend_avx512(
    _open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    coeff: f64,
    period: usize,
    no_volume: bool,
    first_val: usize,
    out_k1: &mut [f64],
    out_k2: &mut [f64],
) -> Result<(), AlphaTrendError> {
    use core::arch::x86_64::*;

    // Lane-wise |x|: clears the sign bit of every element.
    #[inline(always)]
    unsafe fn mm512_abs_pd(x: __m512d) -> __m512d {
        let sign = _mm512_set1_pd(-0.0);
        _mm512_andnot_pd(sign, x)
    }

    let len = close.len();
    let warmup = first_val + period - 1;
    let p_f = period as f64;

    // Scratch buffer for the true range; only indices >= first_val are
    // written and read.
    let mut tr_mu = make_uninit_matrix(1, len);
    let tr: &mut [f64] = core::slice::from_raw_parts_mut(tr_mu.as_mut_ptr() as *mut f64, len);

    // The first valid bar has no previous close.
    if first_val < len {
        *tr.get_unchecked_mut(first_val) =
            *high.get_unchecked(first_val) - *low.get_unchecked(first_val);
    }

    // Vectorised true range, eight bars per iteration.
    let mut i = first_val + 1;
    while i + 8 <= len {
        let hv = _mm512_loadu_pd(high.as_ptr().add(i));
        let lv = _mm512_loadu_pd(low.as_ptr().add(i));
        let pc = _mm512_loadu_pd(close.as_ptr().add(i - 1));

        let hl = _mm512_sub_pd(hv, lv);
        let hc = mm512_abs_pd(_mm512_sub_pd(hv, pc));
        let lc = mm512_abs_pd(_mm512_sub_pd(lv, pc));

        let m1 = _mm512_max_pd(hl, hc);
        let m = _mm512_max_pd(m1, lc);
        _mm512_storeu_pd(tr.as_mut_ptr().add(i), m);
        i += 8;
    }
    // Scalar tail for the remaining (< 8) bars.
    while i < len {
        let hi = *high.get_unchecked(i);
        let lo = *low.get_unchecked(i);
        let pc = *close.get_unchecked(i - 1);
        let hl = hi - lo;
        let hc = (hi - pc).abs();
        let lc = (lo - pc).abs();
        let m = if hl >= hc { hl } else { hc };
        *tr.get_unchecked_mut(i) = if m >= lc { m } else { lc };
        i += 1;
    }

    // Momentum oscillator: RSI of close, or MFI of typical price and volume.
    let momentum_values: Vec<f64> = if no_volume {
        let rsi_params = RsiParams {
            period: Some(period),
        };
        let rsi_input = RsiInput::from_slice(close, rsi_params);
        rsi_with_kernel(&rsi_input, Kernel::Avx512)
            .map_err(|e| AlphaTrendError::RsiError { msg: e.to_string() })?
            .values
    } else {
        // Every element of this buffer is written before it is read.
        let mut hlc3_mu = make_uninit_matrix(1, len);
        let hlc3: &mut [f64] =
            core::slice::from_raw_parts_mut(hlc3_mu.as_mut_ptr() as *mut f64, len);

        let inv3 = _mm512_set1_pd(1.0 / 3.0);
        let mut j = 0usize;
        while j + 8 <= len {
            let hv = _mm512_loadu_pd(high.as_ptr().add(j));
            let lv = _mm512_loadu_pd(low.as_ptr().add(j));
            let cv = _mm512_loadu_pd(close.as_ptr().add(j));
            let s = _mm512_add_pd(_mm512_add_pd(hv, lv), cv);
            let h3 = _mm512_mul_pd(s, inv3);
            _mm512_storeu_pd(hlc3.as_mut_ptr().add(j), h3);
            j += 8;
        }
        while j < len {
            *hlc3.get_unchecked_mut(j) =
                (*high.get_unchecked(j) + *low.get_unchecked(j) + *close.get_unchecked(j))
                    * (1.0 / 3.0);
            j += 1;
        }

        let mfi_params = MfiParams {
            period: Some(period),
        };
        let mfi_input = MfiInput::from_slices(hlc3, volume, mfi_params);
        mfi_with_kernel(&mfi_input, Kernel::Avx512)
            .map_err(|e| AlphaTrendError::MfiError { msg: e.to_string() })?
            .values
    };

    // Seed the rolling true-range sum with the first full window.
    let mut sum = 0.0f64;
    {
        let mut j = first_val;
        while j <= warmup {
            sum += *tr.get_unchecked(j);
            j += 1;
        }
    }

    // Comparison-based max/min; returns `b` whenever the comparison is false,
    // which includes any comparison against NaN.
    #[inline(always)]
    fn fast_max(a: f64, b: f64) -> f64 {
        if a >= b {
            a
        } else {
            b
        }
    }
    #[inline(always)]
    fn fast_min(a: f64, b: f64) -> f64 {
        if a <= b {
            a
        } else {
            b
        }
    }

    let mut prev2 = f64::NAN;
    let mut prev1 = f64::NAN;
    let mut prev_alpha = f64::NAN;

    // Sequential alpha recurrence.
    let mut k = warmup;
    while k < len {
        let a = sum / p_f;
        let hi = *high.get_unchecked(k);
        let lo = *low.get_unchecked(k);
        let up = (-coeff).mul_add(a, lo);
        let dn = coeff.mul_add(a, hi);
        let m_ge_50 = *momentum_values.get_unchecked(k) >= 50.0;

        let alpha = if k == warmup {
            if m_ge_50 {
                up
            } else {
                dn
            }
        } else if m_ge_50 {
            fast_max(up, prev_alpha)
        } else {
            fast_min(dn, prev_alpha)
        };

        *out_k1.get_unchecked_mut(k) = alpha;
        // k2 is k1 delayed by two bars.
        if k >= warmup + 2 {
            *out_k2.get_unchecked_mut(k) = prev2;
        }

        prev2 = prev1;
        prev1 = alpha;
        prev_alpha = alpha;

        // Slide the window: add the next bar, drop the oldest one.
        let nxt = k + 1;
        if nxt < len {
            sum += *tr.get_unchecked(nxt) - *tr.get_unchecked(nxt - period);
        }
        k += 1;
    }

    // Warm-up prefixes are NaN.
    for v in &mut out_k1[..warmup.min(len)] {
        *v = f64::NAN;
    }
    for v in &mut out_k2[..(warmup + 2).min(len)] {
        *v = f64::NAN;
    }

    Ok(())
}
972
/// WebAssembly simd128 entry point for AlphaTrend.
///
/// Currently contains no SIMD code of its own: it forwards every argument to
/// `alphatrend_scalar` and returns its result unchanged.
///
/// # Safety
/// Declared `unsafe` for parity with the other SIMD kernels; the body itself
/// performs no unsafe operation.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn alphatrend_simd128(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    coeff: f64,
    period: usize,
    no_volume: bool,
    first_val: usize,
    out_k1: &mut [f64],
    out_k2: &mut [f64],
) -> Result<(), AlphaTrendError> {
    // Not referenced by the body below.
    use core::arch::wasm32::*;

    alphatrend_scalar(
        open, high, low, close, volume, coeff, period, no_volume, first_val, out_k1, out_k2,
    )
}
994
995#[derive(Debug, Clone)]
996pub struct AlphaTrendStream {
997 coeff: f64,
998 period: usize,
999 inv_period: f64,
1000 no_volume: bool,
1001
1002 tr_ring: Vec<f64>,
1003 tr_sum: f64,
1004 tr_idx: usize,
1005 tr_filled: usize,
1006
1007 rsi_seeded: bool,
1008 rsi_init_gains: f64,
1009 rsi_init_losses: f64,
1010 rsi_count: usize,
1011 rsi_avg_gain: f64,
1012 rsi_avg_loss: f64,
1013
1014 mfi_pos_ring: Vec<f64>,
1015 mfi_neg_ring: Vec<f64>,
1016 mfi_pos_sum: f64,
1017 mfi_neg_sum: f64,
1018 mfi_idx: usize,
1019 mfi_filled: usize,
1020 prev_tp: f64,
1021
1022 prev_close: f64,
1023 have_prev: bool,
1024
1025 prev_alpha: f64,
1026 prev1: f64,
1027 prev2: f64,
1028 alpha_count: usize,
1029}
1030
1031impl AlphaTrendStream {
1032 pub fn try_new(params: AlphaTrendParams) -> Result<Self, AlphaTrendError> {
1033 let coeff = params.coeff.unwrap_or(1.0);
1034 let period = params.period.unwrap_or(14);
1035 let no_volume = params.no_volume.unwrap_or(false);
1036
1037 if period == 0 {
1038 return Err(AlphaTrendError::InvalidPeriod {
1039 period,
1040 data_len: 0,
1041 });
1042 }
1043 if coeff <= 0.0 || !coeff.is_finite() {
1044 return Err(AlphaTrendError::InvalidCoeff { coeff });
1045 }
1046
1047 Ok(Self {
1048 coeff,
1049 period,
1050 inv_period: 1.0 / (period as f64),
1051 no_volume,
1052
1053 tr_ring: vec![0.0; period],
1054 tr_sum: 0.0,
1055 tr_idx: 0,
1056 tr_filled: 0,
1057
1058 rsi_seeded: false,
1059 rsi_init_gains: 0.0,
1060 rsi_init_losses: 0.0,
1061 rsi_count: 0,
1062 rsi_avg_gain: 0.0,
1063 rsi_avg_loss: 0.0,
1064
1065 mfi_pos_ring: vec![0.0; period],
1066 mfi_neg_ring: vec![0.0; period],
1067 mfi_pos_sum: 0.0,
1068 mfi_neg_sum: 0.0,
1069 mfi_idx: 0,
1070 mfi_filled: 0,
1071 prev_tp: f64::NAN,
1072
1073 prev_close: f64::NAN,
1074 have_prev: false,
1075
1076 prev_alpha: f64::NAN,
1077 prev1: f64::NAN,
1078 prev2: f64::NAN,
1079 alpha_count: 0,
1080 })
1081 }
1082
1083 #[inline]
1084 pub fn update(&mut self, high: f64, low: f64, close: f64, volume: f64) -> Option<(f64, f64)> {
1085 if !(high.is_finite() && low.is_finite() && close.is_finite() && volume.is_finite()) {
1086 return None;
1087 }
1088 if high < low {
1089 return None;
1090 }
1091
1092 let tr = if self.have_prev {
1093 let hl = high - low;
1094 let hc = (high - self.prev_close).abs();
1095 let lc = (low - self.prev_close).abs();
1096 if hl >= hc {
1097 if hl >= lc {
1098 hl
1099 } else {
1100 lc
1101 }
1102 } else {
1103 if hc >= lc {
1104 hc
1105 } else {
1106 lc
1107 }
1108 }
1109 } else {
1110 high - low
1111 };
1112
1113 if self.tr_filled < self.period {
1114 self.tr_ring[self.tr_idx] = tr;
1115 self.tr_sum += tr;
1116 self.tr_filled += 1;
1117 self.tr_idx = (self.tr_idx + 1) % self.period;
1118 } else {
1119 let old = self.tr_ring[self.tr_idx];
1120 self.tr_ring[self.tr_idx] = tr;
1121 self.tr_sum += tr - old;
1122 self.tr_idx = (self.tr_idx + 1) % self.period;
1123 }
1124 let atr_ready = self.tr_filled == self.period;
1125 let atr = if atr_ready {
1126 self.tr_sum * self.inv_period
1127 } else {
1128 f64::NAN
1129 };
1130
1131 let mut m_ge_50 = false;
1132
1133 if self.no_volume {
1134 let (gain, loss) = if self.have_prev {
1135 let d = close - self.prev_close;
1136 if d >= 0.0 {
1137 (d, 0.0)
1138 } else {
1139 (0.0, -d)
1140 }
1141 } else {
1142 (0.0, 0.0)
1143 };
1144
1145 if !self.rsi_seeded {
1146 self.rsi_init_gains += gain;
1147 self.rsi_init_losses += loss;
1148 self.rsi_count += 1;
1149 if self.rsi_count >= self.period {
1150 self.rsi_avg_gain = self.rsi_init_gains * self.inv_period;
1151 self.rsi_avg_loss = self.rsi_init_losses * self.inv_period;
1152 self.rsi_seeded = true;
1153 }
1154 } else {
1155 let n1 = (self.period as f64) - 1.0;
1156 self.rsi_avg_gain = (self.rsi_avg_gain * n1 + gain) * self.inv_period;
1157 self.rsi_avg_loss = (self.rsi_avg_loss * n1 + loss) * self.inv_period;
1158 }
1159
1160 if self.rsi_seeded {
1161 if self.rsi_avg_loss == 0.0 {
1162 m_ge_50 = self.rsi_avg_gain >= 0.0;
1163 } else if self.rsi_avg_gain == 0.0 {
1164 m_ge_50 = false;
1165 } else {
1166 m_ge_50 = self.rsi_avg_gain >= self.rsi_avg_loss;
1167 }
1168 } else {
1169 m_ge_50 = false;
1170 }
1171 } else {
1172 let tp = (high + low + close) / 3.0;
1173 if self.have_prev {
1174 let mf = (tp * volume).max(0.0);
1175 let (pos, neg) = if tp > self.prev_tp {
1176 (mf, 0.0)
1177 } else if tp < self.prev_tp {
1178 (0.0, mf)
1179 } else {
1180 (0.0, 0.0)
1181 };
1182
1183 if self.mfi_filled < self.period {
1184 self.mfi_pos_sum += pos;
1185 self.mfi_neg_sum += neg;
1186 self.mfi_pos_ring[self.mfi_idx] = pos;
1187 self.mfi_neg_ring[self.mfi_idx] = neg;
1188 self.mfi_idx = (self.mfi_idx + 1) % self.period;
1189 self.mfi_filled += 1;
1190 } else {
1191 let op = self.mfi_pos_ring[self.mfi_idx];
1192 let on = self.mfi_neg_ring[self.mfi_idx];
1193 self.mfi_pos_ring[self.mfi_idx] = pos;
1194 self.mfi_neg_ring[self.mfi_idx] = neg;
1195 self.mfi_pos_sum += pos - op;
1196 self.mfi_neg_sum += neg - on;
1197 self.mfi_idx = (self.mfi_idx + 1) % self.period;
1198 }
1199 }
1200
1201 if self.mfi_filled == self.period {
1202 if self.mfi_neg_sum == 0.0 {
1203 m_ge_50 = self.mfi_pos_sum >= 0.0;
1204 } else if self.mfi_pos_sum == 0.0 {
1205 m_ge_50 = false;
1206 } else {
1207 m_ge_50 = self.mfi_pos_sum >= self.mfi_neg_sum;
1208 }
1209 } else {
1210 m_ge_50 = false;
1211 }
1212 self.prev_tp = tp;
1213 }
1214
1215 let mut emitted = false;
1216 let mut cur = f64::NAN;
1217 let mut k2_out = f64::NAN;
1218
1219 if atr_ready {
1220 let up = (-self.coeff).mul_add(atr, low);
1221 let dn = self.coeff.mul_add(atr, high);
1222
1223 cur = if self.alpha_count == 0 {
1224 if m_ge_50 {
1225 up
1226 } else {
1227 dn
1228 }
1229 } else if m_ge_50 {
1230 if up < self.prev_alpha {
1231 self.prev_alpha
1232 } else {
1233 up
1234 }
1235 } else {
1236 if dn > self.prev_alpha {
1237 self.prev_alpha
1238 } else {
1239 dn
1240 }
1241 };
1242
1243 let k2_emit = self.prev2;
1244 self.prev2 = self.prev1;
1245 self.prev1 = cur;
1246 self.prev_alpha = cur;
1247 self.alpha_count += 1;
1248 emitted = true;
1249 k2_out = k2_emit;
1250 }
1251
1252 self.prev_close = close;
1253 self.have_prev = true;
1254
1255 if emitted && self.alpha_count >= 3 {
1256 Some((cur, k2_out))
1257 } else {
1258 None
1259 }
1260 }
1261
1262 #[inline(always)]
1263 pub fn get_warmup_period(&self) -> usize {
1264 self.period - 1
1265 }
1266}
1267
/// Python binding for the batch AlphaTrend calculation.
///
/// Takes five 1-D float64 NumPy arrays plus the indicator parameters and an
/// optional kernel name, and returns the `(k1, k2)` arrays. The computation
/// runs with the GIL released.
///
/// Raises `ValueError` when the calculation fails; errors from
/// `as_slice` (non-contiguous arrays) and from `validate_kernel` are
/// propagated as raised by those calls.
#[cfg(feature = "python")]
#[pyfunction(name = "alphatrend")]
#[pyo3(signature = (open, high, low, close, volume, coeff=1.0, period=14, no_volume=false, kernel=None))]
pub fn alphatrend_py<'py>(
    py: Python<'py>,
    open: PyReadonlyArray1<'py, f64>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    volume: PyReadonlyArray1<'py, f64>,
    coeff: f64,
    period: usize,
    no_volume: bool,
    kernel: Option<&str>,
) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
    // Borrow the arrays first, in argument order, then resolve the kernel.
    let (o, h, l, c, v) = (
        open.as_slice()?,
        high.as_slice()?,
        low.as_slice()?,
        close.as_slice()?,
        volume.as_slice()?,
    );
    let kern = validate_kernel(kernel, false)?;

    let input = AlphaTrendInput::from_slices(
        o,
        h,
        l,
        c,
        v,
        AlphaTrendParams {
            coeff: Some(coeff),
            period: Some(period),
            no_volume: Some(no_volume),
        },
    );

    let computed = py.allow_threads(|| alphatrend_with_kernel(&input, kern));
    let AlphaTrendOutput { k1, k2 } =
        computed.map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok((k1.into_pyarray(py), k2.into_pyarray(py)))
}
1310
/// Python-visible wrapper (exposed as `AlphaTrendStream`) around the native
/// streaming AlphaTrend state.
#[cfg(feature = "python")]
#[pyclass(name = "AlphaTrendStream")]
pub struct AlphaTrendStreamPy {
    // Native incremental state; every Python call is forwarded to it.
    stream: AlphaTrendStream,
}
1316
#[cfg(feature = "python")]
#[pymethods]
impl AlphaTrendStreamPy {
    /// Builds a stream from explicit parameters.
    ///
    /// Any error returned by `AlphaTrendStream::try_new` is converted into a
    /// Python `ValueError` carrying the error's display text.
    #[new]
    fn new(coeff: f64, period: usize, no_volume: bool) -> PyResult<Self> {
        AlphaTrendStream::try_new(AlphaTrendParams {
            coeff: Some(coeff),
            period: Some(period),
            no_volume: Some(no_volume),
        })
        .map(|stream| AlphaTrendStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one bar to the underlying stream and returns whatever it emits
    /// (`None` in Python while the stream is still warming up).
    fn update(&mut self, high: f64, low: f64, close: f64, volume: f64) -> Option<(f64, f64)> {
        let emitted = self.stream.update(high, low, close, volume);
        emitted
    }
}
1336
/// Serialized result handed to JavaScript by `alphatrend_js`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct AlphaTrendJsOutput {
    // Flat buffer: `alphatrend_js` stores all `k1` values followed by all `k2` values.
    pub values: Vec<f64>,
    // Number of stacked series in `values` (`alphatrend_js` sets this to 2).
    pub rows: usize,
    // Length of each series (`alphatrend_js` sets this to `close.len()`).
    pub cols: usize,
}
1344
/// WASM entry point: computes AlphaTrend and returns an `AlphaTrendJsOutput`
/// serialized to a `JsValue`.
///
/// The returned `values` buffer holds `k1` in its first `close.len()` entries
/// and `k2` in the following `close.len()` entries (`rows == 2`,
/// `cols == close.len()`).
///
/// # Errors
/// Returns the indicator error's display text as a `JsValue` string, or a
/// `"Serialization error: ..."` string if the result cannot be serialized.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn alphatrend_js(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    coeff: f64,
    period: usize,
    no_volume: bool,
) -> Result<JsValue, JsValue> {
    let params = AlphaTrendParams {
        coeff: Some(coeff),
        period: Some(period),
        no_volume: Some(no_volume),
    };
    let input = AlphaTrendInput::from_slices(open, high, low, close, volume, params);

    // Compute straight into the two halves of the final buffer instead of
    // filling two scratch vectors and copying them into a third.
    let cols = close.len();
    let mut values = vec![0.0_f64; cols * 2];
    {
        let (k1, k2) = values.split_at_mut(cols);
        alphatrend_into_slices(k1, k2, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    let out = AlphaTrendJsOutput {
        values,
        rows: 2,
        cols,
    };
    serde_wasm_bindgen::to_value(&out)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1382
/// Reserves room for `2 * n` `f64` values and hands the raw pointer to the
/// caller without freeing it.
///
/// The memory is uninitialized. Release it with `alphatrend_free_flat`,
/// passing the same `n`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn alphatrend_alloc_flat(n: usize) -> *mut f64 {
    // ManuallyDrop keeps the allocation alive after this function returns.
    let mut storage = core::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(2 * n));
    storage.as_mut_ptr()
}
1391
/// Releases a buffer previously obtained from `alphatrend_alloc_flat(n)`.
///
/// A null `ptr` is ignored. For any other pointer the caller must pass the
/// pointer returned by `alphatrend_alloc_flat` together with the same `n`,
/// and must not free it twice.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn alphatrend_free_flat(ptr: *mut f64, n: usize) {
    // Rebuilding a Vec from a null pointer is undefined behaviour; treat it as a no-op.
    if ptr.is_null() {
        return;
    }
    // SAFETY: per the contract above, `ptr` came from a `Vec<f64>` created with
    // capacity `2 * n`. Length 0 is used so no possibly-uninitialized element
    // is claimed to be initialized; only the allocation itself is released.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, 2 * n));
    }
}
1399
/// Computes AlphaTrend from raw input pointers and writes both series into a
/// single caller-provided flat buffer.
///
/// `out_flat_ptr` is treated as `2 * len` values: the first `len` receive
/// `k1`, the next `len` receive `k2`.
///
/// # Errors
/// Returns `"null pointer"` if any of the six pointers is null; otherwise any
/// error from `alphatrend_into_slices` is returned as its display text.
///
/// # Safety contract (not expressed in the signature)
/// Each input pointer must be valid for reading `len` `f64` values and
/// `out_flat_ptr` must be valid for writing `2 * len` values. Only null-ness
/// is checked here.
/// NOTE(review): the inputs are viewed as shared slices while the output is
/// viewed as mutable slices, so the output region presumably must not overlap
/// any input — confirm against callers.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn alphatrend_into_flat(
    open_ptr: *const f64,
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    volume_ptr: *const f64,
    out_flat_ptr: *mut f64,
    len: usize,
    coeff: f64,
    period: usize,
    no_volume: bool,
) -> Result<(), JsValue> {
    if [open_ptr, high_ptr, low_ptr, close_ptr, volume_ptr]
        .iter()
        .any(|&p| p.is_null())
        || out_flat_ptr.is_null()
    {
        return Err(JsValue::from_str("null pointer"));
    }
    unsafe {
        let (open, high, low, close, volume) = (
            core::slice::from_raw_parts(open_ptr, len),
            core::slice::from_raw_parts(high_ptr, len),
            core::slice::from_raw_parts(low_ptr, len),
            core::slice::from_raw_parts(close_ptr, len),
            core::slice::from_raw_parts(volume_ptr, len),
        );
        // Split the flat output into its two halves: [0, len) for k1, [len, 2*len) for k2.
        let (k1, k2) = (
            core::slice::from_raw_parts_mut(out_flat_ptr, len),
            core::slice::from_raw_parts_mut(out_flat_ptr.add(len), len),
        );
        let params = AlphaTrendParams {
            coeff: Some(coeff),
            period: Some(period),
            no_volume: Some(no_volume),
        };
        let input = AlphaTrendInput::from_slices(open, high, low, close, volume, params);
        alphatrend_into_slices(k1, k2, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))
    }
}
1443
/// Deprecated stub kept for API compatibility: it allocates nothing and
/// always returns a null pointer. Use `alphatrend_alloc_flat` instead.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(note = "Use alphatrend_alloc_flat/alphatrend_into_flat")]
pub fn alphatrend_alloc(_len: usize) -> *mut f64 {
    core::ptr::null_mut()
}
1450
/// Deprecated stub kept for API compatibility: it does nothing with its
/// arguments. Use `alphatrend_free_flat` instead.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(note = "Use alphatrend_free_flat")]
pub fn alphatrend_free(_ptr: *mut f64, _len: usize) {}
1455
/// Parameter sweep description for batch AlphaTrend runs.
///
/// Each range is a `(start, end, step)` triple that `expand_grid_alphatrend`
/// expands into a list of values; a zero step (or `start == end`) yields the
/// single value `start`.
#[derive(Clone, Debug)]
pub struct AlphaTrendBatchRange {
    // (start, end, step) for the coefficient axis.
    pub coeff: (f64, f64, f64),
    // (start, end, step) for the period axis.
    pub period: (usize, usize, usize),
    // Applied unchanged to every combination in the sweep.
    pub no_volume: bool,
}
1462
1463impl Default for AlphaTrendBatchRange {
1464 fn default() -> Self {
1465 Self {
1466 coeff: (1.0, 1.0, 0.0),
1467 period: (14, 263, 1),
1468 no_volume: false,
1469 }
1470 }
1471}
1472
/// Fluent builder that collects a parameter sweep and a kernel choice before
/// running a batch AlphaTrend computation.
#[derive(Clone, Debug, Default)]
pub struct AlphaTrendBatchBuilder {
    // Sweep definition edited by the `*_range`, `*_static` and `no_volume` setters.
    range: AlphaTrendBatchRange,
    // Kernel forwarded to `alphatrend_batch_with_kernel`.
    kernel: Kernel,
}
1478
1479impl AlphaTrendBatchBuilder {
1480 pub fn new() -> Self {
1481 Self::default()
1482 }
1483
1484 pub fn kernel(mut self, k: Kernel) -> Self {
1485 self.kernel = k;
1486 self
1487 }
1488
1489 #[inline]
1490 pub fn coeff_range(mut self, start: f64, end: f64, step: f64) -> Self {
1491 self.range.coeff = (start, end, step);
1492 self
1493 }
1494
1495 #[inline]
1496 pub fn coeff_static(mut self, val: f64) -> Self {
1497 self.range.coeff = (val, val, 0.0);
1498 self
1499 }
1500
1501 #[inline]
1502 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1503 self.range.period = (start, end, step);
1504 self
1505 }
1506
1507 #[inline]
1508 pub fn period_static(mut self, val: usize) -> Self {
1509 self.range.period = (val, val, 0);
1510 self
1511 }
1512
1513 #[inline]
1514 pub fn no_volume(mut self, val: bool) -> Self {
1515 self.range.no_volume = val;
1516 self
1517 }
1518
1519 pub fn apply_candles(self, c: &Candles) -> Result<AlphaTrendBatchOutput, AlphaTrendError> {
1520 alphatrend_batch_with_kernel(c, &self.range, self.kernel)
1521 }
1522
1523 pub fn apply_slices(
1524 self,
1525 open: &[f64],
1526 high: &[f64],
1527 low: &[f64],
1528 close: &[f64],
1529 volume: &[f64],
1530 ) -> Result<AlphaTrendBatchOutput, AlphaTrendError> {
1531 let len = close.len();
1532 if open.len() != len || high.len() != len || low.len() != len || volume.len() != len {
1533 return Err(AlphaTrendError::InconsistentDataLengths);
1534 }
1535
1536 let candles = Candles {
1537 timestamp: vec![0; len],
1538 open: open.to_vec(),
1539 high: high.to_vec(),
1540 low: low.to_vec(),
1541 close: close.to_vec(),
1542 volume: volume.to_vec(),
1543 fields: CandleFieldFlags {
1544 open: true,
1545 high: true,
1546 low: true,
1547 close: true,
1548 volume: true,
1549 },
1550 hl2: vec![],
1551 hlc3: vec![],
1552 ohlc4: vec![],
1553 hlcc4: vec![],
1554 };
1555
1556 alphatrend_batch_with_kernel(&candles, &self.range, self.kernel)
1557 }
1558
1559 pub fn with_default_candles(c: &Candles) -> Result<AlphaTrendBatchOutput, AlphaTrendError> {
1560 AlphaTrendBatchBuilder::new()
1561 .kernel(Kernel::Auto)
1562 .apply_candles(c)
1563 }
1564
1565 pub fn with_default_slices(
1566 open: &[f64],
1567 high: &[f64],
1568 low: &[f64],
1569 close: &[f64],
1570 volume: &[f64],
1571 k: Kernel,
1572 ) -> Result<AlphaTrendBatchOutput, AlphaTrendError> {
1573 AlphaTrendBatchBuilder::new()
1574 .kernel(k)
1575 .apply_slices(open, high, low, close, volume)
1576 }
1577}
1578
/// Result of a batch run: two row-major matrices sharing the same shape.
#[derive(Clone, Debug)]
pub struct AlphaTrendBatchOutput {
    // `k1` values, `rows * cols` entries; row `r` occupies `[r * cols, (r + 1) * cols)`.
    pub values_k1: Vec<f64>,
    // `k2` values, laid out identically to `values_k1`.
    pub values_k2: Vec<f64>,
    // Parameter set used for each row, in row order.
    pub combos: Vec<AlphaTrendParams>,
    // Number of parameter combinations.
    pub rows: usize,
    // Number of input samples per row.
    pub cols: usize,
}
1587
1588impl AlphaTrendBatchOutput {
1589 pub fn row_for_params(&self, p: &AlphaTrendParams) -> Option<usize> {
1590 self.combos.iter().position(|c| {
1591 (c.coeff.unwrap_or(1.0) - p.coeff.unwrap_or(1.0)).abs() < 1e-12
1592 && c.period.unwrap_or(14) == p.period.unwrap_or(14)
1593 && c.no_volume.unwrap_or(false) == p.no_volume.unwrap_or(false)
1594 })
1595 }
1596
1597 pub fn values_for(&self, p: &AlphaTrendParams) -> Option<(&[f64], &[f64])> {
1598 self.row_for_params(p).map(|row| {
1599 let start = row * self.cols;
1600 let end = start + self.cols;
1601 (&self.values_k1[start..end], &self.values_k2[start..end])
1602 })
1603 }
1604}
1605
1606#[inline(always)]
1607fn expand_grid_alphatrend(
1608 r: &AlphaTrendBatchRange,
1609) -> Result<Vec<AlphaTrendParams>, AlphaTrendError> {
1610 fn axis_usize(
1611 (start, end, step): (usize, usize, usize),
1612 ) -> Result<Vec<usize>, AlphaTrendError> {
1613 if step == 0 || start == end {
1614 return Ok(vec![start]);
1615 }
1616 let mut v = Vec::new();
1617 if start < end {
1618 let mut cur = start;
1619 while cur <= end {
1620 v.push(cur);
1621 cur = cur.saturating_add(step);
1622 if cur == *v.last().unwrap() {
1623 break;
1624 }
1625 }
1626 } else {
1627 let mut cur = start;
1628 while cur >= end {
1629 v.push(cur);
1630 let next = cur.saturating_sub(step);
1631 if next == cur {
1632 break;
1633 }
1634 cur = next;
1635 if cur == 0 && end > 0 {
1636 break;
1637 }
1638 }
1639 }
1640 if v.is_empty() {
1641 return Err(AlphaTrendError::InvalidRange { start, end, step });
1642 }
1643 Ok(v)
1644 }
1645 fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, AlphaTrendError> {
1646 if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
1647 return Ok(vec![start]);
1648 }
1649 let mut out = Vec::new();
1650 if start < end {
1651 let st = if step > 0.0 { step } else { -step };
1652 let mut x = start;
1653 while x <= end + 1e-12 {
1654 out.push(x);
1655 x += st;
1656 }
1657 } else {
1658 let st = if step > 0.0 { -step } else { step };
1659 if st.abs() < 1e-12 {
1660 return Ok(vec![start]);
1661 }
1662 let mut x = start;
1663 while x >= end - 1e-12 {
1664 out.push(x);
1665 x += st;
1666 }
1667 }
1668 if out.is_empty() {
1669 return Err(AlphaTrendError::InvalidRangeF64 { start, end, step });
1670 }
1671 Ok(out)
1672 }
1673
1674 let coeffs = axis_f64(r.coeff)?;
1675 let periods = axis_usize(r.period)?;
1676
1677 let mut out = Vec::with_capacity(coeffs.len().saturating_mul(periods.len()));
1678 for &c in &coeffs {
1679 for &p in &periods {
1680 out.push(AlphaTrendParams {
1681 coeff: Some(c),
1682 period: Some(p),
1683 no_volume: Some(r.no_volume),
1684 });
1685 }
1686 }
1687 Ok(out)
1688}
1689
1690pub fn alphatrend_batch_with_kernel(
1691 candles: &Candles,
1692 sweep: &AlphaTrendBatchRange,
1693 k: Kernel,
1694) -> Result<AlphaTrendBatchOutput, AlphaTrendError> {
1695 let kernel = match k {
1696 Kernel::Auto => detect_best_batch_kernel(),
1697 other if other.is_batch() => other,
1698 other => return Err(AlphaTrendError::InvalidKernelForBatch(other)),
1699 };
1700
1701 let simd = match kernel {
1702 Kernel::Avx512Batch => Kernel::Avx512,
1703 Kernel::Avx2Batch => Kernel::Avx2,
1704 Kernel::ScalarBatch => Kernel::Scalar,
1705 _ => unreachable!(),
1706 };
1707
1708 alphatrend_batch_inner(candles, sweep, simd, true)
1709}
1710
1711#[inline(always)]
1712pub fn alphatrend_batch_slice(
1713 open: &[f64],
1714 high: &[f64],
1715 low: &[f64],
1716 close: &[f64],
1717 volume: &[f64],
1718 sweep: &AlphaTrendBatchRange,
1719 kern: Kernel,
1720) -> Result<AlphaTrendBatchOutput, AlphaTrendError> {
1721 alphatrend_batch_inner_from_slices(open, high, low, close, volume, sweep, kern, false)
1722}
1723
1724#[inline(always)]
1725pub fn alphatrend_batch_par_slice(
1726 open: &[f64],
1727 high: &[f64],
1728 low: &[f64],
1729 close: &[f64],
1730 volume: &[f64],
1731 sweep: &AlphaTrendBatchRange,
1732 kern: Kernel,
1733) -> Result<AlphaTrendBatchOutput, AlphaTrendError> {
1734 alphatrend_batch_inner_from_slices(open, high, low, close, volume, sweep, kern, true)
1735}
1736
1737#[inline(always)]
1738fn alphatrend_batch_inner_from_slices(
1739 open: &[f64],
1740 high: &[f64],
1741 low: &[f64],
1742 close: &[f64],
1743 volume: &[f64],
1744 sweep: &AlphaTrendBatchRange,
1745 kern: Kernel,
1746 parallel: bool,
1747) -> Result<AlphaTrendBatchOutput, AlphaTrendError> {
1748 let combos = expand_grid_alphatrend(sweep)?;
1749 let cols = close.len();
1750 let rows = combos.len();
1751 let _elems = rows
1752 .checked_mul(cols)
1753 .ok_or_else(|| AlphaTrendError::InvalidInput("rows*cols overflow".into()))?;
1754 if cols == 0 {
1755 return Err(AlphaTrendError::EmptyInputData);
1756 }
1757
1758 let first = close
1759 .iter()
1760 .position(|x| !x.is_nan())
1761 .ok_or(AlphaTrendError::AllValuesNaN)?;
1762 let warm_k1: Vec<usize> = combos
1763 .iter()
1764 .map(|p| first + p.period.unwrap_or(14) - 1)
1765 .collect();
1766 let warm_k2: Vec<usize> = warm_k1.iter().map(|&w| w.saturating_add(2)).collect();
1767
1768 let mut k1_mu = make_uninit_matrix(rows, cols);
1769 let mut k2_mu = make_uninit_matrix(rows, cols);
1770 init_matrix_prefixes(&mut k1_mu, cols, &warm_k1);
1771 init_matrix_prefixes(&mut k2_mu, cols, &warm_k2);
1772
1773 let mut k1_guard = core::mem::ManuallyDrop::new(k1_mu);
1774 let mut k2_guard = core::mem::ManuallyDrop::new(k2_mu);
1775 let out_k1: &mut [f64] = unsafe {
1776 core::slice::from_raw_parts_mut(k1_guard.as_mut_ptr() as *mut f64, k1_guard.len())
1777 };
1778 let out_k2: &mut [f64] = unsafe {
1779 core::slice::from_raw_parts_mut(k2_guard.as_mut_ptr() as *mut f64, k2_guard.len())
1780 };
1781
1782 let actual = match kern {
1783 Kernel::Auto => detect_best_batch_kernel(),
1784 k => k,
1785 };
1786 alphatrend_batch_inner_into_slices(
1787 open, high, low, close, volume, sweep, actual, parallel, out_k1, out_k2,
1788 )?;
1789
1790 let values_k1 = unsafe {
1791 Vec::from_raw_parts(
1792 k1_guard.as_mut_ptr() as *mut f64,
1793 k1_guard.len(),
1794 k1_guard.capacity(),
1795 )
1796 };
1797 let values_k2 = unsafe {
1798 Vec::from_raw_parts(
1799 k2_guard.as_mut_ptr() as *mut f64,
1800 k2_guard.len(),
1801 k2_guard.capacity(),
1802 )
1803 };
1804 core::mem::forget(k1_guard);
1805 core::mem::forget(k2_guard);
1806
1807 Ok(AlphaTrendBatchOutput {
1808 values_k1,
1809 values_k2,
1810 combos,
1811 rows,
1812 cols,
1813 })
1814}
1815
1816#[inline(always)]
1817fn alphatrend_batch_inner(
1818 candles: &Candles,
1819 sweep: &AlphaTrendBatchRange,
1820 kern: Kernel,
1821 parallel: bool,
1822) -> Result<AlphaTrendBatchOutput, AlphaTrendError> {
1823 let combos = expand_grid_alphatrend(sweep)?;
1824 let cols = candles.close.len();
1825 let rows = combos.len();
1826 let _elems = rows
1827 .checked_mul(cols)
1828 .ok_or_else(|| AlphaTrendError::InvalidInput("rows*cols overflow".into()))?;
1829 if cols == 0 {
1830 return Err(AlphaTrendError::EmptyInputData);
1831 }
1832
1833 let mut k1_mu = make_uninit_matrix(rows, cols);
1834 let mut k2_mu = make_uninit_matrix(rows, cols);
1835
1836 let first = candles
1837 .close
1838 .iter()
1839 .position(|x| !x.is_nan())
1840 .ok_or(AlphaTrendError::AllValuesNaN)?;
1841 let warm_k1: Vec<usize> = combos
1842 .iter()
1843 .map(|p| first + p.period.unwrap_or(14) - 1)
1844 .collect();
1845 let warm_k2: Vec<usize> = warm_k1.iter().map(|&w| w.saturating_add(2)).collect();
1846
1847 init_matrix_prefixes(&mut k1_mu, cols, &warm_k1);
1848 init_matrix_prefixes(&mut k2_mu, cols, &warm_k2);
1849
1850 let mut k1_guard = core::mem::ManuallyDrop::new(k1_mu);
1851 let mut k2_guard = core::mem::ManuallyDrop::new(k2_mu);
1852 let out_k1: &mut [f64] = unsafe {
1853 core::slice::from_raw_parts_mut(k1_guard.as_mut_ptr() as *mut f64, k1_guard.len())
1854 };
1855 let out_k2: &mut [f64] = unsafe {
1856 core::slice::from_raw_parts_mut(k2_guard.as_mut_ptr() as *mut f64, k2_guard.len())
1857 };
1858
1859 let actual = match kern {
1860 Kernel::Auto => detect_best_batch_kernel(),
1861 k => k,
1862 };
1863
1864 let do_row =
1865 |row: usize, k1_row: &mut [f64], k2_row: &mut [f64]| -> Result<(), AlphaTrendError> {
1866 let p = &combos[row];
1867 let input = AlphaTrendInput::from_candles(candles, p.clone());
1868 alphatrend_into_slices(k1_row, k2_row, &input, actual)
1869 };
1870
1871 #[cfg(not(target_arch = "wasm32"))]
1872 if parallel {
1873 use rayon::prelude::*;
1874
1875 out_k1
1876 .par_chunks_mut(cols)
1877 .zip(out_k2.par_chunks_mut(cols))
1878 .enumerate()
1879 .try_for_each(|(row, (k1r, k2r))| do_row(row, k1r, k2r))?;
1880 } else {
1881 for (row, (k1r, k2r)) in out_k1
1882 .chunks_mut(cols)
1883 .zip(out_k2.chunks_mut(cols))
1884 .enumerate()
1885 {
1886 do_row(row, k1r, k2r)?;
1887 }
1888 }
1889
1890 #[cfg(target_arch = "wasm32")]
1891 for (row, (k1r, k2r)) in out_k1
1892 .chunks_mut(cols)
1893 .zip(out_k2.chunks_mut(cols))
1894 .enumerate()
1895 {
1896 do_row(row, k1r, k2r)?;
1897 }
1898
1899 let values_k1 = unsafe {
1900 Vec::from_raw_parts(
1901 k1_guard.as_mut_ptr() as *mut f64,
1902 k1_guard.len(),
1903 k1_guard.capacity(),
1904 )
1905 };
1906 let values_k2 = unsafe {
1907 Vec::from_raw_parts(
1908 k2_guard.as_mut_ptr() as *mut f64,
1909 k2_guard.len(),
1910 k2_guard.capacity(),
1911 )
1912 };
1913
1914 Ok(AlphaTrendBatchOutput {
1915 values_k1,
1916 values_k2,
1917 combos,
1918 rows,
1919 cols,
1920 })
1921}
1922
1923#[inline(always)]
1924pub fn alphatrend_batch_inner_into_slices(
1925 open: &[f64],
1926 high: &[f64],
1927 low: &[f64],
1928 close: &[f64],
1929 volume: &[f64],
1930 sweep: &AlphaTrendBatchRange,
1931 kern: Kernel,
1932 parallel: bool,
1933 k1_slice: &mut [f64],
1934 k2_slice: &mut [f64],
1935) -> Result<(), AlphaTrendError> {
1936 let combos = expand_grid_alphatrend(sweep)?;
1937 let cols = close.len();
1938 let rows = combos.len();
1939
1940 if open.len() != cols || high.len() != cols || low.len() != cols || volume.len() != cols {
1941 return Err(AlphaTrendError::InconsistentDataLengths);
1942 }
1943
1944 if cols == 0 {
1945 return Err(AlphaTrendError::EmptyInputData);
1946 }
1947
1948 let total = rows
1949 .checked_mul(cols)
1950 .ok_or_else(|| AlphaTrendError::InvalidInput("rows*cols overflow".into()))?;
1951 if k1_slice.len() != total {
1952 return Err(AlphaTrendError::OutputLengthMismatch {
1953 expected: total,
1954 got: k1_slice.len(),
1955 });
1956 }
1957 if k2_slice.len() != total {
1958 return Err(AlphaTrendError::OutputLengthMismatch {
1959 expected: total,
1960 got: k2_slice.len(),
1961 });
1962 }
1963
1964 let actual = match kern {
1965 Kernel::Auto => detect_best_batch_kernel(),
1966 k => k,
1967 };
1968 let simd_kernel = match actual {
1969 Kernel::Avx512Batch => Kernel::Avx512,
1970 Kernel::Avx2Batch => Kernel::Avx2,
1971 Kernel::ScalarBatch => Kernel::Scalar,
1972 _ => detect_best_kernel(),
1973 };
1974
1975 let first = close
1976 .iter()
1977 .position(|x| !x.is_nan())
1978 .ok_or(AlphaTrendError::AllValuesNaN)?;
1979
1980 let warm_k1: Vec<usize> = combos
1981 .iter()
1982 .map(|p| first + p.period.unwrap_or(14) - 1)
1983 .collect();
1984 let warm_k2: Vec<usize> = warm_k1.iter().map(|&w| w.saturating_add(2)).collect();
1985
1986 let k1_uninit: &mut [MaybeUninit<f64>] = unsafe {
1987 core::slice::from_raw_parts_mut(
1988 k1_slice.as_mut_ptr() as *mut MaybeUninit<f64>,
1989 k1_slice.len(),
1990 )
1991 };
1992 let k2_uninit: &mut [MaybeUninit<f64>] = unsafe {
1993 core::slice::from_raw_parts_mut(
1994 k2_slice.as_mut_ptr() as *mut MaybeUninit<f64>,
1995 k2_slice.len(),
1996 )
1997 };
1998 init_matrix_prefixes(k1_uninit, cols, &warm_k1);
1999 init_matrix_prefixes(k2_uninit, cols, &warm_k2);
2000
2001 let mut tr_mu = make_uninit_matrix(1, cols);
2002 let tr: &mut [f64] =
2003 unsafe { core::slice::from_raw_parts_mut(tr_mu.as_mut_ptr() as *mut f64, cols) };
2004 if first < cols {
2005 tr[first] = high[first] - low[first];
2006 }
2007 for i in (first + 1)..cols {
2008 let hl = high[i] - low[i];
2009 let hc = (high[i] - close[i - 1]).abs();
2010 let lc = (low[i] - close[i - 1]).abs();
2011 tr[i] = hl.max(hc).max(lc);
2012 }
2013
2014 let use_rsi = sweep.no_volume;
2015 let hlc3_opt: Option<Vec<f64>> = if use_rsi {
2016 None
2017 } else {
2018 let mut hlc3_mu = make_uninit_matrix(1, cols);
2019 let hlc3: &mut [f64] =
2020 unsafe { core::slice::from_raw_parts_mut(hlc3_mu.as_mut_ptr() as *mut f64, cols) };
2021 for i in 0..cols {
2022 hlc3[i] = (high[i] + low[i] + close[i]) / 3.0;
2023 }
2024
2025 let v = unsafe {
2026 Vec::from_raw_parts(
2027 hlc3_mu.as_mut_ptr() as *mut f64,
2028 hlc3_mu.len(),
2029 hlc3_mu.capacity(),
2030 )
2031 };
2032 core::mem::forget(hlc3_mu);
2033 Some(v)
2034 };
2035
2036 use std::collections::HashMap;
2037 let mut unique_periods: Vec<usize> = combos.iter().map(|p| p.period.unwrap_or(14)).collect();
2038 unique_periods.sort_unstable();
2039 unique_periods.dedup();
2040
2041 let mut momentum_map: HashMap<usize, Vec<f64>> = HashMap::with_capacity(unique_periods.len());
2042 for &p in &unique_periods {
2043 if p == 0 || p > cols {
2044 return Err(AlphaTrendError::InvalidPeriod {
2045 period: p,
2046 data_len: cols,
2047 });
2048 }
2049 if use_rsi {
2050 let rsi_params = RsiParams { period: Some(p) };
2051 let rsi_input = RsiInput::from_slice(close, rsi_params);
2052 let mv = rsi_with_kernel(&rsi_input, simd_kernel)
2053 .map_err(|e| AlphaTrendError::RsiError { msg: e.to_string() })?
2054 .values;
2055 momentum_map.insert(p, mv);
2056 } else {
2057 let hlc3 = hlc3_opt.as_ref().expect("hlc3 precomputed");
2058 let mfi_params = MfiParams { period: Some(p) };
2059 let mfi_input = MfiInput::from_slices(hlc3, volume, mfi_params);
2060 let mv = mfi_with_kernel(&mfi_input, simd_kernel)
2061 .map_err(|e| AlphaTrendError::MfiError { msg: e.to_string() })?
2062 .values;
2063 momentum_map.insert(p, mv);
2064 }
2065 }
2066
2067 let do_row =
2068 |row: usize, k1_row: &mut [f64], k2_row: &mut [f64]| -> Result<(), AlphaTrendError> {
2069 let params = &combos[row];
2070 let coeff = params.coeff.unwrap_or(1.0);
2071 if !coeff.is_finite() || coeff <= 0.0 {
2072 return Err(AlphaTrendError::InvalidCoeff { coeff });
2073 }
2074 let period = params.period.unwrap_or(14);
2075 if period == 0 || period > cols {
2076 return Err(AlphaTrendError::InvalidPeriod {
2077 period,
2078 data_len: cols,
2079 });
2080 }
2081 let warmup = first + period - 1;
2082 if warmup >= cols {
2083 return Ok(());
2084 }
2085
2086 let mom = momentum_map.get(&period).expect("momentum precomputed");
2087
2088 let mut sum = 0.0f64;
2089 for j in first..=warmup {
2090 sum += tr[j];
2091 }
2092
2093 let mut prev_alpha = f64::NAN;
2094 let mut prev1 = f64::NAN;
2095 let mut prev2 = f64::NAN;
2096
2097 for i in warmup..cols {
2098 let a = sum / period as f64;
2099 let up = low[i] - a * coeff;
2100 let dn = high[i] + a * coeff;
2101 let m_ge_50 = mom[i] >= 50.0;
2102
2103 let cur = if i == warmup {
2104 if m_ge_50 {
2105 up
2106 } else {
2107 dn
2108 }
2109 } else if m_ge_50 {
2110 if up < prev_alpha {
2111 prev_alpha
2112 } else {
2113 up
2114 }
2115 } else {
2116 if dn > prev_alpha {
2117 prev_alpha
2118 } else {
2119 dn
2120 }
2121 };
2122
2123 k1_row[i] = cur;
2124 if i >= warmup + 2 {
2125 k2_row[i] = prev2;
2126 }
2127
2128 prev2 = prev1;
2129 prev1 = cur;
2130 prev_alpha = cur;
2131
2132 if i + 1 < cols {
2133 sum += tr[i + 1] - tr[i + 1 - period];
2134 }
2135 }
2136 Ok(())
2137 };
2138
2139 #[cfg(not(target_arch = "wasm32"))]
2140 if parallel {
2141 use rayon::prelude::*;
2142 k1_slice
2143 .par_chunks_mut(cols)
2144 .zip(k2_slice.par_chunks_mut(cols))
2145 .enumerate()
2146 .try_for_each(|(row, (k1r, k2r))| do_row(row, k1r, k2r))?;
2147 } else {
2148 for (row, (k1r, k2r)) in k1_slice
2149 .chunks_mut(cols)
2150 .zip(k2_slice.chunks_mut(cols))
2151 .enumerate()
2152 {
2153 do_row(row, k1r, k2r)?;
2154 }
2155 }
2156
2157 #[cfg(target_arch = "wasm32")]
2158 for (row, (k1r, k2r)) in k1_slice
2159 .chunks_mut(cols)
2160 .zip(k2_slice.chunks_mut(cols))
2161 .enumerate()
2162 {
2163 do_row(row, k1r, k2r)?;
2164 }
2165
2166 Ok(())
2167}
2168
/// Python entry point `alphatrend_batch`: runs a parameter sweep and returns
/// a dict with keys `k1`, `k2` (arrays reshaped to `(rows, cols)`), `rows`,
/// `cols`, and `combos` (a list of dicts with `coeff`, `period`,
/// `no_volume`, one per row).
///
/// The row count is taken from `expand_grid_alphatrend`, the same expansion
/// the computation uses, so ascending and descending ranges both yield
/// correctly sized outputs.
///
/// Raises `ValueError` for mismatched input lengths, an invalid kernel or
/// range, an overflowing output size, or any indicator error.
#[cfg(feature = "python")]
#[pyfunction(name = "alphatrend_batch")]
#[pyo3(signature = (open, high, low, close, volume, coeff_range, period_range, no_volume=false, kernel=None))]
pub fn alphatrend_batch_py<'py>(
    py: Python<'py>,
    open: PyReadonlyArray1<'py, f64>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    volume: PyReadonlyArray1<'py, f64>,
    coeff_range: (f64, f64, f64),
    period_range: (usize, usize, usize),
    no_volume: bool,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::PyArray1;

    let (o, h, l, c, v) = (
        open.as_slice()?,
        high.as_slice()?,
        low.as_slice()?,
        close.as_slice()?,
        volume.as_slice()?,
    );
    let len = c.len();
    if o.len() != len || h.len() != len || l.len() != len || v.len() != len {
        return Err(PyValueError::new_err("Inconsistent data lengths"));
    }

    let sweep = AlphaTrendBatchRange {
        coeff: coeff_range,
        period: period_range,
        no_volume,
    };
    let kern = validate_kernel(kernel, true)?;

    // Single source of truth for the grid: the same expansion the compute
    // step performs, instead of a separate closed-form row count.
    let combos =
        expand_grid_alphatrend(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let total = rows
        .checked_mul(len)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // SAFETY: the arrays are freshly allocated and not yet visible to Python,
    // so the mutable slices are the only references to their storage.
    let out_k1 = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_k2 = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let k1_slice = unsafe { out_k1.as_slice_mut()? };
    let k2_slice = unsafe { out_k2.as_slice_mut()? };

    py.allow_threads(|| {
        alphatrend_batch_inner_into_slices(o, h, l, c, v, &sweep, kern, true, k1_slice, k2_slice)
    })
    .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("k1", out_k1.reshape([rows, len])?)?;
    dict.set_item("k2", out_k2.reshape([rows, len])?)?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", len)?;

    // Build the per-row parameter dicts with `?` so a Python-side failure
    // becomes an exception rather than a panic.
    let mut combo_dicts = Vec::with_capacity(combos.len());
    for combo in &combos {
        let d = PyDict::new(py);
        d.set_item("coeff", combo.coeff.unwrap_or(1.0))?;
        d.set_item("period", combo.period.unwrap_or(14))?;
        d.set_item("no_volume", combo.no_volume.unwrap_or(false))?;
        combo_dicts.push(d);
    }
    let combo_list = PyList::new(py, combo_dicts)?;
    dict.set_item("combos", combo_list)?;

    Ok(dict)
}
2256
/// Python entry point `alphatrend_cuda_batch_dev`: runs a parameter sweep on
/// a CUDA device and returns device-resident results.
///
/// The returned dict has keys `k1` and `k2` (`AtDeviceArrayF32Py` objects),
/// `coeffs` and `periods` (NumPy arrays, one entry per row), `rows`, and
/// `cols`.
///
/// Raises `ValueError` when CUDA is unavailable, when the four inputs differ
/// in length, or when the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "alphatrend_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, close_f32, volume_f32, coeff_range, period_range, no_volume=false, device_id=0))]
pub fn alphatrend_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    high_f32: PyReadonlyArray1<'py, f32>,
    low_f32: PyReadonlyArray1<'py, f32>,
    close_f32: PyReadonlyArray1<'py, f32>,
    volume_f32: PyReadonlyArray1<'py, f32>,
    coeff_range: (f64, f64, f64),
    period_range: (usize, usize, usize),
    no_volume: bool,
    device_id: usize,
) -> PyResult<Bound<'py, PyDict>> {
    use crate::cuda::cuda_available;
    use numpy::IntoPyArray;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let (h, l, c, v) = (
        high_f32.as_slice()?,
        low_f32.as_slice()?,
        close_f32.as_slice()?,
        volume_f32.as_slice()?,
    );
    if h.len() != l.len() || h.len() != c.len() || h.len() != v.len() {
        return Err(PyValueError::new_err("Inconsistent data lengths"));
    }
    let sweep = AlphaTrendBatchRange {
        coeff: coeff_range,
        period: period_range,
        no_volume,
    };
    // Device work runs with the GIL released. The context handle and device id
    // are carried out of the closure so they can be stored alongside the buffers.
    let (batch, coeffs_vec, periods_vec, ctx_guard, dev_id) = py.allow_threads(|| {
        let cuda =
            CudaAlphaTrend::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let out = cuda
            .alphatrend_batch_dev(h, l, c, v, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let coeffs: Vec<f64> = out.combos.iter().map(|p| p.coeff.unwrap_or(1.0)).collect();
        let periods: Vec<u64> = out
            .combos
            .iter()
            .map(|p| p.period.unwrap_or(14) as u64)
            .collect();
        Ok::<_, PyErr>((out, coeffs, periods, cuda.context_arc(), cuda.device_id()))
    })?;

    // Both outputs are labelled with the shape reported for k1.
    let rows = batch.k1.rows;
    let cols = batch.k1.cols;
    let dict = PyDict::new(py);

    // Each wrapper holds its own clone of the context handle.
    let k1_py = AtDeviceArrayF32Py {
        buf: Some(batch.k1.buf),
        rows,
        cols,
        _ctx: ctx_guard.clone(),
        device_id: dev_id,
    };
    let k2_py = AtDeviceArrayF32Py {
        buf: Some(batch.k2.buf),
        rows,
        cols,
        _ctx: ctx_guard,
        device_id: dev_id,
    };
    dict.set_item("k1", Py::new(py, k1_py)?)?;
    dict.set_item("k2", Py::new(py, k2_py)?)?;
    dict.set_item("coeffs", coeffs_vec.into_pyarray(py))?;
    dict.set_item("periods", periods_vec.into_pyarray(py))?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
2331
/// Python entry point `alphatrend_cuda_many_series_one_param_dev`: applies a
/// single parameter set to many series on a CUDA device.
///
/// Each input is a flat array that must contain exactly `cols * rows`
/// values. Returns the pair `(k1, k2)` as device-resident
/// `AtDeviceArrayF32Py` objects.
///
/// Raises `ValueError` when CUDA is unavailable, when `cols * rows`
/// overflows, when any input length differs from `cols * rows`, or when the
/// CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "alphatrend_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, close_tm_f32, volume_tm_f32, cols, rows, coeff=1.0, period=14, no_volume=false, device_id=0))]
pub fn alphatrend_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    high_tm_f32: PyReadonlyArray1<'py, f32>,
    low_tm_f32: PyReadonlyArray1<'py, f32>,
    close_tm_f32: PyReadonlyArray1<'py, f32>,
    volume_tm_f32: PyReadonlyArray1<'py, f32>,
    cols: usize,
    rows: usize,
    coeff: f64,
    period: usize,
    no_volume: bool,
    device_id: usize,
) -> PyResult<(AtDeviceArrayF32Py, AtDeviceArrayF32Py)> {
    use crate::cuda::cuda_available;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let (h, l, c, v) = (
        high_tm_f32.as_slice()?,
        low_tm_f32.as_slice()?,
        close_tm_f32.as_slice()?,
        volume_tm_f32.as_slice()?,
    );
    // Checked product: a wrapped `cols * rows` could accidentally equal the
    // buffer length and let mismatched shapes through.
    let expected = cols
        .checked_mul(rows)
        .ok_or_else(|| PyValueError::new_err("cols*rows overflow"))?;
    if h.len() != expected || l.len() != expected || c.len() != expected || v.len() != expected {
        return Err(PyValueError::new_err("Inconsistent time-major shapes"));
    }
    // Device work runs with the GIL released; the context handle and device id
    // are carried out so they can be stored alongside the buffers.
    let (k1, k2, ctx_guard, dev_id) = py.allow_threads(|| {
        let cuda =
            CudaAlphaTrend::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let out = cuda
            .alphatrend_many_series_one_param_time_major_dev(
                h, l, c, v, cols, rows, coeff, period, no_volume,
            )
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((out.0, out.1, cuda.context_arc(), cuda.device_id()))
    })?;
    Ok((
        AtDeviceArrayF32Py {
            buf: Some(k1.buf),
            rows: k1.rows,
            cols: k1.cols,
            _ctx: ctx_guard.clone(),
            device_id: dev_id,
        },
        AtDeviceArrayF32Py {
            buf: Some(k2.buf),
            rows: k2.rows,
            cols: k2.cols,
            _ctx: ctx_guard,
            device_id: dev_id,
        },
    ))
}
2392
/// Python-visible handle to a 2-D `f32` buffer living on a CUDA device.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct AtDeviceArrayF32Py {
    // Device allocation; becomes `None` once `__dlpack__` has taken it.
    pub(crate) buf: Option<DeviceBuffer<f32>>,
    // Logical shape reported to Python as `(rows, cols)`.
    pub(crate) rows: usize,
    pub(crate) cols: usize,
    // Shared CUDA context handle, never read directly.
    // NOTE(review): presumably held so the context outlives `buf` — confirm.
    pub(crate) _ctx: Arc<Context>,
    // Device index reported by `__dlpack_device__`.
    pub(crate) device_id: u32,
}
2402
2403#[cfg(all(feature = "python", feature = "cuda"))]
2404#[pymethods]
2405impl AtDeviceArrayF32Py {
2406 #[getter]
2407 fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
2408 let d = PyDict::new(py);
2409 d.set_item("shape", (self.rows, self.cols))?;
2410 d.set_item("typestr", "<f4")?;
2411 d.set_item(
2412 "strides",
2413 (
2414 self.cols * std::mem::size_of::<f32>(),
2415 std::mem::size_of::<f32>(),
2416 ),
2417 )?;
2418 let ptr = self
2419 .buf
2420 .as_ref()
2421 .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?
2422 .as_device_ptr()
2423 .as_raw() as usize;
2424 d.set_item("data", (ptr, false))?;
2425
2426 d.set_item("version", 3)?;
2427 Ok(d)
2428 }
2429
2430 fn __dlpack_device__(&self) -> (i32, i32) {
2431 (2, self.device_id as i32)
2432 }
2433
2434 #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
2435 fn __dlpack__<'py>(
2436 &mut self,
2437 py: Python<'py>,
2438 stream: Option<PyObject>,
2439 max_version: Option<PyObject>,
2440 dl_device: Option<PyObject>,
2441 copy: Option<PyObject>,
2442 ) -> PyResult<PyObject> {
2443 let (kdl, alloc_dev) = self.__dlpack_device__();
2444 if let Some(dev_obj) = dl_device.as_ref() {
2445 if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
2446 if dev_ty != kdl || dev_id != alloc_dev {
2447 let wants_copy = copy
2448 .as_ref()
2449 .and_then(|c| c.extract::<bool>(py).ok())
2450 .unwrap_or(false);
2451 if wants_copy {
2452 return Err(PyValueError::new_err(
2453 "device copy not implemented for __dlpack__",
2454 ));
2455 } else {
2456 return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
2457 }
2458 }
2459 }
2460 }
2461 let _ = stream;
2462
2463 let buf = self
2464 .buf
2465 .take()
2466 .ok_or_else(|| PyValueError::new_err("__dlpack__ may only be called once"))?;
2467
2468 let rows = self.rows;
2469 let cols = self.cols;
2470 let max_version_bound = max_version.map(|obj| obj.into_bound(py));
2471
2472 export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
2473 }
2474}
2475
/// Serialized result of the wasm batch entry points (`alphatrend_batch` and
/// `alphatrend_batch_unified`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct AlphaTrendBatchJsOutput {
    /// Flattened output: for each parameter combination, its K1 row of `cols`
    /// values immediately followed by its K2 row.
    pub values: Vec<f64>,
    /// Parameter combinations, one per K1/K2 row pair.
    pub combos: Vec<AlphaTrendParams>,
    /// Number of rows in `values`; the batch entry points set this to twice
    /// the number of combinations.
    pub rows: usize,
    /// Length of each row.
    pub cols: usize,
}
2484
/// Runs an AlphaTrend parameter sweep and returns an
/// [`AlphaTrendBatchJsOutput`] serialized to a `JsValue`.
///
/// The sweep is described by `(start, end, step)` triples for `coeff` and
/// `period`. In the returned `values`, each combination contributes its K1 row
/// followed by its K2 row, so `rows` is twice the number of combinations and
/// `cols` is `close.len()`.
///
/// Errors from grid expansion, the batch computation, size overflow or
/// serialization are returned as string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = alphatrend_batch)]
pub fn alphatrend_batch_js(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    coeff_start: f64,
    coeff_end: f64,
    coeff_step: f64,
    period_start: usize,
    period_end: usize,
    period_step: usize,
    no_volume: bool,
) -> Result<JsValue, JsValue> {
    /// Converts any displayable error into a string `JsValue`.
    fn to_js<E: std::fmt::Display>(err: E) -> JsValue {
        JsValue::from_str(&err.to_string())
    }

    let sweep = AlphaTrendBatchRange {
        coeff: (coeff_start, coeff_end, coeff_step),
        period: (period_start, period_end, period_step),
        no_volume,
    };
    let combos = expand_grid_alphatrend(&sweep).map_err(to_js)?;
    let (rows, cols) = (combos.len(), close.len());

    let plane_len = match rows.checked_mul(cols) {
        Some(n) => n,
        None => return Err(JsValue::from_str("rows*cols overflow")),
    };
    let mut k1 = vec![f64::NAN; plane_len];
    let mut k2 = vec![f64::NAN; plane_len];

    // The boolean flag is forwarded unchanged (presumably the parallel
    // switch of the inner routine; confirm against its definition).
    alphatrend_batch_inner_into_slices(
        open,
        high,
        low,
        close,
        volume,
        &sweep,
        detect_best_batch_kernel(),
        true,
        &mut k1,
        &mut k2,
    )
    .map_err(to_js)?;

    let interleaved_len = match rows.checked_mul(2).and_then(|r2| r2.checked_mul(cols)) {
        Some(n) => n,
        None => return Err(JsValue::from_str("rows*2*cols overflow")),
    };

    // Interleave per combination: K1 row, then K2 row.
    let mut values = Vec::with_capacity(interleaved_len);
    for row in 0..rows {
        let span = row * cols..(row + 1) * cols;
        values.extend_from_slice(&k1[span.clone()]);
        values.extend_from_slice(&k2[span]);
    }

    let payload = AlphaTrendBatchJsOutput {
        values,
        combos,
        rows: rows * 2,
        cols,
    };
    match serde_wasm_bindgen::to_value(&payload) {
        Ok(js) => Ok(js),
        Err(e) => Err(JsValue::from_str(&format!("Serialization error: {}", e))),
    }
}
2550
/// Runs an AlphaTrend parameter sweep directly into caller-provided memory.
///
/// Each input pointer must reference `len` readable `f64` values. `out_ptr`
/// must reference `2 * rows * len` writable `f64` values, where `rows` is the
/// number of parameter combinations produced by the sweep: the first
/// `rows * len` values receive K1 and the following `rows * len` values
/// receive K2.
///
/// Returns the number of parameter combinations (`rows`).
///
/// # Errors
/// Returns a string `JsValue` when any pointer is null, when the sweep cannot
/// be expanded, when the output size overflows, or when the computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn alphatrend_batch_into_flat(
    open_ptr: *const f64,
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    volume_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    coeff_start: f64,
    coeff_end: f64,
    coeff_step: f64,
    period_start: usize,
    period_end: usize,
    period_step: usize,
    no_volume: bool,
) -> Result<usize, JsValue> {
    if [open_ptr, high_ptr, low_ptr, close_ptr, volume_ptr]
        .iter()
        .any(|&p| p.is_null())
        || out_ptr.is_null()
    {
        return Err(JsValue::from_str("null pointer"));
    }

    let sweep = AlphaTrendBatchRange {
        coeff: (coeff_start, coeff_end, coeff_step),
        period: (period_start, period_end, period_step),
        no_volume,
    };
    let combos = expand_grid_alphatrend(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    // The output spans two planes of `total` values. Validate the doubled
    // length, and its size in bytes, before forming slices over it: slices
    // larger than `isize::MAX` bytes are not allowed.
    let out_bytes = total
        .checked_mul(2)
        .and_then(|n| n.checked_mul(core::mem::size_of::<f64>()))
        .ok_or_else(|| JsValue::from_str("rows*2*cols overflow"))?;
    if out_bytes > isize::MAX as usize {
        return Err(JsValue::from_str("rows*2*cols overflow"));
    }

    // SAFETY: all pointers were checked for null above. The caller guarantees
    // that every input points to `len` initialised `f64`s, that `out_ptr`
    // points to `2 * total` writable `f64`s, and that the output region does
    // not overlap any input. `out_ptr.add(total)` stays inside that region
    // because `2 * total` was overflow-checked.
    let (open, high, low, close, volume, k1, k2) = unsafe {
        (
            core::slice::from_raw_parts(open_ptr, len),
            core::slice::from_raw_parts(high_ptr, len),
            core::slice::from_raw_parts(low_ptr, len),
            core::slice::from_raw_parts(close_ptr, len),
            core::slice::from_raw_parts(volume_ptr, len),
            core::slice::from_raw_parts_mut(out_ptr, total),
            core::slice::from_raw_parts_mut(out_ptr.add(total), total),
        )
    };

    alphatrend_batch_inner_into_slices(
        open,
        high,
        low,
        close,
        volume,
        &sweep,
        detect_best_batch_kernel(),
        false,
        k1,
        k2,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}
2617
/// Sweep configuration deserialized from the `config` argument of
/// `alphatrend_batch_unified`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct AlphaTrendBatchConfig {
    /// Coefficient sweep triple, copied into `AlphaTrendBatchRange::coeff`.
    pub coeff_range: (f64, f64, f64),
    /// Period sweep triple, copied into `AlphaTrendBatchRange::period`.
    pub period_range: (usize, usize, usize),
    /// Copied into `AlphaTrendBatchRange::no_volume`.
    pub no_volume: bool,
}
2625
/// Config-object variant of the batch sweep.
///
/// `config` is deserialized into [`AlphaTrendBatchConfig`]. The result is an
/// [`AlphaTrendBatchJsOutput`] in which every combination contributes its K1
/// row followed by its K2 row, so `rows` is twice the number of combinations.
///
/// Config parsing, computation, size-overflow and serialization failures are
/// returned as string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = alphatrend_batch_unified)]
pub fn alphatrend_batch_unified_js(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let AlphaTrendBatchConfig {
        coeff_range,
        period_range,
        no_volume,
    } = match serde_wasm_bindgen::from_value(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let sweep = AlphaTrendBatchRange {
        coeff: coeff_range,
        period: period_range,
        no_volume,
    };

    // NOTE(review): this passes `detect_best_kernel()` while `alphatrend_batch`
    // uses `detect_best_batch_kernel()`; confirm `alphatrend_batch_slice`
    // accepts a non-batch kernel.
    let batch =
        match alphatrend_batch_slice(open, high, low, close, volume, &sweep, detect_best_kernel())
        {
            Ok(out) => out,
            Err(e) => return Err(JsValue::from_str(&e.to_string())),
        };

    let doubled_rows = batch.rows * 2;
    let cols = batch.cols;
    let capacity = match doubled_rows.checked_mul(cols) {
        Some(n) => n,
        None => return Err(JsValue::from_str("rows2*cols overflow")),
    };

    // Interleave per combination: K1 row, then K2 row.
    let mut values = Vec::with_capacity(capacity);
    for row in 0..batch.rows {
        let span = row * cols..(row + 1) * cols;
        values.extend_from_slice(&batch.values_k1[span.clone()]);
        values.extend_from_slice(&batch.values_k2[span]);
    }

    let payload = AlphaTrendBatchJsOutput {
        values,
        combos: batch.combos,
        rows: doubled_rows,
        cols,
    };

    match serde_wasm_bindgen::to_value(&payload) {
        Ok(js) => Ok(js),
        Err(e) => Err(JsValue::from_str(&format!("Serialization error: {}", e))),
    }
}
2671
/// Computes AlphaTrend for one parameter set into caller-provided buffers.
///
/// `in_ptr` is the close series. Every input pointer must reference `len`
/// readable `f64` values, and each output pointer must reference `len`
/// writable `f64` values. An output pointer may equal one of the input
/// pointers (in-place use): the results are then staged in temporary buffers
/// and copied out afterwards. The two output buffers must not overlap each
/// other.
///
/// # Errors
/// Returns a string `JsValue` when any pointer is null or when the
/// computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn alphatrend_into(
    in_ptr: *const f64,
    out_k1_ptr: *mut f64,
    out_k2_ptr: *mut f64,
    len: usize,
    open_ptr: *const f64,
    high_ptr: *const f64,
    low_ptr: *const f64,
    volume_ptr: *const f64,
    coeff: f64,
    period: usize,
    no_volume: bool,
) -> Result<(), JsValue> {
    if in_ptr.is_null()
        || out_k1_ptr.is_null()
        || out_k2_ptr.is_null()
        || open_ptr.is_null()
        || high_ptr.is_null()
        || low_ptr.is_null()
        || volume_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer passed to alphatrend_into"));
    }

    // Writing straight into an output that is also an input would create
    // overlapping `&`/`&mut` slices and corrupt data that is still being read.
    let aliased = [in_ptr, open_ptr, high_ptr, low_ptr, volume_ptr]
        .iter()
        .any(|&p| p == out_k1_ptr as *const f64 || p == out_k2_ptr as *const f64);

    let params = AlphaTrendParams {
        coeff: Some(coeff),
        period: Some(period),
        no_volume: Some(no_volume),
    };

    if aliased {
        let mut temp_k1 = vec![0.0; len];
        let mut temp_k2 = vec![0.0; len];
        {
            // SAFETY: pointers are non-null (checked above) and the caller
            // guarantees each references `len` initialised `f64`s. No mutable
            // view of the outputs exists while these shared slices are alive.
            let (open, high, low, close, volume) = unsafe {
                (
                    std::slice::from_raw_parts(open_ptr, len),
                    std::slice::from_raw_parts(high_ptr, len),
                    std::slice::from_raw_parts(low_ptr, len),
                    std::slice::from_raw_parts(in_ptr, len),
                    std::slice::from_raw_parts(volume_ptr, len),
                )
            };
            let input = AlphaTrendInput::from_slices(open, high, low, close, volume, params);
            alphatrend_into_slices(&mut temp_k1, &mut temp_k2, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
        // SAFETY: the outputs are non-null and the caller guarantees room for
        // `len` values each. The temporaries are fresh heap allocations, so
        // source and destination cannot overlap.
        unsafe {
            std::ptr::copy_nonoverlapping(temp_k1.as_ptr(), out_k1_ptr, len);
            std::ptr::copy_nonoverlapping(temp_k2.as_ptr(), out_k2_ptr, len);
        }
    } else {
        // SAFETY: pointers are non-null (checked above); the caller guarantees
        // `len` valid elements behind each and that the outputs do not overlap
        // each other. No output pointer equals an input pointer on this path.
        let (open, high, low, close, volume, out_k1, out_k2) = unsafe {
            (
                std::slice::from_raw_parts(open_ptr, len),
                std::slice::from_raw_parts(high_ptr, len),
                std::slice::from_raw_parts(low_ptr, len),
                std::slice::from_raw_parts(in_ptr, len),
                std::slice::from_raw_parts(volume_ptr, len),
                std::slice::from_raw_parts_mut(out_k1_ptr, len),
                std::slice::from_raw_parts_mut(out_k2_ptr, len),
            )
        };
        let input = AlphaTrendInput::from_slices(open, high, low, close, volume, params);
        alphatrend_into_slices(out_k1, out_k2, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
2721
/// Deprecated wasm handle that stores one AlphaTrend parameter set and applies
/// it to caller-provided buffers via [`AlphaTrendContext::update_into`].
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(
    since = "1.0.0",
    note = "For weight reuse patterns, use the fast/unsafe API with persistent buffers"
)]
pub struct AlphaTrendContext {
    coeff: f64,
    period: usize,
    no_volume: bool,
    kernel: Kernel,
}

#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[allow(deprecated)]
impl AlphaTrendContext {
    /// Creates a context after validating the parameters.
    ///
    /// # Errors
    /// Returns a string `JsValue` when `period` is zero or when `coeff` is
    /// non-positive or not finite (NaN included).
    #[wasm_bindgen(constructor)]
    #[deprecated(
        since = "1.0.0",
        note = "For weight reuse patterns, use the fast/unsafe API with persistent buffers"
    )]
    pub fn new(coeff: f64, period: usize, no_volume: bool) -> Result<AlphaTrendContext, JsValue> {
        if period == 0 {
            return Err(JsValue::from_str("Invalid period: 0"));
        }
        if coeff <= 0.0 || !coeff.is_finite() {
            return Err(JsValue::from_str(&format!(
                "Invalid coefficient: {}",
                coeff
            )));
        }

        Ok(AlphaTrendContext {
            coeff,
            period,
            no_volume,
            kernel: Kernel::Auto,
        })
    }

    /// Computes AlphaTrend with the stored parameters into `out_k1_ptr` and
    /// `out_k2_ptr`.
    ///
    /// Every pointer must reference `len` valid `f64` values. An output pointer
    /// may equal any input pointer: results are then staged in temporary
    /// buffers before being copied out. The two outputs must not overlap each
    /// other.
    ///
    /// # Errors
    /// Returns a string `JsValue` when `len` is smaller than the period, when
    /// any pointer is null, or when the computation fails.
    pub fn update_into(
        &self,
        open_ptr: *const f64,
        high_ptr: *const f64,
        low_ptr: *const f64,
        close_ptr: *const f64,
        volume_ptr: *const f64,
        out_k1_ptr: *mut f64,
        out_k2_ptr: *mut f64,
        len: usize,
    ) -> Result<(), JsValue> {
        if len < self.period {
            return Err(JsValue::from_str("Data length less than period"));
        }

        let inputs = [open_ptr, high_ptr, low_ptr, close_ptr, volume_ptr];
        if inputs.iter().any(|&p| p.is_null()) || out_k1_ptr.is_null() || out_k2_ptr.is_null() {
            return Err(JsValue::from_str("Null pointer passed to update_into"));
        }

        // In-place use is detected for every input, not only `close`, so an
        // output is never written while it is still being read as an input.
        let aliased = inputs
            .iter()
            .any(|&p| p == out_k1_ptr as *const f64 || p == out_k2_ptr as *const f64);

        let params = AlphaTrendParams {
            coeff: Some(self.coeff),
            period: Some(self.period),
            no_volume: Some(self.no_volume),
        };

        if aliased {
            let mut temp_k1 = vec![0.0; len];
            let mut temp_k2 = vec![0.0; len];
            {
                // SAFETY: pointers are non-null (checked above) and the caller
                // guarantees `len` initialised values behind each. No mutable
                // view of the outputs exists while these slices are alive.
                let (open, high, low, close, volume) = unsafe {
                    (
                        std::slice::from_raw_parts(open_ptr, len),
                        std::slice::from_raw_parts(high_ptr, len),
                        std::slice::from_raw_parts(low_ptr, len),
                        std::slice::from_raw_parts(close_ptr, len),
                        std::slice::from_raw_parts(volume_ptr, len),
                    )
                };
                let input = AlphaTrendInput::from_slices(open, high, low, close, volume, params);
                alphatrend_into_slices(&mut temp_k1, &mut temp_k2, &input, self.kernel)
                    .map_err(|e| JsValue::from_str(&e.to_string()))?;
            }
            // SAFETY: outputs are non-null with room for `len` values each
            // (caller contract); the temporaries are fresh allocations, so
            // source and destination cannot overlap.
            unsafe {
                std::ptr::copy_nonoverlapping(temp_k1.as_ptr(), out_k1_ptr, len);
                std::ptr::copy_nonoverlapping(temp_k2.as_ptr(), out_k2_ptr, len);
            }
        } else {
            // SAFETY: pointers are non-null (checked above); the caller
            // guarantees `len` valid elements behind each and non-overlapping
            // outputs. No output pointer equals an input pointer on this path.
            let (open, high, low, close, volume, out_k1, out_k2) = unsafe {
                (
                    std::slice::from_raw_parts(open_ptr, len),
                    std::slice::from_raw_parts(high_ptr, len),
                    std::slice::from_raw_parts(low_ptr, len),
                    std::slice::from_raw_parts(close_ptr, len),
                    std::slice::from_raw_parts(volume_ptr, len),
                    std::slice::from_raw_parts_mut(out_k1_ptr, len),
                    std::slice::from_raw_parts_mut(out_k2_ptr, len),
                )
            };
            let input = AlphaTrendInput::from_slices(open, high, low, close, volume, params);
            alphatrend_into_slices(out_k1, out_k2, &input, self.kernel)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }

        Ok(())
    }

    /// Returns `period - 1`; `new` rejects a zero period, so this cannot
    /// underflow.
    pub fn get_warmup_period(&self) -> usize {
        self.period - 1
    }
}
2816
2817#[cfg(test)]
2818mod tests {
2819 use super::*;
2820 use crate::utilities::data_loader::read_candles_from_csv;
2821 use std::error::Error;
2822
2823 fn check_alphatrend_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2824 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2825 let candles = read_candles_from_csv(file_path)?;
2826
2827 let input = AlphaTrendInput::from_candles(&candles, AlphaTrendParams::default());
2828 let result = alphatrend_with_kernel(&input, kernel)?;
2829
2830 let expected_k1 = [
2831 60243.00,
2832 60243.00,
2833 60138.92857143,
2834 60088.42857143,
2835 59937.21428571,
2836 ];
2837
2838 let expected_k2 = [
2839 60542.42857143,
2840 60454.14285714,
2841 60243.00,
2842 60243.00,
2843 60138.92857143,
2844 ];
2845
2846 let start = result.k1.len().saturating_sub(5);
2847
2848 for (i, &val) in result.k1[start..].iter().enumerate() {
2849 let diff = (val - expected_k1[i]).abs();
2850 assert!(
2851 diff < 1e-6,
2852 "[{}] AlphaTrend K1 {:?} mismatch at idx {}: got {}, expected {} (diff: {})",
2853 test_name,
2854 kernel,
2855 i,
2856 val,
2857 expected_k1[i],
2858 diff
2859 );
2860 }
2861
2862 for (i, &val) in result.k2[start..].iter().enumerate() {
2863 let diff = (val - expected_k2[i]).abs();
2864 assert!(
2865 diff < 1e-6,
2866 "[{}] AlphaTrend K2 {:?} mismatch at idx {}: got {}, expected {} (diff: {})",
2867 test_name,
2868 kernel,
2869 i,
2870 val,
2871 expected_k2[i],
2872 diff
2873 );
2874 }
2875
2876 Ok(())
2877 }
2878
2879 fn check_alphatrend_partial_params(
2880 test_name: &str,
2881 kernel: Kernel,
2882 ) -> Result<(), Box<dyn Error>> {
2883 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2884 let candles = read_candles_from_csv(file_path)?;
2885
2886 let default_params = AlphaTrendParams {
2887 coeff: None,
2888 period: None,
2889 no_volume: None,
2890 };
2891 let input = AlphaTrendInput::from_candles(&candles, default_params);
2892 let output = alphatrend_with_kernel(&input, kernel)?;
2893 assert_eq!(output.k1.len(), candles.close.len());
2894 assert_eq!(output.k2.len(), candles.close.len());
2895
2896 Ok(())
2897 }
2898
2899 fn check_alphatrend_default_candles(
2900 test_name: &str,
2901 kernel: Kernel,
2902 ) -> Result<(), Box<dyn Error>> {
2903 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2904 let candles = read_candles_from_csv(file_path)?;
2905
2906 let input = AlphaTrendInput::with_default_candles(&candles);
2907 let output = alphatrend_with_kernel(&input, kernel)?;
2908 assert_eq!(output.k1.len(), candles.close.len());
2909 assert_eq!(output.k2.len(), candles.close.len());
2910
2911 Ok(())
2912 }
2913
2914 fn check_alphatrend_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2915 let open = vec![10.0, 20.0, 30.0];
2916 let high = vec![12.0, 22.0, 32.0];
2917 let low = vec![8.0, 18.0, 28.0];
2918 let close = vec![11.0, 21.0, 31.0];
2919 let volume = vec![100.0, 200.0, 300.0];
2920
2921 let params = AlphaTrendParams {
2922 coeff: Some(1.0),
2923 period: Some(0),
2924 no_volume: Some(false),
2925 };
2926 let input = AlphaTrendInput::from_slices(&open, &high, &low, &close, &volume, params);
2927 let res = alphatrend_with_kernel(&input, kernel);
2928 assert!(
2929 res.is_err(),
2930 "[{}] AlphaTrend should fail with zero period",
2931 test_name
2932 );
2933 Ok(())
2934 }
2935
2936 fn check_alphatrend_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2937 let empty: [f64; 0] = [];
2938 let params = AlphaTrendParams::default();
2939 let input = AlphaTrendInput::from_slices(&empty, &empty, &empty, &empty, &empty, params);
2940 let res = alphatrend_with_kernel(&input, kernel);
2941 assert!(
2942 res.is_err(),
2943 "[{}] AlphaTrend should fail with empty input",
2944 test_name
2945 );
2946 Ok(())
2947 }
2948
2949 fn check_alphatrend_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2950 let nan_data = [f64::NAN, f64::NAN, f64::NAN];
2951 let params = AlphaTrendParams::default();
2952 let input = AlphaTrendInput::from_slices(
2953 &nan_data, &nan_data, &nan_data, &nan_data, &nan_data, params,
2954 );
2955 let res = alphatrend_with_kernel(&input, kernel);
2956 assert!(
2957 res.is_err(),
2958 "[{}] AlphaTrend should fail with all NaN values",
2959 test_name
2960 );
2961 Ok(())
2962 }
2963
2964 fn check_alphatrend_period_exceeds_length(
2965 test_name: &str,
2966 kernel: Kernel,
2967 ) -> Result<(), Box<dyn Error>> {
2968 let data_small = [10.0, 20.0, 30.0];
2969 let params = AlphaTrendParams {
2970 coeff: Some(1.0),
2971 period: Some(10),
2972 no_volume: Some(false),
2973 };
2974 let input = AlphaTrendInput::from_slices(
2975 &data_small,
2976 &data_small,
2977 &data_small,
2978 &data_small,
2979 &data_small,
2980 params,
2981 );
2982 let res = alphatrend_with_kernel(&input, kernel);
2983 assert!(
2984 res.is_err(),
2985 "[{}] AlphaTrend should fail with period exceeding length",
2986 test_name
2987 );
2988 Ok(())
2989 }
2990
2991 fn check_alphatrend_very_small_dataset(
2992 test_name: &str,
2993 kernel: Kernel,
2994 ) -> Result<(), Box<dyn Error>> {
2995 let single_point = [42.0];
2996 let params = AlphaTrendParams {
2997 coeff: Some(1.0),
2998 period: Some(14),
2999 no_volume: Some(false),
3000 };
3001 let input = AlphaTrendInput::from_slices(
3002 &single_point,
3003 &single_point,
3004 &single_point,
3005 &single_point,
3006 &single_point,
3007 params,
3008 );
3009 let res = alphatrend_with_kernel(&input, kernel);
3010 assert!(
3011 res.is_err(),
3012 "[{}] AlphaTrend should fail with insufficient data",
3013 test_name
3014 );
3015 Ok(())
3016 }
3017
3018 fn check_alphatrend_invalid_coeff(
3019 test_name: &str,
3020 kernel: Kernel,
3021 ) -> Result<(), Box<dyn Error>> {
3022 let data = vec![1.0; 20];
3023 let params = AlphaTrendParams {
3024 coeff: Some(-1.0),
3025 period: Some(14),
3026 no_volume: Some(false),
3027 };
3028 let input = AlphaTrendInput::from_slices(&data, &data, &data, &data, &data, params);
3029 let res = alphatrend_with_kernel(&input, kernel);
3030 assert!(
3031 matches!(res, Err(AlphaTrendError::InvalidCoeff { .. })),
3032 "[{}] AlphaTrend should fail with invalid coefficient",
3033 test_name
3034 );
3035 Ok(())
3036 }
3037
3038 fn check_alphatrend_inconsistent_lengths(
3039 test_name: &str,
3040 kernel: Kernel,
3041 ) -> Result<(), Box<dyn Error>> {
3042 let open = vec![10.0, 20.0, 30.0];
3043 let high = vec![12.0, 22.0];
3044 let low = vec![8.0, 18.0, 28.0];
3045 let close = vec![11.0, 21.0, 31.0];
3046 let volume = vec![100.0, 200.0, 300.0];
3047
3048 let params = AlphaTrendParams::default();
3049 let input = AlphaTrendInput::from_slices(&open, &high, &low, &close, &volume, params);
3050 let res = alphatrend_with_kernel(&input, kernel);
3051 assert!(
3052 matches!(res, Err(AlphaTrendError::InconsistentDataLengths)),
3053 "[{}] AlphaTrend should fail with inconsistent data lengths",
3054 test_name
3055 );
3056 Ok(())
3057 }
3058
3059 fn check_alphatrend_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3060 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3061 let candles = read_candles_from_csv(file_path)?;
3062
3063 let first_params = AlphaTrendParams {
3064 coeff: Some(1.0),
3065 period: Some(14),
3066 no_volume: Some(false),
3067 };
3068 let first_input = AlphaTrendInput::from_candles(&candles, first_params);
3069 let first_result = alphatrend_with_kernel(&first_input, kernel)?;
3070
3071 let second_params = AlphaTrendParams {
3072 coeff: Some(1.0),
3073 period: Some(14),
3074 no_volume: Some(true),
3075 };
3076
3077 let k1 = &first_result.k1;
3078 let synthetic_high: Vec<f64> = k1
3079 .iter()
3080 .map(|&v| if v.is_nan() { v } else { v + 10.0 })
3081 .collect();
3082 let synthetic_low: Vec<f64> = k1
3083 .iter()
3084 .map(|&v| if v.is_nan() { v } else { v - 10.0 })
3085 .collect();
3086 let synthetic_volume = vec![1000.0; k1.len()];
3087
3088 let second_input = AlphaTrendInput::from_slices(
3089 k1,
3090 &synthetic_high,
3091 &synthetic_low,
3092 k1,
3093 &synthetic_volume,
3094 second_params,
3095 );
3096 let second_result = alphatrend_with_kernel(&second_input, kernel)?;
3097
3098 assert_eq!(second_result.k1.len(), first_result.k1.len());
3099 assert_eq!(second_result.k2.len(), first_result.k2.len());
3100
3101 Ok(())
3102 }
3103
3104 fn check_alphatrend_nan_handling(
3105 test_name: &str,
3106 kernel: Kernel,
3107 ) -> Result<(), Box<dyn Error>> {
3108 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3109 let candles = read_candles_from_csv(file_path)?;
3110
3111 let input = AlphaTrendInput::from_candles(
3112 &candles,
3113 AlphaTrendParams {
3114 coeff: Some(1.0),
3115 period: Some(14),
3116 no_volume: Some(false),
3117 },
3118 );
3119 let res = alphatrend_with_kernel(&input, kernel)?;
3120 assert_eq!(res.k1.len(), candles.close.len());
3121 assert_eq!(res.k2.len(), candles.close.len());
3122
3123 if res.k1.len() > 240 {
3124 for (i, &val) in res.k1[240..].iter().enumerate() {
3125 assert!(
3126 !val.is_nan(),
3127 "[{}] Found unexpected NaN in K1 at out-index {}",
3128 test_name,
3129 240 + i
3130 );
3131 }
3132 }
3133 Ok(())
3134 }
3135
3136 fn check_alphatrend_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3137 let params = AlphaTrendParams {
3138 coeff: Some(1.0),
3139 period: Some(14),
3140 no_volume: Some(false),
3141 };
3142
3143 let mut stream = AlphaTrendStream::try_new(params)?;
3144 let warmup = stream.get_warmup_period();
3145
3146 for i in 0..30 {
3147 let high = 100.0 + i as f64 + 2.0;
3148 let low = 100.0 + i as f64 - 2.0;
3149 let close = 100.0 + i as f64;
3150 let volume = 1000.0 + i as f64 * 10.0;
3151
3152 let result = stream.update(high, low, close, volume);
3153 if i + 1 >= warmup + 3 {
3154 let some = result.expect("streaming should emit after warmup+2");
3155 assert!(
3156 some.0.is_finite() && some.1.is_finite(),
3157 "[{}] Non-finite streaming outputs at i={}",
3158 test_name,
3159 i
3160 );
3161 } else {
3162 assert!(
3163 result.is_none(),
3164 "[{}] Should not emit before warmup+2 at i={}",
3165 test_name,
3166 i
3167 );
3168 }
3169 }
3170 Ok(())
3171 }
3172
3173 #[cfg(debug_assertions)]
3174 fn check_alphatrend_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3175 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3176 let candles = read_candles_from_csv(file_path)?;
3177
3178 let test_params = vec![
3179 AlphaTrendParams::default(),
3180 AlphaTrendParams {
3181 coeff: Some(0.5),
3182 period: Some(7),
3183 no_volume: Some(false),
3184 },
3185 AlphaTrendParams {
3186 coeff: Some(2.0),
3187 period: Some(21),
3188 no_volume: Some(true),
3189 },
3190 AlphaTrendParams {
3191 coeff: Some(1.5),
3192 period: Some(10),
3193 no_volume: Some(false),
3194 },
3195 ];
3196
3197 for (param_idx, params) in test_params.iter().enumerate() {
3198 let input = AlphaTrendInput::from_candles(&candles, params.clone());
3199 let output = alphatrend_with_kernel(&input, kernel)?;
3200
3201 for (i, &val) in output.k1.iter().chain(output.k2.iter()).enumerate() {
3202 if val.is_nan() {
3203 continue;
3204 }
3205
3206 let bits = val.to_bits();
3207
3208 if bits == 0x11111111_11111111 {
3209 panic!(
3210 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
3211 with params: coeff={}, period={}, no_volume={}",
3212 test_name,
3213 val,
3214 bits,
3215 i,
3216 params.coeff.unwrap_or(1.0),
3217 params.period.unwrap_or(14),
3218 params.no_volume.unwrap_or(false)
3219 );
3220 }
3221
3222 if bits == 0x22222222_22222222 {
3223 panic!(
3224 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
3225 with params: coeff={}, period={}, no_volume={}",
3226 test_name,
3227 val,
3228 bits,
3229 i,
3230 params.coeff.unwrap_or(1.0),
3231 params.period.unwrap_or(14),
3232 params.no_volume.unwrap_or(false)
3233 );
3234 }
3235
3236 if bits == 0x33333333_33333333 {
3237 panic!(
3238 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
3239 with params: coeff={}, period={}, no_volume={}",
3240 test_name,
3241 val,
3242 bits,
3243 i,
3244 params.coeff.unwrap_or(1.0),
3245 params.period.unwrap_or(14),
3246 params.no_volume.unwrap_or(false)
3247 );
3248 }
3249 }
3250 }
3251
3252 Ok(())
3253 }
3254
3255 #[cfg(not(debug_assertions))]
3256 fn check_alphatrend_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
3257 Ok(())
3258 }
3259
3260 #[cfg(feature = "proptest")]
3261 #[allow(clippy::float_cmp)]
3262 fn check_alphatrend_property(
3263 test_name: &str,
3264 kernel: Kernel,
3265 ) -> Result<(), Box<dyn std::error::Error>> {
3266 use proptest::prelude::*;
3267
3268 let strat = (1usize..=50).prop_flat_map(|period| {
3269 (
3270 prop::collection::vec(
3271 (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
3272 period..400,
3273 ),
3274 Just(period),
3275 0.1f64..5.0f64,
3276 any::<bool>(),
3277 )
3278 });
3279
3280 proptest::test_runner::TestRunner::default()
3281 .run(&strat, |(close_data, period, coeff, no_volume)| {
3282 let high: Vec<f64> = close_data.iter().map(|&c| c + 5.0).collect();
3283 let low: Vec<f64> = close_data.iter().map(|&c| c - 5.0).collect();
3284 let open = close_data.clone();
3285 let volume = vec![1000.0; close_data.len()];
3286
3287 let params = AlphaTrendParams {
3288 coeff: Some(coeff),
3289 period: Some(period),
3290 no_volume: Some(no_volume),
3291 };
3292 let input =
3293 AlphaTrendInput::from_slices(&open, &high, &low, &close_data, &volume, params);
3294
3295 let result = alphatrend_with_kernel(&input, kernel).unwrap();
3296 let ref_result = alphatrend_with_kernel(&input, Kernel::Scalar).unwrap();
3297
3298 for i in 0..close_data.len() {
3299 let y = result.k1[i];
3300 let r = ref_result.k1[i];
3301
3302 if !y.is_finite() || !r.is_finite() {
3303 prop_assert!(
3304 y.to_bits() == r.to_bits(),
3305 "K1 finite/NaN mismatch idx {i}: {y} vs {r}"
3306 );
3307 continue;
3308 }
3309
3310 let ulp_diff: u64 = y.to_bits().abs_diff(r.to_bits());
3311 prop_assert!(
3312 (y - r).abs() <= 1e-9 || ulp_diff <= 4,
3313 "K1 mismatch idx {i}: {y} vs {r} (ULP={ulp_diff})"
3314 );
3315 }
3316
3317 for i in 0..close_data.len() {
3318 let y = result.k2[i];
3319 let r = ref_result.k2[i];
3320
3321 if !y.is_finite() || !r.is_finite() {
3322 prop_assert!(
3323 y.to_bits() == r.to_bits(),
3324 "K2 finite/NaN mismatch idx {i}: {y} vs {r}"
3325 );
3326 continue;
3327 }
3328
3329 let ulp_diff: u64 = y.to_bits().abs_diff(r.to_bits());
3330 prop_assert!(
3331 (y - r).abs() <= 1e-9 || ulp_diff <= 4,
3332 "K2 mismatch idx {i}: {y} vs {r} (ULP={ulp_diff})"
3333 );
3334 }
3335
3336 Ok(())
3337 })
3338 .unwrap();
3339
3340 Ok(())
3341 }
3342
3343 macro_rules! generate_all_alphatrend_tests {
3344 ($($test_fn:ident),*) => {
3345 paste::paste! {
3346 $(
3347 #[test]
3348 fn [<$test_fn _scalar_f64>]() {
3349 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
3350 }
3351 )*
3352 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3353 $(
3354 #[test]
3355 fn [<$test_fn _avx2_f64>]() {
3356 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
3357 }
3358 #[test]
3359 fn [<$test_fn _avx512_f64>]() {
3360 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
3361 }
3362 )*
3363 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
3364 $(
3365 #[test]
3366 fn [<$test_fn _simd128_f64>]() {
3367 let _ = $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar);
3368 }
3369 )*
3370 }
3371 }
3372 }
3373
3374 generate_all_alphatrend_tests!(
3375 check_alphatrend_accuracy,
3376 check_alphatrend_partial_params,
3377 check_alphatrend_default_candles,
3378 check_alphatrend_zero_period,
3379 check_alphatrend_empty_input,
3380 check_alphatrend_all_nan,
3381 check_alphatrend_period_exceeds_length,
3382 check_alphatrend_very_small_dataset,
3383 check_alphatrend_invalid_coeff,
3384 check_alphatrend_inconsistent_lengths,
3385 check_alphatrend_reinput,
3386 check_alphatrend_nan_handling,
3387 check_alphatrend_streaming,
3388 check_alphatrend_no_poison
3389 );
3390
3391 #[cfg(feature = "proptest")]
3392 generate_all_alphatrend_tests!(check_alphatrend_property);
3393
3394 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
3395 #[test]
3396 fn test_alphatrend_into_matches_api() -> Result<(), Box<dyn Error>> {
3397 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3398 let candles = read_candles_from_csv(file_path)?;
3399
3400 let input = AlphaTrendInput::from_candles(&candles, AlphaTrendParams::default());
3401
3402 let baseline = alphatrend(&input)?;
3403
3404 let mut out_k1 = vec![0.0; candles.close.len()];
3405 let mut out_k2 = vec![0.0; candles.close.len()];
3406 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
3407 {
3408 alphatrend_into(&input, &mut out_k1, &mut out_k2)?;
3409 }
3410
3411 fn eq_or_both_nan(a: f64, b: f64) -> bool {
3412 if a.is_nan() && b.is_nan() {
3413 true
3414 } else {
3415 a == b || (a - b).abs() <= 1e-12
3416 }
3417 }
3418
3419 assert_eq!(baseline.k1.len(), out_k1.len());
3420 assert_eq!(baseline.k2.len(), out_k2.len());
3421
3422 for i in 0..out_k1.len() {
3423 assert!(
3424 eq_or_both_nan(baseline.k1[i], out_k1[i]),
3425 "k1 mismatch at idx {i}: api={} into={}",
3426 baseline.k1[i],
3427 out_k1[i]
3428 );
3429 assert!(
3430 eq_or_both_nan(baseline.k2[i], out_k2[i]),
3431 "k2 mismatch at idx {i}: api={} into={}",
3432 baseline.k2[i],
3433 out_k2[i]
3434 );
3435 }
3436
3437 Ok(())
3438 }
3439
3440 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3441 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3442 let c = read_candles_from_csv(file)?;
3443
3444 let sweep = AlphaTrendBatchRange::default();
3445 let output = alphatrend_batch_with_kernel(&c, &sweep, kernel)?;
3446
3447 let def = AlphaTrendParams::default();
3448 let row = output.row_for_params(&def).expect("default row missing");
3449
3450 assert_eq!(output.cols, c.close.len());
3451
3452 let k1_start = row * output.cols;
3453 let k2_start = row * output.cols;
3454
3455 let expected_k1 = [
3456 60243.00,
3457 60243.00,
3458 60138.92857143,
3459 60088.42857143,
3460 59937.21428571,
3461 ];
3462
3463 let start = output.cols - 5;
3464 for (i, &expected) in expected_k1.iter().enumerate() {
3465 let actual = output.values_k1[k1_start + start + i];
3466 assert!(
3467 (actual - expected).abs() < 1e-6,
3468 "[{test}] default-row K1 mismatch at idx {i}: {actual} vs {expected}"
3469 );
3470 }
3471 Ok(())
3472 }
3473
3474 fn check_batch_sweep(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3475 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3476 let c = read_candles_from_csv(file)?;
3477
3478 let sweep = AlphaTrendBatchRange {
3479 coeff: (1.0, 2.0, 0.5),
3480 period: (10, 20, 5),
3481 no_volume: false,
3482 };
3483
3484 let output = alphatrend_batch_with_kernel(&c, &sweep, kernel)?;
3485
3486 let coeff_count = ((2.0 - 1.0) / 0.5) as usize + 1;
3487 let period_count = ((20 - 10) / 5) as usize + 1;
3488 let expected_combos = coeff_count * period_count;
3489
3490 assert_eq!(output.combos.len(), expected_combos);
3491 assert_eq!(output.rows, expected_combos);
3492 assert_eq!(output.cols, c.close.len());
3493
3494 Ok(())
3495 }
3496
3497 #[cfg(debug_assertions)]
3498 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3499 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3500 let c = read_candles_from_csv(file)?;
3501
3502 let test_configs = vec![
3503 (1.0, 1.0, 0.0, 10, 15, 5, false),
3504 (0.5, 2.0, 0.5, 14, 14, 0, true),
3505 (1.5, 1.5, 0.0, 7, 21, 7, false),
3506 ];
3507
3508 for (cfg_idx, &(c_start, c_end, c_step, p_start, p_end, p_step, no_vol)) in
3509 test_configs.iter().enumerate()
3510 {
3511 let sweep = AlphaTrendBatchRange {
3512 coeff: (c_start, c_end, c_step),
3513 period: (p_start, p_end, p_step),
3514 no_volume: no_vol,
3515 };
3516
3517 let output = alphatrend_batch_with_kernel(&c, &sweep, kernel)?;
3518
3519 for (idx, &val) in output
3520 .values_k1
3521 .iter()
3522 .chain(output.values_k2.iter())
3523 .enumerate()
3524 {
3525 if val.is_nan() {
3526 continue;
3527 }
3528
3529 let bits = val.to_bits();
3530 let row = idx / output.cols;
3531 let col = idx % output.cols;
3532
3533 if bits == 0x11111111_11111111
3534 || bits == 0x22222222_22222222
3535 || bits == 0x33333333_33333333
3536 {
3537 let combo = if row < output.combos.len() {
3538 &output.combos[row]
3539 } else {
3540 &output.combos[row - output.combos.len()]
3541 };
3542
3543 panic!(
3544 "[{}] Config {}: Found poison value {} (0x{:016X}) \
3545 at row {} col {} (flat index {}) with params: coeff={}, period={}, no_volume={}",
3546 test,
3547 cfg_idx,
3548 val,
3549 bits,
3550 row,
3551 col,
3552 idx,
3553 combo.coeff.unwrap_or(1.0),
3554 combo.period.unwrap_or(14),
3555 combo.no_volume.unwrap_or(false)
3556 );
3557 }
3558 }
3559 }
3560
3561 Ok(())
3562 }
3563
/// Stand-in for the poison scan when `debug_assertions` is disabled.
///
/// Performs no work and always returns `Ok(())`; both arguments are ignored.
/// It exists so that `gen_batch_tests!(check_batch_no_poison)` still expands
/// to valid tests in builds without debug assertions.
#[cfg(not(debug_assertions))]
fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
    Ok(())
}
3568
3569 macro_rules! gen_batch_tests {
3570 ($fn_name:ident) => {
3571 paste::paste! {
3572 #[test] fn [<$fn_name _scalar>]() {
3573 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
3574 }
3575 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3576 #[test] fn [<$fn_name _avx2>]() {
3577 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
3578 }
3579 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3580 #[test] fn [<$fn_name _avx512>]() {
3581 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
3582 }
3583 #[test] fn [<$fn_name _auto_detect>]() {
3584 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
3585 }
3586 }
3587 };
3588 }
3589
3590 gen_batch_tests!(check_batch_default_row);
3591 gen_batch_tests!(check_batch_sweep);
3592 gen_batch_tests!(check_batch_no_poison);
3593}