1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::moving_averages::CudaJma;
5
6#[cfg(all(feature = "python", feature = "cuda"))]
7use crate::cuda::moving_averages::jma_wrapper::DeviceArrayF32Jma;
8
9use crate::utilities::data_loader::{source_type, Candles};
10use crate::utilities::enums::Kernel;
11use crate::utilities::helpers::{
12 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
13 make_uninit_matrix,
14};
15#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
16use core::arch::x86_64::*;
17#[cfg(not(target_arch = "wasm32"))]
18use rayon::prelude::*;
19use std::convert::AsRef;
20use std::error::Error;
21use std::mem::MaybeUninit;
22use thiserror::Error;
23
24impl<'a> AsRef<[f64]> for JmaInput<'a> {
25 #[inline(always)]
26 fn as_ref(&self) -> &[f64] {
27 match &self.data {
28 JmaData::Slice(slice) => slice,
29 JmaData::Candles { candles, source } => source_type(candles, source),
30 }
31 }
32}
33
34#[derive(Debug, Clone)]
35pub enum JmaData<'a> {
36 Candles {
37 candles: &'a Candles,
38 source: &'a str,
39 },
40 Slice(&'a [f64]),
41}
42
43#[derive(Debug, Clone)]
44pub struct JmaOutput {
45 pub values: Vec<f64>,
46}
47
/// Tunable JMA parameters.
///
/// Any field left as `None` is replaced by its default when the parameters
/// are consumed: period 7, phase 50.0, power 2.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct JmaParams {
    /// Smoothing length; the single-series entry points reject 0.
    pub period: Option<usize>,
    /// Phase in percent; the kernels limit its effect to the range [-100, 100].
    pub phase: Option<f64>,
    /// Exponent applied to the smoothing factor `beta` to obtain `alpha`.
    pub power: Option<u32>,
}

