1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5 make_uninit_matrix,
6};
7#[cfg(feature = "python")]
8use crate::utilities::kernel_validation::validate_kernel;
9#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
10use core::arch::x86_64::*;
11#[cfg(not(target_arch = "wasm32"))]
12use rayon::prelude::*;
13use std::convert::AsRef;
14use std::error::Error;
15use std::mem::MaybeUninit;
16use thiserror::Error;
17
18impl<'a> AsRef<[f64]> for JsaInput<'a> {
19 #[inline(always)]
20 fn as_ref(&self) -> &[f64] {
21 match &self.data {
22 JsaData::Slice(slice) => slice,
23 JsaData::Candles { candles, source } => source_type(candles, source),
24 }
25 }
26}
27
28#[derive(Debug, Clone)]
29pub enum JsaData<'a> {
30 Candles {
31 candles: &'a Candles,
32 source: &'a str,
33 },
34 Slice(&'a [f64]),
35}
36
/// Result of a single-series JSA computation.
#[derive(Clone, Debug)]
pub struct JsaOutput {
    /// One output value per input element.
    pub values: Vec<f64>,
}
41
/// Tunable parameters of the JSA indicator.
#[derive(Clone, Debug)]
pub struct JsaParams {
    /// Look-back distance; consumers in this file substitute 30 when it is `None`.
    pub period: Option<usize>,
}
46
47impl Default for JsaParams {
48 fn default() -> Self {
49 Self { period: Some(30) }
50 }
51}
52
53#[derive(Debug, Clone)]
54pub struct JsaInput<'a> {
55 pub data: JsaData<'a>,
56 pub params: JsaParams,
57}
58
59impl<'a> JsaInput<'a> {
60 #[inline(always)]
61 pub fn from_candles(c: &'a Candles, s: &'a str, p: JsaParams) -> Self {
62 Self {
63 data: JsaData::Candles {
64 candles: c,
65 source: s,
66 },
67 params: p,
68 }
69 }
70 #[inline(always)]
71 pub fn from_slice(sl: &'a [f64], p: JsaParams) -> Self {
72 Self {
73 data: JsaData::Slice(sl),
74 params: p,
75 }
76 }
77 #[inline(always)]
78 pub fn with_default_candles(c: &'a Candles) -> Self {
79 Self::from_candles(c, "close", JsaParams::default())
80 }
81 #[inline(always)]
82 pub fn get_period(&self) -> usize {
83 self.params.period.unwrap_or(30)
84 }
85}
86
87#[derive(Copy, Clone, Debug)]
88pub struct JsaBuilder {
89 period: Option<usize>,
90 kernel: Kernel,
91}
92
93impl Default for JsaBuilder {
94 fn default() -> Self {
95 Self {
96 period: None,
97 kernel: Kernel::Auto,
98 }
99 }
100}
101
102impl JsaBuilder {
103 #[inline(always)]
104 pub fn new() -> Self {
105 Self::default()
106 }
107 #[inline(always)]
108 pub fn period(mut self, n: usize) -> Self {
109 self.period = Some(n);
110 self
111 }
112 #[inline(always)]
113 pub fn kernel(mut self, k: Kernel) -> Self {
114 self.kernel = k;
115 self
116 }
117 #[inline(always)]
118 pub fn apply(self, c: &Candles) -> Result<JsaOutput, JsaError> {
119 let p = JsaParams {
120 period: self.period,
121 };
122 let i = JsaInput::from_candles(c, "close", p);
123 jsa_with_kernel(&i, self.kernel)
124 }
125 #[inline(always)]
126 pub fn apply_slice(self, d: &[f64]) -> Result<JsaOutput, JsaError> {
127 let p = JsaParams {
128 period: self.period,
129 };
130 let i = JsaInput::from_slice(d, p);
131 jsa_with_kernel(&i, self.kernel)
132 }
133 #[inline(always)]
134 pub fn into_stream(self) -> Result<JsaStream, JsaError> {
135 let p = JsaParams {
136 period: self.period,
137 };
138 JsaStream::try_new(p)
139 }
140}
141
142#[derive(Debug, Error)]
143pub enum JsaError {
144 #[error("jsa: Input data slice is empty.")]
145 EmptyInputData,
146
147 #[error("jsa: All values are NaN.")]
148 AllValuesNaN,
149
150 #[error("jsa: Invalid period: period = {period}, data length = {data_len}")]
151 InvalidPeriod { period: usize, data_len: usize },
152
153 #[error("jsa: Not enough valid data: needed = {needed}, valid = {valid}")]
154 NotEnoughValidData { needed: usize, valid: usize },
155
156 #[error("jsa: output length mismatch: expected = {expected}, got = {got}")]
157 OutputLengthMismatch { expected: usize, got: usize },
158
159 #[error("jsa: invalid kernel for batch op: {kernel:?}")]
160 InvalidKernel { kernel: Kernel },
161
162 #[error("jsa: invalid range expansion: start={start}, end={end}, step={step}")]
163 InvalidRange {
164 start: usize,
165 end: usize,
166 step: usize,
167 },
168
169 #[error("jsa: arithmetic overflow while computing sizes")]
170 ArithmeticOverflow,
171
172 #[error("jsa: invalid kernel used for batch op: {kernel:?}")]
173 InvalidKernelForBatch { kernel: Kernel },
174}
175
176#[inline]
177pub fn jsa(input: &JsaInput) -> Result<JsaOutput, JsaError> {
178 jsa_with_kernel(input, Kernel::Auto)
179}
180
181#[inline(always)]
182fn jsa_compute_into(data: &[f64], period: usize, first: usize, k: Kernel, out: &mut [f64]) {
183 unsafe {
184 match k {
185 Kernel::Scalar | Kernel::ScalarBatch => jsa_scalar(data, period, first, out),
186 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
187 Kernel::Avx2 | Kernel::Avx2Batch => jsa_avx2(data, period, first, out),
188 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
189 Kernel::Avx512 | Kernel::Avx512Batch => jsa_avx512(data, period, first, out),
190 _ => unreachable!(),
191 }
192 }
193}
194
195pub fn jsa_with_kernel(input: &JsaInput, kernel: Kernel) -> Result<JsaOutput, JsaError> {
196 let data: &[f64] = match &input.data {
197 JsaData::Candles { candles, source } => source_type(candles, source),
198 JsaData::Slice(sl) => sl,
199 };
200
201 if data.is_empty() {
202 return Err(JsaError::EmptyInputData);
203 }
204
205 let first = data
206 .iter()
207 .position(|x| !x.is_nan())
208 .ok_or(JsaError::AllValuesNaN)?;
209 let len = data.len();
210 let period = input.get_period();
211
212 if period == 0 || period > len {
213 return Err(JsaError::InvalidPeriod {
214 period,
215 data_len: len,
216 });
217 }
218 if (len - first) < period {
219 return Err(JsaError::NotEnoughValidData {
220 needed: period,
221 valid: len - first,
222 });
223 }
224
225 let warm = first
226 .checked_add(period)
227 .ok_or(JsaError::ArithmeticOverflow)?;
228 let mut out = alloc_with_nan_prefix(len, warm);
229 let chosen = match kernel {
230 Kernel::Auto => Kernel::Scalar,
231 k => k,
232 };
233 jsa_compute_into(data, period, first, chosen, &mut out);
234 Ok(JsaOutput { values: out })
235}
236
237#[inline]
238#[inline]
239pub fn jsa_into(input: &JsaInput, out: &mut [f64]) -> Result<(), JsaError> {
240 jsa_with_kernel_into(input, Kernel::Auto, out)
241}
242
243#[inline]
244pub fn jsa_into_slice(dst: &mut [f64], input: &JsaInput, kern: Kernel) -> Result<(), JsaError> {
245 jsa_with_kernel_into(input, kern, dst)
246}
247
248pub fn jsa_with_kernel_into(
249 input: &JsaInput,
250 kernel: Kernel,
251 out: &mut [f64],
252) -> Result<(), JsaError> {
253 let data: &[f64] = match &input.data {
254 JsaData::Candles { candles, source } => source_type(candles, source),
255 JsaData::Slice(sl) => sl,
256 };
257
258 if data.is_empty() {
259 return Err(JsaError::EmptyInputData);
260 }
261
262 let len = data.len();
263
264 if out.len() != len {
265 return Err(JsaError::OutputLengthMismatch {
266 expected: len,
267 got: out.len(),
268 });
269 }
270
271 let first = data
272 .iter()
273 .position(|x| !x.is_nan())
274 .ok_or(JsaError::AllValuesNaN)?;
275 let period = input.get_period();
276
277 if period == 0 || period > len {
278 return Err(JsaError::InvalidPeriod {
279 period,
280 data_len: len,
281 });
282 }
283 if (len - first) < period {
284 return Err(JsaError::NotEnoughValidData {
285 needed: period,
286 valid: len - first,
287 });
288 }
289
290 let warm = first + period;
291
292 out[..warm].fill(f64::NAN);
293
294 let chosen = match kernel {
295 Kernel::Auto => Kernel::Scalar,
296 k => k,
297 };
298 jsa_compute_into(data, period, first, chosen, out);
299 Ok(())
300}
301
/// AVX-512 entry point: picks the short-period or long-period kernel
/// (threshold: `period <= 32`) and writes results for indices
/// `first_valid + period ..` into `out`.
///
/// This is a safe function wrapping raw-pointer kernels, so it now verifies
/// that `out` can hold every index the kernels write to.
///
/// NOTE(review): the CPU is assumed to support AVX-512F; nothing in this
/// function checks it — confirm callers select this kernel only after detection.
///
/// # Panics
/// Panics if any output would be written and `out.len() < data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn jsa_avx512(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    // Nothing is written when the first computable index lies past the data.
    if first_valid.saturating_add(period) >= data.len() {
        return;
    }
    assert!(
        out.len() >= data.len(),
        "jsa_avx512: output slice is shorter than the input"
    );

    // SAFETY: `out` holds at least `data.len()` elements (asserted above), and
    // the kernels only touch indices in `first_valid..data.len()`.
    unsafe {
        if period <= 32 {
            jsa_avx512_short(data, period, first_valid, out)
        } else {
            jsa_avx512_long(data, period, first_valid, out)
        }
    }
}
311
/// Scalar JSA kernel: for every index `i` from `first_val + period` to the end
/// of `data`, stores `(data[i] + data[i - period]) * 0.5` in `out[i]`.
/// Elements before that index are left untouched.
///
/// # Panics
/// Panics if `out` is too short for an index that has to be written.
#[inline]
pub fn jsa_scalar(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    let n = data.len();
    let mut idx = first_val + period;
    while idx < n {
        let lagged = data[idx - period];
        out[idx] = (data[idx] + lagged) * 0.5;
        idx += 1;
    }
}
318
/// AVX2 JSA kernel: for every index `i` from `first_val + period` to the end
/// of `data`, stores `(data[i] + data[i - period]) * 0.5` in `out[i]`,
/// processing four elements per iteration with a scalar tail.
///
/// The loops are index-based: the previous `p_out.add(4) <= end` condition
/// could form a pointer more than one element past the end of `out`, which
/// `pointer::add` does not permit.
///
/// # Safety
/// * `out` must hold at least `data.len()` elements.
/// * The CPU must support AVX2.
/// * `first_val + period` must not overflow.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn jsa_avx2(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;
    const LANES: usize = 4;

    let len = data.len();
    let start = first_val + period;
    if start >= len {
        return;
    }

    let dp = data.as_ptr();
    let op = out.as_mut_ptr();
    let half = _mm256_set1_pd(0.5);

    // Invariant: start <= i <= len, so `len - i` never underflows and every
    // pointer formed below stays inside its slice (`i - period >= first_val`).
    let mut i = start;
    while len - i >= LANES {
        let x = _mm256_loadu_pd(dp.add(i));
        let y = _mm256_loadu_pd(dp.add(i - period));
        let s = _mm256_add_pd(x, y);
        _mm256_storeu_pd(op.add(i), _mm256_mul_pd(s, half));
        i += LANES;
    }

    // Scalar tail for the remaining (< 4) elements.
    while i < len {
        *op.add(i) = (*dp.add(i) + *dp.add(i - period)) * 0.5;
        i += 1;
    }
}
357
/// AVX-512 JSA kernel used by `jsa_avx512` for `period <= 32`: for every index
/// `i` from `first_val + period` to the end of `data`, stores
/// `(data[i] + data[i - period]) * 0.5` in `out[i]`, eight elements per
/// iteration with a scalar tail.
///
/// The loops are index-based: the previous `p_out.add(8) <= end` condition
/// could form a pointer more than one element past the end of `out`, which
/// `pointer::add` does not permit.
///
/// # Safety
/// * `out` must hold at least `data.len()` elements.
/// * The CPU must support AVX-512F.
/// * `first_val + period` must not overflow.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn jsa_avx512_short(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;
    const LANES: usize = 8;

    let len = data.len();
    let start = first_val + period;
    if start >= len {
        return;
    }

    let dp = data.as_ptr();
    let op = out.as_mut_ptr();
    let half = _mm512_set1_pd(0.5);

    // Invariant: start <= i <= len, so `len - i` never underflows and every
    // pointer formed below stays inside its slice (`i - period >= first_val`).
    let mut i = start;
    while len - i >= LANES {
        let x = _mm512_loadu_pd(dp.add(i));
        let y = _mm512_loadu_pd(dp.add(i - period));
        let s = _mm512_add_pd(x, y);
        _mm512_storeu_pd(op.add(i), _mm512_mul_pd(s, half));
        i += LANES;
    }

    // Scalar tail for the remaining (< 8) elements.
    while i < len {
        *op.add(i) = (*dp.add(i) + *dp.add(i - period)) * 0.5;
        i += 1;
    }
}
396
/// AVX-512 JSA kernel used by `jsa_avx512` for `period > 32`: for every index
/// `i` from `first_val + period` to the end of `data`, stores
/// `(data[i] + data[i - period]) * 0.5` in `out[i]`, eight elements per
/// iteration with a scalar tail.
///
/// The loops are index-based: the previous `p_out.add(8) <= end` condition
/// could form a pointer more than one element past the end of `out`, which
/// `pointer::add` does not permit.
///
/// # Safety
/// * `out` must hold at least `data.len()` elements.
/// * The CPU must support AVX-512F.
/// * `first_val + period` must not overflow.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn jsa_avx512_long(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;
    const LANES: usize = 8;

    let len = data.len();
    let start = first_val + period;
    if start >= len {
        return;
    }

    let dp = data.as_ptr();
    let op = out.as_mut_ptr();
    let half = _mm512_set1_pd(0.5);

    // Invariant: start <= i <= len, so `len - i` never underflows and every
    // pointer formed below stays inside its slice (`i - period >= first_val`).
    let mut i = start;
    while len - i >= LANES {
        let x = _mm512_loadu_pd(dp.add(i));
        let y = _mm512_loadu_pd(dp.add(i - period));
        let s = _mm512_add_pd(x, y);
        _mm512_storeu_pd(op.add(i), _mm512_mul_pd(s, half));
        i += LANES;
    }

    // Scalar tail for the remaining (< 8) elements.
    while i < len {
        *op.add(i) = (*dp.add(i) + *dp.add(i - period)) * 0.5;
        i += 1;
    }
}
434
/// Streaming JSA calculator backed by a ring buffer of the last `period` inputs.
#[derive(Clone, Debug)]
pub struct JsaStream {
    /// Look-back distance; also the ring buffer length.
    period: usize,
    /// Ring buffer of the most recent inputs.
    buffer: Vec<f64>,
    /// Slot that the next input will overwrite.
    head: usize,
    /// Becomes `true` once `period` inputs have been received.
    filled: bool,
}
442
443impl JsaStream {
444 pub fn try_new(params: JsaParams) -> Result<Self, JsaError> {
445 let period = params.period.unwrap_or(30);
446 if period == 0 {
447 return Err(JsaError::InvalidPeriod {
448 period,
449 data_len: 0,
450 });
451 }
452 Ok(Self {
453 period,
454 buffer: vec![f64::NAN; period],
455 head: 0,
456 filled: false,
457 })
458 }
459
460 #[inline(always)]
461 pub fn update(&mut self, value: f64) -> Option<f64> {
462 let out = if self.filled {
463 let past = self.buffer[self.head];
464 Some((value + past) * 0.5)
465 } else {
466 None
467 };
468
469 self.buffer[self.head] = value;
470
471 let next = self.head + 1;
472 if next == self.period {
473 self.head = 0;
474 if !self.filled {
475 self.filled = true;
476 }
477 } else {
478 self.head = next;
479 }
480
481 out
482 }
483}
484
/// Parameter sweep description for batch runs.
#[derive(Debug, Clone)]
pub struct JsaBatchRange {
    /// `(start, end, step)` for the period axis; `expand_grid` treats a step
    /// of 0 or `start == end` as the single value `start`.
    pub period: (usize, usize, usize),
}
489
490impl Default for JsaBatchRange {
491 fn default() -> Self {
492 Self {
493 period: (30, 279, 1),
494 }
495 }
496}
497
498#[derive(Clone, Debug, Default)]
499pub struct JsaBatchBuilder {
500 range: JsaBatchRange,
501 kernel: Kernel,
502}
503
504impl JsaBatchBuilder {
505 pub fn new() -> Self {
506 Self::default()
507 }
508 pub fn kernel(mut self, k: Kernel) -> Self {
509 self.kernel = k;
510 self
511 }
512 #[inline]
513 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
514 self.range.period = (start, end, step);
515 self
516 }
517 #[inline]
518 pub fn period_static(mut self, p: usize) -> Self {
519 self.range.period = (p, p, 0);
520 self
521 }
522 pub fn apply_slice(self, data: &[f64]) -> Result<JsaBatchOutput, JsaError> {
523 jsa_batch_with_kernel(data, &self.range, self.kernel)
524 }
525 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<JsaBatchOutput, JsaError> {
526 JsaBatchBuilder::new().kernel(k).apply_slice(data)
527 }
528 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<JsaBatchOutput, JsaError> {
529 let slice = source_type(c, src);
530 self.apply_slice(slice)
531 }
532 pub fn with_default_candles(c: &Candles) -> Result<JsaBatchOutput, JsaError> {
533 JsaBatchBuilder::new()
534 .kernel(Kernel::Auto)
535 .apply_candles(c, "close")
536 }
537}
538
539pub fn jsa_batch_with_kernel(
540 data: &[f64],
541 sweep: &JsaBatchRange,
542 k: Kernel,
543) -> Result<JsaBatchOutput, JsaError> {
544 let kernel = match k {
545 Kernel::Auto => detect_best_batch_kernel(),
546 other if other.is_batch() => other,
547 other => return Err(JsaError::InvalidKernelForBatch { kernel: other }),
548 };
549
550 let simd = match kernel {
551 Kernel::Avx512Batch => Kernel::Avx512,
552 Kernel::Avx2Batch => Kernel::Avx2,
553 Kernel::ScalarBatch => Kernel::Scalar,
554 _ => unreachable!(),
555 };
556 jsa_batch_par_slice(data, sweep, simd)
557}
558
559#[derive(Clone, Debug)]
560pub struct JsaBatchOutput {
561 pub values: Vec<f64>,
562 pub combos: Vec<JsaParams>,
563 pub rows: usize,
564 pub cols: usize,
565}
566impl JsaBatchOutput {
567 pub fn row_for_params(&self, p: &JsaParams) -> Option<usize> {
568 self.combos
569 .iter()
570 .position(|c| c.period.unwrap_or(30) == p.period.unwrap_or(30))
571 }
572 pub fn values_for(&self, p: &JsaParams) -> Option<&[f64]> {
573 self.row_for_params(p).map(|row| {
574 let start = row * self.cols;
575 &self.values[start..start + self.cols]
576 })
577 }
578}
579
580#[inline(always)]
581fn expand_grid(r: &JsaBatchRange) -> Result<Vec<JsaParams>, JsaError> {
582 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, JsaError> {
583 if step == 0 || start == end {
584 return Ok(vec![start]);
585 }
586 let mut v = Vec::new();
587 if start < end {
588 let mut cur = start;
589 while cur <= end {
590 v.push(cur);
591
592 match cur.checked_add(step) {
593 Some(next) => cur = next,
594 None => break,
595 }
596 }
597 } else {
598 let mut cur = start;
599 while cur >= end {
600 v.push(cur);
601
602 if cur < end + step {
603 break;
604 }
605 cur -= step;
606 if cur == usize::MAX {
607 break;
608 }
609 }
610 }
611 if v.is_empty() {
612 return Err(JsaError::InvalidRange { start, end, step });
613 }
614 Ok(v)
615 }
616 let periods = axis_usize(r.period)?;
617 let mut out = Vec::with_capacity(periods.len());
618 for &p in &periods {
619 out.push(JsaParams { period: Some(p) });
620 }
621 Ok(out)
622}
623
624#[inline(always)]
625pub fn jsa_batch_slice(
626 data: &[f64],
627 sweep: &JsaBatchRange,
628 kern: Kernel,
629) -> Result<JsaBatchOutput, JsaError> {
630 jsa_batch_inner(data, sweep, kern, false)
631}
632
633#[inline(always)]
634pub fn jsa_batch_par_slice(
635 data: &[f64],
636 sweep: &JsaBatchRange,
637 kern: Kernel,
638) -> Result<JsaBatchOutput, JsaError> {
639 jsa_batch_inner(data, sweep, kern, true)
640}
641
642#[inline(always)]
643fn jsa_batch_inner(
644 data: &[f64],
645 sweep: &JsaBatchRange,
646 kern: Kernel,
647 parallel: bool,
648) -> Result<JsaBatchOutput, JsaError> {
649 if data.is_empty() {
650 return Err(JsaError::EmptyInputData);
651 }
652
653 let combos = expand_grid(sweep)?;
654 let first = data
655 .iter()
656 .position(|x| !x.is_nan())
657 .ok_or(JsaError::AllValuesNaN)?;
658 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
659 if data.len() - first < max_p {
660 return Err(JsaError::NotEnoughValidData {
661 needed: max_p,
662 valid: data.len() - first,
663 });
664 }
665 let rows = combos.len();
666 let cols = data.len();
667
668 let _total = rows.checked_mul(cols).ok_or(JsaError::ArithmeticOverflow)?;
669 let mut warm: Vec<usize> = Vec::with_capacity(rows);
670 for c in &combos {
671 let p = c.period.unwrap();
672 let w = first.checked_add(p).ok_or(JsaError::ArithmeticOverflow)?;
673 warm.push(w);
674 }
675
676 let mut raw = make_uninit_matrix(rows, cols);
677
678 init_matrix_prefixes(&mut raw, cols, &warm);
679
680 let actual_kern = match kern {
681 Kernel::Auto => detect_best_batch_kernel(),
682 k => k,
683 };
684
685 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
686 let period = combos[row].period.unwrap();
687
688 let out_row =
689 core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
690
691 match actual_kern {
692 Kernel::ScalarBatch | Kernel::Scalar => jsa_row_scalar(data, first, period, out_row),
693 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
694 Kernel::Avx2Batch | Kernel::Avx2 => jsa_row_avx2(data, first, period, out_row),
695 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
696 Kernel::Avx512Batch | Kernel::Avx512 => jsa_row_avx512(data, first, period, out_row),
697 _ => unreachable!(),
698 }
699 };
700
701 if parallel {
702 #[cfg(not(target_arch = "wasm32"))]
703 {
704 raw.par_chunks_mut(cols)
705 .enumerate()
706 .for_each(|(row, slice)| do_row(row, slice));
707 }
708
709 #[cfg(target_arch = "wasm32")]
710 {
711 for (row, slice) in raw.chunks_mut(cols).enumerate() {
712 do_row(row, slice);
713 }
714 }
715 } else {
716 for (row, slice) in raw.chunks_mut(cols).enumerate() {
717 do_row(row, slice);
718 }
719 }
720
721 use core::mem::ManuallyDrop;
722
723 let mut buf_guard = ManuallyDrop::new(raw);
724 let values = unsafe {
725 Vec::from_raw_parts(
726 buf_guard.as_mut_ptr() as *mut f64,
727 buf_guard.len(),
728 buf_guard.capacity(),
729 )
730 };
731
732 Ok(JsaBatchOutput {
733 values,
734 combos,
735 rows,
736 cols,
737 })
738}
739
740#[inline(always)]
741fn jsa_batch_inner_into(
742 data: &[f64],
743 sweep: &JsaBatchRange,
744 kern: Kernel,
745 parallel: bool,
746 out: &mut [f64],
747) -> Result<(Vec<JsaParams>, usize, usize), JsaError> {
748 if data.is_empty() {
749 return Err(JsaError::EmptyInputData);
750 }
751
752 let combos = expand_grid(sweep)?;
753 let first = data
754 .iter()
755 .position(|x| !x.is_nan())
756 .ok_or(JsaError::AllValuesNaN)?;
757 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
758 if data.len() - first < max_p {
759 return Err(JsaError::NotEnoughValidData {
760 needed: max_p,
761 valid: data.len() - first,
762 });
763 }
764 let rows = combos.len();
765 let cols = data.len();
766
767 let expected = rows.checked_mul(cols).ok_or(JsaError::ArithmeticOverflow)?;
768 if out.len() != expected {
769 return Err(JsaError::OutputLengthMismatch {
770 expected,
771 got: out.len(),
772 });
773 }
774 let mut warm: Vec<usize> = Vec::with_capacity(rows);
775 for c in &combos {
776 let p = c.period.unwrap();
777 let w = first.checked_add(p).ok_or(JsaError::ArithmeticOverflow)?;
778 warm.push(w);
779 }
780
781 let out_uninit = unsafe {
782 std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
783 };
784 init_matrix_prefixes(out_uninit, cols, &warm);
785
786 let actual_kern = match kern {
787 Kernel::Auto => detect_best_batch_kernel(),
788 k => k,
789 };
790
791 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
792 let period = combos[row].period.unwrap();
793
794 match actual_kern {
795 Kernel::ScalarBatch | Kernel::Scalar => jsa_row_scalar(data, first, period, out_row),
796 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
797 Kernel::Avx2Batch | Kernel::Avx2 => jsa_row_avx2(data, first, period, out_row),
798 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
799 Kernel::Avx512Batch | Kernel::Avx512 => jsa_row_avx512(data, first, period, out_row),
800 _ => unreachable!(),
801 }
802 };
803
804 if parallel {
805 #[cfg(not(target_arch = "wasm32"))]
806 {
807 out.par_chunks_mut(cols)
808 .enumerate()
809 .for_each(|(row, slice)| do_row(row, slice));
810 }
811 #[cfg(target_arch = "wasm32")]
812 {
813 for (row, slice) in out.chunks_mut(cols).enumerate() {
814 do_row(row, slice);
815 }
816 }
817 } else {
818 for (row, slice) in out.chunks_mut(cols).enumerate() {
819 do_row(row, slice);
820 }
821 }
822
823 Ok((combos, rows, cols))
824}
825
/// Scalar row kernel for batch runs: for every index `i` from `first + period`
/// to the end of `data`, stores `(data[i] + data[i - period]) * 0.5` in
/// `out[i]`. Earlier elements of `out` are left untouched.
///
/// # Safety
/// The body uses only bounds-checked indexing; the function is `unsafe` to
/// match the signatures of the SIMD row kernels.
#[inline(always)]
unsafe fn jsa_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    let n = data.len();
    let mut idx = first + period;
    while idx < n {
        let lagged = data[idx - period];
        out[idx] = (data[idx] + lagged) * 0.5;
        idx += 1;
    }
}
832
/// AVX2 row kernel for batch runs: for every index `i` from `first + period`
/// to the end of `data`, stores `(data[i] + data[i - period]) * 0.5` in
/// `out[i]`, four elements per iteration with a scalar tail.
///
/// The loops are index-based: the previous `p_out.add(4) <= end` condition
/// could form a pointer more than one element past the end of `out`, which
/// `pointer::add` does not permit.
///
/// # Safety
/// * `out` must hold at least `data.len()` elements.
/// * The CPU must support AVX2.
/// * `first + period` must not overflow.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn jsa_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;
    const LANES: usize = 4;

    let len = data.len();
    let start = first + period;
    if start >= len {
        return;
    }

    let dp = data.as_ptr();
    let op = out.as_mut_ptr();
    let half = _mm256_set1_pd(0.5);

    // Invariant: start <= i <= len, so `len - i` never underflows and every
    // pointer formed below stays inside its slice (`i - period >= first`).
    let mut i = start;
    while len - i >= LANES {
        let x = _mm256_loadu_pd(dp.add(i));
        let y = _mm256_loadu_pd(dp.add(i - period));
        let s = _mm256_add_pd(x, y);
        _mm256_storeu_pd(op.add(i), _mm256_mul_pd(s, half));
        i += LANES;
    }

    // Scalar tail for the remaining (< 4) elements.
    while i < len {
        *op.add(i) = (*dp.add(i) + *dp.add(i - period)) * 0.5;
        i += 1;
    }
}
870
/// AVX-512 row kernel dispatcher: periods up to 32 use the short variant,
/// larger periods the long variant.
///
/// # Safety
/// Same contract as the variants it forwards to; see those functions.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn jsa_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    match period {
        0..=32 => jsa_row_avx512_short(data, first, period, out),
        _ => jsa_row_avx512_long(data, first, period, out),
    }
}
880
/// AVX-512 row kernel used for `period <= 32`: for every index `i` from
/// `first + period` to the end of `data`, stores
/// `(data[i] + data[i - period]) * 0.5` in `out[i]`, eight elements per
/// iteration with a scalar tail.
///
/// The loops are index-based: the previous `p_out.add(8) <= end` condition
/// could form a pointer more than one element past the end of `out`, which
/// `pointer::add` does not permit.
///
/// # Safety
/// * `out` must hold at least `data.len()` elements.
/// * The CPU must support AVX-512F.
/// * `first + period` must not overflow.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn jsa_row_avx512_short(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;
    const LANES: usize = 8;

    let len = data.len();
    let start = first + period;
    if start >= len {
        return;
    }

    let dp = data.as_ptr();
    let op = out.as_mut_ptr();
    let half = _mm512_set1_pd(0.5);

    // Invariant: start <= i <= len, so `len - i` never underflows and every
    // pointer formed below stays inside its slice (`i - period >= first`).
    let mut i = start;
    while len - i >= LANES {
        let x = _mm512_loadu_pd(dp.add(i));
        let y = _mm512_loadu_pd(dp.add(i - period));
        let s = _mm512_add_pd(x, y);
        _mm512_storeu_pd(op.add(i), _mm512_mul_pd(s, half));
        i += LANES;
    }

    // Scalar tail for the remaining (< 8) elements.
    while i < len {
        *op.add(i) = (*dp.add(i) + *dp.add(i - period)) * 0.5;
        i += 1;
    }
}
918
/// AVX-512 row kernel used for `period > 32`: for every index `i` from
/// `first + period` to the end of `data`, stores
/// `(data[i] + data[i - period]) * 0.5` in `out[i]`, eight elements per
/// iteration with a scalar tail.
///
/// The loops are index-based: the previous `p_out.add(8) <= end` condition
/// could form a pointer more than one element past the end of `out`, which
/// `pointer::add` does not permit.
///
/// # Safety
/// * `out` must hold at least `data.len()` elements.
/// * The CPU must support AVX-512F.
/// * `first + period` must not overflow.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn jsa_row_avx512_long(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;
    const LANES: usize = 8;

    let len = data.len();
    let start = first + period;
    if start >= len {
        return;
    }

    let dp = data.as_ptr();
    let op = out.as_mut_ptr();
    let half = _mm512_set1_pd(0.5);

    // Invariant: start <= i <= len, so `len - i` never underflows and every
    // pointer formed below stays inside its slice (`i - period >= first`).
    let mut i = start;
    while len - i >= LANES {
        let x = _mm512_loadu_pd(dp.add(i));
        let y = _mm512_loadu_pd(dp.add(i - period));
        let s = _mm512_add_pd(x, y);
        _mm512_storeu_pd(op.add(i), _mm512_mul_pd(s, half));
        i += LANES;
    }

    // Scalar tail for the remaining (< 8) elements.
    while i < len {
        *op.add(i) = (*dp.add(i) + *dp.add(i - period)) * 0.5;
        i += 1;
    }
}
956
957#[cfg(feature = "python")]
958use numpy::PyUntypedArrayMethods;
959#[cfg(feature = "python")]
960use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1, PyReadonlyArray2};
961#[cfg(feature = "python")]
962use pyo3::exceptions::PyValueError;
963#[cfg(feature = "python")]
964use pyo3::prelude::*;
965#[cfg(feature = "python")]
966use pyo3::types::PyDict;
967
968#[cfg(all(feature = "python", feature = "cuda"))]
969use crate::cuda::moving_averages::jsa_wrapper::JsaDeviceHandle;
970#[cfg(all(feature = "python", feature = "cuda"))]
971use crate::cuda::{cuda_available, moving_averages::CudaJsa};
972#[cfg(all(feature = "python", feature = "cuda"))]
973use cust::context::Context;
974#[cfg(all(feature = "python", feature = "cuda"))]
975use cust::memory::DeviceBuffer;
976#[cfg(all(feature = "python", feature = "cuda"))]
977use std::sync::Arc;
978
// Python binding `jsa(data, period, kernel=None)`: computes JSA over a 1-D
// float64 array and returns a new array of the same length. Contiguous input
// is processed in place into a fresh NumPy array; non-contiguous input is
// copied first. Errors from the Rust side surface as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction]
#[pyo3(name = "jsa")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn jsa_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    use numpy::PyArrayMethods;

    let kern = validate_kernel(kernel, false)?;
    let params = JsaParams {
        period: Some(period),
    };

    if !data.is_c_contiguous() {
        // Strided input: materialise a contiguous copy before computing.
        let owned = data.as_array().to_owned();
        let src = owned.as_slice().expect("owned array should be contiguous");
        let input = JsaInput::from_slice(src, params);
        let mut buf = vec![f64::NAN; src.len()];
        py.allow_threads(|| jsa_with_kernel_into(&input, kern, &mut buf))
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        return Ok(PyArray1::from_vec(py, buf));
    }

    let src = data.as_slice()?;

    // The array is created without initialisation; on success the callee has
    // written the NaN prefix and the kernel the remainder.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [src.len()], false) };
    let dst = unsafe { out_arr.as_slice_mut()? };
    let input = JsaInput::from_slice(src, params);
    py.allow_threads(|| jsa_with_kernel_into(&input, kern, dst))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;
    Ok(out_arr)
}
1018
// Python binding `jsa_batch(data, period_start, period_end, period_step,
// kernel=None)`: evaluates a period sweep and returns a dict with
// "values" (a rows x cols array) and "periods" (one entry per row).
// Errors from the Rust side surface as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction]
#[pyo3(name = "jsa_batch")]
#[pyo3(signature = (data, period_start, period_end, period_step, kernel=None))]
pub fn jsa_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_start: usize,
    period_end: usize,
    period_step: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let input = data.as_slice()?;
    let requested = validate_kernel(kernel, true)?;
    let sweep = JsaBatchRange {
        period: (period_start, period_end, period_step),
    };

    // Expand the sweep up front only to size the output array.
    let rows = expand_grid(&sweep)
        .map_err(|e| PyValueError::new_err(e.to_string()))?
        .len();
    let cols = input.len();
    let total = match rows.checked_mul(cols) {
        Some(t) => t,
        None => return Err(PyValueError::new_err("size overflow")),
    };

    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    let computed = py.allow_threads(|| {
        let resolved = if matches!(requested, Kernel::Auto) {
            detect_best_batch_kernel()
        } else {
            requested
        };

        // Batch kernels map to the per-row kernel of the same instruction
        // set; anything else is passed through unchanged.
        let row_kernel = match resolved {
            Kernel::ScalarBatch => Kernel::Scalar,
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::Avx512Batch => Kernel::Avx512,
            other => other,
        };

        jsa_batch_inner_into(input, &sweep, row_kernel, true, out_slice)
    });
    let (combos, _, _) = computed.map_err(|e| PyValueError::new_err(e.to_string()))?;

    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item("periods", periods.into_pyarray(py))?;

    Ok(dict)
}
1082
// Python binding `jsa_cuda_batch_dev(data_f32, period_range=(30, 30, 0),
// device_id=0)`: runs the period sweep through `CudaJsa` and wraps the
// resulting device handle. Raises `ValueError` when CUDA is unavailable or
// the CUDA call fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "jsa_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range=(30, 30, 0), device_id=0))]
pub fn jsa_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<JsaDeviceArrayF32Py> {
    use crate::cuda::moving_averages::CudaJsaError;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host_data = data_f32.as_slice()?;
    let sweep = JsaBatchRange {
        period: period_range,
    };

    let outcome = py.allow_threads(|| -> Result<_, CudaJsaError> {
        let cuda = CudaJsa::new(device_id)?;
        cuda.jsa_batch_dev_handle(host_data, &sweep)
    });
    let handle: JsaDeviceHandle = outcome.map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(JsaDeviceArrayF32Py::from_handle_rust(handle))
}
1111
// Python binding `jsa_cuda_many_series_one_param_dev(data_tm_f32, period,
// device_id=0)`: runs one period over many series stored in a 2-D array
// (shape[0] is taken as the series length, shape[1] as the number of series)
// through `CudaJsa` and wraps the resulting device handle. Raises `ValueError`
// when CUDA is unavailable, the period is 0, or the CUDA call fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "jsa_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn jsa_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<JsaDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    if period == 0 {
        return Err(PyValueError::new_err("period must be positive"));
    }

    let flat = data_tm_f32.as_slice()?;
    let shape = data_tm_f32.shape();
    if shape.len() != 2 {
        return Err(PyValueError::new_err("expected a 2D array"));
    }
    let series_len = shape[0];
    let num_series = shape[1];
    let params = JsaParams {
        period: Some(period),
    };

    use crate::cuda::moving_averages::CudaJsaError;
    let handle: JsaDeviceHandle = py
        .allow_threads(|| -> Result<_, CudaJsaError> {
            let cuda = CudaJsa::new(device_id)?;
            // The original text read `¶ms` here: a mis-decoded `&params`.
            cuda.jsa_many_series_one_param_time_major_dev_handle(
                flat, num_series, series_len, &params,
            )
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(JsaDeviceArrayF32Py::from_handle_rust(handle))
}
1151
// Python-visible wrapper (exported as `JsaStream`) around the Rust streaming
// calculator.
#[cfg(feature = "python")]
#[pyclass(name = "JsaStream")]
pub struct JsaStreamPy {
    // The wrapped Rust stream that holds all state.
    inner: JsaStream,
}
1157
#[cfg(feature = "python")]
#[pymethods]
impl JsaStreamPy {
    // Constructor: builds a stream for `period`; an invalid period is
    // reported as `ValueError`.
    #[new]
    pub fn new(period: usize) -> PyResult<Self> {
        let params = JsaParams {
            period: Some(period),
        };
        match JsaStream::try_new(params) {
            Ok(inner) => Ok(JsaStreamPy { inner }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    // Feeds one value; forwards the result of the wrapped stream's `update`.
    pub fn update(&mut self, value: f64) -> Option<f64> {
        let stream = &mut self.inner;
        stream.update(value)
    }
}
1175
1176#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1177use serde::{Deserialize, Serialize};
1178#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1179use wasm_bindgen::prelude::*;
1180
// WASM binding: computes JSA over `data` for `period` and returns a new
// vector of the same length. Errors are converted to a JS string value.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn jsa_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = JsaInput::from_slice(
        data,
        JsaParams {
            period: Some(period),
        },
    );

    // Zero-filled placeholder of the input's length, handed to `jsa_into`.
    let mut output = vec![0.0; data.len()];

    match jsa_into(&input, &mut output) {
        Ok(()) => Ok(output),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1195
/// Configuration object accepted by `jsa_batch` (deserialized from a JS value).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct JsaBatchConfig {
    /// Period sweep, passed unchanged into `JsaBatchRange::period`.
    /// `jsa_batch_simple` fills the same slot with `(start, end, step)`.
    pub period_range: (usize, usize, usize),
}
1201
/// Result object that `jsa_batch` serializes back to JavaScript.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct JsaBatchJsOutput {
    /// The batch output's value buffer, moved over as-is.
    pub values: Vec<f64>,
    /// Period of every parameter combination (30 when a combination has none).
    pub periods: Vec<usize>,
    /// Row count reported by the batch output.
    pub rows: usize,
    /// Column count reported by the batch output.
    pub cols: usize,
}
1210
/// Batch entry point exported to JavaScript as `jsa_batch`.
///
/// `config` must deserialize into a [`JsaBatchConfig`]. The sweep is evaluated
/// with `Kernel::Auto` and the `false` flag passed to `jsa_batch_inner`, and the
/// outcome is returned as a serialized [`JsaBatchJsOutput`].
///
/// # Errors
/// Returns a JS string for an undecodable config (`"Invalid config: …"`), for a
/// failure inside `jsa_batch_inner`, or for a failure while serializing the
/// result (`"Serialization error: …"`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = jsa_batch)]
pub fn jsa_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: JsaBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let range = JsaBatchRange {
        period: parsed.period_range,
    };

    let batch = match jsa_batch_inner(data, &range, Kernel::Auto, false) {
        Ok(batch) => batch,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    // Collect the periods first, then move the value buffer into the payload.
    let mut periods = Vec::with_capacity(batch.combos.len());
    for combo in batch.combos.iter() {
        periods.push(combo.period.unwrap_or(30));
    }

    let payload = JsaBatchJsOutput {
        values: batch.values,
        periods,
        rows: batch.rows,
        cols: batch.cols,
    };

    match serde_wasm_bindgen::to_value(&payload) {
        Ok(js) => Ok(js),
        Err(e) => Err(JsValue::from_str(&format!("Serialization error: {}", e))),
    }
}
1238
/// Batch variant that takes the period sweep as three plain integers and
/// returns only the value buffer of the batch output.
///
/// Errors from `jsa_batch_inner` are returned as JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = jsa_batch_simple)]
pub fn jsa_batch_simple(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let range = JsaBatchRange {
        period: (period_start, period_end, period_step),
    };

    match jsa_batch_inner(data, &range, Kernel::Auto, false) {
        Ok(batch) => Ok(batch.values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1255
/// Reserves room for `len` `f64` values and returns the raw pointer to it.
///
/// The backing `Vec` is deliberately forgotten so the allocation is not
/// released when this function returns; the memory is left uninitialized.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn jsa_alloc(len: usize) -> *mut f64 {
    let mut storage: Vec<f64> = Vec::with_capacity(len);
    let raw = storage.as_mut_ptr();
    // Skip the destructor so the buffer stays allocated for the caller.
    std::mem::forget(storage);
    raw
}
1264
/// Releases a buffer of `len` `f64` values starting at `ptr`.
///
/// A null `ptr` is ignored. Otherwise the pointer is rebuilt into a `Vec` with
/// length and capacity `len` and dropped immediately.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn jsa_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: relies on the caller passing a pointer/length pair that satisfies
    // `Vec::from_raw_parts` (e.g. one obtained from `jsa_alloc(len)`).
    unsafe {
        drop(Vec::from_raw_parts(ptr, len, len));
    }
}
1274
/// Pointer-based JSA entry point exported to JavaScript as `jsa_into`.
///
/// Reads `len` samples from `in_ptr` and writes `len` results to `out_ptr`.
/// When both pointers are equal the result is computed into a scratch buffer
/// first and then copied over the input.
///
/// # Errors
/// Returns a JS string when either pointer is null, when `period` is zero or
/// larger than `len` (`"Invalid period"`), or when `jsa_into_slice` fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = jsa_into)]
pub fn jsa_into_wasm(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to jsa_into"));
    }
    // SAFETY: relies on the caller providing pointers valid for `len` f64s.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        if !(1..=len).contains(&period) {
            return Err(JsValue::from_str("Invalid period"));
        }
        let params = JsaParams {
            period: Some(period),
        };
        let input = JsaInput::from_slice(data, params);

        let aliased = in_ptr == out_ptr;
        let outcome = if aliased {
            // In-place request: compute into scratch space, then overwrite.
            let mut scratch = vec![0.0; len];
            let res = jsa_into_slice(&mut scratch, &input, Kernel::Auto);
            if res.is_ok() {
                std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
            }
            res
        } else {
            jsa_into_slice(
                std::slice::from_raw_parts_mut(out_ptr, len),
                &input,
                Kernel::Auto,
            )
        };

        outcome
            .map(|_| ())
            .map_err(|e| JsValue::from_str(&e.to_string()))
    }
}
1313
/// Deprecated alias kept for existing callers; forwards every argument to
/// [`jsa_into_wasm`] (exported as `jsa_into`) and returns its result unchanged.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(note = "use jsa_into")]
pub fn jsa_fast(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    let forwarded = jsa_into_wasm(in_ptr, out_ptr, len, period);
    forwarded
}
1325
/// Pointer-based batch entry point: evaluates every period in the sweep
/// `(period_start, period_end, period_step)` over the `len` samples at `in_ptr`
/// and writes the results into `out_ptr`.
///
/// The output region is treated as `rows * len` `f64` values, where `rows` is
/// the number of parameter combinations produced by `expand_grid`. On success
/// the row count is returned.
///
/// # Errors
/// Returns a JS string when either pointer is null, when the sweep cannot be
/// expanded, when `rows * len` does not fit in `usize`, or when
/// `jsa_batch_inner_into` fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn jsa_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to jsa_batch_into"));
    }

    // SAFETY: relies on the caller providing `in_ptr` valid for `len` f64s and
    // `out_ptr` valid for `rows * len` f64s.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let sweep = JsaBatchRange {
            period: (period_start, period_end, period_step),
        };

        let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
        let rows = combos.len();

        // `usize` is 32 bits on wasm32, so a wide sweep over a long series can
        // overflow here. A wrapped product would build an output slice of the
        // wrong size, so report the overflow instead of panicking or wrapping.
        let total_size = rows.checked_mul(len).ok_or_else(|| {
            JsValue::from_str("jsa_batch_into: rows * len overflows usize")
        })?;

        let out = std::slice::from_raw_parts_mut(out_ptr, total_size);

        jsa_batch_inner_into(data, &sweep, Kernel::Auto, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        Ok(rows)
    }
}
1358
/// Adds the JSA Python bindings to module `m`: the `jsa_py` and `jsa_batch_py`
/// functions, the `JsaStreamPy` class and, when the `cuda` feature is enabled,
/// the device-array class plus the two CUDA entry points.
#[cfg(feature = "python")]
pub fn register_jsa_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    let single = wrap_pyfunction!(jsa_py, m)?;
    m.add_function(single)?;

    let batch = wrap_pyfunction!(jsa_batch_py, m)?;
    m.add_function(batch)?;

    m.add_class::<JsaStreamPy>()?;

    #[cfg(feature = "cuda")]
    {
        m.add_class::<JsaDeviceArrayF32Py>()?;

        let cuda_batch = wrap_pyfunction!(jsa_cuda_batch_dev_py, m)?;
        m.add_function(cuda_batch)?;

        let cuda_many = wrap_pyfunction!(jsa_cuda_many_series_one_param_dev_py, m)?;
        m.add_function(cuda_many)?;
    }

    Ok(())
}
1372
/// Python object holding a 2-D `f32` result that lives in CUDA device memory.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct JsaDeviceArrayF32Py {
    /// Device allocation; becomes `None` once `__dlpack__` has taken it.
    buf: Option<DeviceBuffer<f32>>,
    /// Number of rows reported in the array interface.
    rows: usize,
    /// Number of columns reported in the array interface.
    cols: usize,
    /// Stored alongside the buffer and never read in this file section;
    /// presumably keeps the CUDA context alive — confirm against the handle type.
    _ctx: Arc<Context>,
    /// Device ordinal returned by `__dlpack_device__`.
    device_id: u32,
}
1382
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl JsaDeviceArrayF32Py {
    /// Builds the `__cuda_array_interface__` dictionary (version 3) describing
    /// the buffer as a `(rows, cols)` little-endian `f32` array with byte
    /// strides `(cols * 4, 4)`.
    ///
    /// For an array with zero rows or columns the data pointer is reported as
    /// `0`. Otherwise a `ValueError` is raised if the buffer has already been
    /// handed out through `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let elem_size = std::mem::size_of::<f32>();
        let iface = PyDict::new(py);

        iface.set_item("shape", (self.rows, self.cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (self.cols * elem_size, elem_size))?;

        let is_empty = self.rows == 0 || self.cols == 0;
        let ptr = if is_empty {
            0usize
        } else {
            match self.buf.as_ref() {
                Some(buffer) => buffer.as_device_ptr().as_raw() as usize,
                None => {
                    return Err(PyValueError::new_err(
                        "buffer already exported via __dlpack__",
                    ))
                }
            }
        };
        // Second tuple element is the interface's read-only flag.
        iface.set_item("data", (ptr, false))?;

        iface.set_item("version", 3)?;
        Ok(iface)
    }

    /// Returns the DLPack `(device_type, device_id)` pair; `2` is DLPack's
    /// device-type code for CUDA.
    fn __dlpack_device__(&self) -> (i32, i32) {
        let ordinal = self.device_id as i32;
        (2, ordinal)
    }

    /// Exports the device buffer as a DLPack capsule, consuming it.
    ///
    /// `stream` is accepted but ignored. If `dl_device` parses as an
    /// `(i32, i32)` pair that differs from this array's own device, a
    /// `ValueError` is raised (its text depends on whether `copy` is truthy).
    /// A second call raises `ValueError` because the buffer is already gone.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;

        let (kdl, alloc_dev) = self.__dlpack_device__();

        // A `dl_device` that is absent or not an (int, int) pair is ignored.
        let requested = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some(target) = requested {
            if target != (kdl, alloc_dev) {
                let wants_copy = match copy.as_ref() {
                    Some(flag) => flag.extract::<bool>(py).unwrap_or(false),
                    None => false,
                };
                let message = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(message));
            }
        }

        let _ = stream;

        let buf = match self.buf.take() {
            Some(buffer) => buffer,
            None => {
                return Err(PyValueError::new_err(
                    "__dlpack__ may only be called once",
                ))
            }
        };

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(
            py,
            buf,
            self.rows,
            self.cols,
            alloc_dev,
            max_version_bound,
        )
    }
}
1462
#[cfg(all(feature = "python", feature = "cuda"))]
impl JsaDeviceArrayF32Py {
    /// Wraps a `JsaDeviceHandle` by moving its buffer, shape, context and
    /// device id into a new Python-facing array object.
    pub(crate) fn from_handle_rust(handle: JsaDeviceHandle) -> Self {
        let JsaDeviceHandle {
            buf,
            rows,
            cols,
            _ctx,
            device_id,
            ..
        } = handle;

        Self {
            buf: Some(buf),
            rows,
            cols,
            _ctx,
            device_id,
        }
    }
}
1475
1476#[cfg(test)]
1477mod tests {
1478 use super::*;
1479 use crate::skip_if_unsupported;
1480 use crate::utilities::data_loader::read_candles_from_csv;
1481
1482 fn check_jsa_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1483 skip_if_unsupported!(kernel, test_name);
1484 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1485 let candles = read_candles_from_csv(file_path)?;
1486
1487 let default_params = JsaParams { period: None };
1488 let input = JsaInput::from_candles(&candles, "close", default_params);
1489 let output = jsa_with_kernel(&input, kernel)?;
1490 assert_eq!(output.values.len(), candles.close.len());
1491 Ok(())
1492 }
1493
1494 fn check_jsa_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1495 skip_if_unsupported!(kernel, test_name);
1496 let expected_last_five = [61640.0, 61418.0, 61240.0, 61060.5, 60889.5];
1497 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1498 let candles = read_candles_from_csv(file_path)?;
1499 let default_params = JsaParams::default();
1500 let input = JsaInput::from_candles(&candles, "close", default_params);
1501 let result = jsa_with_kernel(&input, kernel)?;
1502 let start_idx = result.values.len() - 5;
1503 for (i, &val) in result.values[start_idx..].iter().enumerate() {
1504 let expected = expected_last_five[i];
1505 assert!(
1506 (val - expected).abs() < 1e-5,
1507 "[{}] mismatch idx {}: got {}, expected {}",
1508 test_name,
1509 i,
1510 val,
1511 expected
1512 );
1513 }
1514 Ok(())
1515 }
1516
1517 fn check_jsa_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1518 skip_if_unsupported!(kernel, test_name);
1519 let input_data = [10.0, 20.0, 30.0];
1520 let params = JsaParams { period: Some(0) };
1521 let input = JsaInput::from_slice(&input_data, params);
1522 let res = jsa_with_kernel(&input, kernel);
1523 assert!(
1524 res.is_err(),
1525 "[{}] JSA should fail with zero period",
1526 test_name
1527 );
1528 Ok(())
1529 }
1530
1531 fn check_jsa_period_exceeds_length(
1532 test_name: &str,
1533 kernel: Kernel,
1534 ) -> Result<(), Box<dyn Error>> {
1535 skip_if_unsupported!(kernel, test_name);
1536 let data_small = [10.0, 20.0, 30.0];
1537 let params = JsaParams { period: Some(10) };
1538 let input = JsaInput::from_slice(&data_small, params);
1539 let res = jsa_with_kernel(&input, kernel);
1540 assert!(
1541 res.is_err(),
1542 "[{}] JSA should fail with period exceeding length",
1543 test_name
1544 );
1545 Ok(())
1546 }
1547
1548 fn check_jsa_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1549 skip_if_unsupported!(kernel, test_name);
1550 let single_point = [42.0];
1551 let params = JsaParams { period: Some(5) };
1552 let input = JsaInput::from_slice(&single_point, params);
1553 let res = jsa_with_kernel(&input, kernel);
1554 assert!(
1555 res.is_err(),
1556 "[{}] JSA should fail with insufficient data",
1557 test_name
1558 );
1559 Ok(())
1560 }
1561
1562 fn check_jsa_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1563 skip_if_unsupported!(kernel, test_name);
1564 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1565 let candles = read_candles_from_csv(file_path)?;
1566
1567 let first_params = JsaParams { period: Some(10) };
1568 let first_input = JsaInput::from_candles(&candles, "close", first_params);
1569 let first_result = jsa_with_kernel(&first_input, kernel)?;
1570
1571 let second_params = JsaParams { period: Some(5) };
1572 let second_input = JsaInput::from_slice(&first_result.values, second_params);
1573 let second_result = jsa_with_kernel(&second_input, kernel)?;
1574
1575 assert_eq!(second_result.values.len(), first_result.values.len());
1576 for i in 30..second_result.values.len() {
1577 assert!(
1578 second_result.values[i].is_finite(),
1579 "[{}] NaN at idx {}",
1580 test_name,
1581 i
1582 );
1583 }
1584 Ok(())
1585 }
1586
1587 fn check_jsa_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1588 skip_if_unsupported!(kernel, test_name);
1589 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1590 let candles = read_candles_from_csv(file_path)?;
1591 let period = 30;
1592 let input = JsaInput::from_candles(
1593 &candles,
1594 "close",
1595 JsaParams {
1596 period: Some(period),
1597 },
1598 );
1599 let batch_output = jsa_with_kernel(&input, kernel)?.values;
1600
1601 let mut stream = JsaStream::try_new(JsaParams {
1602 period: Some(period),
1603 })?;
1604 let mut stream_values = Vec::with_capacity(candles.close.len());
1605 for &price in &candles.close {
1606 match stream.update(price) {
1607 Some(val) => stream_values.push(val),
1608 None => stream_values.push(f64::NAN),
1609 }
1610 }
1611 assert_eq!(batch_output.len(), stream_values.len());
1612 for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1613 if b.is_nan() && s.is_nan() {
1614 continue;
1615 }
1616 let diff = (b - s).abs();
1617 assert!(
1618 diff < 1e-9,
1619 "[{}] streaming mismatch at idx {}: batch={}, stream={}, diff={}",
1620 test_name,
1621 i,
1622 b,
1623 s,
1624 diff
1625 );
1626 }
1627 Ok(())
1628 }
1629
1630 macro_rules! generate_all_jsa_tests {
1631 ($($test_fn:ident),*) => {
1632 paste::paste! {
1633 $( #[test]
1634 fn [<$test_fn _scalar_f64>]() {
1635 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1636 }
1637 )*
1638 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1639 $( #[test]
1640 fn [<$test_fn _avx2_f64>]() {
1641 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1642 }
1643 #[test]
1644 fn [<$test_fn _avx512_f64>]() {
1645 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1646 }
1647 )*
1648 }
1649 }
1650 }
1651
1652 #[cfg(debug_assertions)]
1653 fn check_jsa_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1654 skip_if_unsupported!(kernel, test_name);
1655
1656 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1657 let candles = read_candles_from_csv(file_path)?;
1658
1659 let test_periods = vec![2, 5, 10, 14, 20, 30, 50, 100, 200];
1660 let test_sources = vec!["open", "high", "low", "close", "hl2", "hlc3", "ohlc4"];
1661
1662 for period in &test_periods {
1663 for source in &test_sources {
1664 let input = JsaInput::from_candles(
1665 &candles,
1666 source,
1667 JsaParams {
1668 period: Some(*period),
1669 },
1670 );
1671 let output = jsa_with_kernel(&input, kernel)?;
1672
1673 for (i, &val) in output.values.iter().enumerate() {
1674 if val.is_nan() {
1675 continue;
1676 }
1677
1678 let bits = val.to_bits();
1679
1680 if bits == 0x11111111_11111111 {
1681 panic!(
1682 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with period={}, source={}",
1683 test_name, val, bits, i, period, source
1684 );
1685 }
1686
1687 if bits == 0x22222222_22222222 {
1688 panic!(
1689 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with period={}, source={}",
1690 test_name, val, bits, i, period, source
1691 );
1692 }
1693
1694 if bits == 0x33333333_33333333 {
1695 panic!(
1696 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with period={}, source={}",
1697 test_name, val, bits, i, period, source
1698 );
1699 }
1700 }
1701 }
1702 }
1703
1704 Ok(())
1705 }
1706
1707 #[cfg(not(debug_assertions))]
1708 fn check_jsa_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1709 Ok(())
1710 }
1711
1712 generate_all_jsa_tests!(
1713 check_jsa_partial_params,
1714 check_jsa_accuracy,
1715 check_jsa_zero_period,
1716 check_jsa_period_exceeds_length,
1717 check_jsa_very_small_dataset,
1718 check_jsa_reinput,
1719 check_jsa_streaming,
1720 check_jsa_no_poison
1721 );
1722
1723 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1724 skip_if_unsupported!(kernel, test);
1725
1726 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1727 let c = read_candles_from_csv(file)?;
1728
1729 let output = JsaBatchBuilder::new()
1730 .kernel(kernel)
1731 .apply_candles(&c, "close")?;
1732
1733 let def = JsaParams::default();
1734 let row = output.values_for(&def).expect("default row missing");
1735 assert_eq!(row.len(), c.close.len());
1736
1737 let expected = [61640.0, 61418.0, 61240.0, 61060.5, 60889.5];
1738 let start = row.len() - 5;
1739 for (i, &v) in row[start..].iter().enumerate() {
1740 assert!(
1741 (v - expected[i]).abs() < 1e-5,
1742 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1743 );
1744 }
1745 Ok(())
1746 }
1747
1748 macro_rules! gen_batch_tests {
1749 ($fn_name:ident) => {
1750 paste::paste! {
1751 #[test] fn [<$fn_name _scalar>]() {
1752 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1753 }
1754 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1755 #[test] fn [<$fn_name _avx2>]() {
1756 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1757 }
1758 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1759 #[test] fn [<$fn_name _avx512>]() {
1760 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1761 }
1762 #[test] fn [<$fn_name _auto_detect>]() {
1763 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1764 }
1765 }
1766 };
1767 }
1768
1769 #[cfg(debug_assertions)]
1770 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1771 skip_if_unsupported!(kernel, test);
1772
1773 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1774 let c = read_candles_from_csv(file)?;
1775
1776 let test_sources = vec!["open", "high", "low", "close", "hl2", "hlc3", "ohlc4"];
1777
1778 for source in &test_sources {
1779 let output = JsaBatchBuilder::new()
1780 .kernel(kernel)
1781 .period_range(2, 200, 3)
1782 .apply_candles(&c, source)?;
1783
1784 for (idx, &val) in output.values.iter().enumerate() {
1785 if val.is_nan() {
1786 continue;
1787 }
1788
1789 let bits = val.to_bits();
1790 let row = idx / output.cols;
1791 let col = idx % output.cols;
1792
1793 if bits == 0x11111111_11111111 {
1794 panic!(
1795 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {}) with source={}",
1796 test, val, bits, row, col, idx, source
1797 );
1798 }
1799
1800 if bits == 0x22222222_22222222 {
1801 panic!(
1802 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {}) with source={}",
1803 test, val, bits, row, col, idx, source
1804 );
1805 }
1806
1807 if bits == 0x33333333_33333333 {
1808 panic!(
1809 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {}) with source={}",
1810 test, val, bits, row, col, idx, source
1811 );
1812 }
1813 }
1814 }
1815
1816 let edge_case_ranges = vec![(2, 5, 1), (190, 200, 2), (50, 100, 10)];
1817 for (start, end, step) in edge_case_ranges {
1818 let output = JsaBatchBuilder::new()
1819 .kernel(kernel)
1820 .period_range(start, end, step)
1821 .apply_candles(&c, "close")?;
1822
1823 for (idx, &val) in output.values.iter().enumerate() {
1824 if val.is_nan() {
1825 continue;
1826 }
1827
1828 let bits = val.to_bits();
1829 let row = idx / output.cols;
1830 let col = idx % output.cols;
1831
1832 if bits == 0x11111111_11111111
1833 || bits == 0x22222222_22222222
1834 || bits == 0x33333333_33333333
1835 {
1836 panic!(
1837 "[{}] Found poison value {} (0x{:016X}) at row {} col {} with range ({},{},{})",
1838 test, val, bits, row, col, start, end, step
1839 );
1840 }
1841 }
1842 }
1843
1844 Ok(())
1845 }
1846
1847 #[cfg(not(debug_assertions))]
1848 fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1849 Ok(())
1850 }
1851
1852 gen_batch_tests!(check_batch_default_row);
1853 gen_batch_tests!(check_batch_no_poison);
1854
1855 #[cfg(feature = "proptest")]
1856 #[allow(clippy::float_cmp)]
1857 fn check_jsa_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1858 use proptest::prelude::*;
1859 skip_if_unsupported!(kernel, test_name);
1860
1861 let strat = (1usize..=100).prop_flat_map(|period| {
1862 (
1863 prop::collection::vec(
1864 (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
1865 period..400,
1866 ),
1867 Just(period),
1868 )
1869 });
1870
1871 proptest::test_runner::TestRunner::default().run(&strat, |(data, period)| {
1872 let params = JsaParams {
1873 period: Some(period),
1874 };
1875 let input = JsaInput::from_slice(&data, params);
1876
1877 let JsaOutput { values: out } = jsa_with_kernel(&input, kernel).unwrap();
1878
1879 let JsaOutput { values: ref_out } = jsa_with_kernel(&input, Kernel::Scalar).unwrap();
1880
1881 let first_valid = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
1882 let warmup_period = first_valid + period;
1883
1884 for i in 0..warmup_period.min(data.len()) {
1885 prop_assert!(
1886 out[i].is_nan(),
1887 "[{}] Expected NaN during warmup at index {}, got {}",
1888 test_name,
1889 i,
1890 out[i]
1891 );
1892 }
1893
1894 for i in warmup_period..data.len() {
1895 let expected = (data[i] + data[i - period]) * 0.5;
1896 let actual = out[i];
1897
1898 prop_assert!(
1899 (actual - expected).abs() < 1e-9,
1900 "[{}] Formula verification failed at index {}: expected {}, got {}, diff = {}",
1901 test_name,
1902 i,
1903 expected,
1904 actual,
1905 (actual - expected).abs()
1906 );
1907 }
1908
1909 for i in warmup_period..data.len() {
1910 let val1 = data[i];
1911 let val2 = data[i - period];
1912 let min_val = val1.min(val2);
1913 let max_val = val1.max(val2);
1914 let actual = out[i];
1915
1916 prop_assert!(
1917 actual >= min_val - 1e-9 && actual <= max_val + 1e-9,
1918 "[{}] Output bounds check failed at index {}: {} not in [{}, {}]",
1919 test_name,
1920 i,
1921 actual,
1922 min_val,
1923 max_val
1924 );
1925 }
1926
1927 if kernel != Kernel::Scalar {
1928 for i in 0..data.len() {
1929 let y = out[i];
1930 let r = ref_out[i];
1931
1932 if y.is_nan() && r.is_nan() {
1933 continue;
1934 }
1935
1936 let y_bits = y.to_bits();
1937 let r_bits = r.to_bits();
1938 prop_assert!(
1939 y_bits == r_bits,
1940 "[{}] Cross-kernel mismatch at index {}: {} ({:016X}) != {} ({:016X})",
1941 test_name,
1942 i,
1943 y,
1944 y_bits,
1945 r,
1946 r_bits
1947 );
1948 }
1949 }
1950
1951 if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-9) && !data.is_empty() {
1952 let constant = data[first_valid];
1953 for i in warmup_period..data.len() {
1954 prop_assert!(
1955 (out[i] - constant).abs() < 1e-9,
1956 "[{}] Constant data test failed at index {}: expected {}, got {}",
1957 test_name,
1958 i,
1959 constant,
1960 out[i]
1961 );
1962 }
1963 }
1964
1965 let is_monotonic_inc = data.windows(2).all(|w| w[1] >= w[0] - 1e-12);
1966 if is_monotonic_inc && warmup_period + 1 < data.len() {
1967 for i in (warmup_period + 1)..data.len() {
1968 prop_assert!(
1969 out[i] >= out[i - 1] - 1e-9,
1970 "[{}] Monotonic test failed at index {}: {} < {}",
1971 test_name,
1972 i,
1973 out[i],
1974 out[i - 1]
1975 );
1976 }
1977 }
1978
1979 if period == 1 && warmup_period < data.len() {
1980 for i in warmup_period..data.len() {
1981 let expected = (data[i] + data[i - 1]) * 0.5;
1982 let actual = out[i];
1983 prop_assert!(
1984 (actual - expected).abs() < 1e-9,
1985 "[{}] Period=1 test failed at index {}: expected {}, got {}",
1986 test_name,
1987 i,
1988 expected,
1989 actual
1990 );
1991 }
1992 }
1993
1994 #[cfg(debug_assertions)]
1995 {
1996 for (i, &val) in out.iter().enumerate() {
1997 if val.is_nan() {
1998 continue;
1999 }
2000
2001 let bits = val.to_bits();
2002
2003 prop_assert!(
2004 bits != 0x11111111_11111111,
2005 "[{}] Found alloc_with_nan_prefix poison at index {}",
2006 test_name,
2007 i
2008 );
2009 prop_assert!(
2010 bits != 0x22222222_22222222,
2011 "[{}] Found init_matrix_prefixes poison at index {}",
2012 test_name,
2013 i
2014 );
2015 prop_assert!(
2016 bits != 0x33333333_33333333,
2017 "[{}] Found make_uninit_matrix poison at index {}",
2018 test_name,
2019 i
2020 );
2021 }
2022 }
2023
2024 Ok(())
2025 })?;
2026
2027 Ok(())
2028 }
2029
2030 #[cfg(feature = "proptest")]
2031 generate_all_jsa_tests!(check_jsa_property);
2032
2033 #[test]
2034 fn test_jsa_into_matches_api() -> Result<(), Box<dyn Error>> {
2035 let mut data = Vec::with_capacity(256);
2036 for _ in 0..8 {
2037 data.push(f64::NAN);
2038 }
2039 for i in 0..248u64 {
2040 let x = (i as f64).sin() * 3.14159 + (i as f64) * 0.01;
2041 data.push(x);
2042 }
2043
2044 let input = JsaInput::from_slice(&data, JsaParams::default());
2045
2046 let base = jsa(&input)?.values;
2047
2048 let mut out = vec![0.0; data.len()];
2049 jsa_into(&input, &mut out)?;
2050
2051 assert_eq!(base.len(), out.len(), "lengths must match");
2052
2053 for (i, (&a, &b)) in base.iter().zip(out.iter()).enumerate() {
2054 let both_nan = a.is_nan() && b.is_nan();
2055 assert!(both_nan || a == b, "mismatch at idx {}: {} vs {}", i, a, b);
2056 }
2057 Ok(())
2058 }
2059}