1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5 make_uninit_matrix,
6};
7#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
8use core::arch::x86_64::*;
9#[cfg(not(target_arch = "wasm32"))]
10use rayon::prelude::*;
11use std::convert::AsRef;
12use std::error::Error;
13use std::mem::{ManuallyDrop, MaybeUninit};
14use thiserror::Error;
15
16#[cfg(all(feature = "python", feature = "cuda"))]
17use crate::cuda::cuda_available;
18#[cfg(all(feature = "python", feature = "cuda"))]
19use crate::cuda::moving_averages::kama_wrapper::DeviceArrayF32Kama;
20#[cfg(all(feature = "python", feature = "cuda"))]
21use crate::cuda::moving_averages::CudaKama;
/// Input series for KAMA: either one column of a [`Candles`] set, selected by
/// name and resolved through `source_type`, or a raw `f64` slice.
#[derive(Debug, Clone)]
pub enum KamaData<'a> {
    /// Candle data plus the name of the field to read (e.g. `"close"`).
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A plain price slice used as-is.
    Slice(&'a [f64]),
}
30
/// Result of a single-series KAMA run.
///
/// `values` has the same length as the input series; positions before
/// `first_non_nan + period` are the warm-up region, which the entry points in
/// this module fill with NaN.
#[derive(Debug, Clone)]
pub struct KamaOutput {
    pub values: Vec<f64>,
}
35
/// Parameters for a KAMA computation.
///
/// `period` is the efficiency-ratio window length. When it is `None`, the
/// consumers in this module fall back to 30.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct KamaParams {
    pub period: Option<usize>,
}