impl Default for JmaParams {
    /// Returns every parameter explicitly set to its documented default.
    fn default() -> Self {
        let (period, phase, power) = (7, 50.0, 2);
        JmaParams {
            period: Some(period),
            phase: Some(phase),
            power: Some(power),
        }
    }
}
68
69#[derive(Debug, Clone)]
70pub struct JmaInput<'a> {
71 pub data: JmaData<'a>,
72 pub params: JmaParams,
73}
74
75impl<'a> JmaInput<'a> {
76 #[inline]
77 pub fn from_candles(c: &'a Candles, s: &'a str, p: JmaParams) -> Self {
78 Self {
79 data: JmaData::Candles {
80 candles: c,
81 source: s,
82 },
83 params: p,
84 }
85 }
86 #[inline]
87 pub fn from_slice(sl: &'a [f64], p: JmaParams) -> Self {
88 Self {
89 data: JmaData::Slice(sl),
90 params: p,
91 }
92 }
93 #[inline]
94 pub fn with_default_candles(c: &'a Candles) -> Self {
95 Self::from_candles(c, "close", JmaParams::default())
96 }
97 #[inline]
98 pub fn get_period(&self) -> usize {
99 self.params.period.unwrap_or(7)
100 }
101 #[inline]
102 pub fn get_phase(&self) -> f64 {
103 self.params.phase.unwrap_or(50.0)
104 }
105 #[inline]
106 pub fn get_power(&self) -> u32 {
107 self.params.power.unwrap_or(2)
108 }
109}
110
111#[derive(Copy, Clone, Debug)]
112pub struct JmaBuilder {
113 period: Option<usize>,
114 phase: Option<f64>,
115 power: Option<u32>,
116 kernel: Kernel,
117}
118
119impl Default for JmaBuilder {
120 fn default() -> Self {
121 Self {
122 period: None,
123 phase: None,
124 power: None,
125 kernel: Kernel::Auto,
126 }
127 }
128}
129
130impl JmaBuilder {
131 #[inline(always)]
132 pub fn new() -> Self {
133 Self::default()
134 }
135 #[inline(always)]
136 pub fn period(mut self, n: usize) -> Self {
137 self.period = Some(n);
138 self
139 }
140 #[inline(always)]
141 pub fn phase(mut self, x: f64) -> Self {
142 self.phase = Some(x);
143 self
144 }
145 #[inline(always)]
146 pub fn power(mut self, p: u32) -> Self {
147 self.power = Some(p);
148 self
149 }
150 #[inline(always)]
151 pub fn kernel(mut self, k: Kernel) -> Self {
152 self.kernel = k;
153 self
154 }
155
156 #[inline(always)]
157 pub fn apply(self, c: &Candles) -> Result<JmaOutput, JmaError> {
158 let p = JmaParams {
159 period: self.period,
160 phase: self.phase,
161 power: self.power,
162 };
163 let i = JmaInput::from_candles(c, "close", p);
164 jma_with_kernel(&i, self.kernel)
165 }
166
167 #[inline(always)]
168 pub fn apply_slice(self, d: &[f64]) -> Result<JmaOutput, JmaError> {
169 let p = JmaParams {
170 period: self.period,
171 phase: self.phase,
172 power: self.power,
173 };
174 let i = JmaInput::from_slice(d, p);
175 jma_with_kernel(&i, self.kernel)
176 }
177
178 #[inline(always)]
179 pub fn into_stream(self) -> Result<JmaStream, JmaError> {
180 let p = JmaParams {
181 period: self.period,
182 phase: self.phase,
183 power: self.power,
184 };
185 JmaStream::try_new(p)
186 }
187}
188
189#[derive(Debug, Error)]
190pub enum JmaError {
191 #[error("jma: All values are NaN.")]
192 AllValuesNaN,
193 #[error("jma: Empty input data.")]
194 EmptyInputData,
195 #[error("jma: Invalid period: period = {period}, data length = {data_len}")]
196 InvalidPeriod { period: usize, data_len: usize },
197 #[error("jma: Not enough valid data: needed = {needed}, valid = {valid}")]
198 NotEnoughValidData { needed: usize, valid: usize },
199 #[error("jma: Invalid phase: {phase}")]
200 InvalidPhase { phase: f64 },
201 #[error("jma: Output length mismatch: expected = {expected}, got = {got}")]
202 OutputLengthMismatch { expected: usize, got: usize },
203 #[error("jma: Invalid range (usize): start={start}, end={end}, step={step}")]
204 InvalidRangeUsize {
205 start: usize,
206 end: usize,
207 step: usize,
208 },
209 #[error("jma: Invalid range (f64): start={start}, end={end}, step={step}")]
210 InvalidRangeF64 { start: f64, end: f64, step: f64 },
211 #[error("jma: Invalid kernel for batch: {0:?}")]
212 InvalidKernelForBatch(Kernel),
213 #[error("jma: Invalid input: {0}")]
214 InvalidInput(String),
215}
216
217#[inline]
218pub fn jma(input: &JmaInput) -> Result<JmaOutput, JmaError> {
219 jma_with_kernel(input, Kernel::Auto)
220}
221
222pub fn jma_with_kernel(input: &JmaInput, kernel: Kernel) -> Result<JmaOutput, JmaError> {
223 let data: &[f64] = match &input.data {
224 JmaData::Candles { candles, source } => source_type(candles, source),
225 JmaData::Slice(sl) => sl,
226 };
227 let len = data.len();
228 if len == 0 {
229 return Err(JmaError::EmptyInputData);
230 }
231 let first = data
232 .iter()
233 .position(|x| !x.is_nan())
234 .ok_or(JmaError::AllValuesNaN)?;
235 let period = input.get_period();
236 let phase = input.get_phase();
237 let power = input.get_power();
238
239 if period == 0 || period > len {
240 return Err(JmaError::InvalidPeriod {
241 period,
242 data_len: len,
243 });
244 }
245 if (len - first) < period {
246 return Err(JmaError::NotEnoughValidData {
247 needed: period,
248 valid: len - first,
249 });
250 }
251 if phase.is_nan() || phase.is_infinite() {
252 return Err(JmaError::InvalidPhase { phase });
253 }
254
255 let chosen = match kernel {
256 Kernel::Auto => Kernel::Scalar,
257 other => other,
258 };
259
260 let mut out = alloc_with_nan_prefix(len, first);
261 unsafe {
262 match chosen {
263 Kernel::Scalar | Kernel::ScalarBatch => {
264 jma_scalar(data, period, phase, power, first, &mut out)
265 }
266 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
267 Kernel::Avx2 | Kernel::Avx2Batch => {
268 jma_avx2(data, period, phase, power, first, &mut out)
269 }
270 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
271 Kernel::Avx512 | Kernel::Avx512Batch => {
272 jma_avx512(data, period, phase, power, first, &mut out)
273 }
274 _ => unreachable!(),
275 }
276 }
277 Ok(JmaOutput { values: out })
278}
279
280pub fn jma_with_kernel_into(
281 input: &JmaInput,
282 kernel: Kernel,
283 out: &mut [f64],
284) -> Result<(), JmaError> {
285 let data: &[f64] = match &input.data {
286 JmaData::Candles { candles, source } => source_type(candles, source),
287 JmaData::Slice(sl) => sl,
288 };
289 let len = data.len();
290
291 if out.len() != len {
292 return Err(JmaError::OutputLengthMismatch {
293 expected: len,
294 got: out.len(),
295 });
296 }
297 if len == 0 {
298 return Err(JmaError::EmptyInputData);
299 }
300
301 let first = data
302 .iter()
303 .position(|x| !x.is_nan())
304 .ok_or(JmaError::AllValuesNaN)?;
305 let period = input.get_period();
306 let phase = input.get_phase();
307 let power = input.get_power();
308
309 if period == 0 || period > len {
310 return Err(JmaError::InvalidPeriod {
311 period,
312 data_len: len,
313 });
314 }
315 if (len - first) < period {
316 return Err(JmaError::NotEnoughValidData {
317 needed: period,
318 valid: len - first,
319 });
320 }
321 if phase.is_nan() || phase.is_infinite() {
322 return Err(JmaError::InvalidPhase { phase });
323 }
324
325 let chosen = match kernel {
326 Kernel::Auto => Kernel::Scalar,
327 other => other,
328 };
329
330 let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
331 out[..first].fill(qnan);
332
333 unsafe {
334 match chosen {
335 Kernel::Scalar | Kernel::ScalarBatch => {
336 jma_scalar(data, period, phase, power, first, out)
337 }
338 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
339 Kernel::Avx2 | Kernel::Avx2Batch => jma_avx2(data, period, phase, power, first, out),
340 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
341 Kernel::Avx512 | Kernel::Avx512Batch => {
342 jma_avx512(data, period, phase, power, first, out)
343 }
344 _ => unreachable!(),
345 }
346 }
347 Ok(())
348}
349
350#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
351pub fn jma_into(input: &JmaInput, out: &mut [f64]) -> Result<(), JmaError> {
352 jma_with_kernel_into(input, Kernel::Auto, out)
353}
354
/// Scalar JMA kernel.
///
/// Seeds the filter with `data[first_valid]` and copies that value to
/// `output[first_valid]`. It then runs the three-stage recurrence for every
/// later sample. Entries of `output` before `first_valid` are not touched.
///
/// The previous pointer-based loop tested `p.add(3) < end_ptr`. That computes
/// pointers more than one element past the end of `data`, which is undefined
/// behaviour for `pointer::add`. The safe zipped loop below performs the
/// identical floating-point operations in the same order.
///
/// # Panics
/// If `data` and `output` differ in length or `first_valid >= data.len()`.
#[inline]
pub fn jma_scalar(
    data: &[f64],
    period: usize,
    phase: f64,
    power: u32,
    first_valid: usize,
    output: &mut [f64],
) {
    assert_eq!(data.len(), output.len());
    assert!(first_valid < data.len());

    // Phase ratio: phase/100 + 1.5, limited to [0.5, 2.5].
    let pr = if phase < -100.0 {
        0.5
    } else if phase > 100.0 {
        2.5
    } else {
        phase / 100.0 + 1.5
    };

    let beta = {
        let num = 0.45 * (period as f64 - 1.0);
        num / (num + 2.0)
    };
    let one_minus_beta = 1.0 - beta;

    let alpha = beta.powi(power as i32);
    let one_minus_alpha = 1.0 - alpha;
    let alpha_sq = alpha * alpha;
    let oma_sq = one_minus_alpha * one_minus_alpha;

    let mut e0 = data[first_valid];
    let mut e1 = 0.0;
    let mut e2 = 0.0;
    let mut j_prev = e0;
    output[first_valid] = j_prev;

    // Both tails are empty when `first_valid` is the last index.
    let tail_in = &data[first_valid + 1..];
    let tail_out = &mut output[first_valid + 1..];
    for (dst, &x) in tail_out.iter_mut().zip(tail_in) {
        e0 = one_minus_alpha.mul_add(x, alpha * e0);
        e1 = (x - e0).mul_add(one_minus_beta, beta * e1);
        let d = e0 + pr * e1 - j_prev;
        e2 = d.mul_add(oma_sq, alpha_sq * e2);
        j_prev += e2;
        *dst = j_prev;
    }
}
455
/// JMA kernel compiled with AVX2 + FMA enabled.
///
/// The recurrence is inherently sequential, so this is the scalar algorithm.
/// It is compiled with the target features so that `mul_add` lowers to
/// hardware FMA. It seeds from `data[first_valid]`, copies it to
/// `output[first_valid]`, and leaves entries before `first_valid` untouched.
///
/// The previous pointer loop tested `p.add(3) < end_ptr`, which forms
/// pointers more than one element past the end of `data` (undefined behaviour
/// for `pointer::add`). The zipped loop below performs the same operations
/// in the same order.
///
/// # Safety
/// The running CPU must support AVX2 and FMA.
///
/// # Panics
/// If `data` and `output` differ in length or `first_valid >= data.len()`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
#[inline]
pub unsafe fn jma_avx2(
    data: &[f64],
    period: usize,
    phase: f64,
    power: u32,
    first_valid: usize,
    output: &mut [f64],
) {
    assert_eq!(data.len(), output.len());
    assert!(first_valid < data.len());

    // Phase ratio: phase/100 + 1.5, limited to [0.5, 2.5].
    let pr = if phase < -100.0 {
        0.5
    } else if phase > 100.0 {
        2.5
    } else {
        phase / 100.0 + 1.5
    };

    let beta = {
        let num = 0.45 * (period as f64 - 1.0);
        num / (num + 2.0)
    };
    let one_minus_beta = 1.0 - beta;

    let alpha = beta.powi(power as i32);
    let one_minus_alpha = 1.0 - alpha;
    let alpha_sq = alpha * alpha;
    let oma_sq = one_minus_alpha * one_minus_alpha;

    let mut e0 = data[first_valid];
    let mut e1 = 0.0;
    let mut e2 = 0.0;
    let mut j_prev = e0;
    output[first_valid] = j_prev;

    let tail_in = &data[first_valid + 1..];
    let tail_out = &mut output[first_valid + 1..];
    for (dst, &x) in tail_out.iter_mut().zip(tail_in) {
        e0 = one_minus_alpha.mul_add(x, alpha * e0);
        e1 = (x - e0).mul_add(one_minus_beta, beta * e1);
        let d = e0 + pr * e1 - j_prev;
        e2 = d.mul_add(oma_sq, alpha_sq * e2);
        j_prev += e2;
        *dst = j_prev;
    }
}
558
/// Precomputes the JMA filter constants.
///
/// Returns, in order:
/// `(phase_ratio, beta, alpha, alpha^2, (1 - alpha)^2, 1 - alpha, 1 - beta)`.
/// Here `phase_ratio = phase/100 + 1.5` limited to `[0.5, 2.5]`,
/// `beta = 0.45(period-1) / (0.45(period-1) + 2)` and `alpha = beta^power`.
#[inline(always)]
fn jma_consts(period: usize, phase: f64, power: u32) -> (f64, f64, f64, f64, f64, f64, f64) {
    let phase_ratio = match phase {
        p if p < -100.0 => 0.5,
        p if p > 100.0 => 2.5,
        p => p / 100.0 + 1.5,
    };

    let scaled = 0.45 * (period as f64 - 1.0);
    let beta = scaled / (scaled + 2.0);
    let alpha = beta.powi(power as i32);

    let one_minus_alpha = 1.0 - alpha;
    let one_minus_beta = 1.0 - beta;
    (
        phase_ratio,
        beta,
        alpha,
        alpha * alpha,
        one_minus_alpha * one_minus_alpha,
        one_minus_alpha,
        one_minus_beta,
    )
}
584
/// JMA kernel compiled with AVX-512 (F/DQ/VL) + FMA enabled.
///
/// The recurrence is sequential, so this evaluates the same scalar steps as
/// the other kernels, using the constants from `jma_consts`. It seeds from
/// `data[first_valid]` and leaves `out[..first_valid]` untouched.
///
/// The previous version broadcast seven `_mm512_set1_pd` constants that were
/// never used. It also guarded its unchecked indexing only with
/// `debug_assert!`. The checks are now real assertions, matching the sibling
/// kernels, and the loop is a bounds-check-free zip.
///
/// # Safety
/// The running CPU must support AVX-512F, AVX-512DQ, AVX-512VL and FMA.
///
/// # Panics
/// If `data` and `out` differ in length or `first_valid >= data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,avx512vl,fma")]
#[inline]
pub unsafe fn jma_avx512(
    data: &[f64],
    period: usize,
    phase: f64,
    power: u32,
    first_valid: usize,
    out: &mut [f64],
) {
    assert_eq!(data.len(), out.len());
    assert!(first_valid < data.len());

    let (pr, beta, alpha, alpha_sq, oma_sq, one_minus_alpha, one_minus_beta) =
        jma_consts(period, phase, power);

    let mut e0 = data[first_valid];
    let mut e1 = 0.0;
    let mut e2 = 0.0;
    let mut j_prev = e0;
    out[first_valid] = j_prev;

    let tail_in = &data[first_valid + 1..];
    let tail_out = &mut out[first_valid + 1..];
    for (dst, &price) in tail_out.iter_mut().zip(tail_in) {
        e0 = one_minus_alpha.mul_add(price, alpha * e0);
        e1 = (price - e0).mul_add(one_minus_beta, beta * e1);
        let diff = e0 + pr * e1 - j_prev;
        e2 = diff.mul_add(oma_sq, alpha_sq * e2);
        j_prev += e2;
        *dst = j_prev;
    }
}
656
657#[derive(Debug, Clone)]
658pub struct JmaStream {
659 period: usize,
660 phase: f64,
661 power: u32,
662
663 alpha: f64,
664 beta: f64,
665 phase_ratio: f64,
666 one_minus_alpha: f64,
667 one_minus_beta: f64,
668 alpha_sq: f64,
669 oma_sq: f64,
670
671 initialized: bool,
672 e0: f64,
673 e1: f64,
674 e2: f64,
675 jma_prev: f64,
676}
677
678impl JmaStream {
679 #[inline(always)]
680 pub fn try_new(params: JmaParams) -> Result<Self, JmaError> {
681 let period = params.period.unwrap_or(7);
682 if period == 0 {
683 return Err(JmaError::InvalidPeriod {
684 period,
685 data_len: 0,
686 });
687 }
688 let phase = params.phase.unwrap_or(50.0);
689 if phase.is_nan() || phase.is_infinite() {
690 return Err(JmaError::InvalidPhase { phase });
691 }
692 let power = params.power.unwrap_or(2);
693
694 let clamped = phase.max(-100.0).min(100.0);
695 let phase_ratio = clamped / 100.0 + 1.5;
696
697 let numerator = 0.45 * (period as f64 - 1.0);
698 let denominator = numerator + 2.0;
699 let beta = if denominator.abs() < f64::EPSILON {
700 0.0
701 } else {
702 numerator / denominator
703 };
704
705 let alpha = pow_u32(beta, power);
706
707 let one_minus_alpha = 1.0 - alpha;
708 let one_minus_beta = 1.0 - beta;
709 let alpha_sq = alpha * alpha;
710 let oma_sq = one_minus_alpha * one_minus_alpha;
711
712 Ok(Self {
713 period,
714 phase,
715 power,
716 alpha,
717 beta,
718 phase_ratio,
719 one_minus_alpha,
720 one_minus_beta,
721 alpha_sq,
722 oma_sq,
723 initialized: false,
724 e0: f64::NAN,
725 e1: 0.0,
726 e2: 0.0,
727 jma_prev: f64::NAN,
728 })
729 }
730
731 #[inline(always)]
732 pub fn update(&mut self, value: f64) -> Option<f64> {
733 if !self.initialized {
734 if value.is_nan() {
735 return None;
736 }
737 self.initialized = true;
738 self.e0 = value;
739 self.e1 = 0.0;
740 self.e2 = 0.0;
741 self.jma_prev = value;
742 return Some(value);
743 }
744
745 self.e0 = self.one_minus_alpha * value + self.alpha * self.e0;
746 self.e1 = (value - self.e0) * self.one_minus_beta + self.beta * self.e1;
747 let diff = self.e0 + self.phase_ratio * self.e1 - self.jma_prev;
748 self.e2 = diff * self.oma_sq + self.alpha_sq * self.e2;
749 self.jma_prev += self.e2;
750 Some(self.jma_prev)
751 }
752
753 #[inline(always)]
754 pub fn update_fma(&mut self, value: f64) -> Option<f64> {
755 if !self.initialized {
756 if value.is_nan() {
757 return None;
758 }
759 self.initialized = true;
760 self.e0 = value;
761 self.e1 = 0.0;
762 self.e2 = 0.0;
763 self.jma_prev = value;
764 return Some(value);
765 }
766
767 self.e0 = self.one_minus_alpha.mul_add(value, self.alpha * self.e0);
768 self.e1 = (value - self.e0).mul_add(self.one_minus_beta, self.beta * self.e1);
769 let diff = self.e0 + self.phase_ratio * self.e1 - self.jma_prev;
770 self.e2 = diff.mul_add(self.oma_sq, self.alpha_sq * self.e2);
771 self.jma_prev += self.e2;
772 Some(self.jma_prev)
773 }
774}
775
/// Raises `base` to a non-negative integer power by binary exponentiation.
///
/// The bits of `exp` are scanned from least to most significant, multiplying
/// the running square into the result whenever a bit is set.
#[inline(always)]
fn pow_u32(base: f64, exp: u32) -> f64 {
    let (mut result, mut square, mut bits) = (1.0, base, exp);
    while bits > 0 {
        if bits & 1 == 1 {
            result *= square;
        }
        square *= square;
        bits >>= 1;
    }
    result
}
788
789#[derive(Clone, Debug)]
790pub struct JmaBatchRange {
791 pub period: (usize, usize, usize),
792 pub phase: (f64, f64, f64),
793 pub power: (u32, u32, u32),
794}
795
796impl Default for JmaBatchRange {
797 fn default() -> Self {
798 Self {
799 period: (7, 256, 1),
800 phase: (50.0, 50.0, 0.0),
801 power: (2, 2, 0),
802 }
803 }
804}
805
806#[derive(Clone, Debug, Default)]
807pub struct JmaBatchBuilder {
808 range: JmaBatchRange,
809 kernel: Kernel,
810}
811
812impl JmaBatchBuilder {
813 pub fn new() -> Self {
814 Self::default()
815 }
816 pub fn kernel(mut self, k: Kernel) -> Self {
817 self.kernel = k;
818 self
819 }
820 #[inline]
821 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
822 self.range.period = (start, end, step);
823 self
824 }
825 #[inline]
826 pub fn period_static(mut self, p: usize) -> Self {
827 self.range.period = (p, p, 0);
828 self
829 }
830 #[inline]
831 pub fn phase_range(mut self, start: f64, end: f64, step: f64) -> Self {
832 self.range.phase = (start, end, step);
833 self
834 }
835 #[inline]
836 pub fn phase_static(mut self, x: f64) -> Self {
837 self.range.phase = (x, x, 0.0);
838 self
839 }
840 #[inline]
841 pub fn power_range(mut self, start: u32, end: u32, step: u32) -> Self {
842 self.range.power = (start, end, step);
843 self
844 }
845 #[inline]
846 pub fn power_static(mut self, p: u32) -> Self {
847 self.range.power = (p, p, 0);
848 self
849 }
850 pub fn apply_slice(self, data: &[f64]) -> Result<JmaBatchOutput, JmaError> {
851 jma_batch_with_kernel(data, &self.range, self.kernel)
852 }
853 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<JmaBatchOutput, JmaError> {
854 JmaBatchBuilder::new().kernel(k).apply_slice(data)
855 }
856 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<JmaBatchOutput, JmaError> {
857 let slice = source_type(c, src);
858 self.apply_slice(slice)
859 }
860 pub fn with_default_candles(c: &Candles) -> Result<JmaBatchOutput, JmaError> {
861 JmaBatchBuilder::new()
862 .kernel(Kernel::Auto)
863 .apply_candles(c, "close")
864 }
865}
866
867pub fn jma_batch_with_kernel(
868 data: &[f64],
869 sweep: &JmaBatchRange,
870 k: Kernel,
871) -> Result<JmaBatchOutput, JmaError> {
872 let kernel = match k {
873 Kernel::Auto => detect_best_batch_kernel(),
874 other if other.is_batch() => other,
875 other => return Err(JmaError::InvalidKernelForBatch(other)),
876 };
877 let simd = match kernel {
878 Kernel::Avx512Batch => Kernel::Avx512,
879 Kernel::Avx2Batch => Kernel::Avx2,
880 Kernel::ScalarBatch => Kernel::Scalar,
881 _ => unreachable!(),
882 };
883 jma_batch_par_slice(data, sweep, simd)
884}
885
886#[derive(Clone, Debug)]
887pub struct JmaBatchOutput {
888 pub values: Vec<f64>,
889 pub combos: Vec<JmaParams>,
890 pub rows: usize,
891 pub cols: usize,
892}
893
894impl JmaBatchOutput {
895 pub fn row_for_params(&self, p: &JmaParams) -> Option<usize> {
896 self.combos.iter().position(|c| {
897 c.period.unwrap_or(7) == p.period.unwrap_or(7)
898 && (c.phase.unwrap_or(50.0) - p.phase.unwrap_or(50.0)).abs() < 1e-12
899 && c.power.unwrap_or(2) == p.power.unwrap_or(2)
900 })
901 }
902 pub fn values_for(&self, p: &JmaParams) -> Option<&[f64]> {
903 self.row_for_params(p).map(|row| {
904 let start = row * self.cols;
905 &self.values[start..start + self.cols]
906 })
907 }
908}
909
910#[inline(always)]
911fn expand_grid(r: &JmaBatchRange) -> Vec<JmaParams> {
912 fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
913 if step == 0 || start == end {
914 return vec![start];
915 }
916 let mut v = Vec::new();
917 if start < end {
918 let mut x = start;
919 while x <= end {
920 v.push(x);
921 match x.checked_add(step) {
922 Some(n) => x = n,
923 None => break,
924 }
925 }
926 } else {
927 let mut x = start;
928 while x >= end {
929 v.push(x);
930 if x < end + step {
931 break;
932 }
933 x -= step;
934 }
935 }
936 v
937 }
938 fn axis_f64((start, end, step): (f64, f64, f64)) -> Vec<f64> {
939 const EPS: f64 = 1e-12;
940 if step.abs() < EPS || (start - end).abs() < EPS {
941 return vec![start];
942 }
943 let s = step.abs();
944 let mut v = Vec::new();
945 if start < end {
946 let mut x = start;
947 while x <= end + EPS {
948 v.push(x);
949 x += s;
950 }
951 } else {
952 let mut x = start;
953 while x >= end - EPS {
954 v.push(x);
955 x -= s;
956 }
957 }
958 v
959 }
960 fn axis_u32((start, end, step): (u32, u32, u32)) -> Vec<u32> {
961 if step == 0 || start == end {
962 return vec![start];
963 }
964 let mut v = Vec::new();
965 if start < end {
966 let mut x = start;
967 while x <= end {
968 v.push(x);
969 x = x.saturating_add(step);
970 if x == *v.last().unwrap() {
971 break;
972 }
973 }
974 } else {
975 let mut x = start;
976 while x >= end {
977 v.push(x);
978 if x < end + step {
979 break;
980 }
981 x = x.saturating_sub(step);
982 if x == *v.last().unwrap() {
983 break;
984 }
985 }
986 }
987 v
988 }
989 let periods = axis_usize(r.period);
990 let phases = axis_f64(r.phase);
991 let powers = axis_u32(r.power);
992 let mut out = Vec::with_capacity(periods.len() * phases.len() * powers.len());
993 for &p in &periods {
994 for &ph in &phases {
995 for &po in &powers {
996 out.push(JmaParams {
997 period: Some(p),
998 phase: Some(ph),
999 power: Some(po),
1000 });
1001 }
1002 }
1003 }
1004 out
1005}
1006
1007#[inline(always)]
1008pub fn jma_batch_slice(
1009 data: &[f64],
1010 sweep: &JmaBatchRange,
1011 kern: Kernel,
1012) -> Result<JmaBatchOutput, JmaError> {
1013 jma_batch_inner(data, sweep, kern, false)
1014}
1015
1016#[inline(always)]
1017pub fn jma_batch_par_slice(
1018 data: &[f64],
1019 sweep: &JmaBatchRange,
1020 kern: Kernel,
1021) -> Result<JmaBatchOutput, JmaError> {
1022 jma_batch_inner(data, sweep, kern, true)
1023}
1024
1025#[inline(always)]
1026fn jma_batch_inner(
1027 data: &[f64],
1028 sweep: &JmaBatchRange,
1029 kern: Kernel,
1030 parallel: bool,
1031) -> Result<JmaBatchOutput, JmaError> {
1032 let combos = expand_grid(sweep);
1033 if combos.is_empty() {
1034 return Err(JmaError::InvalidInput("no parameter combinations".into()));
1035 }
1036 let cols = data.len();
1037 if cols == 0 {
1038 return Err(JmaError::EmptyInputData);
1039 }
1040 let first = data
1041 .iter()
1042 .position(|x| !x.is_nan())
1043 .ok_or(JmaError::AllValuesNaN)?;
1044 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1045 if data.len() - first < max_p {
1046 return Err(JmaError::NotEnoughValidData {
1047 needed: max_p,
1048 valid: data.len() - first,
1049 });
1050 }
1051 let rows = combos.len();
1052
1053 let warm: Vec<usize> = vec![first; rows];
1054
1055 let _cap = rows
1056 .checked_mul(cols)
1057 .ok_or_else(|| JmaError::InvalidInput("rows * cols overflow".into()))?;
1058 let mut raw = make_uninit_matrix(rows, cols);
1059 init_matrix_prefixes(&mut raw, cols, &warm);
1060
1061 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
1062 let prm = &combos[row];
1063 let period = prm.period.unwrap();
1064 let phase = prm.phase.unwrap();
1065 let power = prm.power.unwrap();
1066
1067 let out_row =
1068 core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1069
1070 match kern {
1071 Kernel::Scalar => jma_row_scalar(data, first, period, phase, power, out_row),
1072 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1073 Kernel::Avx2 => jma_row_avx2(data, first, period, phase, power, out_row),
1074 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1075 Kernel::Avx512 => jma_row_avx512(data, first, period, phase, power, out_row),
1076 _ => unreachable!(),
1077 }
1078 };
1079
1080 if parallel {
1081 #[cfg(not(target_arch = "wasm32"))]
1082 {
1083 raw.par_chunks_mut(cols)
1084 .enumerate()
1085 .for_each(|(row, slice)| do_row(row, slice));
1086 }
1087
1088 #[cfg(target_arch = "wasm32")]
1089 {
1090 for (row, slice) in raw.chunks_mut(cols).enumerate() {
1091 do_row(row, slice);
1092 }
1093 }
1094 } else {
1095 for (row, slice) in raw.chunks_mut(cols).enumerate() {
1096 do_row(row, slice);
1097 }
1098 }
1099
1100 let mut guard = core::mem::ManuallyDrop::new(raw);
1101 let values = unsafe {
1102 Vec::from_raw_parts(
1103 guard.as_mut_ptr() as *mut f64,
1104 guard.len(),
1105 guard.capacity(),
1106 )
1107 };
1108
1109 Ok(JmaBatchOutput {
1110 values,
1111 combos,
1112 rows,
1113 cols,
1114 })
1115}
1116
1117#[inline(always)]
1118fn jma_batch_inner_into(
1119 data: &[f64],
1120 sweep: &JmaBatchRange,
1121 kern: Kernel,
1122 parallel: bool,
1123 out: &mut [f64],
1124) -> Result<(Vec<JmaParams>, usize, usize), JmaError> {
1125 let combos = expand_grid(sweep);
1126 if combos.is_empty() {
1127 return Err(JmaError::InvalidInput("no parameter combinations".into()));
1128 }
1129 let cols = data.len();
1130 if cols == 0 {
1131 return Err(JmaError::EmptyInputData);
1132 }
1133 let first = data
1134 .iter()
1135 .position(|x| !x.is_nan())
1136 .ok_or(JmaError::AllValuesNaN)?;
1137 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1138 if cols - first < max_p {
1139 return Err(JmaError::NotEnoughValidData {
1140 needed: max_p,
1141 valid: cols - first,
1142 });
1143 }
1144 let rows = combos.len();
1145
1146 let expected = rows
1147 .checked_mul(cols)
1148 .ok_or_else(|| JmaError::InvalidInput("rows * cols overflow".into()))?;
1149 if out.len() != expected {
1150 return Err(JmaError::OutputLengthMismatch {
1151 expected,
1152 got: out.len(),
1153 });
1154 }
1155
1156 let warm: Vec<usize> = vec![first; rows];
1157 let out_uninit = unsafe {
1158 std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1159 };
1160 init_matrix_prefixes(out_uninit, cols, &warm);
1161
1162 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
1163 let prm = &combos[row];
1164 let period = prm.period.unwrap();
1165 let phase = prm.phase.unwrap();
1166 let power = prm.power.unwrap();
1167
1168 match kern {
1169 Kernel::Scalar => jma_row_scalar(data, first, period, phase, power, out_row),
1170 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1171 Kernel::Avx2 => jma_row_avx2(data, first, period, phase, power, out_row),
1172 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1173 Kernel::Avx512 => jma_row_avx512(data, first, period, phase, power, out_row),
1174 _ => unreachable!(),
1175 }
1176 };
1177
1178 if parallel {
1179 #[cfg(not(target_arch = "wasm32"))]
1180 {
1181 out.par_chunks_mut(cols)
1182 .enumerate()
1183 .for_each(|(row, slice)| do_row(row, slice));
1184 }
1185 #[cfg(target_arch = "wasm32")]
1186 {
1187 for (row, slice) in out.chunks_mut(cols).enumerate() {
1188 do_row(row, slice);
1189 }
1190 }
1191 } else {
1192 for (row, slice) in out.chunks_mut(cols).enumerate() {
1193 do_row(row, slice);
1194 }
1195 }
1196
1197 Ok((combos, rows, cols))
1198}
1199
1200#[inline(always)]
1201unsafe fn jma_row_scalar(
1202 data: &[f64],
1203 first: usize,
1204 period: usize,
1205 phase: f64,
1206 power: u32,
1207 out: &mut [f64],
1208) {
1209 jma_scalar(data, period, phase, power, first, out);
1210}
1211
1212#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1213#[inline(always)]
1214unsafe fn jma_row_avx2(
1215 data: &[f64],
1216 first: usize,
1217 period: usize,
1218 phase: f64,
1219 power: u32,
1220 out: &mut [f64],
1221) {
1222 jma_avx2(data, period, phase, power, first, out);
1223}
1224
1225#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1226#[inline(always)]
1227unsafe fn jma_row_avx512(
1228 data: &[f64],
1229 first: usize,
1230 period: usize,
1231 phase: f64,
1232 power: u32,
1233 out: &mut [f64],
1234) {
1235 jma_avx512(data, period, phase, power, first, out);
1236}
1237
1238#[inline(always)]
1239pub fn expand_grid_jma(r: &JmaBatchRange) -> Vec<JmaParams> {
1240 expand_grid(r)
1241}
1242
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;
    #[cfg(feature = "proptest")]
    use proptest::prelude::*;

    /// Candle fixture shared by every file-based test in this module.
    const CANDLES_CSV: &str = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";

    /// Reference JMA(period=7, phase=50, power=2) values for the last five closes of
    /// `CANDLES_CSV`; shared by the single-series and batch accuracy checks.
    const EXPECTED_LAST_FIVE: [f64; 5] = [
        59305.04794668568,
        59261.270455005455,
        59156.791263606865,
        59128.30656791065,
        58918.89223153998,
    ];

    /// Maps the debug poison bit patterns checked by the no-poison tests to the name of
    /// the allocation helper they are attributed to; `None` for an ordinary value.
    #[cfg(debug_assertions)]
    fn poison_source(bits: u64) -> Option<&'static str> {
        match bits {
            0x11111111_11111111 => Some("alloc_with_nan_prefix"),
            0x22222222_22222222 => Some("init_matrix_prefixes"),
            0x33333333_33333333 => Some("make_uninit_matrix"),
            _ => None,
        }
    }

    /// All-`None` params must fall back to defaults and produce a full-length output.
    fn check_jma_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;

        let default_params = JmaParams {
            period: None,
            phase: None,
            power: None,
        };
        let input = JmaInput::from_candles(&candles, "close", default_params);
        let output = jma_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());

        Ok(())
    }

    /// Default parameters must reproduce the reference tail values.
    fn check_jma_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let input = JmaInput::from_candles(&candles, "close", JmaParams::default());
        let result = jma_with_kernel(&input, kernel)?;

        assert!(
            result.values.len() >= EXPECTED_LAST_FIVE.len(),
            "[{}] JMA output too short: {}",
            test_name,
            result.values.len()
        );
        let start = result.values.len() - EXPECTED_LAST_FIVE.len();
        for (i, (&val, &expected)) in result.values[start..]
            .iter()
            .zip(EXPECTED_LAST_FIVE.iter())
            .enumerate()
        {
            let diff = (val - expected).abs();
            assert!(
                diff < 1e-6,
                "[{}] JMA {:?} mismatch at idx {}: got {}, expected {}",
                test_name,
                kernel,
                i,
                val,
                expected
            );
        }
        Ok(())
    }

    /// `with_default_candles` must select the "close" source.
    fn check_jma_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let input = JmaInput::with_default_candles(&candles);
        match input.data {
            JmaData::Candles { source, .. } => assert_eq!(source, "close"),
            _ => panic!("Expected JmaData::Candles"),
        }
        let output = jma_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    /// A zero period is rejected.
    fn check_jma_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [10.0, 20.0, 30.0];
        let params = JmaParams {
            period: Some(0),
            phase: None,
            power: None,
        };
        let input = JmaInput::from_slice(&input_data, params);
        let res = jma_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] JMA should fail with zero period",
            test_name
        );
        Ok(())
    }

    /// A period longer than the input is rejected.
    fn check_jma_period_exceeds_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data_small = [10.0, 20.0, 30.0];
        let params = JmaParams {
            period: Some(10),
            phase: None,
            power: None,
        };
        let input = JmaInput::from_slice(&data_small, params);
        let res = jma_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] JMA should fail with period exceeding length",
            test_name
        );
        Ok(())
    }

    /// A single data point with period 7 is rejected.
    fn check_jma_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let single_point = [42.0];
        let params = JmaParams {
            period: Some(7),
            phase: None,
            power: None,
        };
        let input = JmaInput::from_slice(&single_point, params);
        let res = jma_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] JMA should fail with insufficient data",
            test_name
        );
        Ok(())
    }

    /// Feeding JMA output back into JMA must succeed and preserve length.
    fn check_jma_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let params = JmaParams {
            period: Some(7),
            phase: None,
            power: None,
        };
        let first_input = JmaInput::from_candles(&candles, "close", params.clone());
        let first_result = jma_with_kernel(&first_input, kernel)?;
        let second_input = JmaInput::from_slice(&first_result.values, params);
        let second_result = jma_with_kernel(&second_input, kernel)?;
        assert_eq!(second_result.values.len(), first_result.values.len());
        Ok(())
    }

    /// After a generous warm-up (index 240) the output must contain no NaN.
    fn check_jma_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let input = JmaInput::from_candles(
            &candles,
            "close",
            JmaParams {
                period: Some(7),
                phase: None,
                power: None,
            },
        );
        let res = jma_with_kernel(&input, kernel)?;
        assert_eq!(res.values.len(), candles.close.len());
        if res.values.len() > 240 {
            for (i, &val) in res.values[240..].iter().enumerate() {
                assert!(
                    !val.is_nan(),
                    "[{}] Found unexpected NaN at out-index {}",
                    test_name,
                    240 + i
                );
            }
        }
        Ok(())
    }

    /// The streaming implementation must match the batch output value-for-value.
    fn check_jma_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let period = 7;
        let phase = 50.0;
        let power = 2;
        let params = JmaParams {
            period: Some(period),
            phase: Some(phase),
            power: Some(power),
        };
        let input = JmaInput::from_candles(&candles, "close", params.clone());
        let batch_output = jma_with_kernel(&input, kernel)?.values;
        let mut stream = JmaStream::try_new(params)?;
        // `None` from the stream is recorded as NaN so it lines up with the batch warm-up.
        let stream_values: Vec<f64> = candles
            .close
            .iter()
            .map(|&price| stream.update(price).unwrap_or(f64::NAN))
            .collect();
        assert_eq!(batch_output.len(), stream_values.len());
        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
            if b.is_nan() && s.is_nan() {
                continue;
            }
            let diff = (b - s).abs();
            assert!(
                diff < 1e-8,
                "[{}] JMA streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
                test_name,
                i,
                b,
                s,
                diff
            );
        }
        Ok(())
    }

    /// In debug builds, no output element may carry an allocation poison pattern.
    #[cfg(debug_assertions)]
    fn check_jma_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let candles = read_candles_from_csv(CANDLES_CSV)?;

        let p = |period: usize, phase: f64, power: u32| JmaParams {
            period: Some(period),
            phase: Some(phase),
            power: Some(power),
        };
        // Includes edge cases: period 1, large period/power, and phases outside [-100, 100].
        let test_params = [
            JmaParams::default(),
            p(3, 0.0, 1),
            p(3, 50.0, 2),
            p(3, 100.0, 3),
            p(7, 25.0, 1),
            p(7, 50.0, 2),
            p(7, 75.0, 3),
            p(10, 0.0, 2),
            p(14, 100.0, 2),
            p(20, 50.0, 1),
            p(30, 50.0, 2),
            p(50, 50.0, 3),
            p(1, 0.0, 1),
            p(100, 100.0, 5),
            p(10, -100.0, 2),
            p(10, 200.0, 2),
        ];

        for params in &test_params {
            let input = JmaInput::from_candles(&candles, "close", params.clone());
            let output = jma_with_kernel(&input, kernel)?;

            for (i, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }
                let bits = val.to_bits();
                if let Some(source) = poison_source(bits) {
                    panic!(
                        "[{}] Found {} poison value {} (0x{:016X}) at index {} \
                         with params: period={}, phase={}, power={}",
                        test_name,
                        source,
                        val,
                        bits,
                        i,
                        params.period.unwrap_or(7),
                        params.phase.unwrap_or(50.0),
                        params.power.unwrap_or(2)
                    );
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_jma_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    /// Property-based checks on random finite series: NaN prefix, finiteness, smoothing,
    /// constant-input fixed point, cross-kernel agreement, and phase direction.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_jma_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        let strat = (2usize..=100).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                    period..400,
                ),
                Just(period),
                -100f64..=100f64,
                1u32..=10,
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, period, phase, power)| {
                let params = JmaParams {
                    period: Some(period),
                    phase: Some(phase),
                    power: Some(power),
                };
                let input = JmaInput::from_slice(&data, params);

                let JmaOutput { values: out } = jma_with_kernel(&input, kernel).unwrap();
                let JmaOutput { values: ref_out } =
                    jma_with_kernel(&input, Kernel::Scalar).unwrap();

                prop_assert_eq!(out.len(), data.len());

                let first_valid = data.iter().position(|x| !x.is_nan()).unwrap_or(data.len());
                // Rough end of the start-up transient; later checks begin here.
                let warmup = first_valid + period;

                // Nothing is emitted before the first valid input.
                for i in 0..first_valid.min(out.len()) {
                    prop_assert!(
                        out[i].is_nan(),
                        "idx {}: expected NaN before first valid input, got {}", i, out[i]
                    );
                }

                // The first valid input already yields a finite value.
                if first_valid < out.len() {
                    prop_assert!(
                        out[first_valid].is_finite(),
                        "JMA should output a finite value at first_valid index {}", first_valid
                    );
                }

                // Output stays finite while the input since `first_valid` is NaN-free.
                // A running flag replaces rescanning `data[first_valid..=i]` each step.
                let mut nan_free = true;
                for i in (first_valid + 1)..data.len() {
                    if data[i].is_nan() {
                        nan_free = false;
                    }
                    if nan_free && data[i].is_finite() {
                        prop_assert!(
                            out[i].is_finite(),
                            "idx {}: expected finite value, got {}", i, out[i]
                        );
                    }
                }

                // Smoothing: output variance over a post-warm-up window should not
                // exceed the input variance by more than 10%.
                if warmup + 20 < data.len() {
                    let window_end = (warmup + 50).min(data.len());
                    let input_slice = &data[warmup..window_end];
                    let output_slice = &out[warmup..window_end];

                    if input_slice.iter().all(|x| x.is_finite())
                        && output_slice.iter().all(|x| x.is_finite())
                    {
                        let n = input_slice.len() as f64;
                        let input_mean: f64 = input_slice.iter().sum::<f64>() / n;
                        let output_mean: f64 = output_slice.iter().sum::<f64>() / n;
                        let input_var: f64 = input_slice
                            .iter()
                            .map(|x| (x - input_mean).powi(2))
                            .sum::<f64>()
                            / n;
                        let output_var: f64 = output_slice
                            .iter()
                            .map(|x| (x - output_mean).powi(2))
                            .sum::<f64>()
                            / n;

                        if input_var > 1e-10 {
                            prop_assert!(
                                output_var <= input_var * 1.1,
                                "JMA should smooth data: input_var={}, output_var={}, period={}, phase={}, power={}",
                                input_var, output_var, period, phase, power
                            );
                        }
                    }
                }

                // period == 1 should track the input closely. Note: the strategy above
                // draws periods from 2..=100, so this branch is currently inactive.
                if period == 1 {
                    for i in first_valid..data.len().min(first_valid + 20) {
                        if data[i].is_finite() && out[i].is_finite() {
                            prop_assert!(
                                (out[i] - data[i]).abs() <= data[i].abs() * 0.1 + 1e-6,
                                "period=1 should closely track input: idx={}, data={}, out={}, diff={}",
                                i, data[i], out[i], (out[i] - data[i]).abs()
                            );
                        }
                    }
                }

                // A constant input is a fixed point once warmed up.
                if data[first_valid..].windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10)
                    && warmup + 5 < data.len()
                {
                    let constant_val = data[first_valid];
                    for i in (warmup + 5)..data.len() {
                        if out[i].is_finite() {
                            prop_assert!(
                                (out[i] - constant_val).abs() <= 1e-6,
                                "constant input should produce constant output: idx={}, expected={}, got={}",
                                i, constant_val, out[i]
                            );
                        }
                    }
                }

                // SIMD kernels must agree with the scalar reference within a few ULPs
                // or a tiny absolute/relative tolerance.
                if kernel != Kernel::Scalar {
                    for i in first_valid..data.len() {
                        if out[i].is_finite() && ref_out[i].is_finite() {
                            let diff_bits = out[i].to_bits().abs_diff(ref_out[i].to_bits());
                            let abs_diff = (out[i] - ref_out[i]).abs();
                            let rel_diff = if ref_out[i].abs() > 1e-10 {
                                abs_diff / ref_out[i].abs()
                            } else {
                                abs_diff
                            };

                            prop_assert!(
                                diff_bits <= 1000 || abs_diff < 1e-9 || rel_diff < 1e-12,
                                "kernel consistency failed at idx {}: {:?}={}, Scalar={}, diff_bits={}, abs_diff={}, rel_diff={}",
                                i, kernel, out[i], ref_out[i], diff_bits, abs_diff, rel_diff
                            );
                        }
                    }
                }

                // Phase direction: relative to phase=0, a positive phase should move with
                // the input at least once (lead) and a negative phase against it (lag).
                if phase.abs() > 10.0 && warmup + 30 < data.len() {
                    let params_neutral = JmaParams {
                        period: Some(period),
                        phase: Some(0.0),
                        power: Some(power),
                    };
                    let input_neutral = JmaInput::from_slice(&data, params_neutral);
                    if let Ok(JmaOutput { values: out_neutral }) =
                        jma_with_kernel(&input_neutral, kernel)
                    {
                        let check_start = warmup + 10;
                        let check_end = (warmup + 30).min(data.len() - 1);

                        let mut lead_count = 0;
                        let mut lag_count = 0;

                        for i in check_start..check_end {
                            if data[i].is_finite()
                                && data[i - 1].is_finite()
                                && out[i].is_finite()
                                && out_neutral[i].is_finite()
                            {
                                let data_change = data[i] - data[i - 1];
                                if data_change.abs() > 1e-10 {
                                    let phase_diff = out[i] - out_neutral[i];
                                    if data_change > 0.0 {
                                        if phase > 0.0 && phase_diff > 0.0 {
                                            lead_count += 1;
                                        } else if phase < 0.0 && phase_diff < 0.0 {
                                            lag_count += 1;
                                        }
                                    } else if phase > 0.0 && phase_diff < 0.0 {
                                        lead_count += 1;
                                    } else if phase < 0.0 && phase_diff > 0.0 {
                                        lag_count += 1;
                                    }
                                }
                            }
                        }

                        if phase > 10.0 {
                            prop_assert!(
                                lead_count > 0,
                                "Positive phase should show leading behavior: phase={}, lead_count={}",
                                phase, lead_count
                            );
                        } else if phase < -10.0 {
                            prop_assert!(
                                lag_count > 0,
                                "Negative phase should show lagging behavior: phase={}, lag_count={}",
                                phase, lag_count
                            );
                        }
                    }
                }

                // Power comparison: compute the mean tracking error for `power` and for
                // power=1 on the same data. No ordering between the two is asserted — an
                // ordering check guarded by the very inequality it asserts can never fail —
                // so this only verifies both variants stay finite.
                if power > 1 && warmup + 20 < data.len() {
                    let params_low_power = JmaParams {
                        period: Some(period),
                        phase: Some(phase),
                        power: Some(1),
                    };
                    let input_low_power = JmaInput::from_slice(&data, params_low_power);
                    if let Ok(JmaOutput { values: out_low_power }) =
                        jma_with_kernel(&input_low_power, kernel)
                    {
                        let check_end = (warmup + 30).min(data.len());

                        let mut high_err = 0.0f64;
                        let mut low_err = 0.0f64;
                        let mut count = 0usize;

                        for i in warmup..check_end {
                            if data[i].is_finite()
                                && out[i].is_finite()
                                && out_low_power[i].is_finite()
                            {
                                high_err += (out[i] - data[i]).abs();
                                low_err += (out_low_power[i] - data[i]).abs();
                                count += 1;
                            }
                        }

                        if count > 0 {
                            let high_resp = high_err / count as f64;
                            let low_resp = low_err / count as f64;
                            prop_assert!(
                                high_resp.is_finite() && low_resp.is_finite(),
                                "mean tracking error should be finite: power={} resp={}, power=1 resp={}",
                                power, high_resp, low_resp
                            );
                        }
                    }
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }

    #[cfg(not(feature = "proptest"))]
    fn check_jma_property(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    // Each generated #[test] returns the check's Result so `?`-propagated errors fail
    // the test instead of being discarded. The AVX cfg is attached to every SIMD item:
    // an attribute placed before a `$( ... )*` repetition only applies to the first
    // item the repetition emits.
    macro_rules! generate_all_jma_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)
                    }
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2)
                    }
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)
                    }
                )*
            }
        };
    }

    generate_all_jma_tests!(
        check_jma_partial_params,
        check_jma_accuracy,
        check_jma_default_candles,
        check_jma_zero_period,
        check_jma_period_exceeds_length,
        check_jma_very_small_dataset,
        check_jma_reinput,
        check_jma_nan_handling,
        check_jma_streaming,
        check_jma_property,
        check_jma_no_poison
    );

    /// The default-parameter row of a default batch must match the reference tail.
    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let c = read_candles_from_csv(CANDLES_CSV)?;

        let output = JmaBatchBuilder::new()
            .kernel(kernel)
            .apply_candles(&c, "close")?;

        let def = JmaParams::default();
        let row = output.values_for(&def).expect("default row missing");

        assert_eq!(row.len(), c.close.len());

        let start = row.len() - EXPECTED_LAST_FIVE.len();
        for (i, &v) in row[start..].iter().enumerate() {
            assert!(
                (v - EXPECTED_LAST_FIVE[i]).abs() < 1e-6,
                "[{test}] default-row mismatch at idx {i}: {v} vs {:?}",
                EXPECTED_LAST_FIVE
            );
        }
        Ok(())
    }

    /// In debug builds, no batch output element may carry an allocation poison pattern.
    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let c = read_candles_from_csv(CANDLES_CSV)?;

        // (period start, end, step, phase start, end, step, power start, end, step)
        let batch_configs = [
            (5, 20, 5, 0.0, 100.0, 50.0, 1, 3, 1),
            (1, 10, 3, -100.0, 200.0, 100.0, 1, 5, 2),
            (50, 100, 25, -50.0, 50.0, 25.0, 2, 4, 1),
            (14, 14, 1, -100.0, 200.0, 50.0, 1, 5, 1),
            (1, 1, 1, 0.0, 0.0, 1.0, 1, 1, 1),
            (10, 30, 10, 50.0, 50.0, 1.0, 1, 5, 2),
            (5, 50, 5, -50.0, 150.0, 25.0, 1, 3, 1),
            (7, 21, 7, -100.0, 100.0, 40.0, 2, 2, 1),
            (80, 100, 20, 100.0, 200.0, 50.0, 4, 5, 1),
        ];

        for (idx, config) in batch_configs.iter().enumerate() {
            let output = JmaBatchBuilder::new()
                .kernel(kernel)
                .period_range(config.0, config.1, config.2)
                .phase_range(config.3, config.4, config.5)
                .power_range(config.6, config.7, config.8)
                .apply_candles(&c, "close")?;

            for (val_idx, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();
                if let Some(source) = poison_source(bits) {
                    let row = val_idx / output.cols;
                    let col = val_idx % output.cols;
                    let combo = &output.combos[row];
                    panic!(
                        "[{}] Found {} poison value {} (0x{:016X}) at row {} col {} (flat index {}) \
                         in batch config #{}: period_range=({},{},{}), phase_range=({},{},{}), power_range=({},{},{}) \
                         combo params: period={}, phase={}, power={}",
                        test,
                        source,
                        val,
                        bits,
                        row,
                        col,
                        val_idx,
                        idx,
                        config.0,
                        config.1,
                        config.2,
                        config.3,
                        config.4,
                        config.5,
                        config.6,
                        config.7,
                        config.8,
                        combo.period.unwrap_or(7),
                        combo.phase.unwrap_or(50.0),
                        combo.power.unwrap_or(2)
                    );
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    // Generated batch tests propagate the check's Result (see generate_all_jma_tests).
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test]
                fn [<$fn_name _scalar>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx2>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx512>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
                }
                #[test]
                fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
                }
            }
        };
    }
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);

    /// `jma_into` must produce exactly the same values (NaN-for-NaN) as `jma`.
    #[test]
    fn test_jma_into_matches_api() -> Result<(), Box<dyn Error>> {
        let len = 256usize;
        // Five leading quiet NaNs exercise the first-valid / warm-up handling.
        let mut data = vec![f64::from_bits(0x7ff8_0000_0000_0000); 5];
        data.extend((0..(len - 5)).map(|i| {
            let x = i as f64;
            (x * 0.1).sin() * 10.0 + x * 0.01
        }));

        let input = JmaInput::from_slice(&data, JmaParams::default());

        let baseline = jma(&input)?.values;

        let mut out = vec![0.0; data.len()];

        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
        {
            jma_into(&input, &mut out)?;

            assert_eq!(out.len(), baseline.len());
            for (a, b) in out.iter().zip(baseline.iter()) {
                if a.is_nan() || b.is_nan() {
                    assert!(a.is_nan() && b.is_nan());
                } else {
                    assert_eq!(a, b);
                }
            }
        }

        Ok(())
    }
}
2130
2131#[cfg(feature = "python")]
2132use crate::utilities::kernel_validation::validate_kernel;
2133#[cfg(all(feature = "python", feature = "cuda"))]
2134use numpy::PyUntypedArrayMethods;
2135#[cfg(feature = "python")]
2136use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1, PyReadonlyArray2};
2137#[cfg(feature = "python")]
2138use pyo3::exceptions::PyValueError;
2139#[cfg(feature = "python")]
2140use pyo3::prelude::*;
2141#[cfg(feature = "python")]
2142use pyo3::types::PyDict;
2143
/// Python entry point: computes the JMA of a 1-D float64 array.
///
/// `phase` defaults to 50.0 and `power` to 2; `kernel` is validated with
/// `validate_kernel(kernel, false)`. The computation runs with the GIL released.
/// Computation errors are raised as `ValueError` carrying the error's message.
#[cfg(feature = "python")]
#[pyfunction(name = "jma")]
#[pyo3(signature = (data, period, phase=50.0, power=2, kernel=None))]
pub fn jma_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    phase: f64,
    power: u32,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let prices = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;

    let input = JmaInput::from_slice(
        prices,
        JmaParams {
            period: Some(period),
            phase: Some(phase),
            power: Some(power),
        },
    );

    let computed = py.allow_threads(|| jma_with_kernel(&input, chosen_kernel));
    match computed {
        Ok(output) => Ok(output.values.into_pyarray(py)),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
2171
/// Python entry point for a JMA parameter sweep over one float64 series.
///
/// Each of `period_range`, `phase_range` and `power_range` is a `(start, end, step)`
/// triple; `expand_grid` turns them into one parameter combination per output row.
/// Returns a dict with `values` (shape `(rows, len(data))`) and the per-row
/// `periods`, `phases` and `powers`.
///
/// Raises `ValueError` if the ranges expand to no combinations, if `rows * cols`
/// overflows `usize`, or if the batch computation fails.
#[cfg(feature = "python")]
#[pyfunction(name = "jma_batch")]
#[pyo3(signature = (data, period_range, phase_range=(50.0, 50.0, 0.0), power_range=(2, 2, 0), kernel=None))]
pub fn jma_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    phase_range: (f64, f64, f64),
    power_range: (u32, u32, u32),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    let sweep = JmaBatchRange {
        period: period_range,
        phase: phase_range,
        power: power_range,
    };

    let combos = expand_grid(&sweep);
    if combos.is_empty() {
        return Err(PyValueError::new_err("Invalid parameter ranges"));
    }

    let rows = combos.len();
    let cols = slice_in.len();
    // The product sizes the output allocation; unchecked, it would wrap in release
    // builds (or panic in debug) and size the array incorrectly.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("jma_batch: rows * cols overflows usize"))?;

    // NOTE(review): the array is created uninitialised (`false`); this relies on
    // `jma_batch_inner_into` writing every element before the array is returned.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let (combos_result, _, _) = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };

            // Batch kernels map to their per-row SIMD counterparts.
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => kernel,
            };

            jma_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;

    dict.set_item(
        "periods",
        combos_result
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "phases",
        combos_result
            .iter()
            .map(|p| p.phase.unwrap())
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "powers",
        combos_result
            .iter()
            .map(|p| p.power.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
2251
/// Runs a JMA parameter sweep on a CUDA device and returns a device-resident handle.
///
/// Input is a 1-D float32 array; the ranges are `(start, end, step)` triples as in
/// `jma_batch`. Raises `ValueError` if CUDA is unavailable or if creating the CUDA
/// context or running the batch fails. Device work runs with the GIL released.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "jma_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, phase_range=(50.0, 50.0, 0.0), power_range=(2, 2, 0), device_id=0))]
pub fn jma_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    phase_range: (f64, f64, f64),
    power_range: (u32, u32, u32),
    device_id: usize,
) -> PyResult<JmaDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = data_f32.as_slice()?;
    let sweep = JmaBatchRange {
        period: period_range,
        phase: phase_range,
        power: power_range,
    };

    let device_result = py.allow_threads(|| -> PyResult<_> {
        let engine =
            CudaJma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        engine
            .jma_batch_dev(prices, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(JmaDeviceArrayF32Py {
        inner: Some(device_result),
    })
}
2282
/// Runs one JMA parameter set over many series on a CUDA device.
///
/// `prices_tm_f32` is a 2-D float32 array; axis 0 is passed as `rows` and axis 1 as
/// `cols`. NOTE(review): given the "time-major" naming, presumably rows are time
/// steps and columns are series — confirm against `CudaJma`.
/// Raises `ValueError` if CUDA is unavailable, the array is not 2-D or contiguous,
/// or the CUDA call fails. Device work runs with the GIL released.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "jma_cuda_many_series_one_param_dev")]
#[pyo3(signature = (prices_tm_f32, period, phase=50.0, power=2, device_id=0))]
pub fn jma_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    prices_tm_f32: PyReadonlyArray2<'_, f32>,
    period: usize,
    phase: f64,
    power: u32,
    device_id: usize,
) -> PyResult<JmaDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let shape = prices_tm_f32.shape();
    if shape.len() != 2 {
        return Err(PyValueError::new_err("prices matrix must be 2-dimensional"));
    }
    let rows = shape[0];
    let cols = shape[1];
    let prices_flat = prices_tm_f32.as_slice()?;
    let params = JmaParams {
        period: Some(period),
        phase: Some(phase),
        power: Some(power),
    };

    let inner = py.allow_threads(|| {
        let cuda = CudaJma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        // Passed as (cols, rows), matching the callee's argument order.
        cuda.jma_many_series_one_param_time_major_dev(prices_flat, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(JmaDeviceArrayF32Py { inner: Some(inner) })
}
2319
/// Python handle to a float32 JMA result matrix held on a CUDA device.
///
/// `inner` is `Some` until the buffer is handed off through `__dlpack__`, which takes
/// it and leaves `None`; later exports then raise `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct JmaDeviceArrayF32Py {
    /// Device buffer plus its shape; `None` once exported.
    pub(crate) inner: Option<DeviceArrayF32Jma>,
}
2325
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl JmaDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) describing the row-major `(rows, cols)` f32 buffer.
    ///
    /// Raises `ValueError` if the buffer was already exported via `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let handle = match self.inner.as_ref() {
            Some(h) => h,
            None => {
                return Err(PyValueError::new_err(
                    "buffer already exported via __dlpack__",
                ))
            }
        };
        let elem_bytes = std::mem::size_of::<f32>();

        let iface = PyDict::new(py);
        iface.set_item("shape", (handle.rows, handle.cols))?;
        iface.set_item("typestr", "<f4")?;
        // Byte strides for a contiguous row-major layout.
        iface.set_item("strides", (handle.cols * elem_bytes, elem_bytes))?;
        // (pointer, read_only=false)
        iface.set_item("data", (handle.device_ptr() as usize, false))?;
        iface.set_item("version", 3)?;
        Ok(iface)
    }

    /// DLPack device tuple: device type 2 with the buffer's device id
    /// (0 once the buffer has been exported).
    fn __dlpack_device__(&self) -> (i32, i32) {
        let device = self
            .inner
            .as_ref()
            .map_or(0, |h| h.device_id as i32);
        (2, device)
    }

    /// Exports the buffer as a 2-D DLPack capsule, transferring ownership.
    ///
    /// A `dl_device` that parses as `(type, id)` and differs from this buffer's device
    /// is rejected with `ValueError` (no cross-device copy is implemented). `stream` is
    /// not used here. Raises `ValueError` on a second export.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;

        let (kdl, alloc_dev) = self.__dlpack_device__();

        // A dl_device value that does not extract as (i32, i32) is ignored.
        let requested = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some(target) = requested {
            if target != (kdl, alloc_dev) {
                let wants_copy = copy
                    .as_ref()
                    .and_then(|c| c.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let message = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(message));
            }
        }

        let _ = stream;

        let handle = self
            .inner
            .take()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?;

        let (rows, cols) = (handle.rows, handle.cols);
        let buf = handle.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
