1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::moving_averages::CudaUma;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
7#[cfg(all(feature = "python", feature = "cuda"))]
8use cust::context::Context;
9#[cfg(all(feature = "python", feature = "cuda"))]
10use cust::memory::DeviceBuffer;
11#[cfg(feature = "python")]
12use numpy::{IntoPyArray, PyArray1, PyReadonlyArray1, PyReadonlyArray2};
13#[cfg(feature = "python")]
14use pyo3::exceptions::PyValueError;
15#[cfg(feature = "python")]
16use pyo3::prelude::*;
17#[cfg(feature = "python")]
18use pyo3::types::{PyDict, PyList};
19#[cfg(all(feature = "python", feature = "cuda"))]
20use std::sync::Arc;
21
22#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
23use js_sys;
24#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
25use serde::{Deserialize, Serialize};
26#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
27use wasm_bindgen::prelude::*;
28
29use crate::indicators::deviation::{deviation, DeviationInput, DeviationParams};
30use crate::indicators::mfi::{mfi, MfiInput, MfiParams};
31use crate::indicators::moving_averages::sma::{sma, SmaInput, SmaParams};
32use crate::indicators::moving_averages::wma::{wma, WmaInput, WmaParams};
33use crate::indicators::rsi::{rsi, RsiInput, RsiParams};
34use crate::utilities::data_loader::{source_type, Candles};
35use crate::utilities::enums::Kernel;
36use crate::utilities::helpers::{
37 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
38 make_uninit_matrix,
39};
40#[cfg(feature = "python")]
41use crate::utilities::kernel_validation::validate_kernel;
42#[cfg(not(target_arch = "wasm32"))]
43use rayon::prelude::*;
44use std::convert::AsRef;
45use std::error::Error;
46use std::mem::MaybeUninit;
47use thiserror::Error;
48
49impl<'a> AsRef<[f64]> for UmaInput<'a> {
50 #[inline(always)]
51 fn as_ref(&self) -> &[f64] {
52 match &self.data {
53 UmaData::Slice(slice) => slice,
54 UmaData::Candles { candles, source } => source_type(candles, source),
55 }
56 }
57}
58
/// Price source for the UMA indicator: either a candle series plus a
/// source-field selector (resolved through `source_type`), or a raw slice.
#[derive(Debug, Clone)]
pub enum UmaData<'a> {
    /// Prices taken from `candles`, selected by `source` (e.g. "close").
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// Prices supplied directly as an `f64` slice.
    Slice(&'a [f64]),
}
67
/// Result of a UMA computation: one value per input bar. The warmup prefix
/// is NaN-filled by the calling entry points (see `uma_with_kernel`).
#[derive(Debug, Clone)]
pub struct UmaOutput {
    pub values: Vec<f64>,
}
72
/// Tunable UMA parameters. `None` fields fall back to the defaults used by
/// the getters on `UmaInput`: accelerator 1.0, min_length 5, max_length 50,
/// smooth_length 4.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct UmaParams {
    // Weighting exponent base; validation requires >= 1.0.
    pub accelerator: Option<f64>,
    // Lower clamp for the adaptive window length.
    pub min_length: Option<usize>,
    // Upper clamp for the adaptive window length (also the warmup length).
    pub max_length: Option<usize>,
    // Final WMA smoothing window applied to the raw UMA series.
    pub smooth_length: Option<usize>,
}
84
85impl Default for UmaParams {
86 fn default() -> Self {
87 Self {
88 accelerator: Some(1.0),
89 min_length: Some(5),
90 max_length: Some(50),
91 smooth_length: Some(4),
92 }
93 }
94}
95
/// Full input bundle for a UMA run: price source, parameters, and an
/// optional volume series (used for money-flow weighting when present).
#[derive(Debug, Clone)]
pub struct UmaInput<'a> {
    pub data: UmaData<'a>,
    pub params: UmaParams,
    // When `None` (or zero at a bar), the computation falls back to an
    // RSI-style pressure measure instead of volume-based money flow.
    pub volume: Option<&'a [f64]>,
}
102
103impl<'a> UmaInput<'a> {
104 #[inline]
105 pub fn from_candles(c: &'a Candles, s: &'a str, p: UmaParams) -> Self {
106 Self {
107 data: UmaData::Candles {
108 candles: c,
109 source: s,
110 },
111 params: p,
112 volume: Some(&c.volume),
113 }
114 }
115
116 #[inline]
117 pub fn from_slice(sl: &'a [f64], vol: Option<&'a [f64]>, p: UmaParams) -> Self {
118 Self {
119 data: UmaData::Slice(sl),
120 params: p,
121 volume: vol,
122 }
123 }
124
125 #[inline]
126 pub fn with_default_candles(c: &'a Candles) -> Self {
127 Self::from_candles(c, "close", UmaParams::default())
128 }
129
130 #[inline]
131 pub fn get_accelerator(&self) -> f64 {
132 self.params.accelerator.unwrap_or(1.0)
133 }
134
135 #[inline]
136 pub fn get_min_length(&self) -> usize {
137 self.params.min_length.unwrap_or(5)
138 }
139
140 #[inline]
141 pub fn get_max_length(&self) -> usize {
142 self.params.max_length.unwrap_or(50)
143 }
144
145 #[inline]
146 pub fn get_smooth_length(&self) -> usize {
147 self.params.smooth_length.unwrap_or(4)
148 }
149}
150
/// Fluent builder for single UMA runs; unset fields fall back to the
/// `UmaParams` defaults at apply time.
#[derive(Copy, Clone, Debug)]
pub struct UmaBuilder {
    accelerator: Option<f64>,
    min_length: Option<usize>,
    max_length: Option<usize>,
    smooth_length: Option<usize>,
    // Compute kernel; `Kernel::Auto` is resolved by `uma_with_kernel`.
    kernel: Kernel,
}
159
160impl Default for UmaBuilder {
161 fn default() -> Self {
162 Self {
163 accelerator: None,
164 min_length: None,
165 max_length: None,
166 smooth_length: None,
167 kernel: Kernel::Auto,
168 }
169 }
170}
171
172impl UmaBuilder {
173 #[inline(always)]
174 pub fn new() -> Self {
175 Self::default()
176 }
177
178 #[inline(always)]
179 pub fn accelerator(mut self, a: f64) -> Self {
180 self.accelerator = Some(a);
181 self
182 }
183
184 #[inline(always)]
185 pub fn min_length(mut self, n: usize) -> Self {
186 self.min_length = Some(n);
187 self
188 }
189
190 #[inline(always)]
191 pub fn max_length(mut self, n: usize) -> Self {
192 self.max_length = Some(n);
193 self
194 }
195
196 #[inline(always)]
197 pub fn smooth_length(mut self, n: usize) -> Self {
198 self.smooth_length = Some(n);
199 self
200 }
201
202 #[inline(always)]
203 pub fn kernel(mut self, k: Kernel) -> Self {
204 self.kernel = k;
205 self
206 }
207
208 #[inline(always)]
209 pub fn apply(self, c: &Candles) -> Result<UmaOutput, UmaError> {
210 let p = UmaParams {
211 accelerator: self.accelerator,
212 min_length: self.min_length,
213 max_length: self.max_length,
214 smooth_length: self.smooth_length,
215 };
216 let i = UmaInput::from_candles(c, "close", p);
217 uma_with_kernel(&i, self.kernel)
218 }
219
220 #[inline(always)]
221 pub fn apply_slice(self, d: &[f64], v: Option<&[f64]>) -> Result<UmaOutput, UmaError> {
222 let p = UmaParams {
223 accelerator: self.accelerator,
224 min_length: self.min_length,
225 max_length: self.max_length,
226 smooth_length: self.smooth_length,
227 };
228 let i = UmaInput::from_slice(d, v, p);
229 uma_with_kernel(&i, self.kernel)
230 }
231
232 #[inline(always)]
233 pub fn into_stream(self) -> Result<UmaStream, UmaError> {
234 let p = UmaParams {
235 accelerator: self.accelerator,
236 min_length: self.min_length,
237 max_length: self.max_length,
238 smooth_length: self.smooth_length,
239 };
240 UmaStream::try_new(p)
241 }
242}
243
/// Incremental (O(1)-per-bar) UMA state: mirrors the batch computation of
/// `uma_core_into` using ring buffers and running sums.
#[derive(Debug, Clone)]
pub struct UmaStream {
    // Effective (validated) parameter values.
    accelerator: f64,
    min_length: usize,
    max_length: usize,
    smooth_length: usize,

    // Price/volume ring buffers of capacity `cap` (== max_length).
    price: Vec<f64>,
    volume: Vec<f64>,
    cap: usize,
    head: usize,  // next write slot in the rings
    count: usize, // filled slots, saturating at `cap`

    // Running sum / sum of squares of the price window (mean and stddev).
    sum: f64,
    sumsq: f64,

    // Previous price, for computing up/down diffs.
    has_prev: bool,
    prev_price: f64,

    // Cumulative up/down volume rings (length cap + 1): windowed sums are
    // taken as prefix differences in O(1).
    upvol_cum: Vec<f64>,
    dnvol_cum: Vec<f64>,
    diff_step: usize, // total number of price diffs observed so far

    // Wilder-style RSI fallback state, used when volume is absent/zero.
    rsi_avg_up: f64,
    rsi_avg_dn: f64,
    rsi_inited: bool,

    // Adaptive window length, clamped to [min_length, max_length].
    dynamic_length: f64,

    // ln_lut[k] == ln(k); precomputed for the power-law kernel weights.
    ln_lut: Vec<f64>,

    // Ring of raw UMA values plus running sums for O(1) WMA smoothing:
    // s1 is the plain sum, s2 the position-weighted sum.
    uma_raw: Vec<f64>,
    uma_raw_head: usize,
    uma_raw_count: usize,
    s1: f64,
    s2: f64,
    wma_denom: f64,