impl Default for KamaParams {
    /// The default configuration: an explicit 30-bar period.
    fn default() -> Self {
        let period = Some(30);
        Self { period }
    }
}
50
/// A complete KAMA request: where to read the series from and which
/// parameters to apply.
#[derive(Debug, Clone)]
pub struct KamaInput<'a> {
    pub data: KamaData<'a>,
    pub params: KamaParams,
}
56
57impl<'a> AsRef<[f64]> for KamaInput<'a> {
58 #[inline(always)]
59 fn as_ref(&self) -> &[f64] {
60 match &self.data {
61 KamaData::Slice(slice) => slice,
62 KamaData::Candles { candles, source } => source_type(candles, source),
63 }
64 }
65}
66
67impl<'a> KamaInput<'a> {
68 #[inline]
69 pub fn from_candles(c: &'a Candles, s: &'a str, p: KamaParams) -> Self {
70 Self {
71 data: KamaData::Candles {
72 candles: c,
73 source: s,
74 },
75 params: p,
76 }
77 }
78 #[inline]
79 pub fn from_slice(sl: &'a [f64], p: KamaParams) -> Self {
80 Self {
81 data: KamaData::Slice(sl),
82 params: p,
83 }
84 }
85 #[inline]
86 pub fn with_default_candles(c: &'a Candles) -> Self {
87 Self::from_candles(c, "close", KamaParams::default())
88 }
89 #[inline]
90 pub fn get_period(&self) -> usize {
91 self.params.period.unwrap_or(30)
92 }
93}
94
95#[derive(Copy, Clone, Debug)]
96pub struct KamaBuilder {
97 period: Option<usize>,
98 kernel: Kernel,
99}
100
101impl Default for KamaBuilder {
102 fn default() -> Self {
103 Self {
104 period: None,
105 kernel: Kernel::Auto,
106 }
107 }
108}
109
110impl KamaBuilder {
111 #[inline(always)]
112 pub fn new() -> Self {
113 Self::default()
114 }
115 #[inline(always)]
116 pub fn period(mut self, n: usize) -> Self {
117 self.period = Some(n);
118 self
119 }
120 #[inline(always)]
121 pub fn kernel(mut self, k: Kernel) -> Self {
122 self.kernel = k;
123 self
124 }
125 #[inline(always)]
126 pub fn apply(self, c: &Candles) -> Result<KamaOutput, KamaError> {
127 let p = KamaParams {
128 period: self.period,
129 };
130 let i = KamaInput::from_candles(c, "close", p);
131 kama_with_kernel(&i, self.kernel)
132 }
133 #[inline(always)]
134 pub fn apply_slice(self, d: &[f64]) -> Result<KamaOutput, KamaError> {
135 let p = KamaParams {
136 period: self.period,
137 };
138 let i = KamaInput::from_slice(d, p);
139 kama_with_kernel(&i, self.kernel)
140 }
141 #[inline(always)]
142 pub fn into_stream(self) -> Result<KamaStream, KamaError> {
143 let p = KamaParams {
144 period: self.period,
145 };
146 KamaStream::try_new(p)
147 }
148}
149
/// Errors produced by the KAMA entry points in this module.
#[derive(Debug, Error)]
pub enum KamaError {
    /// The input series has length zero.
    #[error("kama: Input data slice is empty.")]
    EmptyInputData,
    /// No element of the input is non-NaN.
    #[error("kama: All values are NaN.")]
    AllValuesNaN,
    /// `period` is zero or exceeds the data length. Streams report
    /// `data_len: 0` because they have no series length.
    #[error("kama: Invalid period: period = {period}, data_len = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer than `period + 1` samples remain from the first non-NaN value on.
    #[error("kama: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A caller-supplied output buffer has the wrong length.
    #[error("kama: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A `(start, end, step)` period sweep expanded to no values.
    #[error("kama: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// A non-batch kernel was passed to a batch entry point.
    #[error("kama: Invalid kernel for batch operation: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// Other invalid input, e.g. a `rows * cols` size overflow.
    #[error("kama: invalid input: {0}")]
    InvalidInput(&'static str),
}
173
/// Computes KAMA for `input` with [`Kernel::Auto`].
///
/// Convenience wrapper around `kama_with_kernel`, which defines the output
/// layout and error conditions.
#[inline]
pub fn kama(input: &KamaInput) -> Result<KamaOutput, KamaError> {
    kama_with_kernel(input, Kernel::Auto)
}
178
179#[inline(always)]
180fn kama_prepare<'a>(
181 input: &'a KamaInput,
182 kernel: Kernel,
183) -> Result<(&'a [f64], usize, usize, Kernel), KamaError> {
184 let data: &[f64] = input.as_ref();
185 let len = data.len();
186 if len == 0 {
187 return Err(KamaError::EmptyInputData);
188 }
189 let first = data
190 .iter()
191 .position(|x| !x.is_nan())
192 .ok_or(KamaError::AllValuesNaN)?;
193 let period = input.get_period();
194
195 if period == 0 || period > len {
196 return Err(KamaError::InvalidPeriod {
197 period,
198 data_len: len,
199 });
200 }
201 if (len - first) <= period {
202 return Err(KamaError::NotEnoughValidData {
203 needed: period + 1,
204 valid: len - first,
205 });
206 }
207
208 let chosen = match kernel {
209 Kernel::Auto => Kernel::Scalar,
210 other => other,
211 };
212
213 Ok((data, period, first, chosen))
214}
215
216#[inline(always)]
217fn kama_compute_into(data: &[f64], period: usize, first: usize, kernel: Kernel, out: &mut [f64]) {
218 unsafe {
219 match kernel {
220 Kernel::Scalar | Kernel::ScalarBatch => kama_scalar(data, period, first, out),
221 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
222 Kernel::Avx2 | Kernel::Avx2Batch => kama_avx2(data, period, first, out),
223 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
224 Kernel::Avx512 | Kernel::Avx512Batch => kama_avx512(data, period, first, out),
225 _ => kama_scalar(data, period, first, out),
226 }
227 }
228}
229
230pub fn kama_with_kernel(input: &KamaInput, kernel: Kernel) -> Result<KamaOutput, KamaError> {
231 let (data, period, first, chosen) = kama_prepare(input, kernel)?;
232
233 let warm = first + period;
234 let mut out = alloc_with_nan_prefix(data.len(), warm);
235
236 kama_compute_into(data, period, first, chosen, &mut out);
237
238 Ok(KamaOutput { values: out })
239}
240
241#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
242pub fn kama_into(input: &KamaInput, out: &mut [f64]) -> Result<(), KamaError> {
243 let (data, period, first, chosen) = kama_prepare(input, Kernel::Auto)?;
244
245 if out.len() != data.len() {
246 return Err(KamaError::OutputLengthMismatch {
247 expected: data.len(),
248 got: out.len(),
249 });
250 }
251
252 let warm = first + period;
253 let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
254 let pref = warm.min(out.len());
255 for v in &mut out[..pref] {
256 *v = qnan;
257 }
258
259 kama_compute_into(data, period, first, chosen, out);
260
261 Ok(())
262}
263
264#[inline]
265pub fn kama_into_slice(dst: &mut [f64], input: &KamaInput, kern: Kernel) -> Result<(), KamaError> {
266 let (data, period, first, chosen) = kama_prepare(input, kern)?;
267
268 if dst.len() != data.len() {
269 return Err(KamaError::OutputLengthMismatch {
270 expected: data.len(),
271 got: dst.len(),
272 });
273 }
274
275 kama_compute_into(data, period, first, chosen, dst);
276
277 let warmup_end = first + period;
278 for v in &mut dst[..warmup_end] {
279 *v = f64::NAN;
280 }
281
282 Ok(())
283}
284
/// Scalar kernel for Kaufman's Adaptive Moving Average (KAMA).
///
/// For each index `i` after the seed, the efficiency ratio is
/// `|p[i] - p[i - period]| / Σ|p[j] - p[j - 1]|`, where the sum runs over the
/// last `period` one-step differences (the ratio is 0 when that sum is 0). It
/// is mapped linearly between `2/(30+1)` and `2/(2+1)`, squared into a
/// smoothing constant `sc`, and applied as `kama += sc * (p[i] - kama)`.
///
/// Output layout: `out[first_valid + max(period, 1)]` is seeded with the input
/// value at that index, and every later slot receives the recursive KAMA value.
/// Slots before the seed are **not written**; callers pre-fill them (e.g. with
/// NaN) themselves.
///
/// # Panics
/// * If `out.len() != data.len()`.
/// * If `data` does not extend past index `first_valid + max(period, 1)`. This
///   is checked up front so the unchecked accesses below are always in bounds.
#[inline(always)]
pub fn kama_scalar(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    assert_eq!(
        out.len(),
        data.len(),
        "`out` must be the same length as `data`"
    );

    let len = data.len();
    // `period == 0` is treated like `period == 1`.
    let lookback = period.saturating_sub(1);

    // This is a safe `pub fn` that performs unchecked accesses, so the window
    // must be validated here. The comparison is written so that it cannot
    // overflow, even for absurd `first_valid`/`period` values.
    assert!(
        first_valid < len && len - first_valid > lookback + 1,
        "kama_scalar: not enough data after `first_valid` for the requested period"
    );

    let const_max = 2.0 / (30.0 + 1.0);
    let const_diff = (2.0 / (2.0 + 1.0)) - const_max;

    // Seed volatility: the first `lookback + 1` absolute one-step differences,
    // accumulated left to right.
    let initial_idx = first_valid + lookback + 1;
    let mut sum_roc1 = 0.0;
    for pair in data[first_valid..=initial_idx].windows(2) {
        sum_roc1 += (pair[1] - pair[0]).abs();
    }

    let mut kama = data[initial_idx];
    out[initial_idx] = kama;

    let mut trailing_idx = first_valid;
    let mut trailing_value = data[trailing_idx];

    // SAFETY: `initial_idx < len` was asserted above and `out.len() == len`.
    // The loop visits `i` in `initial_idx + 1 .. len`, so `i - 1 >= initial_idx`.
    // At every iteration `trailing_idx + 1 == i - lookback - 1 < i`, so every
    // read and write stays in bounds.
    unsafe {
        let dp = data.as_ptr();
        let op = out.as_mut_ptr();
        for i in (initial_idx + 1)..len {
            let price_prev = *dp.add(i - 1);
            let price = *dp.add(i);

            // Roll the volatility window: add the newest difference and drop
            // the one leaving the window.
            let next_tail = *dp.add(trailing_idx + 1);
            let old_diff = (next_tail - trailing_value).abs();
            let new_diff = (price - price_prev).abs();
            sum_roc1 += new_diff - old_diff;

            trailing_value = next_tail;
            trailing_idx += 1;

            let direction = (price - next_tail).abs();
            let er = if sum_roc1 == 0.0 {
                0.0
            } else {
                direction / sum_roc1
            };
            let t = er.mul_add(const_diff, const_max);
            let sc = t * t;

            kama = (price - kama).mul_add(sc, kama);
            *op.add(i) = kama;
        }
    }
}
346
/// AVX2 + FMA kernel for KAMA, following the same recurrence as `kama_scalar`.
///
/// Only the seed volatility sum (the first `period` absolute one-step
/// differences) is vectorised. It uses 4-wide loads and two accumulators,
/// covers 16 differences per iteration, and finishes with a scalar tail. The
/// KAMA recursion itself is sequential.
///
/// Writes `out[first_valid + period..]`; earlier slots are left untouched.
/// The vectorised seed sum adds in a different order from the scalar kernel,
/// so results may differ in the last bits.
///
/// # Safety
/// * The CPU must support AVX2 and FMA.
/// * `out.len() == data.len()`, `period >= 1`, and
///   `first_valid + period < data.len()`; all accesses are unchecked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
#[inline]
pub unsafe fn kama_avx2(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;

    // Clearing the sign bit of an f64 yields its absolute value.
    const ABS_MASK: i64 = 0x7FFF_FFFF_FFFF_FFFFu64 as i64;
    // `period == 1` (lookback 0) is accepted by `kama_prepare` and handled
    // below, so only the real preconditions are asserted.
    debug_assert!(period >= 1 && first_valid + period < data.len());
    debug_assert_eq!(data.len(), out.len());

    let lookback = period - 1;
    let mut sum_roc1: f64 = 0.0;
    let base = data.as_ptr().add(first_valid);

    if lookback >= 15 {
        let mask_pd = _mm256_castsi256_pd(_mm256_set1_epi64x(ABS_MASK));
        let mut acc0 = _mm256_setzero_pd();
        let mut acc1 = _mm256_setzero_pd();
        let mut idx = 0usize;

        /// `|ptr[ofs + k + 1] - ptr[ofs + k]|` for `k` in `0..4`.
        #[inline(always)]
        unsafe fn abs_diff(ptr: *const f64, ofs: usize, m: __m256d) -> __m256d {
            let a = _mm256_loadu_pd(ptr.add(ofs));
            let b = _mm256_loadu_pd(ptr.add(ofs + 1));
            _mm256_and_pd(_mm256_sub_pd(b, a), m)
        }

        // Each iteration reads up to `base[idx + 16]`, which is within the
        // seed window while `idx + 15 <= lookback`.
        while idx + 15 <= lookback {
            acc0 = _mm256_add_pd(acc0, abs_diff(base, idx, mask_pd));
            acc1 = _mm256_add_pd(acc1, abs_diff(base, idx + 4, mask_pd));
            acc0 = _mm256_add_pd(acc0, abs_diff(base, idx + 8, mask_pd));
            acc1 = _mm256_add_pd(acc1, abs_diff(base, idx + 12, mask_pd));
            idx += 16;
        }

        // Horizontal sum of the combined accumulators.
        let sumv = _mm256_add_pd(acc0, acc1);
        let hi = _mm256_extractf128_pd::<1>(sumv);
        let lo = _mm256_castpd256_pd128(sumv);
        let pair = _mm_add_pd(lo, hi);
        sum_roc1 = _mm_cvtsd_f64(pair) + _mm_cvtsd_f64(_mm_unpackhi_pd(pair, pair));

        while idx <= lookback {
            sum_roc1 += (*base.add(idx + 1) - *base.add(idx)).abs();
            idx += 1;
        }
    } else {
        for k in 0..=lookback {
            sum_roc1 += (*base.add(k + 1) - *base.add(k)).abs();
        }
    }

    let init_idx = first_valid + lookback + 1;
    let mut kama = *data.get_unchecked(init_idx);
    *out.get_unchecked_mut(init_idx) = kama;

    let const_max = 2.0 / 31.0;
    let const_diff = (2.0 / 3.0) - const_max;

    let mut tail_idx = first_valid;
    let mut tail_val = *data.get_unchecked(tail_idx);

    for i in (init_idx + 1)..data.len() {
        let price = *data.get_unchecked(i);
        let new_diff = (price - *data.get_unchecked(i - 1)).abs();

        // Roll the volatility window by one difference.
        let next_tail = *data.get_unchecked(tail_idx + 1);
        let old_diff = (next_tail - tail_val).abs();
        sum_roc1 += new_diff - old_diff;

        tail_val = next_tail;
        tail_idx += 1;

        let direction = (price - next_tail).abs();
        let er = if sum_roc1 == 0.0 {
            0.0
        } else {
            direction / sum_roc1
        };
        let t = er.mul_add(const_diff, const_max);
        let sc = t * t;

        kama = (price - kama).mul_add(sc, kama);

        *out.get_unchecked_mut(i) = kama;

        // Software prefetch 128 elements ahead (T1 hint), clamped to the last
        // element.
        let pf = core::cmp::min(i + 128, data.len() - 1);
        _mm_prefetch(data.as_ptr().add(pf) as *const i8, _MM_HINT_T1);
    }
}
436
/// AVX-512 + FMA kernel for KAMA, following the same recurrence as
/// `kama_scalar`.
///
/// Only the seed volatility sum is vectorised. It uses 8-wide loads and four
/// accumulators, covers 32 differences per iteration, and finishes with a
/// scalar tail. The KAMA recursion itself is sequential.
///
/// Writes `out[first_valid + period..]`; earlier slots are left untouched.
/// The vectorised seed sum adds in a different order from the scalar kernel,
/// so results may differ in the last bits.
///
/// # Safety
/// * The CPU must support AVX-512F, AVX-512VL and FMA.
/// * `out.len() == data.len()`, `period >= 1`, and
///   `first_valid + period < data.len()`; all accesses are unchecked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512vl,fma")]
#[inline]
pub unsafe fn kama_avx512(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;

    // Clearing the sign bit of an f64 yields its absolute value.
    const ABS_MASK: i64 = 0x7FFF_FFFF_FFFF_FFFFu64 as i64;

    // `period == 1` (lookback 0) is accepted by `kama_prepare` and handled
    // below, so only the real preconditions are asserted.
    debug_assert!(period >= 1 && first_valid + period < data.len());
    debug_assert_eq!(data.len(), out.len());

    let lookback = period - 1;
    let mut sum_roc1: f64 = 0.0;
    let base = data.as_ptr().add(first_valid);

    if lookback >= 31 {
        let mask_pd = _mm512_castsi512_pd(_mm512_set1_epi64(ABS_MASK));
        let mut acc0 = _mm512_setzero_pd();
        let mut acc1 = _mm512_setzero_pd();
        let mut acc2 = _mm512_setzero_pd();
        let mut acc3 = _mm512_setzero_pd();

        /// `|ptr[idx + k + 1] - ptr[idx + k]|` for `k` in `0..8`.
        #[inline(always)]
        unsafe fn abs_diff(ptr: *const f64, idx: usize, mask: __m512d) -> __m512d {
            let a = _mm512_loadu_pd(ptr.add(idx));
            let b = _mm512_loadu_pd(ptr.add(idx + 1));
            _mm512_and_pd(_mm512_sub_pd(b, a), mask)
        }

        // Each iteration reads up to `base[j + 32]`, which is within the seed
        // window while `j + 31 <= lookback`.
        let mut j = 0usize;
        while j + 31 <= lookback {
            acc0 = _mm512_add_pd(acc0, abs_diff(base, j, mask_pd));
            acc1 = _mm512_add_pd(acc1, abs_diff(base, j + 8, mask_pd));
            acc2 = _mm512_add_pd(acc2, abs_diff(base, j + 16, mask_pd));
            acc3 = _mm512_add_pd(acc3, abs_diff(base, j + 24, mask_pd));
            j += 32;
        }
        let acc_all = _mm512_add_pd(_mm512_add_pd(acc0, acc1), _mm512_add_pd(acc2, acc3));
        sum_roc1 = _mm512_reduce_add_pd(acc_all);

        while j <= lookback {
            sum_roc1 += (*base.add(j + 1) - *base.add(j)).abs();
            j += 1;
        }
    } else {
        for k in 0..=lookback {
            sum_roc1 += (*base.add(k + 1) - *base.add(k)).abs();
        }
    }

    let init_idx = first_valid + lookback + 1;
    let mut kama = *data.get_unchecked(init_idx);
    *out.get_unchecked_mut(init_idx) = kama;

    let const_max = 2.0 / 31.0;
    let const_diff = (2.0 / 3.0) - const_max;

    let mut tail_idx = first_valid;
    let mut tail_val = *data.get_unchecked(tail_idx);

    for i in (init_idx + 1)..data.len() {
        let price = *data.get_unchecked(i);
        let new_diff = (price - *data.get_unchecked(i - 1)).abs();

        // Roll the volatility window by one difference.
        let next_tail = *data.get_unchecked(tail_idx + 1);
        let old_diff = (next_tail - tail_val).abs();
        sum_roc1 += new_diff - old_diff;

        tail_val = next_tail;
        tail_idx += 1;

        let direction = (price - next_tail).abs();
        let er = if sum_roc1 == 0.0 {
            0.0
        } else {
            direction / sum_roc1
        };

        let t = er.mul_add(const_diff, const_max);
        let sc = t * t;

        kama = (price - kama).mul_add(sc, kama);

        *out.get_unchecked_mut(i) = kama;

        // Software prefetch 128 elements ahead (T1 hint), clamped to the last
        // element.
        let pf = core::cmp::min(i + 128, data.len() - 1);
        _mm_prefetch(data.as_ptr().add(pf) as *const i8, _MM_HINT_T1);
    }
}
526
527#[derive(Debug, Clone)]
528pub struct KamaStream {
529 period: usize,
530
531 prices: Vec<f64>,
532 diffs: Vec<f64>,
533
534 head_p: usize,
535 head_d: usize,
536 count: usize,
537 seeded: bool,
538
539 prev_price: f64,
540 prev_kama: f64,
541 sum_roc1: f64,
542
543 const_max: f64,
544 const_diff: f64,
545}
546
547impl KamaStream {
548 #[inline]
549 pub fn try_new(params: KamaParams) -> Result<Self, KamaError> {
550 let period = params.period.unwrap_or(30);
551 if period == 0 {
552 return Err(KamaError::InvalidPeriod {
553 period,
554 data_len: 0,
555 });
556 }
557 Ok(Self {
558 period,
559 prices: vec![0.0; period],
560 diffs: vec![0.0; period],
561 head_p: 0,
562 head_d: 0,
563 count: 0,
564 seeded: false,
565 prev_price: 0.0,
566 prev_kama: 0.0,
567 sum_roc1: 0.0,
568 const_max: 2.0 / (30.0 + 1.0),
569 const_diff: (2.0 / (2.0 + 1.0)) - (2.0 / (30.0 + 1.0)),
570 })
571 }
572
573 #[inline(always)]
574 pub fn update(&mut self, value: f64) -> Option<f64> {
575 if self.count == 0 {
576 self.prices[0] = value;
577 self.prev_price = value;
578 self.count = 1;
579 return None;
580 }
581
582 let new_diff = (value - self.prev_price).abs();
583
584 if self.count < self.period {
585 self.diffs[self.count - 1] = new_diff;
586 self.sum_roc1 += new_diff;
587
588 self.prices[self.count] = value;
589 self.prev_price = value;
590 self.count += 1;
591 return None;
592 }
593
594 if !self.seeded {
595 self.sum_roc1 += new_diff;
596
597 self.prev_kama = value;
598
599 self.diffs[self.period - 1] = new_diff;
600
601 self.prices[self.head_p] = value;
602 self.head_p = if self.period == 1 { 0 } else { 1 };
603 self.head_d = 0;
604
605 self.prev_price = value;
606 self.seeded = true;
607 return Some(self.prev_kama);
608 }
609
610 let old_diff = self.diffs[self.head_d];
611 self.sum_roc1 += new_diff - old_diff;
612
613 let tail_price = self.prices[self.head_p];
614 let direction = (value - tail_price).abs();
615 let er = if self.sum_roc1 == 0.0 {
616 0.0
617 } else {
618 direction * (1.0 / self.sum_roc1)
619 };
620 let t = er.mul_add(self.const_diff, self.const_max);
621 let sc = t * t;
622
623 self.prev_kama = (value - self.prev_kama).mul_add(sc, self.prev_kama);
624
625 self.diffs[self.head_d] = new_diff;
626 self.head_d = if self.head_d + 1 == self.period {
627 0
628 } else {
629 self.head_d + 1
630 };
631
632 self.prices[self.head_p] = value;
633 self.head_p = if self.head_p + 1 == self.period {
634 0
635 } else {
636 self.head_p + 1
637 };
638
639 self.prev_price = value;
640 Some(self.prev_kama)
641 }
642}
643
/// Inclusive period sweep `(start, end, step)` for batch KAMA.
///
/// `expand_grid` interprets it as follows. `step == 0` or `start == end`
/// yields only `start`. Otherwise values step from `start` toward `end`,
/// ascending or descending, without passing it.
#[derive(Clone, Debug)]
pub struct KamaBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for KamaBatchRange {
    /// Periods 30 through 279 in steps of 1.
    fn default() -> Self {
        let (start, end, step) = (30, 279, 1);
        KamaBatchRange {
            period: (start, end, step),
        }
    }
}
656
657#[derive(Clone, Debug, Default)]
658pub struct KamaBatchBuilder {
659 range: KamaBatchRange,
660 kernel: Kernel,
661}
662
663impl KamaBatchBuilder {
664 pub fn new() -> Self {
665 Self::default()
666 }
667 pub fn kernel(mut self, k: Kernel) -> Self {
668 self.kernel = k;
669 self
670 }
671 #[inline]
672 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
673 self.range.period = (start, end, step);
674 self
675 }
676 #[inline]
677 pub fn period_static(mut self, p: usize) -> Self {
678 self.range.period = (p, p, 0);
679 self
680 }
681 pub fn apply_slice(self, data: &[f64]) -> Result<KamaBatchOutput, KamaError> {
682 kama_batch_with_kernel(data, &self.range, self.kernel)
683 }
684 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<KamaBatchOutput, KamaError> {
685 KamaBatchBuilder::new().kernel(k).apply_slice(data)
686 }
687 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<KamaBatchOutput, KamaError> {
688 let slice = source_type(c, src);
689 self.apply_slice(slice)
690 }
691 pub fn with_default_candles(c: &Candles) -> Result<KamaBatchOutput, KamaError> {
692 KamaBatchBuilder::new()
693 .kernel(Kernel::Auto)
694 .apply_candles(c, "close")
695 }
696}
697
698pub fn kama_batch_with_kernel(
699 data: &[f64],
700 sweep: &KamaBatchRange,
701 k: Kernel,
702) -> Result<KamaBatchOutput, KamaError> {
703 let kernel = match k {
704 Kernel::Auto => match detect_best_batch_kernel() {
705 Kernel::Avx512Batch => Kernel::Avx2Batch,
706 other => other,
707 },
708 other if other.is_batch() => other,
709 _ => return Err(KamaError::InvalidKernelForBatch(k)),
710 };
711 let simd = match kernel {
712 Kernel::Avx512Batch => Kernel::Avx512,
713 Kernel::Avx2Batch => Kernel::Avx2,
714 Kernel::ScalarBatch => Kernel::Scalar,
715 _ => unreachable!(),
716 };
717 kama_batch_par_slice(data, sweep, simd)
718}
719
720#[derive(Clone, Debug)]
721pub struct KamaBatchOutput {
722 pub values: Vec<f64>,
723 pub combos: Vec<KamaParams>,
724 pub rows: usize,
725 pub cols: usize,
726}
727
728impl KamaBatchOutput {
729 pub fn row_for_params(&self, p: &KamaParams) -> Option<usize> {
730 self.combos
731 .iter()
732 .position(|c| c.period.unwrap_or(30) == p.period.unwrap_or(30))
733 }
734 pub fn values_for(&self, p: &KamaParams) -> Option<&[f64]> {
735 self.row_for_params(p).map(|row| {
736 let start = row * self.cols;
737 &self.values[start..start + self.cols]
738 })
739 }
740}
741
742#[inline(always)]
743fn expand_grid(r: &KamaBatchRange) -> Result<Vec<KamaParams>, KamaError> {
744 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, KamaError> {
745 if step == 0 || start == end {
746 return Ok(vec![start]);
747 }
748 if start < end {
749 return Ok((start..=end).step_by(step).collect());
750 }
751
752 let mut v = Vec::new();
753 let mut cur = start;
754 loop {
755 v.push(cur);
756 match cur.checked_sub(step) {
757 Some(next) if next >= end => {
758 cur = next;
759 }
760 _ => break,
761 }
762 }
763 if v.is_empty() {
764 Err(KamaError::InvalidRange { start, end, step })
765 } else {
766 Ok(v)
767 }
768 }
769 let periods = axis_usize(r.period)?;
770 let combos: Vec<KamaParams> = periods
771 .into_iter()
772 .map(|p| KamaParams { period: Some(p) })
773 .collect();
774 if combos.is_empty() {
775 return Err(KamaError::InvalidRange {
776 start: r.period.0,
777 end: r.period.1,
778 step: r.period.2,
779 });
780 }
781 Ok(combos)
782}
783
/// Runs a KAMA period sweep sequentially, one row after another, using the
/// per-row kernel `kern`.
#[inline(always)]
pub fn kama_batch_slice(
    data: &[f64],
    sweep: &KamaBatchRange,
    kern: Kernel,
) -> Result<KamaBatchOutput, KamaError> {
    kama_batch_inner(data, sweep, kern, false)
}