2403
/// Python-facing streaming JMA, exported to Python under the class name `JmaStream`.
///
/// Thin wrapper over the Rust `JmaStream`; construction and updates are defined in the
/// `#[pymethods]` impl for this type.
#[cfg(feature = "python")]
#[pyclass(name = "JmaStream")]
pub struct JmaStreamPy {
    /// Incremental JMA state, advanced one value at a time by `update`.
    inner: JmaStream,
}
2409
#[cfg(feature = "python")]
#[pymethods]
impl JmaStreamPy {
    /// Builds a streaming JMA with the given `period`, `phase` (default 50.0) and `power`
    /// (default 2). Raises `ValueError` if `JmaStream::try_new` rejects the parameters.
    #[new]
    #[pyo3(signature = (period, phase=50.0, power=2))]
    fn new(period: usize, phase: f64, power: u32) -> PyResult<Self> {
        let params = JmaParams {
            period: Some(period),
            phase: Some(phase),
            power: Some(power),
        };
        JmaStream::try_new(params)
            .map(|inner| Self { inner })
            .map_err(|err| PyErr::new::<pyo3::exceptions::PyValueError, _>(err.to_string()))
    }

    /// Feeds one value into the stream and returns whatever `JmaStream::update` yields
    /// (presumably `None` while warming up — confirm against `JmaStream::update`).
    fn update(&mut self, value: f64) -> Option<f64> {
        self.inner.update(value)
    }
}
2432
2433#[inline]
2434pub fn jma_into_slice(dst: &mut [f64], input: &JmaInput, kern: Kernel) -> Result<(), JmaError> {
2435 let data: &[f64] = match &input.data {
2436 JmaData::Candles { candles, source } => source_type(candles, source),
2437 JmaData::Slice(sl) => sl,
2438 };
2439
2440 if dst.len() != data.len() {
2441 return Err(JmaError::OutputLengthMismatch {
2442 expected: data.len(),
2443 got: dst.len(),
2444 });
2445 }
2446
2447 jma_with_kernel_into(input, kern, dst)
2448}
2449
2450#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2451use serde::{Deserialize, Serialize};
2452#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2453use wasm_bindgen::prelude::*;
2454
/// JS entry point: computes JMA over `data` and returns a newly allocated output vector.
///
/// Uses `Kernel::Auto`; any computation error is returned as a JS string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn jma_js(data: &[f64], period: usize, phase: f64, power: u32) -> Result<Vec<f64>, JsValue> {
    let input = JmaInput::from_slice(
        data,
        JmaParams {
            period: Some(period),
            phase: Some(phase),
            power: Some(power),
        },
    );

    let mut values = vec![0.0; data.len()];
    match jma_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(()) => Ok(values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
2472
/// Reserves room for `len` `f64` values and leaks it to the JS caller.
///
/// The memory is uninitialized. Release it with `jma_free(ptr, len)` using the same `len`,
/// because `jma_free` rebuilds the `Vec` with capacity `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn jma_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the allocation alive after this function returns.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2481
/// Frees a buffer obtained from `jma_alloc`; `len` must equal the value passed to `jma_alloc`.
///
/// A null pointer is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn jma_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `jma_alloc(len)`, which leaked a `Vec<f64>` with capacity `len`;
    // rebuilding and dropping it returns the allocation.
    unsafe {
        drop(Vec::from_raw_parts(ptr, len, len));
    }
}
2491
/// Computes JMA from `len` values at `in_ptr` into `len` output slots at `out_ptr`.
///
/// Both pointers must be non-null and valid for `len` `f64` elements (reads for `in_ptr`,
/// writes for `out_ptr`). The regions may be identical or partially overlap. In that case the
/// result is computed into a scratch buffer and copied afterwards, so a shared and a mutable
/// slice over the same memory never coexist.
///
/// # Errors
/// - `"null pointer passed to jma_into"` if either pointer is null.
/// - `"Invalid period"` if `period == 0` or `period > len`.
/// - The `Display` text of any error returned by `jma_into_slice`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn jma_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
    phase: f64,
    power: u32,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to jma_into"));
    }
    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    let params = JmaParams {
        period: Some(period),
        phase: Some(phase),
        power: Some(power),
    };

    // Compare byte ranges rather than only `in_ptr == out_ptr`: a partial overlap is just as
    // unsound for the slices built below, and it corrupts the input mid-computation.
    let bytes = len.saturating_mul(std::mem::size_of::<f64>());
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps =
        in_start < out_start.saturating_add(bytes) && out_start < in_start.saturating_add(bytes);

    if overlaps {
        let scratch = {
            // SAFETY: the caller guarantees `in_ptr` is valid for `len` reads. No mutable slice
            // over this memory exists while `data` is alive (it ends with this block).
            let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
            let input = JmaInput::from_slice(data, params);
            let mut scratch = vec![0.0; len];
            jma_into_slice(&mut scratch, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            scratch
        };
        // SAFETY: the caller guarantees `out_ptr` is valid for `len` writes, and the input
        // slice above is no longer alive.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&scratch);
    } else {
        // SAFETY: both regions are valid per the caller contract and were checked above to be
        // disjoint, so the shared and mutable slices cannot alias.
        let (data, out) = unsafe {
            (
                std::slice::from_raw_parts(in_ptr, len),
                std::slice::from_raw_parts_mut(out_ptr, len),
            )
        };
        let input = JmaInput::from_slice(data, params);
        jma_into_slice(out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
2536
/// Sweep description accepted by the JS `jma_batch` entry point (deserialized from a JS value).
///
/// Each field is a `(start, end, step)` triple for one JMA parameter and is copied verbatim into
/// the corresponding `JmaBatchRange` field.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct JmaBatchConfig {
    /// `(start, end, step)` for the period parameter.
    pub period_range: (usize, usize, usize),
    /// `(start, end, step)` for the phase parameter.
    pub phase_range: (f64, f64, f64),
    /// `(start, end, step)` for the power parameter.
    pub power_range: (u32, u32, u32),
}
2544
/// Result serialized back to JS by `jma_batch`: the flattened output matrix and its parameter grid.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct JmaBatchJsOutput {
    /// Flattened output values; presumably row-major with one row per entry of `combos`
    /// (NOTE(review): confirm layout against `jma_batch_inner`).
    pub values: Vec<f64>,
    /// Parameter combinations of the sweep, as returned by `jma_batch_inner`.
    pub combos: Vec<JmaParams>,
    /// Number of rows in `values`.
    pub rows: usize,
    /// Number of columns in `values` (presumably the input length — confirm).
    pub cols: usize,
}
2553
/// JS `jma_batch`: runs a JMA parameter sweep described by a config object.
///
/// `config` must deserialize into `JmaBatchConfig`. The result is a serialized
/// `JmaBatchJsOutput` (`values`, `combos`, `rows`, `cols`). Config, computation and
/// serialization failures are all returned as JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = jma_batch)]
pub fn jma_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let JmaBatchConfig {
        period_range,
        phase_range,
        power_range,
    } = serde_wasm_bindgen::from_value::<JmaBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = JmaBatchRange {
        period: period_range,
        phase: phase_range,
        power: power_range,
    };

    // `jma_batch_inner` is handed the single-series kernel matching the detected batch kernel;
    // anything other than AVX2/AVX-512 falls back to scalar.
    let row_kernel = match detect_best_batch_kernel() {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        _ => Kernel::Scalar,
    };

    let batch = jma_batch_inner(data, &sweep, row_kernel, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = JmaBatchJsOutput {
        rows: batch.rows,
        cols: batch.cols,
        values: batch.values,
        combos: batch.combos,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2586
/// Legacy JS batch entry point returning only the flattened output values.
///
/// Always uses the scalar kernel. Prefer `jma_batch`, which also returns the parameter grid
/// and the matrix shape.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(since = "1.0.0", note = "Use jma_batch instead")]
pub fn jma_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
    phase_start: f64,
    phase_end: f64,
    phase_step: f64,
    power_start: u32,
    power_end: u32,
    power_step: u32,
) -> Result<Vec<f64>, JsValue> {
    let period = (period_start, period_end, period_step);
    let phase = (phase_start, phase_end, phase_step);
    let power = (power_start, power_end, power_step);
    let sweep = JmaBatchRange {
        period,
        phase,
        power,
    };

    match jma_batch_inner(data, &sweep, Kernel::Scalar, false) {
        Ok(batch) => Ok(batch.values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
2612
/// Enumerates the `(period, phase, power)` combinations of a batch sweep as a flat `f64` list.
///
/// The result is consecutive `[period, phase, power]` triples. Period is the outermost loop,
/// then phase, then power (innermost). Each axis runs from start up to and including end.
///
/// Step semantics (unchanged for every input that previously terminated):
/// - `phase_step == 0.0` / `power_step == 0`: exactly the start value is emitted for that axis,
///   even when start > end.
/// - `period_step == 0`: `period_start` is emitted once if `period_start <= period_end`;
///   otherwise no triples are produced.
/// - An axis stops as soon as its next value cannot strictly advance: a negative, NaN or
///   too-small phase step, or integer overflow. Previously these cases looped forever, or
///   panicked in debug builds.
///
/// NOTE(review): this walk is written independently of `expand_grid_jma`, which
/// `jma_batch_into` uses to size its output. Confirm the two agree on ordering and edge cases.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn jma_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
    phase_start: f64,
    phase_end: f64,
    phase_step: f64,
    power_start: u32,
    power_end: u32,
    power_step: u32,
) -> Vec<f64> {
    let mut metadata = Vec::new();

    let mut current_period = period_start;
    while current_period <= period_end {
        let mut current_phase = phase_start;
        while current_phase <= phase_end || (phase_step == 0.0 && current_phase == phase_start) {
            let mut current_power = power_start;
            while current_power <= power_end || (power_step == 0 && current_power == power_start) {
                metadata.push(current_period as f64);
                metadata.push(current_phase);
                metadata.push(current_power as f64);

                // A zero step yields a single value. On overflow there is no next value; a plain
                // `+=` would panic (debug) or wrap into an endless loop (release).
                match current_power.checked_add(power_step) {
                    Some(next) if power_step != 0 => current_power = next,
                    _ => break,
                }
            }

            // Continue only if the phase strictly increases. This covers a zero step (single
            // value) plus negative, NaN and too-small steps, which would otherwise never
            // pass `phase_end`.
            let next_phase = current_phase + phase_step;
            if next_phase <= current_phase || next_phase.is_nan() {
                break;
            }
            current_phase = next_phase;
        }

        match current_period.checked_add(period_step) {
            Some(next) if period_step != 0 => current_period = next,
            _ => break,
        }
    }

    metadata
}
2656
/// Runs the JMA batch sweep on `len` inputs at `in_ptr`, writing a `rows × len` block to `out_ptr`.
///
/// `rows` is the number of parameter combinations `expand_grid_jma` produces for the sweep; it
/// is returned on success.
///
/// `in_ptr` must be valid for `len` reads and `out_ptr` for `rows * len` writes. The output
/// block may overlap the input (e.g. in-place use). In that case the input is copied first, so
/// a shared and a mutable slice over the same memory never coexist.
///
/// Always uses the scalar kernel.
///
/// # Errors
/// - `"null pointer passed to jma_batch_into"` if either pointer is null.
/// - An overflow message if `rows * len` does not fit in `usize`.
/// - The `Display` text of any error from `jma_batch_inner_into`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn jma_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
    phase_start: f64,
    phase_end: f64,
    phase_step: f64,
    power_start: u32,
    power_end: u32,
    power_step: u32,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to jma_batch_into"));
    }

    let sweep = JmaBatchRange {
        period: (period_start, period_end, period_step),
        phase: (phase_start, phase_end, phase_step),
        power: (power_start, power_end, power_step),
    };

    let rows = expand_grid_jma(&sweep).len();
    let cols = len;
    // An unchecked product could wrap and hand `from_raw_parts_mut` a bogus length.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("jma_batch_into: rows * cols overflows usize"))?;

    // Byte ranges of the input and output buffers; saturation only makes the check more
    // conservative (worst case an unnecessary copy).
    let elem = std::mem::size_of::<f64>();
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps = in_start < out_start.saturating_add(total.saturating_mul(elem))
        && out_start < in_start.saturating_add(len.saturating_mul(elem));

    let snapshot: Option<Vec<f64>> = if overlaps {
        // SAFETY: the caller guarantees `in_ptr` is valid for `len` reads. The temporary slice
        // is gone before any mutable slice over this memory is created.
        Some(unsafe { std::slice::from_raw_parts(in_ptr, len) }.to_vec())
    } else {
        None
    };
    let data: &[f64] = match &snapshot {
        Some(copy) => copy.as_slice(),
        // SAFETY: the caller guarantees `in_ptr` is valid for `len` reads, and this range was
        // checked above to be disjoint from the output range.
        None => unsafe { std::slice::from_raw_parts(in_ptr, len) },
    };
    // SAFETY: the caller guarantees `out_ptr` is valid for `rows * len` writes.
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };

    jma_batch_inner_into(data, &sweep, Kernel::Scalar, false, out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}