    // Original construction inputs (kernel is currently fixed to Auto).
    params: UmaParams,
    kernel: Kernel,
}
285
286impl UmaStream {
287 pub fn try_new(params: UmaParams) -> Result<Self, UmaError> {
288 let accelerator = params.accelerator.unwrap_or(1.0);
289 let min_length = params.min_length.unwrap_or(5);
290 let max_length = params.max_length.unwrap_or(50);
291 let smooth_len = params.smooth_length.unwrap_or(4);
292
293 if min_length == 0 {
294 return Err(UmaError::InvalidMinLength { min_length });
295 }
296 if max_length == 0 {
297 return Err(UmaError::InvalidMaxLength {
298 max_length,
299 data_len: 0,
300 });
301 }
302 if smooth_len == 0 {
303 return Err(UmaError::InvalidSmoothLength {
304 smooth_length: smooth_len,
305 });
306 }
307 if min_length > max_length {
308 return Err(UmaError::MinLengthGreaterThanMaxLength {
309 min_length,
310 max_length,
311 });
312 }
313 if accelerator < 1.0 {
314 return Err(UmaError::InvalidAccelerator { accelerator });
315 }
316
317 let mut ln_lut = Vec::with_capacity(max_length + 1);
318 ln_lut.push(0.0);
319 for k in 1..=max_length {
320 ln_lut.push((k as f64).ln());
321 }
322
323 Ok(Self {
324 accelerator,
325 min_length,
326 max_length,
327 smooth_length: smooth_len,
328
329 price: vec![0.0; max_length],
330 volume: vec![0.0; max_length],
331 cap: max_length,
332 head: 0,
333 count: 0,
334
335 sum: 0.0,
336 sumsq: 0.0,
337
338 has_prev: false,
339 prev_price: 0.0,
340
341 upvol_cum: vec![0.0; max_length + 1],
342 dnvol_cum: vec![0.0; max_length + 1],
343 diff_step: 0,
344
345 rsi_avg_up: 0.0,
346 rsi_avg_dn: 0.0,
347 rsi_inited: false,
348
349 dynamic_length: max_length as f64,
350 ln_lut,
351
352 uma_raw: vec![0.0; smooth_len],
353 uma_raw_head: 0,
354 uma_raw_count: 0,
355 s1: 0.0,
356 s2: 0.0,
357 wma_denom: (smooth_len as f64) * ((smooth_len as f64) + 1.0) * 0.5,
358
359 params,
360 kernel: Kernel::Auto,
361 })
362 }
363
364 #[inline]
365 pub fn update(&mut self, value: f64) -> Option<f64> {
366 self.update_with_volume(value, None)
367 }
368
369 pub fn update_with_volume(&mut self, value: f64, volume: Option<f64>) -> Option<f64> {
370 let v = volume.unwrap_or(0.0);
371
372 if self.count == self.cap {
373 let old = self.price[self.head];
374 self.sum -= old;
375 self.sumsq -= old * old;
376 } else {
377 self.count += 1;
378 }
379 self.price[self.head] = value;
380 self.volume[self.head] = v;
381 self.sum += value;
382 self.sumsq += value * value;
383
384 if self.has_prev {
385 let diff = value - self.prev_price;
386 let up = if diff > 0.0 { diff } else { 0.0 };
387 let dn = if diff < 0.0 { -diff } else { 0.0 };
388
389 let cur = (self.diff_step + 1) % (self.cap + 1);
390 let prev = self.diff_step % (self.cap + 1);
391 let up_contrib = if up > 0.0 { v } else { 0.0 };
392 let dn_contrib = if dn > 0.0 { v } else { 0.0 };
393 self.upvol_cum[cur] = self.upvol_cum[prev] + up_contrib;
394 self.dnvol_cum[cur] = self.dnvol_cum[prev] + dn_contrib;
395 self.diff_step += 1;
396
397 if !self.rsi_inited {
398 self.rsi_avg_up = up;
399 self.rsi_avg_dn = dn;
400 self.rsi_inited = true;
401 } else {
402 }
403 }
404 self.prev_price = value;
405 self.has_prev = true;
406
407 self.head = (self.head + 1) % self.cap;
408
409 if self.count < self.cap {
410 return None;
411 }
412
413 let n = self.cap as f64;
414 let mu = self.sum / n;
415 let var = (self.sumsq / n) - mu * mu;
416 let sd = if var > 0.0 { var.sqrt() } else { 0.0 };
417
418 let a = (-1.75f64).mul_add(sd, mu);
419 let b = (-0.25f64).mul_add(sd, mu);
420 let c = (0.25f64).mul_add(sd, mu);
421 let d = (1.75f64).mul_add(sd, mu);
422
423 let src = value;
424 if src >= b && src <= c {
425 self.dynamic_length += 1.0;
426 } else if src < a || src > d {
427 self.dynamic_length -= 1.0;
428 }
429 self.dynamic_length = self
430 .dynamic_length
431 .max(self.min_length as f64)
432 .min(self.max_length as f64);
433 let len_r = self.dynamic_length.round().max(1.0) as usize;
434
435 let mf = if v > 0.0 && self.diff_step > 0 {
436 let cur = self.diff_step % (self.cap + 1);
437 let prev = (self.diff_step + self.cap + 1 - len_r) % (self.cap + 1);
438 let up_sum = self.upvol_cum[cur] - self.upvol_cum[prev];
439 let dn_sum = self.dnvol_cum[cur] - self.dnvol_cum[prev];
440 let tot = up_sum + dn_sum;
441 if tot > 0.0 {
442 100.0 * up_sum / tot
443 } else {
444 50.0
445 }
446 } else {
447 if self.diff_step > 0 {
448 let newest_idx = (self.head + self.cap - 1) % self.cap;
449 let prev_idx = (self.head + self.cap - 2) % self.cap;
450 let prevv = self.price[prev_idx];
451 let last_diff = self.price[newest_idx] - prevv;
452 let up = if last_diff > 0.0 { last_diff } else { 0.0 };
453 let dn = if last_diff < 0.0 { -last_diff } else { 0.0 };
454
455 let alpha = 1.0 / (len_r as f64);
456 self.rsi_avg_up = (1.0 - alpha) * self.rsi_avg_up + alpha * up;
457 self.rsi_avg_dn = (1.0 - alpha) * self.rsi_avg_dn + alpha * dn;
458 }
459 if self.rsi_avg_dn == 0.0 {
460 100.0
461 } else {
462 let s = self.rsi_avg_up + self.rsi_avg_dn;
463 if s == 0.0 {
464 50.0
465 } else {
466 100.0 * self.rsi_avg_up / s
467 }
468 }
469 };
470
471 let mf_scaled = mf.mul_add(2.0, -100.0);
472 let p = self.accelerator + (mf_scaled.abs() * 0.04);
473
474 let start = (self.head + self.cap - len_r) % self.cap;
475 let mut xws = 0.0f64;
476 let mut wsum = 0.0f64;
477
478 let mut j = 0usize;
479 while j + 1 < len_r {
480 let k1 = len_r - j;
481 let k2 = k1 - 1;
482
483 let w1 = exp_kernel(p * self.ln_lut[k1]);
484 let idx1 = (start + j) % self.cap;
485 let x1 = self.price[idx1];
486 xws = x1.mul_add(w1, xws);
487 wsum += w1;
488
489 let w2 = exp_kernel(p * self.ln_lut[k2]);
490 let idx2 = (start + j + 1) % self.cap;
491 let x2 = self.price[idx2];
492 xws = x2.mul_add(w2, xws);
493 wsum += w2;
494
495 j += 2;
496 }
497 if j < len_r {
498 let k = len_r - j;
499 let w = exp_kernel(p * self.ln_lut[k]);
500 let idx = (start + j) % self.cap;
501 let x = self.price[idx];
502 xws = x.mul_add(w, xws);
503 wsum += w;
504 }
505
506 let uma_raw = if wsum > 0.0 { xws / wsum } else { src };
507
508 let m = self.smooth_length;
509 if self.uma_raw_count < m {
510 self.s1 += uma_raw;
511 self.s2 += uma_raw * ((self.uma_raw_count as f64) + 1.0);
512 self.uma_raw[self.uma_raw_head] = uma_raw;
513 self.uma_raw_head = (self.uma_raw_head + 1) % m;
514 self.uma_raw_count += 1;
515
516 if self.uma_raw_count < m {
517 return None;
518 }
519
520 let out = self.s2 / self.wma_denom;
521 Some(out)
522 } else {
523 let oldest = self.uma_raw[self.uma_raw_head];
524 let s1_prev = self.s1;
525 self.s1 = self.s1 - oldest + uma_raw;
526 self.s2 = self.s2 - s1_prev + (m as f64) * uma_raw;
527
528 self.uma_raw[self.uma_raw_head] = uma_raw;
529 self.uma_raw_head = (self.uma_raw_head + 1) % m;
530
531 Some(self.s2 / self.wma_denom)
532 }
533 }
534
535 pub fn reset(&mut self) {
536 self.price.fill(0.0);
537 self.volume.fill(0.0);
538 self.head = 0;
539 self.count = 0;
540 self.sum = 0.0;
541 self.sumsq = 0.0;
542
543 self.has_prev = false;
544 self.prev_price = 0.0;
545
546 self.upvol_cum.fill(0.0);
547 self.dnvol_cum.fill(0.0);
548 self.diff_step = 0;
549
550 self.rsi_avg_up = 0.0;
551 self.rsi_avg_dn = 0.0;
552 self.rsi_inited = false;
553
554 self.dynamic_length = self.max_length as f64;
555
556 self.uma_raw.fill(0.0);
557 self.uma_raw_head = 0;
558 self.uma_raw_count = 0;
559 self.s1 = 0.0;
560 self.s2 = 0.0;
561 }
562}
563
/// Kernel-weight helper: plain natural exponential, kept as a named function
/// so all weight computations route through one place.
#[inline(always)]
fn exp_kernel(x: f64) -> f64 {
    f64::exp(x)
}
568
/// Errors produced by the UMA APIs (validation, data-shape, and dependency
/// failures). Display strings come from the `thiserror` attributes.
#[derive(Debug, Error)]
pub enum UmaError {
    #[error("uma: Input data slice is empty.")]
    EmptyInputData,
    #[error("uma: All values are NaN.")]
    AllValuesNaN,
    // Raised when max_length is 0 or exceeds the data length.
    #[error("uma: Invalid max_length: max_length = {max_length}, data length = {data_len}")]
    InvalidMaxLength { max_length: usize, data_len: usize },
    // Fewer non-NaN values than the required warmup window.
    #[error("uma: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    // Accelerator must be >= 1.0.
    #[error("uma: Invalid accelerator: {accelerator}")]
    InvalidAccelerator { accelerator: f64 },
    #[error("uma: Invalid min_length: {min_length}")]
    InvalidMinLength { min_length: usize },
    #[error("uma: Invalid smooth_length: {smooth_length}")]
    InvalidSmoothLength { smooth_length: usize },
    #[error("uma: min_length ({min_length}) must be <= max_length ({max_length})")]
    MinLengthGreaterThanMaxLength {
        min_length: usize,
        max_length: usize,
    },
    // Destination slice length does not match the input length.
    #[error("uma: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    #[error("uma: Invalid range: start = {start}, end = {end}, step = {step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    #[error("uma: Invalid kernel for batch API: {0:?}")]
    InvalidKernelForBatch(Kernel),
    // Checked index arithmetic overflowed; `context` names the expression.
    #[error("uma: arithmetic overflow while computing {context}")]
    ArithmeticOverflow { context: &'static str },
    // Wraps errors from the sma/deviation/wma helper indicators.
    #[error("uma: Error from dependency: {0}")]
    DependencyError(String),
}
605
606#[inline(always)]
607fn uma_prepare<'a>(input: &'a UmaInput) -> Result<(&'a [f64], usize, usize, usize, f64), UmaError> {
608 let data: &[f64] = input.as_ref();
609 let len = data.len();
610 if len == 0 {
611 return Err(UmaError::EmptyInputData);
612 }
613
614 let first = data
615 .iter()
616 .position(|x| !x.is_nan())
617 .ok_or(UmaError::AllValuesNaN)?;
618 let accelerator = input.get_accelerator();
619 let min_length = input.get_min_length();
620 let max_length = input.get_max_length();
621 let smooth_length = input.get_smooth_length();
622
623 if max_length == 0 || max_length > len {
624 return Err(UmaError::InvalidMaxLength {
625 max_length,
626 data_len: len,
627 });
628 }
629 if min_length == 0 {
630 return Err(UmaError::InvalidMinLength { min_length });
631 }
632 if min_length > max_length {
633 return Err(UmaError::MinLengthGreaterThanMaxLength {
634 min_length,
635 max_length,
636 });
637 }
638 if smooth_length == 0 {
639 return Err(UmaError::InvalidSmoothLength { smooth_length });
640 }
641 if accelerator < 1.0 {
642 return Err(UmaError::InvalidAccelerator { accelerator });
643 }
644 if len - first < max_length {
645 return Err(UmaError::NotEnoughValidData {
646 needed: max_length,
647 valid: len - first,
648 });
649 }
650 Ok((data, first, min_length, max_length, accelerator))
651}
652
/// Core UMA pass: for each bar past the warmup boundary, adapts the window
/// length based on where the price sits relative to its rolling mean/std
/// bands, derives a money-flow/RSI "pressure" value that steepens a
/// power-law weight curve, and writes the weighted average of the last
/// `len_r` prices into `out[i]`. Slots before the warmup boundary (and bars
/// where mean/std are NaN) are left untouched — callers pre-fill them.
#[inline(always)]
fn uma_core_into(
    input: &UmaInput,
    first: usize,
    min_length: usize,
    max_length: usize,
    accelerator: f64,
    _kernel: Kernel,
    out: &mut [f64],
) -> Result<(), UmaError> {
    let data: &[f64] = input.as_ref();
    let len = data.len();
    debug_assert!(len == out.len());

    // Rolling mean over max_length bars (from the sma indicator).
    let mean = sma(&SmaInput::from_slice(
        data,
        SmaParams {
            period: Some(max_length),
        },
    ))
    .map_err(|e| UmaError::DependencyError(e.to_string()))?
    .values;

    // Rolling deviation (devtype 0) over the same window.
    let std_dev = deviation(&DeviationInput::from_slice(
        data,
        DeviationParams {
            period: Some(max_length),
            devtype: Some(0),
        },
    ))
    .map_err(|e| UmaError::DependencyError(e.to_string()))?
    .values;

    // ln_lut[k] == ln(k); index 0 is padding so weight rank k maps directly.
    let mut ln_lut: Vec<f64> = Vec::with_capacity(max_length + 1);

    ln_lut.push(0.0);
    for k in 1..=max_length {
        ln_lut.push((k as f64).ln());
    }

    // Candle reference (for the MFI path) and optional volume series.
    let (candles_opt, vol_opt) = match &input.data {
        UmaData::Candles { candles, .. } => (Some(*candles), input.volume),
        UmaData::Slice(_) => (None, input.volume),
    };

    // First index with a full max_length window behind it.
    let warmup_end = first
        .checked_add(max_length)
        .and_then(|x| x.checked_sub(1))
        .ok_or(UmaError::ArithmeticOverflow {
            context: "first + max_length - 1 (warmup_end)",
        })?;
    if warmup_end >= len {
        return Ok(());
    }

    let mut dyn_len = max_length as f64;

    // Wilder RSI over data[start..=end] with the given period; returns only
    // the final value. Falls back to 50.0 (neutral) on short or NaN input.
    #[inline(always)]
    fn rsi_wilder_last(data: &[f64], start: usize, end: usize, period: usize) -> f64 {
        if period == 0 || end <= start || end - start + 1 < period + 1 {
            return 50.0;
        }

        // Initial simple averages over the first `period` diffs.
        let mut sum_up = 0.0f64;
        let mut sum_dn = 0.0f64;
        let mut has_nan = false;
        let mut prev = data[start];
        let init_last = start + period;
        for j in (start + 1)..=init_last {
            let cur = data[j];
            let diff = cur - prev;
            if !diff.is_finite() {
                has_nan = true;
                break;
            }
            if diff > 0.0 {
                sum_up += diff;
            } else {
                sum_dn -= diff;
            }
            prev = cur;
        }
        if has_nan {
            return 50.0;
        }
        let mut avg_up = sum_up / (period as f64);
        let mut avg_dn = sum_dn / (period as f64);

        // Wilder smoothing over the remaining diffs.
        if init_last < end {
            let n_1 = (period - 1) as f64;
            let n = period as f64;
            for j in (init_last + 1)..=end {
                let cur = data[j];
                let diff = cur - prev;
                let up = if diff > 0.0 { diff } else { 0.0 };
                let dn = if diff < 0.0 { -diff } else { 0.0 };
                avg_up = (avg_up * n_1 + up) / n;
                avg_dn = (avg_dn * n_1 + dn) / n;
                prev = cur;
            }
        }

        if avg_dn == 0.0 {
            100.0
        } else if avg_up + avg_dn == 0.0 {
            50.0
        } else {
            100.0 * avg_up / (avg_up + avg_dn)
        }
    }

    // MFI-style ratio over candles[start..=end]: typical price * volume split
    // into positive/negative flow; 50.0 (neutral) when there is no flow.
    #[inline(always)]
    fn mfi_window_last_candles(c: &Candles, start: usize, end: usize) -> f64 {
        if end <= start {
            return 50.0;
        }

        let mut tp_prev = (c.high[start] + c.low[start] + c.close[start]) / 3.0;
        let mut pos = 0.0f64;
        let mut neg = 0.0f64;

        for j in (start + 1)..=end {
            let tp = (c.high[j] + c.low[j] + c.close[j]) / 3.0;

            let mf = tp * c.volume[j];
            if tp > tp_prev {
                pos += mf;
            } else if tp < tp_prev {
                neg += mf;
            }
            tp_prev = tp;
        }

        let denom = pos + neg;
        if denom > 0.0 {
            100.0 * pos / denom
        } else {
            50.0
        }
    }

    for i in warmup_end..len {
        let mu = mean[i];
        let sd = std_dev[i];
        if mu.is_nan() || sd.is_nan() {
            continue;
        }
        let src = data[i];

        // Band edges: inner band mu +/- 0.25*sd, outer band mu +/- 1.75*sd.
        let a = (-1.75f64).mul_add(sd, mu);
        let b = (-0.25f64).mul_add(sd, mu);
        let c = (0.25f64).mul_add(sd, mu);
        let d = (1.75f64).mul_add(sd, mu);

        // Inside the inner band: lengthen the window; outside the outer
        // band: shorten it.
        if src >= b && src <= c {
            dyn_len += 1.0;
        } else if src < a || src > d {
            dyn_len -= 1.0;
        }

        dyn_len = dyn_len.max(min_length as f64).min(max_length as f64);
        let len_r = dyn_len.round() as usize;
        if i + 1 < len_r {
            continue;
        }

        // Pressure measure: candle MFI when real volume exists at bar i,
        // up/down volume split for slices with volume, RSI fallback otherwise.
        let mf: f64 = match (vol_opt, candles_opt) {
            (Some(vol), Some(candles)) => {
                let v_i = vol[i];
                if v_i == 0.0 || v_i.is_nan() {
                    // No usable volume at this bar: RSI over a 2*len_r lookback.
                    let end = i;
                    let start = if end + 1 >= 2 * len_r {
                        end + 1 - 2 * len_r
                    } else {
                        0
                    };
                    rsi_wilder_last(data, start, end, len_r)
                } else {
                    let start = i + 1 - len_r;
                    mfi_window_last_candles(candles, start, i)
                }
            }
            (Some(vol), None) => {
                // Slice input with volume: split volume by price direction.
                let start = i + 1 - len_r;
                let mut up_vol = 0.0f64;
                let mut dn_vol = 0.0f64;
                let mut prev = data[start];
                for j in (start + 1)..=i {
                    let cur = data[j];
                    let v = vol[j];
                    let diff = cur - prev;
                    if diff > 0.0 {
                        up_vol += v;
                    } else if diff < 0.0 {
                        dn_vol += v;
                    }
                    prev = cur;
                }
                let tot = up_vol + dn_vol;
                if tot > 0.0 {
                    100.0 * up_vol / tot
                } else {
                    50.0
                }
            }
            _ => {
                // No volume at all: RSI over a 2*len_r lookback.
                let end = i;
                let start = if end + 1 >= 2 * len_r {
                    end + 1 - 2 * len_r
                } else {
                    0
                };
                rsi_wilder_last(data, start, end, len_r)
            }
        };

        // Map mf from [0, 100] to [-100, 100]; the magnitude steepens the
        // power-law weighting exponent.
        let mf_scaled = mf.mul_add(2.0, -100.0);
        let p = accelerator + (mf_scaled.abs() * 0.04);

        let start = i + 1 - len_r;

        let (mut xws, mut wsum) = (0.0f64, 0.0f64);

        // Optional SIMD fast path; leaves wsum at 0.0 when not taken so the
        // scalar loop below runs instead.
        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
        match _kernel {
            #[cfg(target_feature = "avx512f")]
            Kernel::Avx512 => unsafe {
                let (sx, sw) = uma_weighted_accumulate_avx512(
                    data.as_ptr().add(start),
                    ln_lut.as_ptr(),
                    len_r,
                    p,
                );
                xws = sx;
                wsum = sw;
            },
            #[cfg(target_feature = "avx2")]
            Kernel::Avx2 => unsafe {
                let (sx, sw) = uma_weighted_accumulate_avx2(
                    data.as_ptr().add(start),
                    ln_lut.as_ptr(),
                    len_r,
                    p,
                );
                xws = sx;
                wsum = sw;
            },
            _ => {}
        }

        // Scalar path: weights k^p via exp(p * ln k), NaN inputs skipped,
        // unrolled two elements per iteration.
        if wsum == 0.0 {
            let mut j = 0usize;
            while j + 1 < len_r {
                let k1 = len_r - j;
                let k2 = k1 - 1;

                let w1 = exp_kernel(p * ln_lut[k1]);
                let x1 = data[start + j];
                if !x1.is_nan() {
                    xws = x1.mul_add(w1, xws);
                    wsum += w1;
                }

                let w2 = exp_kernel(p * ln_lut[k2]);
                let x2 = data[start + j + 1];
                if !x2.is_nan() {
                    xws = x2.mul_add(w2, xws);
                    wsum += w2;
                }

                j += 2;
            }
            if j < len_r {
                let k = len_r - j;
                let w = exp_kernel(p * ln_lut[k]);
                let x = data[start + j];
                if !x.is_nan() {
                    xws = x.mul_add(w, xws);
                    wsum += w;
                }
            }
        }

        out[i] = if wsum > 0.0 { xws / wsum } else { src };
    }

    Ok(())
}
941
// AVX2 weighted accumulation: returns (sum of x*w, sum of w) over `len_r`
// elements, where w = exp(p * ln_lut[k]) and rank k counts down from len_r.
// NaN inputs contribute neither to the value sum nor to the weight sum,
// matching the scalar path in `uma_core_into`.
//
// SAFETY (caller contract): `data` must be valid for `len_r` f64 reads,
// `ln_lut` must be indexable at 1..=len_r, and the CPU must support the
// AVX2/FMA instructions used here — TODO confirm a runtime feature check
// guards the call site when built without `-C target-feature=+avx2`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn uma_weighted_accumulate_avx2(
    data: *const f64,
    ln_lut: *const f64,
    len_r: usize,
    p: f64,
) -> (f64, f64) {
    use core::arch::x86_64::*;
    let mut sum_v = _mm256_setzero_pd();
    let mut wsum_v = _mm256_setzero_pd();

    let mut j = 0usize;
    // Main loop: four lanes per iteration.
    while j + 3 < len_r {
        let k0 = len_r - j;
        let k1 = k0 - 1;
        let k2 = k1 - 1;
        let k3 = k2 - 1;

        // Weights are computed scalar (exp has no AVX2 intrinsic here) and
        // packed into one vector.
        let w0 = exp_kernel(p * *ln_lut.add(k0));
        let w1 = exp_kernel(p * *ln_lut.add(k1));
        let w2 = exp_kernel(p * *ln_lut.add(k2));
        let w3 = exp_kernel(p * *ln_lut.add(k3));
        let wv = _mm256_setr_pd(w0, w1, w2, w3);

        let xv = _mm256_loadu_pd(data.add(j));

        // _CMP_ORD_Q yields all-ones lanes where x is not NaN; ANDing zeroes
        // NaN lanes out of both the value and its weight.
        let nan_mask = _mm256_cmp_pd(xv, xv, _CMP_ORD_Q);
        let xv_nz = _mm256_and_pd(xv, nan_mask);

        sum_v = _mm256_fmadd_pd(xv_nz, wv, sum_v);
        let wv_masked = _mm256_and_pd(wv, nan_mask);
        wsum_v = _mm256_add_pd(wsum_v, wv_masked);

        j += 4;
    }

    // Horizontal reduction of the four lanes.
    let mut xws = 0.0f64;
    let mut wsum = 0.0f64;
    {
        let tmp: [f64; 4] = core::mem::transmute(sum_v);
        xws += tmp[0] + tmp[1] + tmp[2] + tmp[3];
        let t2: [f64; 4] = core::mem::transmute(wsum_v);
        wsum += t2[0] + t2[1] + t2[2] + t2[3];
    }

    // Scalar tail for the remaining (< 4) elements.
    while j < len_r {
        let k = len_r - j;
        let w = exp_kernel(p * *ln_lut.add(k));
        let x = *data.add(j);
        if !x.is_nan() {
            xws = x.mul_add(w, xws);
            wsum += w;
        }
        j += 1;
    }
    (xws, wsum)
}
1000
// Scalar stand-in for the AVX-512 accumulator, compiled only when the
// `avx512f` target feature is absent so call sites can reference the symbol
// unconditionally. Semantics match the SIMD version: weights k^p via
// exp(p * ln k), NaN inputs skipped.
//
// SAFETY (caller contract): `data` must be valid for `len_r` f64 reads and
// `ln_lut` must be indexable at 1..=len_r.
#[cfg(all(
    feature = "nightly-avx",
    target_arch = "x86_64",
    not(target_feature = "avx512f")
))]
#[inline(always)]
unsafe fn uma_weighted_accumulate_avx512(
    data: *const f64,
    ln_lut: *const f64,
    len_r: usize,
    p: f64,
) -> (f64, f64) {
    let mut xws = 0.0f64;
    let mut wsum = 0.0f64;
    let mut j = 0usize;
    while j < len_r {
        let k = len_r - j;
        let w = exp_kernel(p * *ln_lut.add(k));
        let x = *data.add(j);
        if !x.is_nan() {
            xws = x.mul_add(w, xws);
            wsum += w;
        }
        j += 1;
    }
    (xws, wsum)
}
1028
// AVX-512 weighted accumulation: returns (sum of x*w, sum of w) over `len_r`
// elements, eight lanes per iteration; rank k counts down from len_r and
// w = exp(p * ln_lut[k]). NaN lanes are excluded from both sums via a
// self-comparison mask, matching the scalar path.
//
// SAFETY (caller contract): `data` must be valid for `len_r` f64 reads,
// `ln_lut` must be indexable at 1..=len_r, and the CPU must support the
// AVX-512F instructions used here.
#[cfg(all(
    feature = "nightly-avx",
    target_arch = "x86_64",
    target_feature = "avx512f"
))]
#[inline(always)]
unsafe fn uma_weighted_accumulate_avx512(
    data: *const f64,
    ln_lut: *const f64,
    len_r: usize,
    p: f64,
) -> (f64, f64) {
    use core::arch::x86_64::*;
    let mut sum_v = _mm512_setzero_pd();
    let mut wsum_v = _mm512_setzero_pd();

    let mut j = 0usize;
    while j + 7 < len_r {
        let k0 = len_r - j;
        // Eight descending ranks; weights computed scalar and loaded as one
        // 512-bit vector.
        let ks = [k0, k0 - 1, k0 - 2, k0 - 3, k0 - 4, k0 - 5, k0 - 6, k0 - 7];
        let ws = [
            exp_kernel(p * *ln_lut.add(ks[0])),
            exp_kernel(p * *ln_lut.add(ks[1])),
            exp_kernel(p * *ln_lut.add(ks[2])),
            exp_kernel(p * *ln_lut.add(ks[3])),
            exp_kernel(p * *ln_lut.add(ks[4])),
            exp_kernel(p * *ln_lut.add(ks[5])),
            exp_kernel(p * *ln_lut.add(ks[6])),
            exp_kernel(p * *ln_lut.add(ks[7])),
        ];
        let wv = _mm512_loadu_pd(ws.as_ptr());
        let xv = _mm512_loadu_pd(data.add(j));

        // Mask bit set where x is ordered (not NaN); maskz_mov zeroes the
        // NaN lanes out of both value and weight vectors.
        let nan_mask = _mm512_cmp_pd_mask(xv, xv, _CMP_ORD_Q);
        let xv_nz = _mm512_maskz_mov_pd(nan_mask, xv);

        sum_v = _mm512_fmadd_pd(xv_nz, wv, sum_v);
        let wv_masked = _mm512_maskz_mov_pd(nan_mask, wv);
        wsum_v = _mm512_add_pd(wsum_v, wv_masked);

        j += 8;
    }

    // Horizontal reduction of the eight lanes.
    let xws = {
        let mut tmp: [f64; 8] = core::mem::zeroed();
        _mm512_storeu_pd(tmp.as_mut_ptr(), sum_v);
        tmp.iter().copied().sum::<f64>()
    };
    let wsum0 = {
        let mut tmp: [f64; 8] = core::mem::zeroed();
        _mm512_storeu_pd(tmp.as_mut_ptr(), wsum_v);
        tmp.iter().copied().sum::<f64>()
    };

    // Scalar tail for the remaining (< 8) elements.
    let mut xws_acc = xws;
    let mut wsum_acc = wsum0;
    while j < len_r {
        let k = len_r - j;
        let w = exp_kernel(p * *ln_lut.add(k));
        let x = *data.add(j);
        if !x.is_nan() {
            xws_acc = x.mul_add(w, xws_acc);
            wsum_acc += w;
        }
        j += 1;
    }
    (xws_acc, wsum_acc)
}
1097
/// Computes UMA with automatic kernel selection; see [`uma_with_kernel`]
/// for the full pipeline and error conditions.
#[inline]
pub fn uma(input: &UmaInput) -> Result<UmaOutput, UmaError> {
    uma_with_kernel(input, Kernel::Auto)
}
1102
1103pub fn uma_with_kernel(input: &UmaInput, kernel: Kernel) -> Result<UmaOutput, UmaError> {
1104 let (data, first, min_len, max_len, accel) = uma_prepare(input)?;
1105 let warm = first
1106 .checked_add(max_len)
1107 .and_then(|x| x.checked_sub(1))
1108 .ok_or(UmaError::ArithmeticOverflow {
1109 context: "first + max_len - 1 (warm)",
1110 })?;
1111
1112 let mut out = alloc_with_nan_prefix(data.len(), warm);
1113
1114 let chosen = match kernel {
1115 Kernel::Auto => Kernel::Scalar,
1116 k => k,
1117 };
1118 uma_core_into(input, first, min_len, max_len, accel, chosen, &mut out)?;
1119
1120 let smooth = input.get_smooth_length();
1121 if smooth > 1 {
1122 let w = wma(&WmaInput::from_slice(
1123 &out,
1124 WmaParams {
1125 period: Some(smooth),
1126 },
1127 ))
1128 .map_err(|e| UmaError::DependencyError(e.to_string()))?
1129 .values;
1130 return Ok(UmaOutput { values: w });
1131 }
1132
1133 Ok(UmaOutput { values: out })
1134}
1135
1136#[inline]
1137pub fn uma_into_slice(dst: &mut [f64], input: &UmaInput, kern: Kernel) -> Result<(), UmaError> {
1138 let (data, first, min_len, max_len, accel) = uma_prepare(input)?;
1139 if dst.len() != data.len() {
1140 return Err(UmaError::OutputLengthMismatch {
1141 expected: data.len(),
1142 got: dst.len(),
1143 });
1144 }
1145
1146 let warm = first
1147 .checked_add(max_len)
1148 .and_then(|x| x.checked_sub(1))
1149 .ok_or(UmaError::ArithmeticOverflow {
1150 context: "first + max_len - 1 (warm)",
1151 })?;
1152 let warm_end = warm.min(dst.len());
1153 for v in &mut dst[..warm_end] {
1154 *v = f64::NAN;
1155 }
1156
1157 let chosen = match kern {
1158 Kernel::Auto => Kernel::Scalar,
1159 k => k,
1160 };
1161 uma_core_into(input, first, min_len, max_len, accel, chosen, dst)?;
1162
1163 let smooth = input.get_smooth_length();
1164 if smooth > 1 {
1165 let tmp = wma(&WmaInput::from_slice(
1166 dst,
1167 WmaParams {
1168 period: Some(smooth),
1169 },
1170 ))
1171 .map_err(|e| UmaError::DependencyError(e.to_string()))?
1172 .values;
1173 dst.copy_from_slice(&tmp);
1174 }
1175 Ok(())
1176}
1177
/// Non-wasm convenience wrapper: writes UMA values into `out` (which must
/// match the input length) using automatic kernel selection.
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
#[inline]
pub fn uma_into(input: &UmaInput, out: &mut [f64]) -> Result<(), UmaError> {
    uma_into_slice(out, input, Kernel::Auto)
}
1183
/// Parameter sweep specification for the batch API. Each field is a
/// `(start, end, step)` triple; a step of 0 pins the parameter to `start`.
#[derive(Clone, Debug)]
pub struct UmaBatchRange {
    pub accelerator: (f64, f64, f64),
    pub min_length: (usize, usize, usize),
    pub max_length: (usize, usize, usize),
    pub smooth_length: (usize, usize, usize),
}
1191
1192impl Default for UmaBatchRange {
1193 fn default() -> Self {
1194 Self {
1195 accelerator: (1.0, 1.0, 0.0),
1196 min_length: (5, 5, 0),
1197 max_length: (50, 299, 1),
1198 smooth_length: (4, 4, 0),
1199 }
1200 }
1201}
1202
/// Fluent builder for batch (parameter-sweep) UMA runs. Each range field is
/// a `(start, end, step)` triple, mirroring `UmaBatchRange`.
#[derive(Copy, Clone)]
pub struct UmaBatchBuilder {
    accelerator_range: (f64, f64, f64),
    min_length_range: (usize, usize, usize),
    max_length_range: (usize, usize, usize),
    smooth_length_range: (usize, usize, usize),
    kernel: Kernel,
}
1211
1212impl Default for UmaBatchBuilder {
1213 fn default() -> Self {
1214 Self {
1215 accelerator_range: (1.0, 1.0, 0.0),
1216 min_length_range: (5, 5, 0),
1217 max_length_range: (50, 299, 1),
1218 smooth_length_range: (4, 4, 0),
1219 kernel: Kernel::Auto,
1220 }
1221 }
1222}
1223
impl UmaBatchBuilder {
    /// Creates a builder with the default sweep (see [`UmaBatchBuilder::default`]).
    #[inline(always)]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the accelerator sweep axis as `(start, end, step)`.
    #[inline(always)]
    pub fn accelerator_range(mut self, start: f64, end: f64, step: f64) -> Self {
        self.accelerator_range = (start, end, step);
        self
    }