/// Runs a KAMA period sweep with rows computed in parallel (rayon on
/// non-wasm targets), using the per-row kernel `kern`.
#[inline(always)]
pub fn kama_batch_par_slice(
    data: &[f64],
    sweep: &KamaBatchRange,
    kern: Kernel,
) -> Result<KamaBatchOutput, KamaError> {
    kama_batch_inner(data, sweep, kern, true)
}
801
802#[inline(always)]
803fn kama_batch_inner(
804 data: &[f64],
805 sweep: &KamaBatchRange,
806 kern: Kernel,
807 parallel: bool,
808) -> Result<KamaBatchOutput, KamaError> {
809 let combos = expand_grid(sweep)?;
810 let cols = data.len();
811 let rows = combos.len();
812
813 if cols == 0 {
814 return Err(KamaError::EmptyInputData);
815 }
816
817 let total_cells = rows
818 .checked_mul(cols)
819 .ok_or(KamaError::InvalidInput("rows*cols overflow"))?;
820
821 let mut buf_mu = make_uninit_matrix(rows, cols);
822
823 let first = data
824 .iter()
825 .position(|x| !x.is_nan())
826 .ok_or(KamaError::AllValuesNaN)?;
827
828 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
829 if data.len() - first <= max_p {
830 return Err(KamaError::NotEnoughValidData {
831 needed: max_p + 1,
832 valid: data.len() - first,
833 });
834 }
835
836 let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
837
838 init_matrix_prefixes(&mut buf_mu, cols, &warm);
839
840 let mut buf_guard = ManuallyDrop::new(buf_mu);
841 let out: &mut [f64] = unsafe {
842 core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
843 };
844
845 kama_batch_inner_into(data, sweep, kern, parallel, out)?;
846
847 let values = unsafe {
848 Vec::from_raw_parts(
849 buf_guard.as_mut_ptr() as *mut f64,
850 total_cells,
851 buf_guard.capacity(),
852 )
853 };
854
855 Ok(KamaBatchOutput {
856 values,
857 combos,
858 rows,
859 cols,
860 })
861}
862
863#[inline(always)]
864fn kama_batch_inner_into(
865 data: &[f64],
866 sweep: &KamaBatchRange,
867 kern: Kernel,
868 parallel: bool,
869 out: &mut [f64],
870) -> Result<Vec<KamaParams>, KamaError> {
871 let combos = expand_grid(sweep)?;
872 let first = data
873 .iter()
874 .position(|x| !x.is_nan())
875 .ok_or(KamaError::AllValuesNaN)?;
876 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
877 if data.len() - first <= max_p {
878 return Err(KamaError::NotEnoughValidData {
879 needed: max_p + 1,
880 valid: data.len() - first,
881 });
882 }
883 let rows = combos.len();
884 let cols = data.len();
885
886 let expected = rows
887 .checked_mul(cols)
888 .ok_or(KamaError::InvalidInput("rows*cols overflow"))?;
889 if out.len() != expected {
890 return Err(KamaError::OutputLengthMismatch {
891 expected,
892 got: out.len(),
893 });
894 }
895
896 let mut abs_delta = vec![0.0f64; cols];
897 if cols > 1 {
898 for i in (first + 1)..cols {
899 let a = data[i];
900 let b = data[i - 1];
901 abs_delta[i] = (a - b).abs();
902 }
903 }
904 let mut prefix = vec![0.0f64; cols];
905 let mut run = 0.0f64;
906 for i in 0..cols {
907 run += abs_delta[i];
908 prefix[i] = run;
909 }
910
911 let out_uninit = unsafe {
912 std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
913 };
914
915 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
916 let period = combos[row].period.unwrap();
917
918 let dst = core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
919
920 match kern {
921 Kernel::Scalar | Kernel::ScalarBatch => {
922 kama_row_scalar_prefixed(data, &prefix, first, period, dst)
923 }
924 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
925 Kernel::Avx2 | Kernel::Avx2Batch => kama_row_avx2(data, first, period, dst),
926 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
927 Kernel::Avx512 | Kernel::Avx512Batch => kama_row_avx512(data, first, period, dst),
928 _ => kama_row_scalar(data, first, period, dst),
929 }
930 };
931
932 if parallel {
933 #[cfg(not(target_arch = "wasm32"))]
934 {
935 out_uninit
936 .par_chunks_mut(cols)
937 .enumerate()
938 .for_each(|(row, slice)| do_row(row, slice));
939 }
940
941 #[cfg(target_arch = "wasm32")]
942 {
943 for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
944 do_row(row, slice);
945 }
946 }
947 } else {
948 for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
949 do_row(row, slice);
950 }
951 }
952
953 Ok(combos)
954}
955
/// Batch row kernel that delegates to `kama_scalar`, re-summing the seed
/// volatility instead of using a prefix array. This is the fallback arm in
/// `kama_batch_inner_into` for kernels without a dedicated row path.
#[inline(always)]
unsafe fn kama_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    kama_scalar(data, period, first, out)
}
960
/// Scalar batch row kernel that seeds the volatility sum from a precomputed
/// prefix-sum array instead of re-summing differences.
///
/// `prefix[k]` must hold the running sum of `|data[j] - data[j - 1]|` for
/// `first < j <= k` (zero at and before `first`), as built in
/// `kama_batch_inner_into`. Then `prefix[first + period] - prefix[first]` is
/// the sum of the first `period` differences.
///
/// Writes `out[first + period..]`; earlier slots are left untouched.
///
/// # Safety
/// Performs only bounds-checked indexing, so bad arguments panic rather than
/// cause UB. `period == 0` underflows `period - 1`.
#[inline(always)]
unsafe fn kama_row_scalar_prefixed(
    data: &[f64],
    prefix: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    let lookback = period - 1;
    let mut volatility = prefix[first + period] - prefix[first];

    let seed_idx = first + lookback + 1;
    let mut kama = data[seed_idx];
    out[seed_idx] = kama;

    let slow = 2.0 / 31.0;
    let span = (2.0 / 3.0) - slow;

    // Pair each entering step `(prev, price)` with the step leaving the window
    // `(old_tail, new_tail)`; both advance by one index per output slot.
    let entering = data[seed_idx..].windows(2);
    let leaving = data[first..].windows(2);
    for ((step, tail), slot) in entering.zip(leaving).zip(out[seed_idx + 1..].iter_mut()) {
        let (prev, price) = (step[0], step[1]);
        let (old_tail, new_tail) = (tail[0], tail[1]);
        volatility += (price - prev).abs() - (new_tail - old_tail).abs();

        let direction = (price - new_tail).abs();
        let er = if volatility == 0.0 {
            0.0
        } else {
            direction / volatility
        };
        let t = er.mul_add(span, slow);
        kama = (price - kama).mul_add(t * t, kama);
        *slot = kama;
    }
}
1009
/// Batch row wrapper around `kama_avx2`, adapting the row-kernel argument
/// order `(data, first, period, out)`.
///
/// # Safety
/// Same requirements as `kama_avx2` (AVX2/FMA support and a valid window).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn kama_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    kama_avx2(data, period, first, out)
}