    /// Sets the minimum-length sweep axis as `(start, end, step)`.
    #[inline(always)]
    pub fn min_length_range(mut self, start: usize, end: usize, step: usize) -> Self {
        self.min_length_range = (start, end, step);
        self
    }

    /// Sets the maximum-length sweep axis as `(start, end, step)`.
    #[inline(always)]
    pub fn max_length_range(mut self, start: usize, end: usize, step: usize) -> Self {
        self.max_length_range = (start, end, step);
        self
    }

    /// Sets the smoothing-length sweep axis as `(start, end, step)`.
    #[inline(always)]
    pub fn smooth_length_range(mut self, start: usize, end: usize, step: usize) -> Self {
        self.smooth_length_range = (start, end, step);
        self
    }

    /// Sets the kernel the batch run will use.
    #[inline(always)]
    pub fn kernel(mut self, k: Kernel) -> Self {
        self.kernel = k;
        self
    }

    /// Runs the sweep over the chosen candle `source` field, passing candle
    /// volume through to volume-aware kernels. Rows are computed sequentially
    /// (`parallel = false`).
    #[inline]
    pub fn apply_candles(
        self,
        candles: &Candles,
        source: &str,
    ) -> Result<UmaBatchOutput, UmaError> {
        let sweep = UmaBatchRange {
            accelerator: self.accelerator_range,
            min_length: self.min_length_range,
            max_length: self.max_length_range,
            smooth_length: self.smooth_length_range,
        };

        let data = source_type(candles, source);
        uma_batch_inner(data, Some(&candles.volume), &sweep, self.kernel, false)
    }

    /// Runs the sweep over a raw price slice with optional volume. Rows are
    /// computed sequentially (`parallel = false`).
    #[inline]
    pub fn apply_slice(
        self,
        data: &[f64],
        volume: Option<&[f64]>,
    ) -> Result<UmaBatchOutput, UmaError> {
        let sweep = UmaBatchRange {
            accelerator: self.accelerator_range,
            min_length: self.min_length_range,
            max_length: self.max_length_range,
            smooth_length: self.smooth_length_range,
        };

        uma_batch_inner(data, volume, &sweep, self.kernel, false)
    }

    /// Convenience: default sweep over a slice with an explicit kernel.
    pub fn with_default_slice(
        data: &[f64],
        volume: Option<&[f64]>,
        k: Kernel,
    ) -> Result<UmaBatchOutput, UmaError> {
        UmaBatchBuilder::new().kernel(k).apply_slice(data, volume)
    }

    /// Convenience: default sweep over candle closes with `Kernel::Auto`.
    pub fn with_default_candles(c: &Candles) -> Result<UmaBatchOutput, UmaError> {
        UmaBatchBuilder::new()
            .kernel(Kernel::Auto)
            .apply_candles(c, "close")
    }
}
1307
/// Result of a UMA batch run: a row-major `rows x cols` matrix of output
/// values, one row per parameter combination in `combos`.
#[derive(Clone, Debug)]
pub struct UmaBatchOutput {
    /// Flattened row-major matrix; row `r` occupies `values[r * cols .. (r + 1) * cols]`.
    pub values: Vec<f64>,
    /// Parameter combination that produced each row, in row order.
    pub combos: Vec<UmaParams>,
    /// Number of parameter combinations (matrix rows).
    pub rows: usize,
    /// Length of the input series (matrix columns).
    pub cols: usize,
}
1315
impl UmaBatchOutput {
    /// Returns the row index whose combo matches `p`, with unset fields
    /// compared against the defaults (accelerator 1.0, min 5, max 50, smooth 4).
    ///
    /// NOTE(review): `accelerator` is matched with exact `f64` equality; values
    /// generated by the float sweep accumulation may differ from a nominally
    /// equal caller-supplied value by rounding — confirm callers pass the exact
    /// swept values.
    pub fn row_for_params(&self, p: &UmaParams) -> Option<usize> {
        self.combos.iter().position(|c| {
            c.accelerator.unwrap_or(1.0) == p.accelerator.unwrap_or(1.0)
                && c.min_length.unwrap_or(5) == p.min_length.unwrap_or(5)
                && c.max_length.unwrap_or(50) == p.max_length.unwrap_or(50)
                && c.smooth_length.unwrap_or(4) == p.smooth_length.unwrap_or(4)
        })
    }