/// Batch row wrapper around `kama_avx512`, adapting the row-kernel argument
/// order `(data, first, period, out)`.
///
/// # Safety
/// Same requirements as `kama_avx512` (AVX-512F/VL/FMA support and a valid
/// window).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn kama_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    kama_avx512(data, period, first, out)
}
1021
1022#[cfg(feature = "python")]
1023use crate::utilities::kernel_validation::validate_kernel;
1024#[cfg(all(feature = "python", feature = "cuda"))]
1025use numpy::PyReadonlyArray2;
1026#[cfg(feature = "python")]
1027use numpy::PyUntypedArrayMethods;
1028#[cfg(feature = "python")]
1029use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
1030#[cfg(feature = "python")]
1031use pyo3::exceptions::PyValueError;
1032#[cfg(feature = "python")]
1033use pyo3::prelude::*;
1034#[cfg(all(feature = "python", feature = "cuda"))]
1035use pyo3::types::PyDict;
1036
/// Python entry point: KAMA of a 1-D `float64` array.
///
/// `kernel` is an optional kernel name, validated by `validate_kernel` as a
/// non-batch kernel. The computation runs with the GIL released, and the
/// result is a new array of the same length as `data`.
///
/// # Errors
/// Propagates failures from `as_slice` and `validate_kernel`, and maps any
/// `KamaError` to `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "kama")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn kama_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let prices = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let input = KamaInput::from_slice(
        prices,
        KamaParams {
            period: Some(period),
        },
    );

    let computed = py.allow_threads(|| kama_with_kernel(&input, kern));
    match computed {
        Ok(output) => Ok(output.values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
1062
/// Python entry point: KAMA over a `(start, end, step)` period sweep.
///
/// Returns a dict with `"values"`, a `(rows, cols)` `float64` array (one row
/// per period, warm-up prefixes set to NaN), and `"periods"`, the period of
/// each row as `u64`. The computation runs with the GIL released.
///
/// # Errors
/// `ValueError` for empty or all-NaN input, a bad sweep, or any `KamaError`
/// from the batch computation. Failures from `as_slice`, `validate_kernel`
/// and `reshape` are propagated.
#[cfg(feature = "python")]
#[pyfunction(name = "kama_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn kama_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    // An empty array also has no non-NaN element. Report the accurate cause
    // instead of falling through to the all-NaN error below.
    if slice_in.is_empty() {
        return Err(PyValueError::new_err(KamaError::EmptyInputData.to_string()));
    }

    let first = slice_in
        .iter()
        .position(|x| !x.is_nan())
        .ok_or_else(|| PyValueError::new_err("kama: All values are NaN."))?;

    let sweep = KamaBatchRange {
        period: period_range,
    };

    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = slice_in.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("kama_batch: rows*cols overflow"))?;

    // The array starts uninitialised: warm-up prefixes are written here and
    // the remaining cells by `kama_batch_inner_into`.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();

    for (row, &warmup) in warm.iter().enumerate() {
        let start = row * cols;
        let end = start + warmup.min(cols);
        for v in &mut slice_out[start..end] {
            *v = f64::NAN;
        }
    }

    let combos = py
        .allow_threads(|| {
            // Same policy as `kama_batch_with_kernel`: Auto never picks AVX-512.
            let kernel = match kern {
                Kernel::Auto => match detect_best_batch_kernel() {
                    Kernel::Avx512Batch => Kernel::Avx2Batch,
                    other => other,
                },
                k => k,
            };
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,

                Kernel::Scalar => Kernel::Scalar,
                Kernel::Avx2 => Kernel::Avx2,
                Kernel::Avx512 => Kernel::Avx512,
                _ => Kernel::Scalar,
            };

            kama_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
1144
/// Python entry point: KAMA period sweep over a 1-D `float32` array computed
/// on CUDA device `device_id`. The result stays on the device, wrapped in
/// `KamaDeviceArrayF32Py`.
///
/// # Errors
/// `ValueError` if CUDA is unavailable or the CUDA wrapper reports an error.
/// `as_slice` failures are propagated.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "kama_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn kama_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<KamaDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = data_f32.as_slice()?;
    let sweep = KamaBatchRange {
        period: period_range,
    };

    let device_result = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaKama::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.kama_batch_dev(prices, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    });

    device_result.map(|inner| KamaDeviceArrayF32Py { inner })
}
1171
/// Python binding: computes KAMA with a single `period` over many series stored
/// time-major in a 2-D float32 array, and leaves the result on the GPU.
///
/// The array is read as `rows x cols` (`shape()[0]`, `shape()[1]`) and passed to
/// `CudaKama::kama_many_series_one_param_time_major_dev` as `(cols, rows)`.
/// NOTE(review): presumably `cols` is the number of series and `rows` the series length;
/// confirm against `CudaKama`.
///
/// Raises `ValueError` when CUDA is not available, or when `CudaKama::new` or the CUDA call
/// returns an error. A non-contiguous array makes `as_slice` fail, and that error propagates.
/// The GIL is released while the CUDA work runs.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "kama_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn kama_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<KamaDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in = data_tm_f32.as_slice()?;
    let shape = data_tm_f32.shape();
    let (rows, cols) = (shape[0], shape[1]);
    let params = KamaParams {
        period: Some(period),
    };

    let inner = py.allow_threads(|| {
        let cuda = CudaKama::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        // `params` is passed by reference; this argument had been mangled into a
        // non-compiling `¶ms` token.
        cuda.kama_many_series_one_param_time_major_dev(flat_in, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(KamaDeviceArrayF32Py { inner })
}
1200
/// Python-visible handle to a KAMA result matrix of `f32` values held in CUDA device
/// memory, returned by the `kama_cuda_*_dev` bindings.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(unsendable, module = "ta_indicators.cuda")]
pub struct KamaDeviceArrayF32Py {
    /// Device buffer together with its `rows`/`cols`, CUDA context, and device id.
    pub(crate) inner: DeviceArrayF32Kama,
}
1206
// Python protocol methods that let GPU array libraries consume the KAMA device buffer
// without a host round-trip: `__cuda_array_interface__` and DLPack (`__dlpack__`,
// `__dlpack_device__`).
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl KamaDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) descriptor for the row-major `rows x cols` f32 buffer.
    ///
    /// Strides are given in bytes. `data` is `(device pointer, read_only = false)`. No `stream`
    /// entry is emitted. NOTE(review): consumers therefore get no stream to synchronise on.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        d.set_item("shape", (self.inner.rows, self.inner.cols))?;
        // Little-endian 4-byte float.
        d.set_item("typestr", "<f4")?;
        d.set_item(
            "strides",
            (
                self.inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        d.set_item("data", (self.inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple: `(2, device_id)`, where 2 is `kDLCUDA` in DLPack's device-type enum.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.inner.device_id as i32)
    }

    /// Exports the buffer as a DLPack capsule and transfers ownership of the device memory.
    ///
    /// * Rejects `stream == 0`, which is ambiguous for CUDA. Any other `stream` value is
    ///   ignored, and no stream synchronisation is performed here.
    /// * If `dl_device` is given and differs from this allocation, raises
    ///   `NotImplementedError` when `copy` is truthy and `ValueError` otherwise.
    /// * The buffer is moved out and replaced by an empty 0x0 placeholder. This export is
    ///   one-shot: later exports or `__cuda_array_interface__` calls see an empty array.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
        use cust::memory::DeviceBuffer;

        // Only integer streams are inspected; 0 (legacy default stream) is disallowed for CUDA.
        if let Some(obj) = &stream {
            if let Ok(i) = obj.extract::<i64>(py) {
                if i == 0 {
                    return Err(pyo3::exceptions::PyValueError::new_err(
                        "__dlpack__: stream 0 is disallowed for CUDA",
                    ));
                }
            }
        }

        // Validate a requested target device before anything is moved out of `self`.
        let (kdl, alloc_dev) = self.__dlpack_device__();
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(pyo3::exceptions::PyNotImplementedError::new_err(
                            "copy across devices not implemented",
                        ));
                    } else {
                        return Err(PyValueError::new_err(
                            "dl_device does not match allocation device_id",
                        ));
                    }
                }
            }
        }

        // The stream value is otherwise unused (no synchronisation is performed).
        let _ = stream;

        // Empty placeholder buffer so `self.inner` stays valid after the real buffer is moved.
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;

        // Swap the real buffer out, keeping the same context and device id on the placeholder.
        let ctx = self.inner.ctx.clone();
        let device_id = self.inner.device_id;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32Kama {
                buf: dummy,
                rows: 0,
                cols: 0,
                ctx,
                device_id,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        // Ownership of `buf` passes to the DLPack capsule produced by the helper.
        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1302
/// Python class (exported under the name `KamaStream`) wrapping the incremental
/// `KamaStream` calculator.
#[cfg(feature = "python")]
#[pyclass(name = "KamaStream")]
pub struct KamaStreamPy {
    /// Rust streaming state that each `update` call advances.
    inner: KamaStream,
}
1308
#[cfg(feature = "python")]
#[pymethods]
impl KamaStreamPy {
    /// Creates a streaming KAMA for `period`.
    ///
    /// Raises `ValueError` (message prefixed with `KamaStream error:`) when
    /// `KamaStream::try_new` rejects the parameters.
    #[new]
    pub fn new(period: usize) -> PyResult<Self> {
        KamaStream::try_new(KamaParams {
            period: Some(period),
        })
        .map(|inner| Self { inner })
        .map_err(|e| PyValueError::new_err(format!("KamaStream error: {}", e)))
    }

    /// Feeds one value and returns what `KamaStream::update` yields (`None` when no value is
    /// available yet, presumably during warmup).
    pub fn update(&mut self, value: f64) -> Option<f64> {
        let stream = &mut self.inner;
        stream.update(value)
    }
}
1327
1328#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1329use serde::{Deserialize, Serialize};
1330#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1331use wasm_bindgen::prelude::*;
1332
/// WASM export: computes KAMA over `data` with the given `period` and returns a new array of
/// the same length. Errors from the computation are returned as JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn kama_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = KamaInput::from_slice(
        data,
        KamaParams {
            period: Some(period),
        },
    );

    let mut values = vec![0.0; data.len()];
    if let Err(e) = kama_into_slice(&mut values, &input, Kernel::Auto) {
        return Err(JsValue::from_str(&e.to_string()));
    }
    Ok(values)
}
1348
/// WASM export: reserves room for `len` `f64` values in linear memory and returns the pointer.
///
/// The memory is not initialised. Ownership passes to the caller, who must release it with
/// `kama_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn kama_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the allocation alive once this function returns.
    let mut buffer = ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1357