    /// Borrows the output row for `p`, or `None` if no combo matches.
    pub fn values_for(&self, p: &UmaParams) -> Option<&[f64]> {
        self.row_for_params(p).map(|row| {
            let start = row * self.cols;
            &self.values[start..start + self.cols]
        })
    }
}
1333
/// Expands a sweep specification into the full cartesian product of
/// [`UmaParams`] combinations.
///
/// Combos where `min_length > max_length` are skipped. An axis whose range is
/// invalid falls back to a single value (the range's `start`). Returns an
/// empty vec when every `min_length` exceeds every `max_length`.
#[inline(always)]
pub fn expand_grid_uma(r: &UmaBatchRange) -> Vec<UmaParams> {
    // Expands a usize axis; supports ascending and descending ranges.
    fn axis_usize(range: (usize, usize, usize)) -> Result<Vec<usize>, UmaError> {
        let (start, end, step) = range;
        // Degenerate axis: a single value.
        if step == 0 || start == end {
            return Ok(vec![start]);
        }
        if start < end {
            let v: Vec<usize> = (start..=end).step_by(step).collect();
            return if v.is_empty() {
                Err(UmaError::InvalidRange { start, end, step })
            } else {
                Ok(v)
            };
        }

        // Descending range: walk down with checked subtraction to avoid underflow.
        let mut v = Vec::new();
        let mut cur = start;
        loop {
            v.push(cur);
            if cur == end {
                break;
            }
            cur = cur
                .checked_sub(step)
                .ok_or(UmaError::InvalidRange { start, end, step })?;
            if cur < end {
                break;
            }
        }
        if v.is_empty() {
            return Err(UmaError::InvalidRange { start, end, step });
        }
        Ok(v)
    }
    // Expands an f64 axis. Values are generated by repeated addition of the
    // step, so later values carry accumulated floating-point drift; the
    // 1e-12 slack on the bounds absorbs that drift at the endpoints.
    fn axis_f64(range: (f64, f64, f64)) -> Result<Vec<f64>, UmaError> {
        let (start, end, step) = range;
        // Degenerate axis: a single value.
        if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
            return Ok(vec![start]);
        }
        let mut v = Vec::new();
        if start <= end {
            let s = step.abs();
            if s < 1e-12 {
                return Ok(vec![start]);
            }
            let mut x = start;
            while x <= end + 1e-12 {
                v.push(x);
                x += s;
            }
        } else {
            // Descending sweep: force the step negative regardless of its sign.
            let s = -step.abs();
            if s.abs() < 1e-12 {
                return Ok(vec![start]);
            }
            let mut x = start;
            while x >= end - 1e-12 {
                v.push(x);
                x += s;
            }
        }
        if v.is_empty() {
            // Error variant only carries usize fields; truncation here is for reporting only.
            return Err(UmaError::InvalidRange {
                start: start as usize,
                end: end as usize,
                step: step.abs() as usize,
            });
        }
        Ok(v)
    }

    // Invalid axes degrade to a single-value axis rather than failing the grid.
    let accs = axis_f64(r.accelerator).unwrap_or_else(|_| vec![r.accelerator.0]);
    let mins = match axis_usize(r.min_length) {
        Ok(v) => v,
        Err(_) => vec![r.min_length.0],
    };
    let maxs = match axis_usize(r.max_length) {
        Ok(v) => v,
        Err(_) => vec![r.max_length.0],
    };
    let smooths = match axis_usize(r.smooth_length) {
        Ok(v) => v,
        Err(_) => vec![r.smooth_length.0],
    };

    // Fast exit: no (min, max) pair can satisfy min <= max.
    if !mins.is_empty() && !maxs.is_empty() {
        if let (Some(min_min), Some(max_max)) = (mins.iter().min(), maxs.iter().max()) {
            if *min_min > *max_max {
                return vec![];
            }
        }
    }

    // Upper-bound capacity; overflow degrades to 0 (Vec grows as needed).
    let cap = accs
        .len()
        .checked_mul(mins.len())
        .and_then(|x| x.checked_mul(maxs.len()))
        .and_then(|x| x.checked_mul(smooths.len()))
        .unwrap_or(0);
    let mut out = Vec::with_capacity(cap);
    for &a in &accs {
        for &min in &mins {
            for &max in &maxs {
                for &s in &smooths {
                    // Only emit feasible combinations.
                    if min <= max {
                        out.push(UmaParams {
                            accelerator: Some(a),
                            min_length: Some(min),
                            max_length: Some(max),
                            smooth_length: Some(s),
                        });
                    }
                }
            }
        }
    }
    out
}
1453
1454pub fn uma_batch_with_kernel(
1455 data: &[f64],
1456 volume: Option<&[f64]>,
1457 sweep: &UmaBatchRange,
1458 k: Kernel,
1459) -> Result<UmaBatchOutput, UmaError> {
1460 let kernel = match k {
1461 Kernel::Auto => detect_best_batch_kernel(),
1462 other if other.is_batch() => other,
1463 other => return Err(UmaError::InvalidKernelForBatch(other)),
1464 };
1465
1466 uma_batch_inner(data, volume, sweep, kernel, true)
1467}
1468
/// Runs a UMA batch sweep over a slice, computing rows sequentially.
/// The kernel is passed through unvalidated; see [`uma_batch_with_kernel`]
/// for the validating entry point.
#[inline(always)]
pub fn uma_batch_slice(
    data: &[f64],
    volume: Option<&[f64]>,
    sweep: &UmaBatchRange,
    kern: Kernel,
) -> Result<UmaBatchOutput, UmaError> {
    uma_batch_inner(data, volume, sweep, kern, false)
}
1478
/// Runs a UMA batch sweep over a slice, computing rows in parallel
/// (sequentially on wasm32, which has no rayon support here).
#[inline(always)]
pub fn uma_batch_par_slice(
    data: &[f64],
    volume: Option<&[f64]>,
    sweep: &UmaBatchRange,
    kern: Kernel,
) -> Result<UmaBatchOutput, UmaError> {
    uma_batch_inner(data, volume, sweep, kern, true)
}
1488
1489#[inline(always)]
1490fn debatch(k: Kernel) -> Kernel {
1491 match k {
1492 Kernel::Avx512Batch | Kernel::Avx512 => Kernel::Avx512,
1493 Kernel::Avx2Batch | Kernel::Avx2 => Kernel::Avx2,
1494 Kernel::ScalarBatch | Kernel::Scalar => Kernel::Scalar,
1495 Kernel::Auto => Kernel::Scalar,
1496 _ => Kernel::Scalar,
1497 }
1498}
1499
1500#[inline(always)]
1501fn uma_batch_inner_into(
1502 data: &[f64],
1503 volume: Option<&[f64]>,
1504 sweep: &UmaBatchRange,
1505 kern: Kernel,
1506 parallel: bool,
1507 out: &mut [f64],
1508) -> Result<Vec<UmaParams>, UmaError> {
1509 match kern {
1510 Kernel::Scalar | Kernel::Avx2 | Kernel::Avx512 => {
1511 return Err(UmaError::InvalidKernelForBatch(kern));
1512 }
1513 _ => {}
1514 }
1515 let combos = expand_grid_uma(sweep);
1516 let combos = combos;
1517 if combos.is_empty() {
1518 return Err(UmaError::InvalidRange {
1519 start: sweep.min_length.0,
1520 end: sweep.max_length.1,
1521 step: sweep.max_length.2,
1522 });
1523 }
1524
1525 let cols = data.len();
1526 let rows = combos.len();
1527 let expected = rows.checked_mul(cols).ok_or(UmaError::ArithmeticOverflow {
1528 context: "rows * cols (batch out)",
1529 })?;
1530 if out.len() != expected {
1531 return Err(UmaError::OutputLengthMismatch {
1532 expected,
1533 got: out.len(),
1534 });
1535 }
1536
1537 let per_row_kernel = debatch(kern);
1538
1539 let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
1540 let warm: Vec<usize> = combos
1541 .iter()
1542 .map(|c| {
1543 let a = first
1544 .checked_add(c.max_length.unwrap_or(50))
1545 .and_then(|x| x.checked_sub(1))
1546 .ok_or(UmaError::ArithmeticOverflow {
1547 context: "first + max_length - 1 (batch warm)",
1548 })?;
1549 a.checked_add(c.smooth_length.unwrap_or(4))
1550 .and_then(|x| x.checked_sub(1))
1551 .ok_or(UmaError::ArithmeticOverflow {
1552 context: "+ smooth_length - 1 (batch warm)",
1553 })
1554 })
1555 .collect::<Result<Vec<_>, _>>()?;
1556
1557 let out_mu = unsafe {
1558 std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1559 };
1560 init_matrix_prefixes(out_mu, cols, &warm);
1561
1562 let do_row = |row: usize, row_slice: &mut [f64]| -> Result<(), UmaError> {
1563 let input = UmaInput::from_slice(data, volume, combos[row].clone());
1564 uma_into_slice(row_slice, &input, per_row_kernel)
1565 };
1566
1567 #[cfg(not(target_arch = "wasm32"))]
1568 if parallel {
1569 use rayon::prelude::*;
1570 out.par_chunks_mut(cols)
1571 .enumerate()
1572 .try_for_each(|(r, s)| do_row(r, s))?;
1573 } else {
1574 for (r, s) in out.chunks_mut(cols).enumerate() {
1575 do_row(r, s)?;
1576 }
1577 }
1578
1579 #[cfg(target_arch = "wasm32")]
1580 {
1581 for (r, s) in out.chunks_mut(cols).enumerate() {
1582 do_row(r, s)?;
1583 }
1584 }
1585
1586 Ok(combos)
1587}
1588
1589#[inline(always)]
1590fn uma_batch_inner(
1591 data: &[f64],
1592 volume: Option<&[f64]>,
1593 sweep: &UmaBatchRange,
1594 kern: Kernel,
1595 parallel: bool,
1596) -> Result<UmaBatchOutput, UmaError> {
1597 match kern {
1598 Kernel::Scalar | Kernel::Avx2 | Kernel::Avx512 => {
1599 return Err(UmaError::InvalidKernelForBatch(kern));
1600 }
1601 _ => {}
1602 }
1603 let combos = expand_grid_uma(sweep);
1604 let cols = data.len();
1605 let rows = combos.len();
1606 if cols == 0 {
1607 return Err(UmaError::EmptyInputData);
1608 }
1609 if rows == 0 {
1610 return Err(UmaError::InvalidRange {
1611 start: sweep.min_length.0,
1612 end: sweep.max_length.1,
1613 step: sweep.max_length.2,
1614 });
1615 }
1616
1617 let _cap = rows.checked_mul(cols).ok_or(UmaError::ArithmeticOverflow {
1618 context: "rows * cols (matrix alloc)",
1619 })?;
1620
1621 let mut buf_mu = make_uninit_matrix(rows, cols);
1622 let warm: Vec<usize> = combos
1623 .iter()
1624 .map(|c| {
1625 let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
1626 let a = first
1627 .checked_add(c.max_length.unwrap_or(50))
1628 .and_then(|x| x.checked_sub(1))
1629 .ok_or(UmaError::ArithmeticOverflow {
1630 context: "first + max_length - 1 (batch warm alloc)",
1631 })?;
1632 a.checked_add(c.smooth_length.unwrap_or(4))
1633 .and_then(|x| x.checked_sub(1))
1634 .ok_or(UmaError::ArithmeticOverflow {
1635 context: "+ smooth_length - 1 (batch warm alloc)",
1636 })
1637 })
1638 .collect::<Result<Vec<_>, _>>()?;
1639 init_matrix_prefixes(&mut buf_mu, cols, &warm);
1640
1641 let mut guard = core::mem::ManuallyDrop::new(buf_mu);
1642 let out_slice: &mut [f64] =
1643 unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
1644
1645 let combos = uma_batch_inner_into(data, volume, sweep, kern, parallel, out_slice)?;
1646
1647 let values = unsafe {
1648 Vec::from_raw_parts(
1649 guard.as_mut_ptr() as *mut f64,
1650 guard.len(),
1651 guard.capacity(),
1652 )
1653 };
1654 Ok(UmaBatchOutput {
1655 values,
1656 combos,
1657 rows,
1658 cols,
1659 })
1660}
1661
#[cfg(feature = "python")]
#[pyfunction(name = "uma")]
#[pyo3(signature = (data, accelerator, min_length, max_length, smooth_length, volume=None, kernel=None))]
/// Python entry point for a single UMA computation.
///
/// Releases the GIL while computing; raises `ValueError` on any UMA error
/// or invalid kernel name.
pub fn uma_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    accelerator: f64,
    min_length: usize,
    max_length: usize,
    smooth_length: usize,
    volume: Option<PyReadonlyArray1<'py, f64>>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    use numpy::PyArrayMethods;
    // `false`: single-series call, so batch kernels are not accepted.
    let kern = validate_kernel(kernel, false)?;
    let slice_in = data.as_slice()?;
    let vol_slice = volume.as_ref().map(|v| v.as_slice()).transpose()?;

    let params = UmaParams {
        accelerator: Some(accelerator),
        min_length: Some(min_length),
        max_length: Some(max_length),
        smooth_length: Some(smooth_length),
    };
    let input = UmaInput::from_slice(slice_in, vol_slice, params);

    // Compute without holding the GIL, then hand the Vec to NumPy.
    let out: Vec<f64> = py
        .allow_threads(|| uma_with_kernel(&input, kern).map(|o| o.values))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;
    Ok(out.into_pyarray(py))
}
1693
/// Python wrapper around the incremental [`UmaStream`] state machine.
#[cfg(feature = "python")]
#[pyclass(name = "UmaStream")]
pub struct UmaStreamPy {
    // Owned Rust-side streaming state.
    stream: UmaStream,
}
1699
#[cfg(feature = "python")]
#[pymethods]
impl UmaStreamPy {
    /// Creates a streaming UMA with the given parameters; raises `ValueError`
    /// if the parameter combination is rejected by `UmaStream::try_new`.
    #[new]
    pub fn new(
        accelerator: f64,
        min_length: usize,
        max_length: usize,
        smooth_length: usize,
    ) -> PyResult<Self> {
        let params = UmaParams {
            accelerator: Some(accelerator),
            min_length: Some(min_length),
            max_length: Some(max_length),
            smooth_length: Some(smooth_length),
        };
        let stream =
            UmaStream::try_new(params).map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(Self { stream })
    }

    /// Pushes one price; returns the current UMA value, or `None` during warm-up.
    pub fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }

    /// Pushes one price with optional volume; returns the current UMA value,
    /// or `None` during warm-up.
    #[pyo3(name = "update_with_volume")]
    pub fn update_with_volume_py(&mut self, value: f64, volume: Option<f64>) -> Option<f64> {
        self.stream.update_with_volume(value, volume)
    }

    /// Clears all streaming state back to the initial (warm-up) condition.
    pub fn reset(&mut self) {
        self.stream.reset();
    }
}
1734
#[cfg(feature = "python")]
#[pyfunction(name = "uma_batch")]
#[pyo3(signature = (data, accelerator_range, min_length_range, max_length_range, smooth_length_range, volume=None, kernel=None))]
/// Python entry point for a UMA batch sweep.
///
/// Returns a dict with the `(rows, cols)` value matrix, per-row parameter
/// arrays, and a list of combo dicts. Raises `ValueError` on any UMA error
/// or invalid kernel name.
pub fn uma_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    accelerator_range: (f64, f64, f64),
    min_length_range: (usize, usize, usize),
    max_length_range: (usize, usize, usize),
    smooth_length_range: (usize, usize, usize),
    volume: Option<numpy::PyReadonlyArray1<'py, f64>>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    let slice_in = data.as_slice()?;
    let vol_slice = volume.as_ref().map(|v| v.as_slice()).transpose()?;
    let sweep = UmaBatchRange {
        accelerator: accelerator_range,
        min_length: min_length_range,
        max_length: max_length_range,
        smooth_length: smooth_length_range,
    };
    // Grid is expanded up front to size the output array before computing.
    let combos = expand_grid_uma(&sweep);
    let rows = combos.len();
    let cols = slice_in.len();

    // SAFETY: array is created uninitialized and fully written (NaN prefixes
    // plus every row) by uma_batch_inner_into before it is returned to Python.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [rows * cols], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    // `true`: batch call, so only batch kernels (or Auto) are accepted.
    let kern = validate_kernel(kernel, true)?;
    py.allow_threads(|| uma_batch_inner_into(slice_in, vol_slice, &sweep, kern, false, out_slice))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    // Column-style parameter arrays, one entry per matrix row.
    dict.set_item(
        "accelerators",
        combos
            .iter()
            .map(|c| c.accelerator.unwrap_or(1.0))
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "min_lengths",
        combos
            .iter()
            .map(|c| c.min_length.unwrap_or(5) as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "max_lengths",
        combos
            .iter()
            .map(|c| c.max_length.unwrap_or(50) as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "smooth_lengths",
        combos
            .iter()
            .map(|c| c.smooth_length.unwrap_or(4) as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    // Row-style view: one dict per combo, mirroring the arrays above.
    let combo_list: Vec<Bound<'py, PyDict>> = combos
        .iter()
        .map(|c| {
            let d = PyDict::new(py);
            d.set_item("accelerator", c.accelerator.unwrap_or(1.0))
                .unwrap();
            d.set_item("min_length", c.min_length.unwrap_or(5)).unwrap();
            d.set_item("max_length", c.max_length.unwrap_or(50))
                .unwrap();
            d.set_item("smooth_length", c.smooth_length.unwrap_or(4))
                .unwrap();
            d.into()
        })
        .collect();
    dict.set_item("combos", combo_list)?;

    Ok(dict.into())
}
1823
/// Python handle to a CUDA-resident f32 result matrix.
///
/// Exposes `__cuda_array_interface__` and DLPack so frameworks (CuPy, PyTorch,
/// …) can consume the buffer without a host copy. `unsendable`: tied to the
/// CUDA context of the creating thread.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct UmaDeviceArrayF32Py {
    // None after the buffer has been handed off via __dlpack__.
    pub(crate) buf: Option<DeviceBuffer<f32>>,
    pub(crate) rows: usize,
    pub(crate) cols: usize,
    // Keeps the CUDA context alive for as long as the buffer exists.
    pub(crate) _ctx: Arc<Context>,
    pub(crate) device_id: u32,
}
1833
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl UmaDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) descriptor for zero-copy consumption.
    /// Raises `ValueError` if the buffer was already exported via `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        d.set_item("shape", (self.rows, self.cols))?;
        // "<f4": little-endian 32-bit float.
        d.set_item("typestr", "<f4")?;
        // Row-major (C-contiguous) strides in bytes.
        d.set_item(
            "strides",
            (
                self.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        let ptr = self
            .buf
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?
            .as_device_ptr()
            .as_raw() as usize;
        // (pointer, read_only=false)
        d.set_item("data", (ptr, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple: (kDLCUDA = 2, device ordinal).
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    /// One-shot DLPack export: ownership of the device buffer moves into the
    /// returned capsule, so a second call raises `ValueError`.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();
        // Cross-device export is not supported; reject explicit mismatches.
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "__dlpack__(copy=True) not implemented for UMA device handle",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for UMA tensor"));
                    }
                }
            }
        }

        // NOTE(review): the consumer-provided stream is ignored — the export
        // performs no stream synchronization here; confirm producers synchronize
        // before handing buffers out.
        let _ = stream;

        // Take the buffer out so subsequent exports fail loudly.
        let buf = self
            .buf
            .take()
            .ok_or_else(|| PyValueError::new_err("__dlpack__ may only be called once"))?;

        let rows = self.rows;
        let cols = self.cols;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1908
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "uma_cuda_batch_dev")]
#[pyo3(signature = (data_f32, accelerator_range, min_length_range, max_length_range, smooth_length_range, volume_f32=None, device_id=0))]
/// Runs the UMA batch sweep on the GPU and returns a device-resident handle
/// (no copy back to the host). Raises `ValueError` when CUDA is unavailable
/// or the kernel launch fails.
pub fn uma_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    accelerator_range: (f64, f64, f64),
    min_length_range: (usize, usize, usize),
    max_length_range: (usize, usize, usize),
    smooth_length_range: (usize, usize, usize),
    volume_f32: Option<numpy::PyReadonlyArray1<'_, f32>>,
    device_id: usize,
) -> PyResult<UmaDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let slice_in = data_f32.as_slice()?;
    let volume_slice = volume_f32.as_ref().map(|v| v.as_slice()).transpose()?;
    let sweep = UmaBatchRange {
        accelerator: accelerator_range,
        min_length: min_length_range,
        max_length: max_length_range,
        smooth_length: smooth_length_range,
    };

    // GIL released for device setup and kernel execution.
    let (inner, ctx, dev_id) = py
        .allow_threads(
            || -> Result<_, crate::cuda::moving_averages::uma_wrapper::CudaUmaError> {
                let cuda = CudaUma::new(device_id)?;
                let out = cuda.uma_batch_dev(slice_in, volume_slice, &sweep)?;
                Ok((out, cuda.context_arc(), cuda.device_id()))
            },
        )
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    // The context Arc is stored on the handle to keep the device memory valid.
    let crate::cuda::DeviceArrayF32 { buf, rows, cols } = inner;
    Ok(UmaDeviceArrayF32Py {
        buf: Some(buf),
        rows,
        cols,
        _ctx: ctx,
        device_id: dev_id,
    })
}
1954
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "uma_cuda_many_series_one_param_dev")]
#[pyo3(signature = (prices_tm_f32, accelerator, min_length, max_length, smooth_length, volume_tm_f32=None, device_id=0))]
/// Applies one UMA parameter set to many series at once on the GPU.
///
/// `prices_tm_f32` (and optional `volume_tm_f32`) are time-major matrices:
/// shape `[time, series]`. Returns a device-resident handle (no host copy).
/// Raises `ValueError` when CUDA is unavailable, the volume shape mismatches,
/// or the kernel launch fails.
pub fn uma_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    prices_tm_f32: PyReadonlyArray2<'_, f32>,
    accelerator: f64,
    min_length: usize,
    max_length: usize,
    smooth_length: usize,
    volume_tm_f32: Option<PyReadonlyArray2<'_, f32>>,
    device_id: usize,
) -> PyResult<UmaDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    use numpy::PyUntypedArrayMethods;

    let rows = prices_tm_f32.shape()[0];
    let cols = prices_tm_f32.shape()[1];
    // Volume, when provided, must align element-for-element with prices.
    if let Some(vol) = &volume_tm_f32 {
        let vshape = vol.shape();
        if vshape != prices_tm_f32.shape() {
            return Err(PyValueError::new_err(
                "price and volume matrices must share shape",
            ));
        }
    }

    let prices_flat = prices_tm_f32.as_slice()?;
    let volume_flat = volume_tm_f32
        .as_ref()
        .map(|arr| arr.as_slice())
        .transpose()?;

    let params = UmaParams {
        accelerator: Some(accelerator),
        min_length: Some(min_length),
        max_length: Some(max_length),
        smooth_length: Some(smooth_length),
    };

    // GIL released for device setup and kernel execution.
    let (inner, ctx, dev_id) = py
        .allow_threads(
            || -> Result<_, crate::cuda::moving_averages::uma_wrapper::CudaUmaError> {
                let cuda = CudaUma::new(device_id)?;
                let out = cuda.uma_many_series_one_param_time_major_dev(
                    prices_flat,
                    volume_flat,
                    cols,
                    rows,
                    // Fixed: this argument was mojibake ("¶ms") in the source,
                    // a corrupted rendering of `&params`.
                    &params,
                )?;
                Ok((out, cuda.context_arc(), cuda.device_id()))
            },
        )
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    // The context Arc is stored on the handle to keep the device memory valid.
    let crate::cuda::DeviceArrayF32 { buf, rows, cols } = inner;
    Ok(UmaDeviceArrayF32Py {
        buf: Some(buf),
        rows,
        cols,
        _ctx: ctx,
        device_id: dev_id,
    })
}
2023
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// JS entry point for a single UMA computation. Returns `{ values: number[] }`
/// or rejects with the error message as a string.
pub fn uma_js(
    data: &[f64],
    accelerator: f64,
    min_length: usize,
    max_length: usize,
    smooth_length: usize,
    volume: Option<Vec<f64>>,
) -> Result<JsValue, JsValue> {
    let params = UmaParams {
        accelerator: Some(accelerator),
        min_length: Some(min_length),
        max_length: Some(max_length),
        smooth_length: Some(smooth_length),
    };
    let vol_slice = volume.as_deref();
    let input = UmaInput::from_slice(data, vol_slice, params);

    match uma_with_kernel(&input, Kernel::Auto) {
        Ok(output) => {
            // Build a plain JS object { values: [...] } element by element.
            let obj = js_sys::Object::new();
            let values_array = js_sys::Array::new();
            for val in output.values {
                values_array.push(&JsValue::from_f64(val));
            }
            js_sys::Reflect::set(&obj, &JsValue::from_str("values"), &values_array)?;
            Ok(obj.into())
        }
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2056
/// Serde-deserialized configuration object for the unified JS batch entry
/// point; each field mirrors an axis of [`UmaBatchRange`].
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct UmaBatchConfig {
    pub accelerator_range: (f64, f64, f64),
    pub min_length_range: (usize, usize, usize),
    pub max_length_range: (usize, usize, usize),
    pub smooth_length_range: (usize, usize, usize),
}
2065
/// Serde-serializable mirror of [`UmaBatchOutput`] returned to JS by the
/// unified batch entry point.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct UmaBatchJsOutput {
    /// Flattened row-major `rows x cols` matrix.
    pub values: Vec<f64>,
    /// Parameter combination per row.
    pub combos: Vec<UmaParams>,
    pub rows: usize,
    pub cols: usize,
}
2074
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "uma_batch")]
/// Unified JS batch entry point: takes a serde config object, runs the sweep
/// sequentially, and returns a serialized [`UmaBatchJsOutput`].
///
/// Fix: previously selected `detect_best_kernel()` (a single-series kernel),
/// which `uma_batch_inner` rejects with `InvalidKernelForBatch`, so every call
/// errored. Use the batch kernel detector, matching `uma_batch_js`.
pub fn uma_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let cfg: UmaBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
    let sweep = UmaBatchRange {
        accelerator: cfg.accelerator_range,
        min_length: cfg.min_length_range,
        max_length: cfg.max_length_range,
        smooth_length: cfg.smooth_length_range,
    };
    let out = uma_batch_inner(data, None, &sweep, detect_best_batch_kernel(), false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    let js = UmaBatchJsOutput {
        values: out.values,
        combos: out.combos,
        rows: out.rows,
        cols: out.cols,
    };
    serde_wasm_bindgen::to_value(&js)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2097
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Allocates an uninitialized buffer of `len` f64s and leaks ownership to the
/// JS side. Must be released with `uma_free(ptr, len)`.
pub fn uma_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop suppresses the destructor so the allocation outlives this call.
    let mut buf = core::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buf.as_mut_ptr()
}
2106
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Releases a buffer previously returned by `uma_alloc`.
///
/// # Safety
/// `ptr` must come from `uma_alloc(len)` with this exact `len`, and must not
/// be used (or freed) again afterwards.
pub fn uma_free(ptr: *mut f64, len: usize) {
    unsafe {
        // SAFETY: reconstructs the Vec from uma_alloc (capacity == len) and drops it.
        let _ = Vec::from_raw_parts(ptr, len, len);
    }
}
2114
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Computes UMA from the buffer at `in_ptr` into the buffer at `out_ptr`
/// (both `len` f64s). Supports `in_ptr == out_ptr` (in-place) by staging
/// through a temporary buffer.
///
/// NOTE(review): the aliasing check only detects exactly-equal pointers;
/// partially overlapping input/output ranges are not detected — confirm JS
/// callers never pass overlapping-but-unequal buffers.
///
/// # Safety (caller contract via wasm)
/// Both pointers must reference valid allocations of at least `len` f64s.
pub fn uma_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    accelerator: f64,
    min_length: usize,
    max_length: usize,
    smooth_length: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let params = UmaParams {
            accelerator: Some(accelerator),
            min_length: Some(min_length),
            max_length: Some(max_length),
            smooth_length: Some(smooth_length),
        };
        let input = UmaInput::from_slice(data, None, params);

        if core::ptr::eq(in_ptr as *const u8, out_ptr as *const u8) {
            // In-place call: compute into a scratch buffer first so the input
            // is not clobbered mid-computation.
            let mut tmp = vec![0.0; len];
            uma_into_slice(&mut tmp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&tmp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            uma_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
        Ok(())
    }
}
2153
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Creates a heap-allocated streaming UMA and returns its raw pointer for JS.
/// Returns a null pointer when the parameters are rejected; free with
/// `uma_stream_free`.
pub fn uma_stream_new(
    accelerator: f64,
    min_length: usize,
    max_length: usize,
    smooth_length: usize,
) -> *mut UmaStream {
    let params = UmaParams {
        accelerator: Some(accelerator),
        min_length: Some(min_length),
        max_length: Some(max_length),
        smooth_length: Some(smooth_length),
    };

    match UmaStream::try_new(params) {
        Ok(stream) => Box::into_raw(Box::new(stream)),
        // Null signals construction failure to the JS side.
        Err(_) => std::ptr::null_mut(),
    }
}
2174
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Pushes one price into the stream (no volume); returns the current value,
/// or `None` during warm-up or for a null handle.
pub fn uma_stream_update(stream: *mut UmaStream, value: f64) -> Option<f64> {
    if stream.is_null() {
        return None;
    }
    // SAFETY: non-null pointer originates from uma_stream_new.
    unsafe { (*stream).update_with_volume(value, None) }
}
2183
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Pushes one price with its volume; returns the current value, or `None`
/// during warm-up or for a null handle.
pub fn uma_stream_update_with_volume(
    stream: *mut UmaStream,
    value: f64,
    volume: f64,
) -> Option<f64> {
    if stream.is_null() {
        return None;
    }
    // SAFETY: non-null pointer originates from uma_stream_new.
    unsafe { (*stream).update_with_volume(value, Some(volume)) }
}
2196
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Resets the stream to its initial (warm-up) state; no-op for a null handle.
pub fn uma_stream_reset(stream: *mut UmaStream) {
    if !stream.is_null() {
        // SAFETY: non-null pointer originates from uma_stream_new.
        unsafe {
            (*stream).reset();
        }
    }
}
2206
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Destroys a stream created by `uma_stream_new`; no-op for a null handle.
/// The pointer must not be used afterwards.
pub fn uma_stream_free(stream: *mut UmaStream) {
    if !stream.is_null() {
        // SAFETY: reclaims the Box leaked by uma_stream_new, then drops it.
        unsafe {
            let _ = Box::from_raw(stream);
        }
    }
}
2216
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Returns a zero-copy `Float64Array` view over wasm memory at `ptr`.
///
/// NOTE(review): the view aliases wasm linear memory directly and is
/// invalidated if the memory grows (any allocation can trigger growth) —
/// JS callers must treat it as transient.
pub fn uma_get_view(ptr: *mut f64, len: usize) -> js_sys::Float64Array {
    // SAFETY: caller guarantees ptr..ptr+len is a live f64 allocation.
    unsafe { js_sys::Float64Array::view(std::slice::from_raw_parts(ptr, len)) }
}
2222
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Computes UMA in place over the buffer at `ptr` (length `len`), staging
/// through a temporary so the input is read before being overwritten.
///
/// # Safety (caller contract via wasm)
/// `ptr` must reference a valid allocation of at least `len` f64s.
pub fn uma_update(
    ptr: *mut f64,
    len: usize,
    accelerator: f64,
    min_length: usize,
    max_length: usize,
    smooth_length: usize,
) -> Result<(), JsValue> {
    if ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }

    unsafe {
        let data = std::slice::from_raw_parts(ptr, len);
        let params = UmaParams {
            accelerator: Some(accelerator),
            min_length: Some(min_length),
            max_length: Some(max_length),
            smooth_length: Some(smooth_length),
        };
        let input = UmaInput::from_slice(data, None, params);

        // Compute into scratch, then overwrite the source buffer.
        let mut tmp = vec![0.0; len];
        uma_into_slice(&mut tmp, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        let out = std::slice::from_raw_parts_mut(ptr, len);
        out.copy_from_slice(&tmp);
        Ok(())
    }
}
2256
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// JS batch entry point taking `[start, end, step]` triples per axis.
/// Returns a plain object with `values`, per-row parameter arrays, `rows`,
/// and `cols`; rejects with a message string on any error.
pub fn uma_batch_js(
    data: &[f64],
    accelerator_range: Vec<f64>,
    min_length_range: Vec<usize>,
    max_length_range: Vec<usize>,
    smooth_length_range: Vec<usize>,
    volume: Option<Vec<f64>>,
) -> Result<JsValue, JsValue> {
    // Each axis must be exactly [start, end, step].
    if accelerator_range.len() != 3
        || min_length_range.len() != 3
        || max_length_range.len() != 3
        || smooth_length_range.len() != 3
    {
        return Err(JsValue::from_str(
            "All range arrays must have exactly 3 elements: [start, end, step]",
        ));
    }

    let sweep = UmaBatchRange {
        accelerator: (
            accelerator_range[0],
            accelerator_range[1],
            accelerator_range[2],
        ),
        min_length: (
            min_length_range[0],
            min_length_range[1],
            min_length_range[2],
        ),
        max_length: (
            max_length_range[0],
            max_length_range[1],
            max_length_range[2],
        ),
        smooth_length: (
            smooth_length_range[0],
            smooth_length_range[1],
            smooth_length_range[2],
        ),
    };

    let vol_slice = volume.as_deref();
    let kernel = detect_best_batch_kernel();

    // parallel = true is a no-op on wasm32 (rows run sequentially).
    let result = uma_batch_inner(data, vol_slice, &sweep, kernel, true)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let obj = js_sys::Object::new();

    let values_array = js_sys::Array::new();
    for val in result.values {
        values_array.push(&JsValue::from_f64(val));
    }
    js_sys::Reflect::set(&obj, &JsValue::from_str("values"), &values_array)?;

    // Column-style parameter arrays, one entry per matrix row.
    let accelerators = js_sys::Array::new();
    let min_lengths = js_sys::Array::new();
    let max_lengths = js_sys::Array::new();
    let smooth_lengths = js_sys::Array::new();

    for combo in &result.combos {
        accelerators.push(&JsValue::from_f64(combo.accelerator.unwrap_or(1.0)));
        min_lengths.push(&JsValue::from_f64(combo.min_length.unwrap_or(5) as f64));
        max_lengths.push(&JsValue::from_f64(combo.max_length.unwrap_or(50) as f64));
        smooth_lengths.push(&JsValue::from_f64(combo.smooth_length.unwrap_or(4) as f64));
    }

    js_sys::Reflect::set(&obj, &JsValue::from_str("accelerators"), &accelerators)?;
    js_sys::Reflect::set(&obj, &JsValue::from_str("min_lengths"), &min_lengths)?;
    js_sys::Reflect::set(&obj, &JsValue::from_str("max_lengths"), &max_lengths)?;
    js_sys::Reflect::set(&obj, &JsValue::from_str("smooth_lengths"), &smooth_lengths)?;
    js_sys::Reflect::set(
        &obj,
        &JsValue::from_str("rows"),
        &JsValue::from_f64(result.rows as f64),
    )?;
    js_sys::Reflect::set(
        &obj,
        &JsValue::from_str("cols"),
        &JsValue::from_f64(result.cols as f64),
    )?;

    Ok(obj.into())
}
2343
2344#[cfg(test)]
2345mod tests {
2346 use super::*;
2347 use crate::skip_if_unsupported;
2348 use crate::utilities::data_loader::read_candles_from_csv;
2349 #[cfg(feature = "proptest")]
2350 use proptest::prelude::*;
2351
2352 fn check_uma_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2353 skip_if_unsupported!(kernel, test_name);
2354 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2355 let candles = read_candles_from_csv(file_path)?;
2356
2357 let default_params = UmaParams {
2358 accelerator: None,
2359 min_length: None,
2360 max_length: None,
2361 smooth_length: None,
2362 };
2363 let input = UmaInput::from_candles(&candles, "close", default_params);
2364 let output = uma_with_kernel(&input, kernel)?;
2365 assert_eq!(output.values.len(), candles.close.len());
2366
2367 Ok(())
2368 }
2369
2370 fn check_uma_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2371 skip_if_unsupported!(kernel, test_name);
2372 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2373 let candles = read_candles_from_csv(file_path)?;
2374
2375 let input = UmaInput::from_candles(&candles, "close", UmaParams::default());
2376 let result = uma_with_kernel(&input, kernel)?;
2377
2378 let values = &result.values;
2379 let valid_values: Vec<f64> = values.iter().filter(|&&v| !v.is_nan()).copied().collect();
2380
2381 let expected_last_five = [
2382 59665.81830666,
2383 59477.69234542,
2384 59314.50778603,
2385 59218.23033661,
2386 59143.61473211,
2387 ];
2388
2389 if valid_values.len() >= 5 {
2390 let start = valid_values.len().saturating_sub(5);
2391 for (i, &val) in valid_values[start..].iter().enumerate() {
2392 let diff = (val - expected_last_five[i]).abs();
2393 let tolerance = expected_last_five[i] * 0.01;
2394 assert!(
2395 diff < tolerance || diff < 100.0,
2396 "[{}] UMA {:?} mismatch at idx {}: got {}, expected {}",
2397 test_name,
2398 kernel,
2399 i,
2400 val,
2401 expected_last_five[i]
2402 );
2403 }
2404 }
2405 Ok(())
2406 }
2407
2408 fn check_uma_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2409 skip_if_unsupported!(kernel, test_name);
2410 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2411 let candles = read_candles_from_csv(file_path)?;
2412
2413 let input = UmaInput::with_default_candles(&candles);
2414 match input.data {
2415 UmaData::Candles { source, .. } => assert_eq!(source, "close"),
2416 _ => panic!("Expected UmaData::Candles"),
2417 }
2418 let output = uma_with_kernel(&input, kernel)?;
2419 assert_eq!(output.values.len(), candles.close.len());
2420
2421 Ok(())
2422 }
2423
2424 fn check_uma_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2425 skip_if_unsupported!(kernel, test_name);
2426 let input_data = [10.0, 20.0, 30.0];
2427 let params = UmaParams {
2428 accelerator: Some(1.0),
2429 min_length: Some(5),
2430 max_length: Some(0),
2431 smooth_length: Some(4),
2432 };
2433 let input = UmaInput::from_slice(&input_data, None, params);
2434 let res = uma_with_kernel(&input, kernel);
2435 assert!(
2436 res.is_err(),
2437 "[{}] UMA should fail with zero max_length",
2438 test_name
2439 );
2440 Ok(())
2441 }
2442
2443 fn check_uma_period_exceeds_length(
2444 test_name: &str,
2445 kernel: Kernel,
2446 ) -> Result<(), Box<dyn Error>> {
2447 skip_if_unsupported!(kernel, test_name);
2448 let data_small = [10.0, 20.0, 30.0];
2449 let params = UmaParams {
2450 accelerator: Some(1.0),
2451 min_length: Some(5),
2452 max_length: Some(10),
2453 smooth_length: Some(4),
2454 };
2455 let input = UmaInput::from_slice(&data_small, None, params);
2456 let res = uma_with_kernel(&input, kernel);
2457 assert!(
2458 res.is_err(),
2459 "[{}] UMA should fail with max_length exceeding data length",
2460 test_name
2461 );
2462 Ok(())
2463 }
2464
2465 fn check_uma_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2466 skip_if_unsupported!(kernel, test_name);
2467 let single_point = [42.0];
2468 let params = UmaParams::default();
2469 let input = UmaInput::from_slice(&single_point, None, params);
2470 let res = uma_with_kernel(&input, kernel);
2471 assert!(
2472 res.is_err(),
2473 "[{}] UMA should fail with insufficient data",
2474 test_name
2475 );
2476 Ok(())
2477 }
2478
2479 fn check_uma_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2480 skip_if_unsupported!(kernel, test_name);
2481 let empty: [f64; 0] = [];
2482 let input = UmaInput::from_slice(&empty, None, UmaParams::default());
2483 let res = uma_with_kernel(&input, kernel);
2484 assert!(
2485 matches!(res, Err(UmaError::EmptyInputData)),
2486 "[{}] UMA should fail with empty input",
2487 test_name
2488 );
2489 Ok(())
2490 }
2491
2492 fn check_uma_invalid_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2493 skip_if_unsupported!(kernel, test_name);
2494 let data: Vec<f64> = (0..100).map(|i| 100.0 + i as f64).collect();
2495
2496 let params = UmaParams {
2497 accelerator: Some(0.5),
2498 max_length: Some(10),
2499 ..Default::default()
2500 };
2501 let input = UmaInput::from_slice(&data, None, params);
2502 let res = uma_with_kernel(&input, kernel);
2503 assert!(
2504 matches!(res, Err(UmaError::InvalidAccelerator { .. })),
2505 "[{}] UMA should fail with invalid accelerator, got: {:?}",
2506 test_name,
2507 res
2508 );
2509 Ok(())
2510 }
2511
2512 fn check_uma_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2513 skip_if_unsupported!(kernel, test_name);
2514 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2515 let candles = read_candles_from_csv(file_path)?;
2516
2517 let first_params = UmaParams::default();
2518 let first_input = UmaInput::from_candles(&candles, "close", first_params);
2519 let first_result = uma_with_kernel(&first_input, kernel)?;
2520
2521 let second_params = UmaParams::default();
2522 let second_input = UmaInput::from_slice(&first_result.values, None, second_params);
2523 let second_result = uma_with_kernel(&second_input, kernel)?;
2524
2525 assert_eq!(second_result.values.len(), first_result.values.len());
2526
2527 let valid_count = second_result
2528 .values
2529 .iter()
2530 .filter(|&&v| !v.is_nan())
2531 .count();
2532 assert!(
2533 valid_count > 0,
2534 "[{}] UMA reinput should produce valid values",
2535 test_name
2536 );
2537
2538 Ok(())
2539 }
2540
2541 fn check_uma_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2542 skip_if_unsupported!(kernel, test_name);
2543
2544 let mut data = vec![f64::NAN; 10];
2545 data.extend((0..100).map(|i| 100.0 + i as f64));
2546
2547 let input = UmaInput::from_slice(&data, None, UmaParams::default());
2548 let res = uma_with_kernel(&input, kernel)?;
2549 assert_eq!(res.values.len(), data.len());
2550
2551 let valid_count = res.values[60..].iter().filter(|&&v| !v.is_nan()).count();
2552 assert!(
2553 valid_count > 0,
2554 "[{}] UMA should handle NaN prefix",
2555 test_name
2556 );
2557
2558 Ok(())
2559 }
2560
2561 fn check_uma_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2562 skip_if_unsupported!(kernel, test_name);
2563
2564 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2565 let candles = read_candles_from_csv(file_path)?;
2566
2567 let params = UmaParams::default();
2568 let input = UmaInput::from_candles(&candles, "close", params.clone());
2569 let batch_output = uma_with_kernel(&input, kernel)?.values;
2570
2571 let mut stream = UmaStream::try_new(params)?;
2572 let mut stream_values = Vec::with_capacity(candles.close.len());
2573
2574 for (i, &price) in candles.close.iter().enumerate() {
2575 let volume = if i < candles.volume.len() {
2576 Some(candles.volume[i])
2577 } else {
2578 None
2579 };
2580
2581 match stream.update_with_volume(price, volume) {
2582 Some(uma_val) => stream_values.push(uma_val),
2583 None => stream_values.push(f64::NAN),
2584 }
2585 }
2586
2587 assert_eq!(batch_output.len(), stream_values.len());
2588
2589 let batch_valid: Vec<f64> = batch_output
2590 .iter()
2591 .filter(|&&v| !v.is_nan())
2592 .copied()
2593 .collect();
2594 let stream_valid: Vec<f64> = stream_values
2595 .iter()
2596 .filter(|&&v| !v.is_nan())
2597 .copied()
2598 .collect();
2599
2600 if batch_valid.len() >= 5 && stream_valid.len() >= 5 {
2601 let batch_last = &batch_valid[batch_valid.len() - 5..];
2602 let stream_last = &stream_valid[stream_valid.len() - 5..];
2603
2604 for (i, (&b, &s)) in batch_last.iter().zip(stream_last.iter()).enumerate() {
2605 let diff = (b - s).abs();
2606 let relative_diff = diff / b.abs().max(1.0);
2607 assert!(
2608 relative_diff < 0.1,
2609 "[{}] UMA streaming mismatch at idx {}: batch={}, stream={}, diff={}, rel_diff={}",
2610 test_name,
2611 i,
2612 b,
2613 s,
2614 diff,
2615 relative_diff
2616 );
2617 }
2618 }
2619
2620 Ok(())
2621 }
2622
2623 #[cfg(debug_assertions)]
2624 fn check_uma_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2625 skip_if_unsupported!(kernel, test_name);
2626
2627 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2628 let candles = read_candles_from_csv(file_path)?;
2629
2630 let test_params = vec![
2631 UmaParams::default(),
2632 UmaParams {
2633 accelerator: Some(1.5),
2634 min_length: Some(3),
2635 max_length: Some(30),
2636 smooth_length: Some(3),
2637 },
2638 UmaParams {
2639 accelerator: Some(2.0),
2640 min_length: Some(10),
2641 max_length: Some(100),
2642 smooth_length: Some(8),
2643 },
2644 ];
2645
2646 for params in test_params.iter() {
2647 let input = UmaInput::from_candles(&candles, "close", params.clone());
2648 let output = uma_with_kernel(&input, kernel)?;
2649
2650 for (i, &val) in output.values.iter().enumerate() {
2651 if val.is_nan() {
2652 continue;
2653 }
2654
2655 let bits = val.to_bits();
2656
2657 if bits == 0x11111111_11111111 {
2658 panic!(
2659 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {}",
2660 test_name, val, bits, i
2661 );
2662 }
2663
2664 if bits == 0x22222222_22222222 {
2665 panic!(
2666 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {}",
2667 test_name, val, bits, i
2668 );
2669 }
2670
2671 if bits == 0x33333333_33333333 {
2672 panic!(
2673 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {}",
2674 test_name, val, bits, i
2675 );
2676 }
2677 }
2678 }
2679
2680 Ok(())
2681 }
2682
    /// Release-build stub: the poison sentinels are only written in debug
    /// builds, so there is nothing to scan for here.
    #[cfg(not(debug_assertions))]
    fn check_uma_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