/// WASM export: releases a buffer obtained from `kama_alloc(len)`.
///
/// `len` must be the value passed to `kama_alloc`. A null `ptr` is ignored.
///
/// The buffer is rebuilt with length 0 and capacity `len`. `f64` has no destructor, so only
/// the allocation needs freeing. Claiming `len` initialised elements would break the
/// `Vec::from_raw_parts` contract whenever the caller never wrote to the buffer.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn kama_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` was produced by `kama_alloc(len)`, i.e. `Vec::<f64>::with_capacity(len)`,
    // so the pointer/capacity pair describes that allocation. Length 0 asserts no initialised
    // elements.
    unsafe {
        drop(Vec::<f64>::from_raw_parts(ptr, 0, len));
    }
}
1367
/// WASM export: computes KAMA for the `len` values at `in_ptr` and writes them to `out_ptr`.
///
/// Both pointers must address at least `len` `f64`s in linear memory (e.g. from `kama_alloc`).
/// Returns an error for null pointers, or when `period` is 0 or exceeds `len`.
///
/// Input and output may be the same buffer or overlap partially. In that case the result is
/// computed into a temporary and copied out, so unread input is never overwritten. Previously
/// only `in_ptr == out_ptr` was detected, and a partial overlap was undefined behaviour.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn kama_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to kama_into"));
    }
    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    // Byte ranges [start, start + len * 8) of both buffers. Any intersection means writes to
    // the output could clobber input that has not been read yet.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let (in_start, out_start) = (in_ptr as usize, out_ptr as usize);
    let overlaps =
        in_start < out_start.saturating_add(span) && out_start < in_start.saturating_add(span);

    // SAFETY: the caller guarantees `in_ptr` is valid for `len` reads.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let input = KamaInput::from_slice(
        data,
        KamaParams {
            period: Some(period),
        },
    );

    if overlaps {
        let mut temp = vec![0.0; len];
        kama_into_slice(&mut temp, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: the caller guarantees `out_ptr` is valid for `len` writes. `data`/`input` are
        // not used after this point, so the aliasing mutable slice is never read through them.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&temp);
    } else {
        // SAFETY: the caller guarantees `out_ptr` is valid for `len` writes, and the two byte
        // ranges were just shown to be disjoint.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        kama_into_slice(out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
1408
/// Configuration object deserialised by the `kama_batch` JS export.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Deserialize, Serialize)]
pub struct KamaBatchConfig {
    /// Period sweep as `(start, end, step)`, forwarded to `KamaBatchRange::period`.
    pub period_range: (usize, usize, usize),
}
1414
/// Result object serialised back to JS by the `kama_batch` export.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Deserialize, Serialize)]
pub struct KamaBatchJsOutput {
    /// Flattened `rows x cols` matrix of KAMA values.
    pub values: Vec<f64>,
    /// Parameter set used for each row.
    pub combos: Vec<KamaParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Length of each output row.
    pub cols: usize,
}
1423
/// WASM export `kama_batch`: runs a KAMA period sweep described by a JS config object
/// (`{ period_range: [start, end, step] }`) and returns values, combos, and dimensions.
///
/// Errors (bad config, computation failure, serialisation failure) come back as JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = kama_batch)]
pub fn kama_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed = serde_wasm_bindgen::from_value::<KamaBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let range = KamaBatchRange {
        period: parsed.period_range,
    };
    let batch = kama_batch_inner(data, &range, Kernel::Auto, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = KamaBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1447
/// WASM export: computes a KAMA period sweep over the `len` values at `in_ptr` and writes
/// the `rows x len` result matrix (row-major) to `out_ptr`. Returns the number of rows.
///
/// `out_ptr` must have room for `rows * len` `f64`s, where `rows` is the number of periods in
/// `(period_start, period_end, period_step)` as expanded by `expand_grid`.
///
/// Hardening compared with the plain pointer version:
/// * `rows * len` is overflow-checked instead of silently wrapping into a short slice.
/// * If the output region overlaps the input (e.g. the same buffer is reused, which
///   `kama_into` also supports), the input is copied first, so rows are never computed from
///   already-overwritten data.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn kama_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to kama_batch_into"));
    }

    let sweep = KamaBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("kama_batch_into: rows * cols overflows usize"))?;

    // Byte ranges of the input (len values) and the output matrix (rows * cols values).
    let elem = std::mem::size_of::<f64>();
    let in_start = in_ptr as usize;
    let in_end = in_start.saturating_add(len.saturating_mul(elem));
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(total.saturating_mul(elem));
    let overlaps = in_start < out_end && out_start < in_end;

    let input_copy: Vec<f64>;
    // SAFETY: the caller guarantees `in_ptr` is valid for `len` reads.
    let raw_input = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let data: &[f64] = if overlaps {
        input_copy = raw_input.to_vec();
        &input_copy
    } else {
        raw_input
    };

    // SAFETY: the caller guarantees `out_ptr` is valid for `rows * cols` writes. When it
    // overlaps the input, `data` is a private copy and `raw_input` is no longer used.
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };

    kama_batch_inner_into(data, &sweep, Kernel::Auto, false, out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}