2687
    /// Property test: for random parameter combinations over finite input
    /// data, the kernel under test must agree with the scalar reference
    /// kernel — non-finite positions bit-for-bit, finite positions within
    /// 1e-6 absolute error or 10 ULPs.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_uma_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Draw (min_len, max_len, smooth_len, accelerator) first, then a data
        // vector guaranteed longer than max_len (len in max_len+10 .. 200).
        let strat = (5usize..=20, 20usize..=50, 2usize..=8, 1.0f64..3.0).prop_flat_map(
            |(min_len, max_len, smooth_len, acc)| {
                (
                    prop::collection::vec(
                        (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                        max_len + 10..200,
                    ),
                    Just(min_len),
                    Just(max_len),
                    Just(smooth_len),
                    Just(acc),
                )
            },
        );

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, min_len, max_len, smooth_len, acc)| {
                let params = UmaParams {
                    accelerator: Some(acc),
                    min_length: Some(min_len),
                    max_length: Some(max_len),
                    smooth_length: Some(smooth_len),
                };
                let input = UmaInput::from_slice(&data, None, params);

                // Kernel under test vs. scalar reference on identical input.
                let UmaOutput { values: out } = uma_with_kernel(&input, kernel).unwrap();
                let UmaOutput { values: ref_out } =
                    uma_with_kernel(&input, Kernel::Scalar).unwrap();

                prop_assert_eq!(out.len(), data.len());
                prop_assert_eq!(ref_out.len(), data.len());

                for i in 0..data.len() {
                    let y = out[i];
                    let r = ref_out[i];

                    // Non-finite values (e.g. the NaN warm-up) must match exactly.
                    if !y.is_finite() || !r.is_finite() {
                        prop_assert!(
                            y.to_bits() == r.to_bits(),
                            "finite/NaN mismatch idx {i}: {y} vs {r}"
                        );
                        continue;
                    }

                    // Bit-pattern distance as a ULP proxy (both values finite here).
                    let y_bits = y.to_bits();
                    let r_bits = r.to_bits();
                    let ulp_diff: u64 = y_bits.abs_diff(r_bits);

                    prop_assert!(
                        (y - r).abs() <= 1e-6 || ulp_diff <= 10,
                        "mismatch idx {i}: {y} vs {r} (ULP={ulp_diff})"
                    );
                }
                Ok(())
            })
            .unwrap();

        Ok(())
    }