1481
/// WASM export: runs a KAMA period sweep and returns only the flattened value matrix.
/// Failures are reported as `"KAMA batch error: ..."` strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn kama_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let range = KamaBatchRange {
        period: (period_start, period_end, period_step),
    };
    kama_batch_slice(data, &range, Kernel::Auto)
        .map(|batch| batch.values)
        .map_err(|e| JsValue::from_str(&format!("KAMA batch error: {}", e)))
}
1498
/// WASM export: lists the periods of a sweep, as `f64`s, in row order.
///
/// A zero step or `period_start == period_end` yields just `[period_start]`. Otherwise
/// periods run from `period_start` up to and including `period_end` in `period_step`
/// increments; a reversed range (`start > end`) yields an empty list.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn kama_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Vec<f64> {
    if period_step == 0 || period_start == period_end {
        return vec![period_start as f64];
    }
    (period_start..=period_end)
        .step_by(period_step)
        .map(|period| period as f64)
        .collect()
}
1518
1519#[cfg(test)]
1520mod tests {
1521 use super::*;
1522 use crate::skip_if_unsupported;
1523 use crate::utilities::data_loader::read_candles_from_csv;
1524
1525 fn check_kama_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1526 skip_if_unsupported!(kernel, test_name);
1527 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1528 let candles = read_candles_from_csv(file_path)?;
1529 let default_params = KamaParams { period: None };
1530 let input = KamaInput::from_candles(&candles, "close", default_params);
1531 let output = kama_with_kernel(&input, kernel)?;
1532 assert_eq!(output.values.len(), candles.close.len());
1533 Ok(())
1534 }
1535
1536 fn check_kama_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1537 skip_if_unsupported!(kernel, test_name);
1538 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1539 let candles = read_candles_from_csv(file_path)?;
1540 let input = KamaInput::with_default_candles(&candles);
1541 let result = kama_with_kernel(&input, kernel)?;
1542 let expected_last_five = [
1543 60234.925553804125,
1544 60176.838757545665,
1545 60115.177367962766,
1546 60071.37070833558,
1547 59992.79386218023,
1548 ];
1549 assert!(
1550 result.values.len() >= 5,
1551 "Expected at least 5 values to compare"
1552 );
1553 assert_eq!(
1554 result.values.len(),
1555 candles.close.len(),
1556 "KAMA output length does not match input length"
1557 );
1558 let start_index = result.values.len().saturating_sub(5);
1559 let last_five = &result.values[start_index..];
1560 for (i, &val) in last_five.iter().enumerate() {
1561 let exp = expected_last_five[i];
1562 assert!(
1563 (val - exp).abs() < 1e-6,
1564 "KAMA mismatch at last-five index {}: expected {}, got {}",
1565 i,
1566 exp,
1567 val
1568 );
1569 }
1570 Ok(())
1571 }
1572
1573 fn check_kama_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1574 skip_if_unsupported!(kernel, test_name);
1575 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1576 let candles = read_candles_from_csv(file_path)?;
1577 let input = KamaInput::with_default_candles(&candles);
1578 match input.data {
1579 KamaData::Candles { source, .. } => assert_eq!(source, "close"),
1580 _ => panic!("Expected KamaData::Candles"),
1581 }
1582 let output = kama_with_kernel(&input, kernel)?;
1583 assert_eq!(output.values.len(), candles.close.len());
1584 Ok(())
1585 }
1586
1587 fn check_kama_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1588 skip_if_unsupported!(kernel, test_name);
1589 let input_data = [10.0, 20.0, 30.0];
1590 let params = KamaParams { period: Some(0) };
1591 let input = KamaInput::from_slice(&input_data, params);
1592 let res = kama_with_kernel(&input, kernel);
1593 assert!(
1594 res.is_err(),
1595 "[{}] KAMA should fail with zero period",
1596 test_name
1597 );
1598 Ok(())
1599 }
1600
1601 fn check_kama_period_exceeds_length(
1602 test_name: &str,
1603 kernel: Kernel,
1604 ) -> Result<(), Box<dyn Error>> {
1605 skip_if_unsupported!(kernel, test_name);
1606 let data_small = [10.0, 20.0, 30.0];
1607 let params = KamaParams { period: Some(10) };
1608 let input = KamaInput::from_slice(&data_small, params);
1609 let res = kama_with_kernel(&input, kernel);
1610 assert!(
1611 res.is_err(),
1612 "[{}] KAMA should fail with period exceeding length",
1613 test_name
1614 );
1615 Ok(())
1616 }
1617
1618 fn check_kama_very_small_dataset(
1619 test_name: &str,
1620 kernel: Kernel,
1621 ) -> Result<(), Box<dyn Error>> {
1622 skip_if_unsupported!(kernel, test_name);
1623 let single_point = [42.0];
1624 let params = KamaParams { period: Some(30) };
1625 let input = KamaInput::from_slice(&single_point, params);
1626 let res = kama_with_kernel(&input, kernel);
1627 assert!(
1628 res.is_err(),
1629 "[{}] KAMA should fail with insufficient data",
1630 test_name
1631 );
1632 Ok(())
1633 }
1634
1635 fn check_kama_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1636 skip_if_unsupported!(kernel, test_name);
1637 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1638 let candles = read_candles_from_csv(file_path)?;
1639 let first_params = KamaParams { period: Some(30) };
1640 let first_input = KamaInput::from_candles(&candles, "close", first_params);
1641 let first_result = kama_with_kernel(&first_input, kernel)?;
1642 let second_params = KamaParams { period: Some(10) };
1643 let second_input = KamaInput::from_slice(&first_result.values, second_params);
1644 let second_result = kama_with_kernel(&second_input, kernel)?;
1645 assert_eq!(second_result.values.len(), first_result.values.len());
1646 Ok(())
1647 }
1648
1649 fn check_kama_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1650 skip_if_unsupported!(kernel, test_name);
1651 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1652 let candles = read_candles_from_csv(file_path)?;
1653 let input = KamaInput::from_candles(&candles, "close", KamaParams { period: Some(30) });
1654 let res = kama_with_kernel(&input, kernel)?;
1655 assert_eq!(res.values.len(), candles.close.len());
1656 for val in res.values.iter().skip(30) {
1657 assert!(val.is_finite());
1658 }
1659 Ok(())
1660 }
1661
1662 fn check_kama_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1663 skip_if_unsupported!(kernel, test_name);
1664 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1665 let candles = read_candles_from_csv(file_path)?;
1666 let period = 30;
1667 let input = KamaInput::from_candles(
1668 &candles,
1669 "close",
1670 KamaParams {
1671 period: Some(period),
1672 },
1673 );
1674 let batch_output = kama_with_kernel(&input, kernel)?.values;
1675 let mut stream = KamaStream::try_new(KamaParams {
1676 period: Some(period),
1677 })?;
1678 let mut stream_values = Vec::with_capacity(candles.close.len());
1679 for &price in &candles.close {
1680 match stream.update(price) {
1681 Some(val) => stream_values.push(val),
1682 None => stream_values.push(f64::NAN),
1683 }
1684 }
1685 assert_eq!(batch_output.len(), stream_values.len());
1686 for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1687 if b.is_nan() && s.is_nan() {
1688 continue;
1689 }
1690 let diff = (b - s).abs();
1691 assert!(
1692 diff < 1e-9,
1693 "[{}] KAMA streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1694 test_name,
1695 i,
1696 b,
1697 s,
1698 diff
1699 );
1700 }
1701 Ok(())
1702 }
1703
1704 macro_rules! generate_all_kama_tests {
1705 ($($test_fn:ident),*) => {
1706 paste::paste! {
1707 $(
1708 #[test]
1709 fn [<$test_fn _scalar_f64>]() {
1710 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1711 }
1712 )*
1713 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1714 $(
1715 #[test]
1716 fn [<$test_fn _avx2_f64>]() {
1717 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1718 }
1719 #[test]
1720 fn [<$test_fn _avx512_f64>]() {
1721 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1722 }
1723 )*
1724 }
1725 }
1726 }
1727
1728 #[cfg(debug_assertions)]
1729 fn check_kama_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1730 skip_if_unsupported!(kernel, test_name);
1731
1732 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1733 let candles = read_candles_from_csv(file_path)?;
1734
1735 let test_periods = vec![2, 5, 10, 14, 20, 30, 50, 100, 200];
1736 let test_sources = vec!["open", "high", "low", "close", "hl2", "hlc3", "ohlc4"];
1737
1738 for period in &test_periods {
1739 for source in &test_sources {
1740 let input = KamaInput::from_candles(
1741 &candles,
1742 source,
1743 KamaParams {
1744 period: Some(*period),
1745 },
1746 );
1747 let output = kama_with_kernel(&input, kernel)?;
1748
1749 for (i, &val) in output.values.iter().enumerate() {
1750 if val.is_nan() {
1751 continue;
1752 }
1753
1754 let bits = val.to_bits();
1755
1756 if bits == 0x11111111_11111111 {
1757 panic!(
1758 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with period={}, source={}",
1759 test_name, val, bits, i, period, source
1760 );
1761 }
1762
1763 if bits == 0x22222222_22222222 {
1764 panic!(
1765 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with period={}, source={}",
1766 test_name, val, bits, i, period, source
1767 );
1768 }
1769
1770 if bits == 0x33333333_33333333 {
1771 panic!(
1772 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with period={}, source={}",
1773 test_name, val, bits, i, period, source
1774 );
1775 }
1776 }
1777 }
1778 }
1779
1780 Ok(())
1781 }
1782
1783 #[cfg(not(debug_assertions))]
1784 fn check_kama_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1785 Ok(())
1786 }
1787
1788 fn check_kama_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1789 use proptest::prelude::*;
1790 skip_if_unsupported!(kernel, test_name);
1791
1792 let strat = (2usize..=100).prop_flat_map(|period| {
1793 (
1794 prop::collection::vec(
1795 (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
1796 (period + 1)..400,
1797 ),
1798 Just(period),
1799 )
1800 });
1801
1802 proptest::test_runner::TestRunner::default().run(&strat, |(data, period)| {
1803 let params = KamaParams {
1804 period: Some(period),
1805 };
1806 let input = KamaInput::from_slice(&data, params);
1807
1808 let result = kama_with_kernel(&input, kernel);
1809 prop_assert!(
1810 result.is_ok(),
1811 "KAMA computation failed: {:?}",
1812 result.err()
1813 );
1814 let out = result.unwrap().values;
1815
1816 let ref_result = kama_with_kernel(&input, Kernel::Scalar);
1817 prop_assert!(ref_result.is_ok(), "Reference KAMA failed");
1818 let ref_out = ref_result.unwrap().values;
1819
1820 prop_assert_eq!(out.len(), data.len(), "Output length mismatch");
1821
1822 let first_valid = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
1823 let warmup_end = first_valid + period;
1824
1825 for i in 0..warmup_end.min(out.len()) {
1826 prop_assert!(
1827 out[i].is_nan(),
1828 "Expected NaN at index {} (warmup period), got {}",
1829 i,
1830 out[i]
1831 );
1832 }
1833
1834 for i in warmup_end..out.len() {
1835 prop_assert!(
1836 out[i].is_finite(),
1837 "Expected finite value at index {} (after warmup), got {}",
1838 i,
1839 out[i]
1840 );
1841 }
1842
1843 for i in warmup_end..out.len() {
1844 let window_start = i.saturating_sub(period);
1845 let window = &data[window_start..=i];
1846 let min_val = window.iter().cloned().fold(f64::INFINITY, f64::min);
1847 let max_val = window.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
1848
1849 prop_assert!(
1850 out[i] >= min_val - 1e-6 && out[i] <= max_val + 1e-6,
1851 "KAMA at index {} = {} is outside window bounds [{}, {}]",
1852 i,
1853 out[i],
1854 min_val,
1855 max_val
1856 );
1857 }
1858
1859 if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10) && !data.is_empty() {
1860 let const_val = data[first_valid];
1861
1862 if out.len() > warmup_end + period * 2 {
1863 let last_val = out[out.len() - 1];
1864 prop_assert!(
1865 (last_val - const_val).abs() < 1e-6,
1866 "For constant data {}, KAMA should converge but got {}",
1867 const_val,
1868 last_val
1869 );
1870 }
1871 }
1872
1873 for i in warmup_end..out.len() {
1874 let diff = (out[i] - ref_out[i]).abs();
1875 let ulps = {
1876 if out[i] == ref_out[i] {
1877 0
1878 } else {
1879 let a_bits = out[i].to_bits() as i64;
1880 let b_bits = ref_out[i].to_bits() as i64;
1881 (a_bits.wrapping_sub(b_bits)).unsigned_abs()
1882 }
1883 };
1884
1885 prop_assert!(
1886 ulps <= 10 || diff < 1e-10,
1887 "Kernel mismatch at index {}: {} vs {} (diff={}, ulps={})",
1888 i,
1889 out[i],
1890 ref_out[i],
1891 diff,
1892 ulps
1893 );
1894 }
1895
1896 for i in (warmup_end + 1)..out.len() {
1897 let change = (out[i] - out[i - 1]).abs();
1898 let price = data[i];
1899 let prev_kama = out[i - 1];
1900
1901 let max_possible_change = (price - prev_kama).abs() * 0.445;
1902 prop_assert!(
1903 change <= max_possible_change + 1e-6,
1904 "KAMA change {} exceeds maximum possible {} at index {}",
1905 change,
1906 max_possible_change,
1907 i
1908 );
1909 }
1910
1911 let post_warmup_data = &data[warmup_end..];
1912 if post_warmup_data.len() > period {
1913 let is_increasing = post_warmup_data.windows(2).all(|w| w[1] >= w[0] - 1e-10);
1914 let is_decreasing = post_warmup_data.windows(2).all(|w| w[1] <= w[0] + 1e-10);
1915
1916 if is_increasing {
1917 for i in (warmup_end + period)..out.len() {
1918 prop_assert!(
1919 out[i] >= out[i - 1] - 1e-6,
1920 "KAMA should be non-decreasing for increasing data at index {}",
1921 i
1922 );
1923 }
1924 }
1925 if is_decreasing {
1926 for i in (warmup_end + period)..out.len() {
1927 prop_assert!(
1928 out[i] <= out[i - 1] + 1e-6,
1929 "KAMA should be non-increasing for decreasing data at index {}",
1930 i
1931 );
1932 }
1933 }
1934 }
1935
1936 for i in (warmup_end + period)..out.len() {
1937 let window_start = i - period + 1;
1938 let window = &data[window_start..=i];
1939 let all_same = window.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);
1940
1941 if all_same && i > warmup_end + period {
1942 let change = (out[i] - out[i - 1]).abs();
1943 let min_sc = (2.0 / 31.0_f64).powi(2);
1944 let max_change = (data[i] - out[i - 1]).abs() * min_sc;
1945 prop_assert!(
1946 change <= max_change + 1e-9,
1947 "With zero volatility at index {}, KAMA change {} exceeds minimum expected {}",
1948 i,
1949 change,
1950 max_change
1951 );
1952 }
1953 }
1954
1955 Ok(())
1956 })?;
1957
1958 Ok(())
1959 }
1960
    // Generate per-kernel `#[test]` wrappers (via `generate_all_kama_tests!`) for each
    // single-series check defined above.
    generate_all_kama_tests!(
        check_kama_partial_params,
        check_kama_accuracy,
        check_kama_default_candles,
        check_kama_zero_period,
        check_kama_period_exceeds_length,
        check_kama_very_small_dataset,
        check_kama_reinput,
        check_kama_nan_handling,
        check_kama_streaming,
        check_kama_no_poison,
        check_kama_property
    );
1974
1975 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1976 skip_if_unsupported!(kernel, test);
1977 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1978 let c = read_candles_from_csv(file)?;
1979 let output = KamaBatchBuilder::new()
1980 .kernel(kernel)
1981 .apply_candles(&c, "close")?;
1982 let def = KamaParams::default();
1983 let row = output.values_for(&def).expect("default row missing");
1984 assert_eq!(row.len(), c.close.len());
1985 let expected = [
1986 60234.925553804125,
1987 60176.838757545665,
1988 60115.177367962766,
1989 60071.37070833558,
1990 59992.79386218023,
1991 ];
1992 let start = row.len() - 5;
1993 for (i, &v) in row[start..].iter().enumerate() {
1994 assert!(
1995 (v - expected[i]).abs() < 1e-6,
1996 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1997 );
1998 }
1999 Ok(())
2000 }
2001
2002 macro_rules! gen_batch_tests {
2003 ($fn_name:ident) => {
2004 paste::paste! {
2005 #[test] fn [<$fn_name _scalar>]() {
2006 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2007 }
2008 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2009 #[test] fn [<$fn_name _avx2>]() {
2010 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2011 }
2012 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2013 #[test] fn [<$fn_name _avx512>]() {
2014 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2015 }
2016 #[test] fn [<$fn_name _auto_detect>]() {
2017 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2018 }
2019 }
2020 };
2021 }
2022
2023 #[cfg(debug_assertions)]
2024 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2025 skip_if_unsupported!(kernel, test);
2026
2027 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2028 let c = read_candles_from_csv(file)?;
2029
2030 let test_sources = vec!["open", "high", "low", "close", "hl2", "hlc3", "ohlc4"];
2031
2032 for source in &test_sources {
2033 let output = KamaBatchBuilder::new()
2034 .kernel(kernel)
2035 .period_range(2, 200, 3)
2036 .apply_candles(&c, source)?;
2037
2038 for (idx, &val) in output.values.iter().enumerate() {
2039 if val.is_nan() {
2040 continue;
2041 }
2042
2043 let bits = val.to_bits();
2044 let row = idx / output.cols;
2045 let col = idx % output.cols;
2046
2047 if bits == 0x11111111_11111111 {
2048 panic!(
2049 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {}) with source={}",
2050 test, val, bits, row, col, idx, source
2051 );
2052 }
2053
2054 if bits == 0x22222222_22222222 {
2055 panic!(
2056 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {}) with source={}",
2057 test, val, bits, row, col, idx, source
2058 );
2059 }
2060
2061 if bits == 0x33333333_33333333 {
2062 panic!(
2063 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {}) with source={}",
2064 test, val, bits, row, col, idx, source
2065 );
2066 }
2067 }
2068 }
2069
2070 let edge_case_ranges = vec![(2, 5, 1), (190, 200, 2), (50, 100, 10)];
2071 for (start, end, step) in edge_case_ranges {
2072 let output = KamaBatchBuilder::new()
2073 .kernel(kernel)
2074 .period_range(start, end, step)
2075 .apply_candles(&c, "close")?;
2076
2077 for (idx, &val) in output.values.iter().enumerate() {
2078 if val.is_nan() {
2079 continue;
2080 }
2081
2082 let bits = val.to_bits();
2083 let row = idx / output.cols;
2084 let col = idx % output.cols;
2085
2086 if bits == 0x11111111_11111111
2087 || bits == 0x22222222_22222222
2088 || bits == 0x33333333_33333333
2089 {
2090 panic!(
2091 "[{}] Found poison value {} (0x{:016X}) at row {} col {} with range ({},{},{})",
2092 test, val, bits, row, col, start, end, step
2093 );
2094 }
2095 }
2096 }
2097
2098 Ok(())
2099 }
2100
2101 #[cfg(not(debug_assertions))]
2102 fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2103 Ok(())
2104 }
2105
    // Generate per-kernel `#[test]` wrappers (via `gen_batch_tests!`) for the batch checks.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
2108
2109 #[test]
2110 fn test_kama_into_matches_api() -> Result<(), Box<dyn Error>> {
2111 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2112 let candles = read_candles_from_csv(file_path)?;
2113 let params = KamaParams::default();
2114 let input = KamaInput::from_candles(&candles, "close", params);
2115
2116 let base = kama(&input)?;
2117
2118 let mut out = vec![0.0f64; candles.close.len()];
2119 kama_into(&input, &mut out)?;
2120
2121 assert_eq!(base.values.len(), out.len());
2122 for (a, b) in base.values.iter().zip(out.iter()) {
2123 let equal = (a.is_nan() && b.is_nan()) || (*a - *b).abs() <= 1e-12;
2124 assert!(equal, "Mismatch: base={} vs into={}", a, b);
2125 }
2126 Ok(())
2127 }
2128}