2756
2757 macro_rules! generate_all_uma_tests {
2758 ($($test_fn:ident),*) => {
2759 paste::paste! {
2760 $(
2761 #[test]
2762 fn [<$test_fn _scalar_f64>]() {
2763 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2764 }
2765 )*
2766 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2767 $(
2768 #[test]
2769 fn [<$test_fn _avx2_f64>]() {
2770 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2771 }
2772 #[test]
2773 fn [<$test_fn _avx512_f64>]() {
2774 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2775 }
2776 )*
2777 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
2778 $(
2779 #[test]
2780 fn [<$test_fn _simd128_f64>]() {
2781 let _ = $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar);
2782 }
2783 )*
2784 }
2785 }
2786 }
2787
    // Expand each checker into per-kernel #[test] functions (scalar always;
    // AVX2/AVX-512 and wasm simd128 variants behind their cfg gates).
    generate_all_uma_tests!(
        check_uma_partial_params,
        check_uma_accuracy,
        check_uma_default_candles,
        check_uma_zero_period,
        check_uma_period_exceeds_length,
        check_uma_very_small_dataset,
        check_uma_empty_input,
        check_uma_invalid_params,
        check_uma_reinput,
        check_uma_nan_handling,
        check_uma_streaming,
        check_uma_no_poison
    );

    // The property-based checker only exists with the `proptest` feature.
    #[cfg(feature = "proptest")]
    generate_all_uma_tests!(check_uma_property);
2805
2806 #[test]
2807 fn uma_into_slice_matches_with_kernel() {
2808 use crate::utilities::data_loader::read_candles_from_csv;
2809 let c = read_candles_from_csv("src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv").unwrap();
2810 let params = UmaParams::default();
2811 let input = UmaInput::from_candles(&c, "close", params.clone());
2812
2813 let via_api = uma_with_kernel(&input, Kernel::Scalar).unwrap().values;
2814 let mut via_into = vec![0.0; via_api.len()];
2815 uma_into_slice(&mut via_into, &input, Kernel::Scalar).unwrap();
2816
2817 assert_eq!(via_api.len(), via_into.len());
2818 for (a, b) in via_api.iter().zip(via_into.iter()) {
2819 assert_eq!(a.to_bits(), b.to_bits());
2820 }
2821 }
2822
2823 #[test]
2824 fn test_uma_into_matches_api() {
2825 let len = 256usize;
2826 let mut ts = Vec::with_capacity(len);
2827 let mut open = Vec::with_capacity(len);
2828 let mut high = Vec::with_capacity(len);
2829 let mut low = Vec::with_capacity(len);
2830 let mut close = Vec::with_capacity(len);
2831 let mut volume = Vec::with_capacity(len);
2832
2833 for i in 0..len {
2834 ts.push(i as i64);
2835 let base = 100.0 + (i as f64) * 0.1 + (i as f64 / 10.0).sin() * 2.0;
2836 let c = if i < 3 { f64::NAN } else { base };
2837 let h = if i < 3 { f64::NAN } else { c + 1.0 };
2838 let l = if i < 3 { f64::NAN } else { c - 1.0 };
2839 let o = if i < 3 { f64::NAN } else { c };
2840 let v = 1000.0 + (i % 10) as f64;
2841
2842 open.push(o);
2843 high.push(h);
2844 low.push(l);
2845 close.push(c);
2846 volume.push(v);
2847 }
2848
2849 let candles =
2850 crate::utilities::data_loader::Candles::new(ts, open, high, low, close, volume);
2851 let params = UmaParams::default();
2852 let input = UmaInput::from_candles(&candles, "close", params);
2853
2854 let baseline = uma(&input).unwrap().values;
2855 let mut via_into = vec![0.0; baseline.len()];
2856
2857 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2858 uma_into(&input, &mut via_into).unwrap();
2859 #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2860 uma_into_slice(&mut via_into, &input, Kernel::Auto).unwrap();
2861
2862 assert_eq!(baseline.len(), via_into.len());
2863 for (a, b) in baseline.iter().zip(via_into.iter()) {
2864 let both_nan = a.is_nan() && b.is_nan();
2865 let close_enough = if both_nan {
2866 true
2867 } else {
2868 (*a - *b).abs() <= 1e-12
2869 };
2870 assert!(close_enough, "Mismatch: a={} b={}", a, b);
2871 }
2872 }
2873
2874 #[cfg(feature = "python")]
2875 #[test]
2876 fn uma_batch_py_no_copy_shape() {
2877 pyo3::Python::with_gil(|py| {
2878 use numpy::{PyArray1, PyArrayMethods};
2879 let data = PyArray1::from_vec(py, (0..256).map(|i| i as f64).collect());
2880 let d = crate::indicators::moving_averages::uma::uma_batch_py(
2881 py,
2882 data.readonly(),
2883 (1.0, 1.0, 0.0),
2884 (5, 5, 0),
2885 (50, 50, 0),
2886 (4, 4, 0),
2887 None,
2888 Some("scalar_batch"),
2889 )
2890 .unwrap();
2891 let v = d.get_item("values").unwrap().expect("values missing");
2892
2893 assert!(v.downcast::<numpy::PyArray2<f64>>().is_ok());
2894 });
2895 }
2896
2897 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2898 skip_if_unsupported!(kernel, test);
2899
2900 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2901 let c = read_candles_from_csv(file)?;
2902
2903 let output = UmaBatchBuilder::new()
2904 .kernel(kernel)
2905 .apply_candles(&c, "close")?;
2906
2907 let def = UmaParams::default();
2908 let row = output.values_for(&def).expect("default row missing");
2909
2910 assert_eq!(row.len(), c.close.len());
2911
2912 let valid_count = row.iter().filter(|&&v| !v.is_nan()).count();
2913 assert!(
2914 valid_count > 0,
2915 "[{}] Batch should produce valid values",
2916 test
2917 );
2918
2919 Ok(())
2920 }
2921
2922 fn check_batch_sweep(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2923 skip_if_unsupported!(kernel, test);
2924
2925 let data: Vec<f64> = (0..100).map(|i| 100.0 + i as f64).collect();
2926
2927 let output = UmaBatchBuilder::new()
2928 .kernel(kernel)
2929 .accelerator_range(1.0, 2.0, 0.5)
2930 .min_length_range(5, 10, 5)
2931 .max_length_range(20, 30, 10)
2932 .smooth_length_range(3, 5, 2)
2933 .apply_slice(&data, None)?;
2934
2935 let expected_combos = 3 * 2 * 2 * 2;
2936 assert_eq!(output.combos.len(), expected_combos);
2937 assert_eq!(output.rows, expected_combos);
2938 assert_eq!(output.cols, data.len());
2939
2940 Ok(())
2941 }
2942
2943 #[cfg(debug_assertions)]
2944 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2945 skip_if_unsupported!(kernel, test);
2946
2947 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2948 let c = read_candles_from_csv(file)?;
2949
2950 let test_configs = vec![
2951 (1.0, 1.5, 0.5, 5, 10, 5, 20, 30, 10, 3, 5, 2),
2952 (2.0, 2.0, 0.0, 10, 10, 0, 50, 50, 0, 4, 4, 0),
2953 ];
2954
2955 for (
2956 cfg_idx,
2957 &(
2958 a_start,
2959 a_end,
2960 a_step,
2961 min_start,
2962 min_end,
2963 min_step,
2964 max_start,
2965 max_end,
2966 max_step,
2967 s_start,
2968 s_end,
2969 s_step,
2970 ),
2971 ) in test_configs.iter().enumerate()
2972 {
2973 let output = UmaBatchBuilder::new()
2974 .kernel(kernel)
2975 .accelerator_range(a_start, a_end, a_step)
2976 .min_length_range(min_start, min_end, min_step)
2977 .max_length_range(max_start, max_end, max_step)
2978 .smooth_length_range(s_start, s_end, s_step)
2979 .apply_candles(&c, "close")?;
2980
2981 for (idx, &val) in output.values.iter().enumerate() {
2982 if val.is_nan() {
2983 continue;
2984 }
2985
2986 let bits = val.to_bits();
2987 let row = idx / output.cols;
2988 let col = idx % output.cols;
2989 let combo = &output.combos[row];
2990
2991 if bits == 0x11111111_11111111 {
2992 panic!(
2993 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2994 at row {} col {} (flat index {}) with params: acc={}, min={}, max={}, smooth={}",
2995 test,
2996 cfg_idx,
2997 val,
2998 bits,
2999 row,
3000 col,
3001 idx,
3002 combo.accelerator.unwrap_or(1.0),
3003 combo.min_length.unwrap_or(5),
3004 combo.max_length.unwrap_or(50),
3005 combo.smooth_length.unwrap_or(4)
3006 );
3007 }
3008
3009 if bits == 0x22222222_22222222 {
3010 panic!(
3011 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
3012 at row {} col {} (flat index {})",
3013 test, cfg_idx, val, bits, row, col, idx
3014 );
3015 }
3016
3017 if bits == 0x33333333_33333333 {
3018 panic!(
3019 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
3020 at row {} col {} (flat index {})",
3021 test, cfg_idx, val, bits, row, col, idx
3022 );
3023 }
3024 }
3025 }
3026
3027 Ok(())
3028 }
3029
    /// Release-build stub: poison sentinels are only written in debug builds,
    /// so there is nothing to scan for here.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
3034
3035 macro_rules! gen_batch_tests {
3036 ($fn_name:ident) => {
3037 paste::paste! {
3038 #[test] fn [<$fn_name _scalar>]() {
3039 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
3040 }
3041 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3042 #[test] fn [<$fn_name _avx2>]() {
3043 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
3044 }
3045 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3046 #[test] fn [<$fn_name _avx512>]() {
3047 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
3048 }
3049 #[test] fn [<$fn_name _auto_detect>]() {
3050 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]),
3051 Kernel::Auto);
3052 }
3053 }
3054 };
3055 }
3056
    // Expand each batch checker into scalar/AVX/auto-detect #[test] functions.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_sweep);
    gen_batch_tests!(check_batch_no_poison);
3060}