Skip to main content

vector_ta/indicators/moving_averages/
uma.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::moving_averages::CudaUma;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
7#[cfg(all(feature = "python", feature = "cuda"))]
8use cust::context::Context;
9#[cfg(all(feature = "python", feature = "cuda"))]
10use cust::memory::DeviceBuffer;
11#[cfg(feature = "python")]
12use numpy::{IntoPyArray, PyArray1, PyReadonlyArray1, PyReadonlyArray2};
13#[cfg(feature = "python")]
14use pyo3::exceptions::PyValueError;
15#[cfg(feature = "python")]
16use pyo3::prelude::*;
17#[cfg(feature = "python")]
18use pyo3::types::{PyDict, PyList};
19#[cfg(all(feature = "python", feature = "cuda"))]
20use std::sync::Arc;
21
22#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
23use js_sys;
24#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
25use serde::{Deserialize, Serialize};
26#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
27use wasm_bindgen::prelude::*;
28
29use crate::indicators::deviation::{deviation, DeviationInput, DeviationParams};
30use crate::indicators::mfi::{mfi, MfiInput, MfiParams};
31use crate::indicators::moving_averages::sma::{sma, SmaInput, SmaParams};
32use crate::indicators::moving_averages::wma::{wma, WmaInput, WmaParams};
33use crate::indicators::rsi::{rsi, RsiInput, RsiParams};
34use crate::utilities::data_loader::{source_type, Candles};
35use crate::utilities::enums::Kernel;
36use crate::utilities::helpers::{
37    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
38    make_uninit_matrix,
39};
40#[cfg(feature = "python")]
41use crate::utilities::kernel_validation::validate_kernel;
42#[cfg(not(target_arch = "wasm32"))]
43use rayon::prelude::*;
44use std::convert::AsRef;
45use std::error::Error;
46use std::mem::MaybeUninit;
47use thiserror::Error;
48
49impl<'a> AsRef<[f64]> for UmaInput<'a> {
50    #[inline(always)]
51    fn as_ref(&self) -> &[f64] {
52        match &self.data {
53            UmaData::Slice(slice) => slice,
54            UmaData::Candles { candles, source } => source_type(candles, source),
55        }
56    }
57}
58
59#[derive(Debug, Clone)]
60pub enum UmaData<'a> {
61    Candles {
62        candles: &'a Candles,
63        source: &'a str,
64    },
65    Slice(&'a [f64]),
66}
67
/// Result of a UMA computation.
#[derive(Clone, Debug)]
pub struct UmaOutput {
    /// Computed indicator values.
    pub values: Vec<f64>,
}
72
/// Tunable parameters of the UMA indicator.
///
/// Every field is optional; `None` means "use the built-in default", which is the value
/// produced by [`UmaParams::default`].
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct UmaParams {
    /// Base exponent of the weighting kernel (default `1.0`).
    pub accelerator: Option<f64>,
    /// Lower bound of the adaptive window length (default `5`).
    pub min_length: Option<usize>,
    /// Upper bound of the adaptive window length (default `50`).
    pub max_length: Option<usize>,
    /// Length of the final smoothing stage (default `4`).
    pub smooth_length: Option<usize>,
}

impl Default for UmaParams {
    /// Returns the default parameter set: accelerator `1.0`, lengths `5..=50`, smoothing `4`.
    fn default() -> Self {
        UmaParams {
            accelerator: Some(1.0),
            min_length: Some(5),
            max_length: Some(50),
            smooth_length: Some(4),
        }
    }
}
95
96#[derive(Debug, Clone)]
97pub struct UmaInput<'a> {
98    pub data: UmaData<'a>,
99    pub params: UmaParams,
100    pub volume: Option<&'a [f64]>,
101}
102
103impl<'a> UmaInput<'a> {
104    #[inline]
105    pub fn from_candles(c: &'a Candles, s: &'a str, p: UmaParams) -> Self {
106        Self {
107            data: UmaData::Candles {
108                candles: c,
109                source: s,
110            },
111            params: p,
112            volume: Some(&c.volume),
113        }
114    }
115
116    #[inline]
117    pub fn from_slice(sl: &'a [f64], vol: Option<&'a [f64]>, p: UmaParams) -> Self {
118        Self {
119            data: UmaData::Slice(sl),
120            params: p,
121            volume: vol,
122        }
123    }
124
125    #[inline]
126    pub fn with_default_candles(c: &'a Candles) -> Self {
127        Self::from_candles(c, "close", UmaParams::default())
128    }
129
130    #[inline]
131    pub fn get_accelerator(&self) -> f64 {
132        self.params.accelerator.unwrap_or(1.0)
133    }
134
135    #[inline]
136    pub fn get_min_length(&self) -> usize {
137        self.params.min_length.unwrap_or(5)
138    }
139
140    #[inline]
141    pub fn get_max_length(&self) -> usize {
142        self.params.max_length.unwrap_or(50)
143    }
144
145    #[inline]
146    pub fn get_smooth_length(&self) -> usize {
147        self.params.smooth_length.unwrap_or(4)
148    }
149}
150
151#[derive(Copy, Clone, Debug)]
152pub struct UmaBuilder {
153    accelerator: Option<f64>,
154    min_length: Option<usize>,
155    max_length: Option<usize>,
156    smooth_length: Option<usize>,
157    kernel: Kernel,
158}
159
160impl Default for UmaBuilder {
161    fn default() -> Self {
162        Self {
163            accelerator: None,
164            min_length: None,
165            max_length: None,
166            smooth_length: None,
167            kernel: Kernel::Auto,
168        }
169    }
170}
171
172impl UmaBuilder {
173    #[inline(always)]
174    pub fn new() -> Self {
175        Self::default()
176    }
177
178    #[inline(always)]
179    pub fn accelerator(mut self, a: f64) -> Self {
180        self.accelerator = Some(a);
181        self
182    }
183
184    #[inline(always)]
185    pub fn min_length(mut self, n: usize) -> Self {
186        self.min_length = Some(n);
187        self
188    }
189
190    #[inline(always)]
191    pub fn max_length(mut self, n: usize) -> Self {
192        self.max_length = Some(n);
193        self
194    }
195
196    #[inline(always)]
197    pub fn smooth_length(mut self, n: usize) -> Self {
198        self.smooth_length = Some(n);
199        self
200    }
201
202    #[inline(always)]
203    pub fn kernel(mut self, k: Kernel) -> Self {
204        self.kernel = k;
205        self
206    }
207
208    #[inline(always)]
209    pub fn apply(self, c: &Candles) -> Result<UmaOutput, UmaError> {
210        let p = UmaParams {
211            accelerator: self.accelerator,
212            min_length: self.min_length,
213            max_length: self.max_length,
214            smooth_length: self.smooth_length,
215        };
216        let i = UmaInput::from_candles(c, "close", p);
217        uma_with_kernel(&i, self.kernel)
218    }
219
220    #[inline(always)]
221    pub fn apply_slice(self, d: &[f64], v: Option<&[f64]>) -> Result<UmaOutput, UmaError> {
222        let p = UmaParams {
223            accelerator: self.accelerator,
224            min_length: self.min_length,
225            max_length: self.max_length,
226            smooth_length: self.smooth_length,
227        };
228        let i = UmaInput::from_slice(d, v, p);
229        uma_with_kernel(&i, self.kernel)
230    }
231
232    #[inline(always)]
233    pub fn into_stream(self) -> Result<UmaStream, UmaError> {
234        let p = UmaParams {
235            accelerator: self.accelerator,
236            min_length: self.min_length,
237            max_length: self.max_length,
238            smooth_length: self.smooth_length,
239        };
240        UmaStream::try_new(p)
241    }
242}
243
244#[derive(Debug, Clone)]
245pub struct UmaStream {
246    accelerator: f64,
247    min_length: usize,
248    max_length: usize,
249    smooth_length: usize,
250
251    price: Vec<f64>,
252    volume: Vec<f64>,
253    cap: usize,
254    head: usize,
255    count: usize,
256
257    sum: f64,
258    sumsq: f64,
259
260    has_prev: bool,
261    prev_price: f64,
262
263    upvol_cum: Vec<f64>,
264    dnvol_cum: Vec<f64>,
265    diff_step: usize,
266
267    rsi_avg_up: f64,
268    rsi_avg_dn: f64,
269    rsi_inited: bool,
270
271    dynamic_length: f64,
272
273    ln_lut: Vec<f64>,
274
275    uma_raw: Vec<f64>,
276    uma_raw_head: usize,
277    uma_raw_count: usize,
278    s1: f64,
279    s2: f64,
280    wma_denom: f64,
281
282    params: UmaParams,
283    kernel: Kernel,
284}
285
286impl UmaStream {
287    pub fn try_new(params: UmaParams) -> Result<Self, UmaError> {
288        let accelerator = params.accelerator.unwrap_or(1.0);
289        let min_length = params.min_length.unwrap_or(5);
290        let max_length = params.max_length.unwrap_or(50);
291        let smooth_len = params.smooth_length.unwrap_or(4);
292
293        if min_length == 0 {
294            return Err(UmaError::InvalidMinLength { min_length });
295        }
296        if max_length == 0 {
297            return Err(UmaError::InvalidMaxLength {
298                max_length,
299                data_len: 0,
300            });
301        }
302        if smooth_len == 0 {
303            return Err(UmaError::InvalidSmoothLength {
304                smooth_length: smooth_len,
305            });
306        }
307        if min_length > max_length {
308            return Err(UmaError::MinLengthGreaterThanMaxLength {
309                min_length,
310                max_length,
311            });
312        }
313        if accelerator < 1.0 {
314            return Err(UmaError::InvalidAccelerator { accelerator });
315        }
316
317        let mut ln_lut = Vec::with_capacity(max_length + 1);
318        ln_lut.push(0.0);
319        for k in 1..=max_length {
320            ln_lut.push((k as f64).ln());
321        }
322
323        Ok(Self {
324            accelerator,
325            min_length,
326            max_length,
327            smooth_length: smooth_len,
328
329            price: vec![0.0; max_length],
330            volume: vec![0.0; max_length],
331            cap: max_length,
332            head: 0,
333            count: 0,
334
335            sum: 0.0,
336            sumsq: 0.0,
337
338            has_prev: false,
339            prev_price: 0.0,
340
341            upvol_cum: vec![0.0; max_length + 1],
342            dnvol_cum: vec![0.0; max_length + 1],
343            diff_step: 0,
344
345            rsi_avg_up: 0.0,
346            rsi_avg_dn: 0.0,
347            rsi_inited: false,
348
349            dynamic_length: max_length as f64,
350            ln_lut,
351
352            uma_raw: vec![0.0; smooth_len],
353            uma_raw_head: 0,
354            uma_raw_count: 0,
355            s1: 0.0,
356            s2: 0.0,
357            wma_denom: (smooth_len as f64) * ((smooth_len as f64) + 1.0) * 0.5,
358
359            params,
360            kernel: Kernel::Auto,
361        })
362    }
363
364    #[inline]
365    pub fn update(&mut self, value: f64) -> Option<f64> {
366        self.update_with_volume(value, None)
367    }
368
369    pub fn update_with_volume(&mut self, value: f64, volume: Option<f64>) -> Option<f64> {
370        let v = volume.unwrap_or(0.0);
371
372        if self.count == self.cap {
373            let old = self.price[self.head];
374            self.sum -= old;
375            self.sumsq -= old * old;
376        } else {
377            self.count += 1;
378        }
379        self.price[self.head] = value;
380        self.volume[self.head] = v;
381        self.sum += value;
382        self.sumsq += value * value;
383
384        if self.has_prev {
385            let diff = value - self.prev_price;
386            let up = if diff > 0.0 { diff } else { 0.0 };
387            let dn = if diff < 0.0 { -diff } else { 0.0 };
388
389            let cur = (self.diff_step + 1) % (self.cap + 1);
390            let prev = self.diff_step % (self.cap + 1);
391            let up_contrib = if up > 0.0 { v } else { 0.0 };
392            let dn_contrib = if dn > 0.0 { v } else { 0.0 };
393            self.upvol_cum[cur] = self.upvol_cum[prev] + up_contrib;
394            self.dnvol_cum[cur] = self.dnvol_cum[prev] + dn_contrib;
395            self.diff_step += 1;
396
397            if !self.rsi_inited {
398                self.rsi_avg_up = up;
399                self.rsi_avg_dn = dn;
400                self.rsi_inited = true;
401            } else {
402            }
403        }
404        self.prev_price = value;
405        self.has_prev = true;
406
407        self.head = (self.head + 1) % self.cap;
408
409        if self.count < self.cap {
410            return None;
411        }
412
413        let n = self.cap as f64;
414        let mu = self.sum / n;
415        let var = (self.sumsq / n) - mu * mu;
416        let sd = if var > 0.0 { var.sqrt() } else { 0.0 };
417
418        let a = (-1.75f64).mul_add(sd, mu);
419        let b = (-0.25f64).mul_add(sd, mu);
420        let c = (0.25f64).mul_add(sd, mu);
421        let d = (1.75f64).mul_add(sd, mu);
422
423        let src = value;
424        if src >= b && src <= c {
425            self.dynamic_length += 1.0;
426        } else if src < a || src > d {
427            self.dynamic_length -= 1.0;
428        }
429        self.dynamic_length = self
430            .dynamic_length
431            .max(self.min_length as f64)
432            .min(self.max_length as f64);
433        let len_r = self.dynamic_length.round().max(1.0) as usize;
434
435        let mf = if v > 0.0 && self.diff_step > 0 {
436            let cur = self.diff_step % (self.cap + 1);
437            let prev = (self.diff_step + self.cap + 1 - len_r) % (self.cap + 1);
438            let up_sum = self.upvol_cum[cur] - self.upvol_cum[prev];
439            let dn_sum = self.dnvol_cum[cur] - self.dnvol_cum[prev];
440            let tot = up_sum + dn_sum;
441            if tot > 0.0 {
442                100.0 * up_sum / tot
443            } else {
444                50.0
445            }
446        } else {
447            if self.diff_step > 0 {
448                let newest_idx = (self.head + self.cap - 1) % self.cap;
449                let prev_idx = (self.head + self.cap - 2) % self.cap;
450                let prevv = self.price[prev_idx];
451                let last_diff = self.price[newest_idx] - prevv;
452                let up = if last_diff > 0.0 { last_diff } else { 0.0 };
453                let dn = if last_diff < 0.0 { -last_diff } else { 0.0 };
454
455                let alpha = 1.0 / (len_r as f64);
456                self.rsi_avg_up = (1.0 - alpha) * self.rsi_avg_up + alpha * up;
457                self.rsi_avg_dn = (1.0 - alpha) * self.rsi_avg_dn + alpha * dn;
458            }
459            if self.rsi_avg_dn == 0.0 {
460                100.0
461            } else {
462                let s = self.rsi_avg_up + self.rsi_avg_dn;
463                if s == 0.0 {
464                    50.0
465                } else {
466                    100.0 * self.rsi_avg_up / s
467                }
468            }
469        };
470
471        let mf_scaled = mf.mul_add(2.0, -100.0);
472        let p = self.accelerator + (mf_scaled.abs() * 0.04);
473
474        let start = (self.head + self.cap - len_r) % self.cap;
475        let mut xws = 0.0f64;
476        let mut wsum = 0.0f64;
477
478        let mut j = 0usize;
479        while j + 1 < len_r {
480            let k1 = len_r - j;
481            let k2 = k1 - 1;
482
483            let w1 = exp_kernel(p * self.ln_lut[k1]);
484            let idx1 = (start + j) % self.cap;
485            let x1 = self.price[idx1];
486            xws = x1.mul_add(w1, xws);
487            wsum += w1;
488
489            let w2 = exp_kernel(p * self.ln_lut[k2]);
490            let idx2 = (start + j + 1) % self.cap;
491            let x2 = self.price[idx2];
492            xws = x2.mul_add(w2, xws);
493            wsum += w2;
494
495            j += 2;
496        }
497        if j < len_r {
498            let k = len_r - j;
499            let w = exp_kernel(p * self.ln_lut[k]);
500            let idx = (start + j) % self.cap;
501            let x = self.price[idx];
502            xws = x.mul_add(w, xws);
503            wsum += w;
504        }
505
506        let uma_raw = if wsum > 0.0 { xws / wsum } else { src };
507
508        let m = self.smooth_length;
509        if self.uma_raw_count < m {
510            self.s1 += uma_raw;
511            self.s2 += uma_raw * ((self.uma_raw_count as f64) + 1.0);
512            self.uma_raw[self.uma_raw_head] = uma_raw;
513            self.uma_raw_head = (self.uma_raw_head + 1) % m;
514            self.uma_raw_count += 1;
515
516            if self.uma_raw_count < m {
517                return None;
518            }
519
520            let out = self.s2 / self.wma_denom;
521            Some(out)
522        } else {
523            let oldest = self.uma_raw[self.uma_raw_head];
524            let s1_prev = self.s1;
525            self.s1 = self.s1 - oldest + uma_raw;
526            self.s2 = self.s2 - s1_prev + (m as f64) * uma_raw;
527
528            self.uma_raw[self.uma_raw_head] = uma_raw;
529            self.uma_raw_head = (self.uma_raw_head + 1) % m;
530
531            Some(self.s2 / self.wma_denom)
532        }
533    }
534
535    pub fn reset(&mut self) {
536        self.price.fill(0.0);
537        self.volume.fill(0.0);
538        self.head = 0;
539        self.count = 0;
540        self.sum = 0.0;
541        self.sumsq = 0.0;
542
543        self.has_prev = false;
544        self.prev_price = 0.0;
545
546        self.upvol_cum.fill(0.0);
547        self.dnvol_cum.fill(0.0);
548        self.diff_step = 0;
549
550        self.rsi_avg_up = 0.0;
551        self.rsi_avg_dn = 0.0;
552        self.rsi_inited = false;
553
554        self.dynamic_length = self.max_length as f64;
555
556        self.uma_raw.fill(0.0);
557        self.uma_raw_head = 0;
558        self.uma_raw_count = 0;
559        self.s1 = 0.0;
560        self.s2 = 0.0;
561    }
562}
563
/// Natural exponential `e^x`; the single place through which all kernel weights are computed.
#[inline(always)]
fn exp_kernel(x: f64) -> f64 {
    f64::exp(x)
}
568
569#[derive(Debug, Error)]
570pub enum UmaError {
571    #[error("uma: Input data slice is empty.")]
572    EmptyInputData,
573    #[error("uma: All values are NaN.")]
574    AllValuesNaN,
575    #[error("uma: Invalid max_length: max_length = {max_length}, data length = {data_len}")]
576    InvalidMaxLength { max_length: usize, data_len: usize },
577    #[error("uma: Not enough valid data: needed = {needed}, valid = {valid}")]
578    NotEnoughValidData { needed: usize, valid: usize },
579    #[error("uma: Invalid accelerator: {accelerator}")]
580    InvalidAccelerator { accelerator: f64 },
581    #[error("uma: Invalid min_length: {min_length}")]
582    InvalidMinLength { min_length: usize },
583    #[error("uma: Invalid smooth_length: {smooth_length}")]
584    InvalidSmoothLength { smooth_length: usize },
585    #[error("uma: min_length ({min_length}) must be <= max_length ({max_length})")]
586    MinLengthGreaterThanMaxLength {
587        min_length: usize,
588        max_length: usize,
589    },
590    #[error("uma: Output length mismatch: expected = {expected}, got = {got}")]
591    OutputLengthMismatch { expected: usize, got: usize },
592    #[error("uma: Invalid range: start = {start}, end = {end}, step = {step}")]
593    InvalidRange {
594        start: usize,
595        end: usize,
596        step: usize,
597    },
598    #[error("uma: Invalid kernel for batch API: {0:?}")]
599    InvalidKernelForBatch(Kernel),
600    #[error("uma: arithmetic overflow while computing {context}")]
601    ArithmeticOverflow { context: &'static str },
602    #[error("uma: Error from dependency: {0}")]
603    DependencyError(String),
604}
605
606#[inline(always)]
607fn uma_prepare<'a>(input: &'a UmaInput) -> Result<(&'a [f64], usize, usize, usize, f64), UmaError> {
608    let data: &[f64] = input.as_ref();
609    let len = data.len();
610    if len == 0 {
611        return Err(UmaError::EmptyInputData);
612    }
613
614    let first = data
615        .iter()
616        .position(|x| !x.is_nan())
617        .ok_or(UmaError::AllValuesNaN)?;
618    let accelerator = input.get_accelerator();
619    let min_length = input.get_min_length();
620    let max_length = input.get_max_length();
621    let smooth_length = input.get_smooth_length();
622
623    if max_length == 0 || max_length > len {
624        return Err(UmaError::InvalidMaxLength {
625            max_length,
626            data_len: len,
627        });
628    }
629    if min_length == 0 {
630        return Err(UmaError::InvalidMinLength { min_length });
631    }
632    if min_length > max_length {
633        return Err(UmaError::MinLengthGreaterThanMaxLength {
634            min_length,
635            max_length,
636        });
637    }
638    if smooth_length == 0 {
639        return Err(UmaError::InvalidSmoothLength { smooth_length });
640    }
641    if accelerator < 1.0 {
642        return Err(UmaError::InvalidAccelerator { accelerator });
643    }
644    if len - first < max_length {
645        return Err(UmaError::NotEnoughValidData {
646            needed: max_length,
647            valid: len - first,
648        });
649    }
650    Ok((data, first, min_length, max_length, accelerator))
651}
652
653#[inline(always)]
654fn uma_core_into(
655    input: &UmaInput,
656    first: usize,
657    min_length: usize,
658    max_length: usize,
659    accelerator: f64,
660    _kernel: Kernel,
661    out: &mut [f64],
662) -> Result<(), UmaError> {
663    let data: &[f64] = input.as_ref();
664    let len = data.len();
665    debug_assert!(len == out.len());
666
667    let mean = sma(&SmaInput::from_slice(
668        data,
669        SmaParams {
670            period: Some(max_length),
671        },
672    ))
673    .map_err(|e| UmaError::DependencyError(e.to_string()))?
674    .values;
675
676    let std_dev = deviation(&DeviationInput::from_slice(
677        data,
678        DeviationParams {
679            period: Some(max_length),
680            devtype: Some(0),
681        },
682    ))
683    .map_err(|e| UmaError::DependencyError(e.to_string()))?
684    .values;
685
686    let mut ln_lut: Vec<f64> = Vec::with_capacity(max_length + 1);
687
688    ln_lut.push(0.0);
689    for k in 1..=max_length {
690        ln_lut.push((k as f64).ln());
691    }
692
693    let (candles_opt, vol_opt) = match &input.data {
694        UmaData::Candles { candles, .. } => (Some(*candles), input.volume),
695        UmaData::Slice(_) => (None, input.volume),
696    };
697
698    let warmup_end = first
699        .checked_add(max_length)
700        .and_then(|x| x.checked_sub(1))
701        .ok_or(UmaError::ArithmeticOverflow {
702            context: "first + max_length - 1 (warmup_end)",
703        })?;
704    if warmup_end >= len {
705        return Ok(());
706    }
707
708    let mut dyn_len = max_length as f64;
709
710    #[inline(always)]
711    fn rsi_wilder_last(data: &[f64], start: usize, end: usize, period: usize) -> f64 {
712        if period == 0 || end <= start || end - start + 1 < period + 1 {
713            return 50.0;
714        }
715
716        let mut sum_up = 0.0f64;
717        let mut sum_dn = 0.0f64;
718        let mut has_nan = false;
719        let mut prev = data[start];
720        let init_last = start + period;
721        for j in (start + 1)..=init_last {
722            let cur = data[j];
723            let diff = cur - prev;
724            if !diff.is_finite() {
725                has_nan = true;
726                break;
727            }
728            if diff > 0.0 {
729                sum_up += diff;
730            } else {
731                sum_dn -= diff;
732            }
733            prev = cur;
734        }
735        if has_nan {
736            return 50.0;
737        }
738        let mut avg_up = sum_up / (period as f64);
739        let mut avg_dn = sum_dn / (period as f64);
740
741        if init_last < end {
742            let n_1 = (period - 1) as f64;
743            let n = period as f64;
744            for j in (init_last + 1)..=end {
745                let cur = data[j];
746                let diff = cur - prev;
747                let up = if diff > 0.0 { diff } else { 0.0 };
748                let dn = if diff < 0.0 { -diff } else { 0.0 };
749                avg_up = (avg_up * n_1 + up) / n;
750                avg_dn = (avg_dn * n_1 + dn) / n;
751                prev = cur;
752            }
753        }
754
755        if avg_dn == 0.0 {
756            100.0
757        } else if avg_up + avg_dn == 0.0 {
758            50.0
759        } else {
760            100.0 * avg_up / (avg_up + avg_dn)
761        }
762    }
763
764    #[inline(always)]
765    fn mfi_window_last_candles(c: &Candles, start: usize, end: usize) -> f64 {
766        if end <= start {
767            return 50.0;
768        }
769
770        let mut tp_prev = (c.high[start] + c.low[start] + c.close[start]) / 3.0;
771        let mut pos = 0.0f64;
772        let mut neg = 0.0f64;
773
774        for j in (start + 1)..=end {
775            let tp = (c.high[j] + c.low[j] + c.close[j]) / 3.0;
776
777            let mf = tp * c.volume[j];
778            if tp > tp_prev {
779                pos += mf;
780            } else if tp < tp_prev {
781                neg += mf;
782            }
783            tp_prev = tp;
784        }
785
786        let denom = pos + neg;
787        if denom > 0.0 {
788            100.0 * pos / denom
789        } else {
790            50.0
791        }
792    }
793
794    for i in warmup_end..len {
795        let mu = mean[i];
796        let sd = std_dev[i];
797        if mu.is_nan() || sd.is_nan() {
798            continue;
799        }
800        let src = data[i];
801
802        let a = (-1.75f64).mul_add(sd, mu);
803        let b = (-0.25f64).mul_add(sd, mu);
804        let c = (0.25f64).mul_add(sd, mu);
805        let d = (1.75f64).mul_add(sd, mu);
806
807        if src >= b && src <= c {
808            dyn_len += 1.0;
809        } else if src < a || src > d {
810            dyn_len -= 1.0;
811        }
812
813        dyn_len = dyn_len.max(min_length as f64).min(max_length as f64);
814        let len_r = dyn_len.round() as usize;
815        if i + 1 < len_r {
816            continue;
817        }
818
819        let mf: f64 = match (vol_opt, candles_opt) {
820            (Some(vol), Some(candles)) => {
821                let v_i = vol[i];
822                if v_i == 0.0 || v_i.is_nan() {
823                    let end = i;
824                    let start = if end + 1 >= 2 * len_r {
825                        end + 1 - 2 * len_r
826                    } else {
827                        0
828                    };
829                    rsi_wilder_last(data, start, end, len_r)
830                } else {
831                    let start = i + 1 - len_r;
832                    mfi_window_last_candles(candles, start, i)
833                }
834            }
835            (Some(vol), None) => {
836                let start = i + 1 - len_r;
837                let mut up_vol = 0.0f64;
838                let mut dn_vol = 0.0f64;
839                let mut prev = data[start];
840                for j in (start + 1)..=i {
841                    let cur = data[j];
842                    let v = vol[j];
843                    let diff = cur - prev;
844                    if diff > 0.0 {
845                        up_vol += v;
846                    } else if diff < 0.0 {
847                        dn_vol += v;
848                    }
849                    prev = cur;
850                }
851                let tot = up_vol + dn_vol;
852                if tot > 0.0 {
853                    100.0 * up_vol / tot
854                } else {
855                    50.0
856                }
857            }
858            _ => {
859                let end = i;
860                let start = if end + 1 >= 2 * len_r {
861                    end + 1 - 2 * len_r
862                } else {
863                    0
864                };
865                rsi_wilder_last(data, start, end, len_r)
866            }
867        };
868
869        let mf_scaled = mf.mul_add(2.0, -100.0);
870        let p = accelerator + (mf_scaled.abs() * 0.04);
871
872        let start = i + 1 - len_r;
873
874        let (mut xws, mut wsum) = (0.0f64, 0.0f64);
875
876        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
877        match _kernel {
878            #[cfg(target_feature = "avx512f")]
879            Kernel::Avx512 => unsafe {
880                let (sx, sw) = uma_weighted_accumulate_avx512(
881                    data.as_ptr().add(start),
882                    ln_lut.as_ptr(),
883                    len_r,
884                    p,
885                );
886                xws = sx;
887                wsum = sw;
888            },
889            #[cfg(target_feature = "avx2")]
890            Kernel::Avx2 => unsafe {
891                let (sx, sw) = uma_weighted_accumulate_avx2(
892                    data.as_ptr().add(start),
893                    ln_lut.as_ptr(),
894                    len_r,
895                    p,
896                );
897                xws = sx;
898                wsum = sw;
899            },
900            _ => {}
901        }
902
903        if wsum == 0.0 {
904            let mut j = 0usize;
905            while j + 1 < len_r {
906                let k1 = len_r - j;
907                let k2 = k1 - 1;
908
909                let w1 = exp_kernel(p * ln_lut[k1]);
910                let x1 = data[start + j];
911                if !x1.is_nan() {
912                    xws = x1.mul_add(w1, xws);
913                    wsum += w1;
914                }
915
916                let w2 = exp_kernel(p * ln_lut[k2]);
917                let x2 = data[start + j + 1];
918                if !x2.is_nan() {
919                    xws = x2.mul_add(w2, xws);
920                    wsum += w2;
921                }
922
923                j += 2;
924            }
925            if j < len_r {
926                let k = len_r - j;
927                let w = exp_kernel(p * ln_lut[k]);
928                let x = data[start + j];
929                if !x.is_nan() {
930                    xws = x.mul_add(w, xws);
931                    wsum += w;
932                }
933            }
934        }
935
936        out[i] = if wsum > 0.0 { xws / wsum } else { src };
937    }
938
939    Ok(())
940}
941
/// AVX2 accumulation of the power-weighted window sums.
///
/// Returns `(Σ x_j·w_j, Σ w_j)` over `j in 0..len_r`, where `x_j = data[j]`,
/// `w_j = exp(p · ln_lut[len_r − j])`, and NaN samples contribute to neither sum.
/// Four samples are processed per vector step; the remainder is handled one by one.
///
/// # Safety
/// * `data` must be valid for reads of `len_r` consecutive `f64` values.
/// * `ln_lut` must be valid for reads at every index in `1..=len_r`.
/// * The vector instructions used here must be supported by the executing CPU.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn uma_weighted_accumulate_avx2(
    data: *const f64,
    ln_lut: *const f64,
    len_r: usize,
    p: f64,
) -> (f64, f64) {
    use core::arch::x86_64::*;

    let mut acc_x = _mm256_setzero_pd();
    let mut acc_w = _mm256_setzero_pd();

    // Largest multiple of four not exceeding `len_r`: the part handled by full vectors.
    let vec_len = len_r - (len_r % 4);
    for base in (0..vec_len).step_by(4) {
        let k = len_r - base;
        let weights = _mm256_setr_pd(
            exp_kernel(p * *ln_lut.add(k)),
            exp_kernel(p * *ln_lut.add(k - 1)),
            exp_kernel(p * *ln_lut.add(k - 2)),
            exp_kernel(p * *ln_lut.add(k - 3)),
        );
        let samples = _mm256_loadu_pd(data.add(base));

        // All-ones in lanes holding a real number, all-zeros in NaN lanes; AND-ing with
        // it zeroes both the sample and its weight for NaN lanes.
        let is_num = _mm256_cmp_pd(samples, samples, _CMP_ORD_Q);
        let samples_clean = _mm256_and_pd(samples, is_num);
        let weights_clean = _mm256_and_pd(weights, is_num);

        acc_x = _mm256_fmadd_pd(samples_clean, weights, acc_x);
        acc_w = _mm256_add_pd(acc_w, weights_clean);
    }

    // Horizontal reduction, lane 0 first.
    let mut lanes_x = [0.0f64; 4];
    let mut lanes_w = [0.0f64; 4];
    _mm256_storeu_pd(lanes_x.as_mut_ptr(), acc_x);
    _mm256_storeu_pd(lanes_w.as_mut_ptr(), acc_w);

    let mut xws = 0.0f64;
    let mut wsum = 0.0f64;
    xws += lanes_x[0] + lanes_x[1] + lanes_x[2] + lanes_x[3];
    wsum += lanes_w[0] + lanes_w[1] + lanes_w[2] + lanes_w[3];

    // Scalar tail for the last `len_r % 4` samples.
    for j in vec_len..len_r {
        let x = *data.add(j);
        if x.is_nan() {
            continue;
        }
        let w = exp_kernel(p * *ln_lut.add(len_r - j));
        xws = x.mul_add(w, xws);
        wsum += w;
    }
    (xws, wsum)
}
1000
/// Scalar stand-in for the AVX-512 accumulator, compiled when the build does not enable
/// the `avx512f` target feature.
///
/// Returns `(Σ x_j·w_j, Σ w_j)` over `j in 0..len_r`, where `x_j = data[j]`,
/// `w_j = exp(p · ln_lut[len_r − j])`, and NaN samples contribute to neither sum.
///
/// # Safety
/// * `data` must be valid for reads of `len_r` consecutive `f64` values.
/// * `ln_lut` must be valid for reads at every index in `1..=len_r`.
#[cfg(all(
    feature = "nightly-avx",
    target_arch = "x86_64",
    not(target_feature = "avx512f")
))]
#[inline(always)]
unsafe fn uma_weighted_accumulate_avx512(
    data: *const f64,
    ln_lut: *const f64,
    len_r: usize,
    p: f64,
) -> (f64, f64) {
    let (mut xws, mut wsum) = (0.0f64, 0.0f64);
    for j in 0..len_r {
        let x = *data.add(j);
        if x.is_nan() {
            continue;
        }
        let w = exp_kernel(p * *ln_lut.add(len_r - j));
        xws = x.mul_add(w, xws);
        wsum += w;
    }
    (xws, wsum)
}
1028
1029#[cfg(all(
1030    feature = "nightly-avx",
1031    target_arch = "x86_64",
1032    target_feature = "avx512f"
1033))]
1034#[inline(always)]
1035unsafe fn uma_weighted_accumulate_avx512(
1036    data: *const f64,
1037    ln_lut: *const f64,
1038    len_r: usize,
1039    p: f64,
1040) -> (f64, f64) {
1041    use core::arch::x86_64::*;
1042    let mut sum_v = _mm512_setzero_pd();
1043    let mut wsum_v = _mm512_setzero_pd();
1044
1045    let mut j = 0usize;
1046    while j + 7 < len_r {
1047        let k0 = len_r - j;
1048        let ks = [k0, k0 - 1, k0 - 2, k0 - 3, k0 - 4, k0 - 5, k0 - 6, k0 - 7];
1049        let ws = [
1050            exp_kernel(p * *ln_lut.add(ks[0])),
1051            exp_kernel(p * *ln_lut.add(ks[1])),
1052            exp_kernel(p * *ln_lut.add(ks[2])),
1053            exp_kernel(p * *ln_lut.add(ks[3])),
1054            exp_kernel(p * *ln_lut.add(ks[4])),
1055            exp_kernel(p * *ln_lut.add(ks[5])),
1056            exp_kernel(p * *ln_lut.add(ks[6])),
1057            exp_kernel(p * *ln_lut.add(ks[7])),
1058        ];
1059        let wv = _mm512_loadu_pd(ws.as_ptr());
1060        let xv = _mm512_loadu_pd(data.add(j));
1061
1062        let nan_mask = _mm512_cmp_pd_mask(xv, xv, _CMP_ORD_Q);
1063        let xv_nz = _mm512_maskz_mov_pd(nan_mask, xv);
1064
1065        sum_v = _mm512_fmadd_pd(xv_nz, wv, sum_v);
1066        let wv_masked = _mm512_maskz_mov_pd(nan_mask, wv);
1067        wsum_v = _mm512_add_pd(wsum_v, wv_masked);
1068
1069        j += 8;
1070    }
1071
1072    let xws = {
1073        let mut tmp: [f64; 8] = core::mem::zeroed();
1074        _mm512_storeu_pd(tmp.as_mut_ptr(), sum_v);
1075        tmp.iter().copied().sum::<f64>()
1076    };
1077    let wsum0 = {
1078        let mut tmp: [f64; 8] = core::mem::zeroed();
1079        _mm512_storeu_pd(tmp.as_mut_ptr(), wsum_v);
1080        tmp.iter().copied().sum::<f64>()
1081    };
1082
1083    let mut xws_acc = xws;
1084    let mut wsum_acc = wsum0;
1085    while j < len_r {
1086        let k = len_r - j;
1087        let w = exp_kernel(p * *ln_lut.add(k));
1088        let x = *data.add(j);
1089        if !x.is_nan() {
1090            xws_acc = x.mul_add(w, xws_acc);
1091            wsum_acc += w;
1092        }
1093        j += 1;
1094    }
1095    (xws_acc, wsum_acc)
1096}
1097
1098#[cfg(all(
1099    feature = "nightly-avx",
1100    target_arch = "x86_64",
1101    not(target_feature = "avx512f")
1102))]
1103#[inline(always)]
1104unsafe fn uma_weighted_accumulate_avx512(
1105    _data: *const f64,
1106    _ln_lut: *const f64,
1107    _len_r: usize,
1108    _p: f64,
1109) -> (f64, f64) {
1110    unreachable!("uma_weighted_accumulate_avx512 should not be called without AVX512F")
1111}
1112
1113#[inline]
1114pub fn uma(input: &UmaInput) -> Result<UmaOutput, UmaError> {
1115    uma_with_kernel(input, Kernel::Auto)
1116}
1117
1118pub fn uma_with_kernel(input: &UmaInput, kernel: Kernel) -> Result<UmaOutput, UmaError> {
1119    let (data, first, min_len, max_len, accel) = uma_prepare(input)?;
1120    let warm = first
1121        .checked_add(max_len)
1122        .and_then(|x| x.checked_sub(1))
1123        .ok_or(UmaError::ArithmeticOverflow {
1124            context: "first + max_len - 1 (warm)",
1125        })?;
1126
1127    let mut out = alloc_with_nan_prefix(data.len(), warm);
1128
1129    let chosen = match kernel {
1130        Kernel::Auto => Kernel::Scalar,
1131        k => k,
1132    };
1133    uma_core_into(input, first, min_len, max_len, accel, chosen, &mut out)?;
1134
1135    let smooth = input.get_smooth_length();
1136    if smooth > 1 {
1137        let w = wma(&WmaInput::from_slice(
1138            &out,
1139            WmaParams {
1140                period: Some(smooth),
1141            },
1142        ))
1143        .map_err(|e| UmaError::DependencyError(e.to_string()))?
1144        .values;
1145        return Ok(UmaOutput { values: w });
1146    }
1147
1148    Ok(UmaOutput { values: out })
1149}
1150
1151#[inline]
1152pub fn uma_into_slice(dst: &mut [f64], input: &UmaInput, kern: Kernel) -> Result<(), UmaError> {
1153    let (data, first, min_len, max_len, accel) = uma_prepare(input)?;
1154    if dst.len() != data.len() {
1155        return Err(UmaError::OutputLengthMismatch {
1156            expected: data.len(),
1157            got: dst.len(),
1158        });
1159    }
1160
1161    let warm = first
1162        .checked_add(max_len)
1163        .and_then(|x| x.checked_sub(1))
1164        .ok_or(UmaError::ArithmeticOverflow {
1165            context: "first + max_len - 1 (warm)",
1166        })?;
1167    let warm_end = warm.min(dst.len());
1168    for v in &mut dst[..warm_end] {
1169        *v = f64::NAN;
1170    }
1171
1172    let chosen = match kern {
1173        Kernel::Auto => Kernel::Scalar,
1174        k => k,
1175    };
1176    uma_core_into(input, first, min_len, max_len, accel, chosen, dst)?;
1177
1178    let smooth = input.get_smooth_length();
1179    if smooth > 1 {
1180        let tmp = wma(&WmaInput::from_slice(
1181            dst,
1182            WmaParams {
1183                period: Some(smooth),
1184            },
1185        ))
1186        .map_err(|e| UmaError::DependencyError(e.to_string()))?
1187        .values;
1188        dst.copy_from_slice(&tmp);
1189    }
1190    Ok(())
1191}
1192
1193#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1194#[inline]
1195pub fn uma_into(input: &UmaInput, out: &mut [f64]) -> Result<(), UmaError> {
1196    uma_into_slice(out, input, Kernel::Auto)
1197}
1198
/// Parameter sweep for a batch UMA run.
///
/// Every field is a `(start, end, step)` triple describing one parameter
/// axis.
#[derive(Clone, Debug)]
pub struct UmaBatchRange {
    /// `(start, end, step)` for the accelerator.
    pub accelerator: (f64, f64, f64),
    /// `(start, end, step)` for the minimum length.
    pub min_length: (usize, usize, usize),
    /// `(start, end, step)` for the maximum length.
    pub max_length: (usize, usize, usize),
    /// `(start, end, step)` for the smoothing length.
    pub smooth_length: (usize, usize, usize),
}

impl Default for UmaBatchRange {
    /// Default sweep: accelerator fixed at 1.0, minimum length fixed at 5,
    /// maximum length from 50 to 299 in steps of 1, smoothing length fixed
    /// at 4.
    fn default() -> Self {
        let accelerator = (1.0, 1.0, 0.0);
        let min_length = (5, 5, 0);
        let max_length = (50, 299, 1);
        let smooth_length = (4, 4, 0);
        Self {
            accelerator,
            min_length,
            max_length,
            smooth_length,
        }
    }
}
1217
1218#[derive(Copy, Clone)]
1219pub struct UmaBatchBuilder {
1220    accelerator_range: (f64, f64, f64),
1221    min_length_range: (usize, usize, usize),
1222    max_length_range: (usize, usize, usize),
1223    smooth_length_range: (usize, usize, usize),
1224    kernel: Kernel,
1225}
1226
1227impl Default for UmaBatchBuilder {
1228    fn default() -> Self {
1229        Self {
1230            accelerator_range: (1.0, 1.0, 0.0),
1231            min_length_range: (5, 5, 0),
1232            max_length_range: (50, 299, 1),
1233            smooth_length_range: (4, 4, 0),
1234            kernel: Kernel::Auto,
1235        }
1236    }
1237}
1238
1239impl UmaBatchBuilder {
1240    #[inline(always)]
1241    pub fn new() -> Self {
1242        Self::default()
1243    }
1244
1245    #[inline(always)]
1246    pub fn accelerator_range(mut self, start: f64, end: f64, step: f64) -> Self {
1247        self.accelerator_range = (start, end, step);
1248        self
1249    }
1250
1251    #[inline(always)]
1252    pub fn min_length_range(mut self, start: usize, end: usize, step: usize) -> Self {
1253        self.min_length_range = (start, end, step);
1254        self
1255    }
1256
1257    #[inline(always)]
1258    pub fn max_length_range(mut self, start: usize, end: usize, step: usize) -> Self {
1259        self.max_length_range = (start, end, step);
1260        self
1261    }
1262
1263    #[inline(always)]
1264    pub fn smooth_length_range(mut self, start: usize, end: usize, step: usize) -> Self {
1265        self.smooth_length_range = (start, end, step);
1266        self
1267    }
1268
1269    #[inline(always)]
1270    pub fn kernel(mut self, k: Kernel) -> Self {
1271        self.kernel = k;
1272        self
1273    }
1274
1275    #[inline]
1276    pub fn apply_candles(
1277        self,
1278        candles: &Candles,
1279        source: &str,
1280    ) -> Result<UmaBatchOutput, UmaError> {
1281        let sweep = UmaBatchRange {
1282            accelerator: self.accelerator_range,
1283            min_length: self.min_length_range,
1284            max_length: self.max_length_range,
1285            smooth_length: self.smooth_length_range,
1286        };
1287
1288        let data = source_type(candles, source);
1289        uma_batch_inner(data, Some(&candles.volume), &sweep, self.kernel, false)
1290    }
1291
1292    #[inline]
1293    pub fn apply_slice(
1294        self,
1295        data: &[f64],
1296        volume: Option<&[f64]>,
1297    ) -> Result<UmaBatchOutput, UmaError> {
1298        let sweep = UmaBatchRange {
1299            accelerator: self.accelerator_range,
1300            min_length: self.min_length_range,
1301            max_length: self.max_length_range,
1302            smooth_length: self.smooth_length_range,
1303        };
1304
1305        uma_batch_inner(data, volume, &sweep, self.kernel, false)
1306    }
1307
1308    pub fn with_default_slice(
1309        data: &[f64],
1310        volume: Option<&[f64]>,
1311        k: Kernel,
1312    ) -> Result<UmaBatchOutput, UmaError> {
1313        UmaBatchBuilder::new().kernel(k).apply_slice(data, volume)
1314    }
1315
1316    pub fn with_default_candles(c: &Candles) -> Result<UmaBatchOutput, UmaError> {
1317        UmaBatchBuilder::new()
1318            .kernel(Kernel::Auto)
1319            .apply_candles(c, "close")
1320    }
1321}
1322
1323#[derive(Clone, Debug)]
1324pub struct UmaBatchOutput {
1325    pub values: Vec<f64>,
1326    pub combos: Vec<UmaParams>,
1327    pub rows: usize,
1328    pub cols: usize,
1329}
1330
1331impl UmaBatchOutput {
1332    pub fn row_for_params(&self, p: &UmaParams) -> Option<usize> {
1333        self.combos.iter().position(|c| {
1334            c.accelerator.unwrap_or(1.0) == p.accelerator.unwrap_or(1.0)
1335                && c.min_length.unwrap_or(5) == p.min_length.unwrap_or(5)
1336                && c.max_length.unwrap_or(50) == p.max_length.unwrap_or(50)
1337                && c.smooth_length.unwrap_or(4) == p.smooth_length.unwrap_or(4)
1338        })
1339    }
1340
1341    pub fn values_for(&self, p: &UmaParams) -> Option<&[f64]> {
1342        self.row_for_params(p).map(|row| {
1343            let start = row * self.cols;
1344            &self.values[start..start + self.cols]
1345        })
1346    }
1347}
1348
1349#[inline(always)]
1350pub fn expand_grid_uma(r: &UmaBatchRange) -> Vec<UmaParams> {
1351    fn axis_usize(range: (usize, usize, usize)) -> Result<Vec<usize>, UmaError> {
1352        let (start, end, step) = range;
1353        if step == 0 || start == end {
1354            return Ok(vec![start]);
1355        }
1356        if start < end {
1357            let v: Vec<usize> = (start..=end).step_by(step).collect();
1358            return if v.is_empty() {
1359                Err(UmaError::InvalidRange { start, end, step })
1360            } else {
1361                Ok(v)
1362            };
1363        }
1364
1365        let mut v = Vec::new();
1366        let mut cur = start;
1367        loop {
1368            v.push(cur);
1369            if cur == end {
1370                break;
1371            }
1372            cur = cur
1373                .checked_sub(step)
1374                .ok_or(UmaError::InvalidRange { start, end, step })?;
1375            if cur < end {
1376                break;
1377            }
1378        }
1379        if v.is_empty() {
1380            return Err(UmaError::InvalidRange { start, end, step });
1381        }
1382        Ok(v)
1383    }
1384    fn axis_f64(range: (f64, f64, f64)) -> Result<Vec<f64>, UmaError> {
1385        let (start, end, step) = range;
1386        if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
1387            return Ok(vec![start]);
1388        }
1389        let mut v = Vec::new();
1390        if start <= end {
1391            let s = step.abs();
1392            if s < 1e-12 {
1393                return Ok(vec![start]);
1394            }
1395            let mut x = start;
1396            while x <= end + 1e-12 {
1397                v.push(x);
1398                x += s;
1399            }
1400        } else {
1401            let s = -step.abs();
1402            if s.abs() < 1e-12 {
1403                return Ok(vec![start]);
1404            }
1405            let mut x = start;
1406            while x >= end - 1e-12 {
1407                v.push(x);
1408                x += s;
1409            }
1410        }
1411        if v.is_empty() {
1412            return Err(UmaError::InvalidRange {
1413                start: start as usize,
1414                end: end as usize,
1415                step: step.abs() as usize,
1416            });
1417        }
1418        Ok(v)
1419    }
1420
1421    let accs = axis_f64(r.accelerator).unwrap_or_else(|_| vec![r.accelerator.0]);
1422    let mins = match axis_usize(r.min_length) {
1423        Ok(v) => v,
1424        Err(_) => vec![r.min_length.0],
1425    };
1426    let maxs = match axis_usize(r.max_length) {
1427        Ok(v) => v,
1428        Err(_) => vec![r.max_length.0],
1429    };
1430    let smooths = match axis_usize(r.smooth_length) {
1431        Ok(v) => v,
1432        Err(_) => vec![r.smooth_length.0],
1433    };
1434
1435    if !mins.is_empty() && !maxs.is_empty() {
1436        if let (Some(min_min), Some(max_max)) = (mins.iter().min(), maxs.iter().max()) {
1437            if *min_min > *max_max {
1438                return vec![];
1439            }
1440        }
1441    }
1442
1443    let cap = accs
1444        .len()
1445        .checked_mul(mins.len())
1446        .and_then(|x| x.checked_mul(maxs.len()))
1447        .and_then(|x| x.checked_mul(smooths.len()))
1448        .unwrap_or(0);
1449    let mut out = Vec::with_capacity(cap);
1450    for &a in &accs {
1451        for &min in &mins {
1452            for &max in &maxs {
1453                for &s in &smooths {
1454                    if min <= max {
1455                        out.push(UmaParams {
1456                            accelerator: Some(a),
1457                            min_length: Some(min),
1458                            max_length: Some(max),
1459                            smooth_length: Some(s),
1460                        });
1461                    }
1462                }
1463            }
1464        }
1465    }
1466    out
1467}
1468
1469pub fn uma_batch_with_kernel(
1470    data: &[f64],
1471    volume: Option<&[f64]>,
1472    sweep: &UmaBatchRange,
1473    k: Kernel,
1474) -> Result<UmaBatchOutput, UmaError> {
1475    let kernel = match k {
1476        Kernel::Auto => detect_best_batch_kernel(),
1477        other if other.is_batch() => other,
1478        other => return Err(UmaError::InvalidKernelForBatch(other)),
1479    };
1480
1481    uma_batch_inner(data, volume, sweep, kernel, true)
1482}
1483
1484#[inline(always)]
1485pub fn uma_batch_slice(
1486    data: &[f64],
1487    volume: Option<&[f64]>,
1488    sweep: &UmaBatchRange,
1489    kern: Kernel,
1490) -> Result<UmaBatchOutput, UmaError> {
1491    uma_batch_inner(data, volume, sweep, kern, false)
1492}
1493
1494#[inline(always)]
1495pub fn uma_batch_par_slice(
1496    data: &[f64],
1497    volume: Option<&[f64]>,
1498    sweep: &UmaBatchRange,
1499    kern: Kernel,
1500) -> Result<UmaBatchOutput, UmaError> {
1501    uma_batch_inner(data, volume, sweep, kern, true)
1502}
1503
1504#[inline(always)]
1505fn debatch(k: Kernel) -> Kernel {
1506    match k {
1507        Kernel::Avx512Batch | Kernel::Avx512 => Kernel::Avx512,
1508        Kernel::Avx2Batch | Kernel::Avx2 => Kernel::Avx2,
1509        Kernel::ScalarBatch | Kernel::Scalar => Kernel::Scalar,
1510        Kernel::Auto => Kernel::Scalar,
1511        _ => Kernel::Scalar,
1512    }
1513}
1514
1515#[inline(always)]
1516fn uma_batch_inner_into(
1517    data: &[f64],
1518    volume: Option<&[f64]>,
1519    sweep: &UmaBatchRange,
1520    kern: Kernel,
1521    parallel: bool,
1522    out: &mut [f64],
1523) -> Result<Vec<UmaParams>, UmaError> {
1524    match kern {
1525        Kernel::Scalar | Kernel::Avx2 | Kernel::Avx512 => {
1526            return Err(UmaError::InvalidKernelForBatch(kern));
1527        }
1528        _ => {}
1529    }
1530    let combos = expand_grid_uma(sweep);
1531    let combos = combos;
1532    if combos.is_empty() {
1533        return Err(UmaError::InvalidRange {
1534            start: sweep.min_length.0,
1535            end: sweep.max_length.1,
1536            step: sweep.max_length.2,
1537        });
1538    }
1539
1540    let cols = data.len();
1541    let rows = combos.len();
1542    let expected = rows.checked_mul(cols).ok_or(UmaError::ArithmeticOverflow {
1543        context: "rows * cols (batch out)",
1544    })?;
1545    if out.len() != expected {
1546        return Err(UmaError::OutputLengthMismatch {
1547            expected,
1548            got: out.len(),
1549        });
1550    }
1551
1552    let per_row_kernel = debatch(kern);
1553
1554    let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
1555    let warm: Vec<usize> = combos
1556        .iter()
1557        .map(|c| {
1558            let a = first
1559                .checked_add(c.max_length.unwrap_or(50))
1560                .and_then(|x| x.checked_sub(1))
1561                .ok_or(UmaError::ArithmeticOverflow {
1562                    context: "first + max_length - 1 (batch warm)",
1563                })?;
1564            a.checked_add(c.smooth_length.unwrap_or(4))
1565                .and_then(|x| x.checked_sub(1))
1566                .ok_or(UmaError::ArithmeticOverflow {
1567                    context: "+ smooth_length - 1 (batch warm)",
1568                })
1569        })
1570        .collect::<Result<Vec<_>, _>>()?;
1571
1572    let out_mu = unsafe {
1573        std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1574    };
1575    init_matrix_prefixes(out_mu, cols, &warm);
1576
1577    let do_row = |row: usize, row_slice: &mut [f64]| -> Result<(), UmaError> {
1578        let input = UmaInput::from_slice(data, volume, combos[row].clone());
1579        uma_into_slice(row_slice, &input, per_row_kernel)
1580    };
1581
1582    #[cfg(not(target_arch = "wasm32"))]
1583    if parallel {
1584        use rayon::prelude::*;
1585        out.par_chunks_mut(cols)
1586            .enumerate()
1587            .try_for_each(|(r, s)| do_row(r, s))?;
1588    } else {
1589        for (r, s) in out.chunks_mut(cols).enumerate() {
1590            do_row(r, s)?;
1591        }
1592    }
1593
1594    #[cfg(target_arch = "wasm32")]
1595    {
1596        for (r, s) in out.chunks_mut(cols).enumerate() {
1597            do_row(r, s)?;
1598        }
1599    }
1600
1601    Ok(combos)
1602}
1603
1604#[inline(always)]
1605fn uma_batch_inner(
1606    data: &[f64],
1607    volume: Option<&[f64]>,
1608    sweep: &UmaBatchRange,
1609    kern: Kernel,
1610    parallel: bool,
1611) -> Result<UmaBatchOutput, UmaError> {
1612    match kern {
1613        Kernel::Scalar | Kernel::Avx2 | Kernel::Avx512 => {
1614            return Err(UmaError::InvalidKernelForBatch(kern));
1615        }
1616        _ => {}
1617    }
1618    let combos = expand_grid_uma(sweep);
1619    let cols = data.len();
1620    let rows = combos.len();
1621    if cols == 0 {
1622        return Err(UmaError::EmptyInputData);
1623    }
1624    if rows == 0 {
1625        return Err(UmaError::InvalidRange {
1626            start: sweep.min_length.0,
1627            end: sweep.max_length.1,
1628            step: sweep.max_length.2,
1629        });
1630    }
1631
1632    let _cap = rows.checked_mul(cols).ok_or(UmaError::ArithmeticOverflow {
1633        context: "rows * cols (matrix alloc)",
1634    })?;
1635
1636    let mut buf_mu = make_uninit_matrix(rows, cols);
1637    let warm: Vec<usize> = combos
1638        .iter()
1639        .map(|c| {
1640            let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
1641            let a = first
1642                .checked_add(c.max_length.unwrap_or(50))
1643                .and_then(|x| x.checked_sub(1))
1644                .ok_or(UmaError::ArithmeticOverflow {
1645                    context: "first + max_length - 1 (batch warm alloc)",
1646                })?;
1647            a.checked_add(c.smooth_length.unwrap_or(4))
1648                .and_then(|x| x.checked_sub(1))
1649                .ok_or(UmaError::ArithmeticOverflow {
1650                    context: "+ smooth_length - 1 (batch warm alloc)",
1651                })
1652        })
1653        .collect::<Result<Vec<_>, _>>()?;
1654    init_matrix_prefixes(&mut buf_mu, cols, &warm);
1655
1656    let mut guard = core::mem::ManuallyDrop::new(buf_mu);
1657    let out_slice: &mut [f64] =
1658        unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
1659
1660    let combos = uma_batch_inner_into(data, volume, sweep, kern, parallel, out_slice)?;
1661
1662    let values = unsafe {
1663        Vec::from_raw_parts(
1664            guard.as_mut_ptr() as *mut f64,
1665            guard.len(),
1666            guard.capacity(),
1667        )
1668    };
1669    Ok(UmaBatchOutput {
1670        values,
1671        combos,
1672        rows,
1673        cols,
1674    })
1675}
1676
// Python entry point `uma(data, accelerator, min_length, max_length,
// smooth_length, volume=None, kernel=None)`.
//
// Validates the kernel name, borrows the NumPy inputs as slices, runs
// `uma_with_kernel` inside `allow_threads`, and returns the result as a new
// 1-D NumPy array. Any `UmaError` is raised as a Python `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "uma")]
#[pyo3(signature = (data, accelerator, min_length, max_length, smooth_length, volume=None, kernel=None))]
pub fn uma_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    accelerator: f64,
    min_length: usize,
    max_length: usize,
    smooth_length: usize,
    volume: Option<PyReadonlyArray1<'py, f64>>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    use numpy::PyArrayMethods;
    let kern = validate_kernel(kernel, false)?;
    let slice_in = data.as_slice()?;
    let vol_slice = match volume.as_ref() {
        Some(v) => Some(v.as_slice()?),
        None => None,
    };

    let input = UmaInput::from_slice(
        slice_in,
        vol_slice,
        UmaParams {
            accelerator: Some(accelerator),
            min_length: Some(min_length),
            max_length: Some(max_length),
            smooth_length: Some(smooth_length),
        },
    );

    let computed = py.allow_threads(|| uma_with_kernel(&input, kern).map(|o| o.values));
    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
1708
// Python class `UmaStream`: a thin wrapper that forwards every call to the
// Rust `UmaStream`.
#[cfg(feature = "python")]
#[pyclass(name = "UmaStream")]
pub struct UmaStreamPy {
    stream: UmaStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl UmaStreamPy {
    // Builds the wrapped stream from the four UMA parameters; a construction
    // failure is raised as a Python `ValueError`.
    #[new]
    pub fn new(
        accelerator: f64,
        min_length: usize,
        max_length: usize,
        smooth_length: usize,
    ) -> PyResult<Self> {
        let params = UmaParams {
            accelerator: Some(accelerator),
            min_length: Some(min_length),
            max_length: Some(max_length),
            smooth_length: Some(smooth_length),
        };
        match UmaStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    // Forwards one value to the wrapped stream and returns its result.
    pub fn update(&mut self, value: f64) -> Option<f64> {
        let stream = &mut self.stream;
        stream.update(value)
    }

    // Forwards one value plus optional volume to the wrapped stream.
    #[pyo3(name = "update_with_volume")]
    pub fn update_with_volume_py(&mut self, value: f64, volume: Option<f64>) -> Option<f64> {
        let stream = &mut self.stream;
        stream.update_with_volume(value, volume)
    }

    // Resets the wrapped stream.
    pub fn reset(&mut self) {
        let stream = &mut self.stream;
        stream.reset();
    }
}
1749
// Python entry point `uma_batch(data, accelerator_range, min_length_range,
// max_length_range, smooth_length_range, volume=None, kernel=None)`.
//
// Expands the sweep, allocates a flat NumPy array of `rows * cols` values,
// fills it with `uma_batch_inner_into` inside `allow_threads`, and returns a
// dict with the reshaped `values` matrix, `rows`, `cols`, one array per
// parameter axis, and `combos` (a list of per-row parameter dicts).
//
// Every failure is surfaced as a Python exception: the `rows * cols` size is
// computed with an overflow check, the kernel name is validated before the
// output array is allocated, and errors while building the per-row dicts are
// propagated instead of panicking.
#[cfg(feature = "python")]
#[pyfunction(name = "uma_batch")]
#[pyo3(signature = (data, accelerator_range, min_length_range, max_length_range, smooth_length_range, volume=None, kernel=None))]
pub fn uma_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    accelerator_range: (f64, f64, f64),
    min_length_range: (usize, usize, usize),
    max_length_range: (usize, usize, usize),
    smooth_length_range: (usize, usize, usize),
    volume: Option<numpy::PyReadonlyArray1<'py, f64>>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    let slice_in = data.as_slice()?;
    let vol_slice = volume.as_ref().map(|v| v.as_slice()).transpose()?;
    let kern = validate_kernel(kernel, true)?;

    let sweep = UmaBatchRange {
        accelerator: accelerator_range,
        min_length: min_length_range,
        max_length: max_length_range,
        smooth_length: smooth_length_range,
    };
    let combos = expand_grid_uma(&sweep);
    let rows = combos.len();
    let cols = slice_in.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows * cols overflows usize"))?;

    // SAFETY: the array is allocated uninitialised and is only exposed to
    // Python after `uma_batch_inner_into` has succeeded; on error it is
    // dropped without being read.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    // SAFETY: `out_arr` was just created and no other reference to its data
    // exists.
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    py.allow_threads(|| uma_batch_inner_into(slice_in, vol_slice, &sweep, kern, false, out_slice))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    dict.set_item(
        "accelerators",
        combos
            .iter()
            .map(|c| c.accelerator.unwrap_or(1.0))
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "min_lengths",
        combos
            .iter()
            .map(|c| c.min_length.unwrap_or(5) as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "max_lengths",
        combos
            .iter()
            .map(|c| c.max_length.unwrap_or(50) as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "smooth_lengths",
        combos
            .iter()
            .map(|c| c.smooth_length.unwrap_or(4) as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    // One dict per row; the first failing `set_item` aborts the whole list.
    let combo_list = combos
        .iter()
        .map(|c| -> PyResult<Bound<'py, PyDict>> {
            let d = PyDict::new(py);
            d.set_item("accelerator", c.accelerator.unwrap_or(1.0))?;
            d.set_item("min_length", c.min_length.unwrap_or(5))?;
            d.set_item("max_length", c.max_length.unwrap_or(50))?;
            d.set_item("smooth_length", c.smooth_length.unwrap_or(4))?;
            Ok(d)
        })
        .collect::<PyResult<Vec<_>>>()?;
    dict.set_item("combos", combo_list)?;

    Ok(dict.into())
}
1838
// Python-visible handle to a `rows x cols` `f32` device buffer.
//
// `buf` becomes `None` once the buffer has been handed out through
// `__dlpack__`; `_ctx` is held alongside the buffer and `device_id` is the
// value reported by `__dlpack_device__`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct UmaDeviceArrayF32Py {
    pub(crate) buf: Option<DeviceBuffer<f32>>,
    pub(crate) rows: usize,
    pub(crate) cols: usize,
    pub(crate) _ctx: Arc<Context>,
    pub(crate) device_id: u32,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl UmaDeviceArrayF32Py {
    // Builds the `__cuda_array_interface__` dict: shape, `<f4` typestr, byte
    // strides of a row-major matrix, the raw device pointer, and version 3.
    // Fails once the buffer has been exported via `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let elem_size = std::mem::size_of::<f32>();

        let d = PyDict::new(py);
        d.set_item("shape", (self.rows, self.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item("strides", (self.cols * elem_size, elem_size))?;

        let device_buf = match self.buf.as_ref() {
            Some(b) => b,
            None => {
                return Err(PyValueError::new_err(
                    "buffer already exported via __dlpack__",
                ))
            }
        };
        let ptr = device_buf.as_device_ptr().as_raw() as usize;
        d.set_item("data", (ptr, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    // Returns `(2, device_id)` as the DLPack device tuple.
    fn __dlpack_device__(&self) -> (i32, i32) {
        let device_type = 2;
        (device_type, self.device_id as i32)
    }

    // Exports the buffer as a DLPack capsule; may only succeed once because
    // the buffer is moved out of `self`.
    //
    // When `dl_device` parses as an `(i32, i32)` pair that differs from this
    // array's own device, the call is rejected (with a dedicated message when
    // `copy=True` was requested). `stream` is accepted but ignored.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();

        // A `dl_device` that is absent or not an (i32, i32) pair is ignored.
        let requested = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some((dev_ty, dev_id)) = requested {
            if dev_ty != kdl || dev_id != alloc_dev {
                let wants_copy = copy
                    .as_ref()
                    .and_then(|c| c.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let msg = if wants_copy {
                    "__dlpack__(copy=True) not implemented for UMA device handle"
                } else {
                    "dl_device mismatch for UMA tensor"
                };
                return Err(PyValueError::new_err(msg));
            }
        }

        let _ = stream;

        let buf = match self.buf.take() {
            Some(b) => b,
            None => return Err(PyValueError::new_err("__dlpack__ may only be called once")),
        };

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(
            py,
            buf,
            self.rows,
            self.cols,
            alloc_dev,
            max_version_bound,
        )
    }
}
1923
// Python entry point `uma_cuda_batch_dev(data_f32, accelerator_range,
// min_length_range, max_length_range, smooth_length_range, volume_f32=None,
// device_id=0)`.
//
// Runs the batch sweep through `CudaUma` inside `allow_threads` and wraps the
// resulting device array, together with the CUDA context and device id
// reported by the `CudaUma` instance, in a `UmaDeviceArrayF32Py`. Raises
// `ValueError` when CUDA is unavailable or the CUDA call fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "uma_cuda_batch_dev")]
#[pyo3(signature = (data_f32, accelerator_range, min_length_range, max_length_range, smooth_length_range, volume_f32=None, device_id=0))]
pub fn uma_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    accelerator_range: (f64, f64, f64),
    min_length_range: (usize, usize, usize),
    max_length_range: (usize, usize, usize),
    smooth_length_range: (usize, usize, usize),
    volume_f32: Option<numpy::PyReadonlyArray1<'_, f32>>,
    device_id: usize,
) -> PyResult<UmaDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let slice_in = data_f32.as_slice()?;
    let volume_slice = match volume_f32.as_ref() {
        Some(v) => Some(v.as_slice()?),
        None => None,
    };
    let sweep = UmaBatchRange {
        accelerator: accelerator_range,
        min_length: min_length_range,
        max_length: max_length_range,
        smooth_length: smooth_length_range,
    };

    let launched = py.allow_threads(
        || -> Result<_, crate::cuda::moving_averages::uma_wrapper::CudaUmaError> {
            let cuda = CudaUma::new(device_id)?;
            let out = cuda.uma_batch_dev(slice_in, volume_slice, &sweep)?;
            Ok((out, cuda.context_arc(), cuda.device_id()))
        },
    );
    let (inner, ctx, dev_id) = match launched {
        Ok(parts) => parts,
        Err(e) => return Err(PyValueError::new_err(e.to_string())),
    };

    let crate::cuda::DeviceArrayF32 { buf, rows, cols } = inner;
    let handle = UmaDeviceArrayF32Py {
        buf: Some(buf),
        rows,
        cols,
        _ctx: ctx,
        device_id: dev_id,
    };
    Ok(handle)
}
1969
// Python binding `uma_cuda_many_series_one_param_dev`: runs one UMA parameter set over
// a 2-D price matrix (optionally with a same-shaped volume matrix) on a CUDA device.
//
// Raises `ValueError` when `cuda_available()` is false, when the volume matrix shape
// differs from the price matrix shape, or when the CUDA wrapper returns an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "uma_cuda_many_series_one_param_dev")]
#[pyo3(signature = (prices_tm_f32, accelerator, min_length, max_length, smooth_length, volume_tm_f32=None, device_id=0))]
pub fn uma_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    prices_tm_f32: PyReadonlyArray2<'_, f32>,
    accelerator: f64,
    min_length: usize,
    max_length: usize,
    smooth_length: usize,
    volume_tm_f32: Option<PyReadonlyArray2<'_, f32>>,
    device_id: usize,
) -> PyResult<UmaDeviceArrayF32Py> {
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let price_dims = prices_tm_f32.shape();
    let (n_rows, n_cols) = (price_dims[0], price_dims[1]);
    match &volume_tm_f32 {
        Some(vol) if vol.shape() != price_dims => {
            return Err(PyValueError::new_err(
                "price and volume matrices must share shape",
            ));
        }
        _ => {}
    }

    let prices_flat = prices_tm_f32.as_slice()?;
    let volume_flat = match volume_tm_f32.as_ref() {
        Some(vol) => Some(vol.as_slice()?),
        None => None,
    };

    let params = UmaParams {
        accelerator: Some(accelerator),
        min_length: Some(min_length),
        max_length: Some(max_length),
        smooth_length: Some(smooth_length),
    };

    // NOTE(review): the dimensions are forwarded as (shape[1], shape[0]); confirm this
    // ordering against the wrapper's signature.
    let launch = || -> Result<_, crate::cuda::moving_averages::uma_wrapper::CudaUmaError> {
        let cuda = CudaUma::new(device_id)?;
        let out = cuda.uma_many_series_one_param_time_major_dev(
            prices_flat,
            volume_flat,
            n_cols,
            n_rows,
            &params,
        )?;
        Ok((out, cuda.context_arc(), cuda.device_id()))
    };
    let (device_out, ctx, dev_id) = py
        .allow_threads(launch)
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(UmaDeviceArrayF32Py {
        buf: Some(device_out.buf),
        rows: device_out.rows,
        cols: device_out.cols,
        _ctx: ctx,
        device_id: dev_id,
    })
}
2038
/// Computes UMA over `data` (with optional `volume`) using `Kernel::Auto` and returns a
/// JS object whose `values` property is an array of the output values.
///
/// Indicator errors are returned as a JS string built from the error's `Display` text.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn uma_js(
    data: &[f64],
    accelerator: f64,
    min_length: usize,
    max_length: usize,
    smooth_length: usize,
    volume: Option<Vec<f64>>,
) -> Result<JsValue, JsValue> {
    let input = UmaInput::from_slice(
        data,
        volume.as_deref(),
        UmaParams {
            accelerator: Some(accelerator),
            min_length: Some(min_length),
            max_length: Some(max_length),
            smooth_length: Some(smooth_length),
        },
    );

    let output =
        uma_with_kernel(&input, Kernel::Auto).map_err(|e| JsValue::from_str(&e.to_string()))?;

    let values_array = js_sys::Array::new();
    output.values.iter().for_each(|&v| {
        values_array.push(&JsValue::from_f64(v));
    });

    let obj = js_sys::Object::new();
    js_sys::Reflect::set(&obj, &JsValue::from_str("values"), &values_array)?;
    Ok(obj.into())
}
2071
/// Sweep configuration deserialized from the JS `config` object passed to `uma_batch`.
///
/// Each field is a three-element tuple that is copied unchanged into the matching
/// `UmaBatchRange` field.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Deserialize, Serialize)]
pub struct UmaBatchConfig {
    /// Range tuple for the accelerator parameter.
    pub accelerator_range: (f64, f64, f64),
    /// Range tuple for the minimum length parameter.
    pub min_length_range: (usize, usize, usize),
    /// Range tuple for the maximum length parameter.
    pub max_length_range: (usize, usize, usize),
    /// Range tuple for the smoothing length parameter.
    pub smooth_length_range: (usize, usize, usize),
}
2080
/// Result object serialized back to JS by `uma_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Deserialize, Serialize)]
pub struct UmaBatchJsOutput {
    /// Output values copied from the batch result.
    pub values: Vec<f64>,
    /// Parameter combinations copied from the batch result.
    pub combos: Vec<UmaParams>,
    /// Row count reported by the batch result.
    pub rows: usize,
    /// Column count reported by the batch result.
    pub cols: usize,
}
2089
/// JS entry point `uma_batch`: deserializes `config` into a sweep, runs
/// `uma_batch_inner` without volume data, and serializes the result back to JS.
///
/// Deserialization, computation and serialization failures are each returned as a JS
/// string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "uma_batch")]
pub fn uma_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed = serde_wasm_bindgen::from_value::<UmaBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = UmaBatchRange {
        accelerator: parsed.accelerator_range,
        min_length: parsed.min_length_range,
        max_length: parsed.max_length_range,
        smooth_length: parsed.smooth_length_range,
    };

    let batch = uma_batch_inner(data, None, &sweep, detect_best_kernel(), false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = UmaBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2112
/// Allocates an uninitialized `f64` buffer with capacity `len` and returns its raw
/// pointer. Ownership passes to the caller; the buffer is not freed by Rust.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn uma_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the destructor so the allocation outlives this call.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2121
/// Releases a buffer previously obtained from `uma_alloc`.
///
/// `len` must be the same value that was passed to `uma_alloc`. A null `ptr` is
/// ignored instead of being handed to `Vec::from_raw_parts`, which would be undefined
/// behaviour.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn uma_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: the caller guarantees `ptr` came from `uma_alloc(len)`, i.e. a `Vec<f64>`
    // allocation with capacity `len`. The vector is rebuilt with length 0 because
    // `uma_alloc` never initialized any element; `f64` has no destructor, so only the
    // capacity matters for deallocation.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2129
/// Computes UMA for `len` values read from `in_ptr` and writes the result to `out_ptr`,
/// using `Kernel::Auto` and no volume data.
///
/// Returns `Err("null pointer")` when either pointer is null. When both pointers are
/// equal the result is computed into a temporary buffer and copied back afterwards.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn uma_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    accelerator: f64,
    min_length: usize,
    max_length: usize,
    smooth_length: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }

    let in_place = std::ptr::eq(in_ptr, out_ptr as *const f64);

    unsafe {
        let source = std::slice::from_raw_parts(in_ptr, len);
        let input = UmaInput::from_slice(
            source,
            None,
            UmaParams {
                accelerator: Some(accelerator),
                min_length: Some(min_length),
                max_length: Some(max_length),
                smooth_length: Some(smooth_length),
            },
        );

        if !in_place {
            let target = std::slice::from_raw_parts_mut(out_ptr, len);
            uma_into_slice(target, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            return Ok(());
        }

        // Input and output share storage: compute into scratch space first so the
        // input is never overwritten while it is still being read.
        let mut scratch = vec![0.0; len];
        uma_into_slice(&mut scratch, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        Ok(())
    }
}
2168
/// Creates a heap-allocated `UmaStream` for the given parameters and returns its raw
/// pointer, or a null pointer when `UmaStream::try_new` fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn uma_stream_new(
    accelerator: f64,
    min_length: usize,
    max_length: usize,
    smooth_length: usize,
) -> *mut UmaStream {
    let requested = UmaParams {
        accelerator: Some(accelerator),
        min_length: Some(min_length),
        max_length: Some(max_length),
        smooth_length: Some(smooth_length),
    };

    UmaStream::try_new(requested)
        .map_or(std::ptr::null_mut(), |state| Box::into_raw(Box::new(state)))
}
2189
/// Feeds one price (without volume) into the stream behind `stream`.
///
/// Returns `None` for a null pointer; otherwise returns whatever
/// `update_with_volume(value, None)` yields.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn uma_stream_update(stream: *mut UmaStream, value: f64) -> Option<f64> {
    // `as_mut` maps a null pointer to `None`, which `?` returns directly.
    let state = unsafe { stream.as_mut() }?;
    state.update_with_volume(value, None)
}
2198
/// Feeds one price together with its volume into the stream behind `stream`.
///
/// Returns `None` for a null pointer; otherwise returns whatever
/// `update_with_volume(value, Some(volume))` yields.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn uma_stream_update_with_volume(
    stream: *mut UmaStream,
    value: f64,
    volume: f64,
) -> Option<f64> {
    // `as_mut` maps a null pointer to `None`, which `?` returns directly.
    let state = unsafe { stream.as_mut() }?;
    state.update_with_volume(value, Some(volume))
}
2211
/// Calls `reset()` on the stream behind `stream`; a null pointer is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn uma_stream_reset(stream: *mut UmaStream) {
    if let Some(state) = unsafe { stream.as_mut() } {
        state.reset();
    }
}
2221
/// Destroys a stream created by `uma_stream_new`; a null pointer is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn uma_stream_free(stream: *mut UmaStream) {
    if stream.is_null() {
        return;
    }
    // Re-box the pointer so the stream is dropped and its allocation released.
    drop(unsafe { Box::from_raw(stream) });
}
2231
/// Returns a `Float64Array` view over the `len` values starting at `ptr`.
///
/// No null or bounds validation is performed here; the caller must pass a valid
/// pointer/length pair.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn uma_get_view(ptr: *mut f64, len: usize) -> js_sys::Float64Array {
    unsafe {
        let window: &[f64] = std::slice::from_raw_parts(ptr, len);
        js_sys::Float64Array::view(window)
    }
}
2237
/// Recomputes UMA in place: reads `len` values at `ptr`, computes the indicator with
/// `Kernel::Auto` and no volume data, then overwrites the same memory with the result.
///
/// Returns `Err("null pointer")` for a null `ptr`; indicator errors are returned as a
/// JS string and leave the buffer untouched.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn uma_update(
    ptr: *mut f64,
    len: usize,
    accelerator: f64,
    min_length: usize,
    max_length: usize,
    smooth_length: usize,
) -> Result<(), JsValue> {
    if ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }

    // Results are staged here so the input is fully read before it is overwritten.
    let mut scratch = vec![0.0; len];

    unsafe {
        let current = std::slice::from_raw_parts(ptr, len);
        let input = UmaInput::from_slice(
            current,
            None,
            UmaParams {
                accelerator: Some(accelerator),
                min_length: Some(min_length),
                max_length: Some(max_length),
                smooth_length: Some(smooth_length),
            },
        );

        uma_into_slice(&mut scratch, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        std::slice::from_raw_parts_mut(ptr, len).copy_from_slice(&scratch);
    }

    Ok(())
}
2271
/// Runs a UMA parameter sweep from JS. Each `*_range` argument must contain exactly
/// three elements (`[start, end, step]`); otherwise an error string is returned.
///
/// The returned object carries `values`, the per-combination parameter arrays
/// (`accelerators`, `min_lengths`, `max_lengths`, `smooth_lengths`), `rows` and `cols`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn uma_batch_js(
    data: &[f64],
    accelerator_range: Vec<f64>,
    min_length_range: Vec<usize>,
    max_length_range: Vec<usize>,
    smooth_length_range: Vec<usize>,
    volume: Option<Vec<f64>>,
) -> Result<JsValue, JsValue> {
    // Slice patterns validate the lengths and extract the triples in one step.
    let sweep = match (
        accelerator_range.as_slice(),
        min_length_range.as_slice(),
        max_length_range.as_slice(),
        smooth_length_range.as_slice(),
    ) {
        (&[a0, a1, a2], &[lo0, lo1, lo2], &[hi0, hi1, hi2], &[s0, s1, s2]) => UmaBatchRange {
            accelerator: (a0, a1, a2),
            min_length: (lo0, lo1, lo2),
            max_length: (hi0, hi1, hi2),
            smooth_length: (s0, s1, s2),
        },
        _ => {
            return Err(JsValue::from_str(
                "All range arrays must have exactly 3 elements: [start, end, step]",
            ));
        }
    };

    let result = uma_batch_inner(
        data,
        volume.as_deref(),
        &sweep,
        detect_best_batch_kernel(),
        true,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let obj = js_sys::Object::new();

    let values_array = js_sys::Array::new();
    result.values.iter().for_each(|&v| {
        values_array.push(&JsValue::from_f64(v));
    });
    js_sys::Reflect::set(&obj, &JsValue::from_str("values"), &values_array)?;

    let accelerators = js_sys::Array::new();
    let min_lengths = js_sys::Array::new();
    let max_lengths = js_sys::Array::new();
    let smooth_lengths = js_sys::Array::new();

    // Unset parameters are reported with the fallback values below.
    result.combos.iter().for_each(|combo| {
        let acc = combo.accelerator.unwrap_or(1.0);
        let min_len = combo.min_length.unwrap_or(5) as f64;
        let max_len = combo.max_length.unwrap_or(50) as f64;
        let smooth_len = combo.smooth_length.unwrap_or(4) as f64;

        accelerators.push(&JsValue::from_f64(acc));
        min_lengths.push(&JsValue::from_f64(min_len));
        max_lengths.push(&JsValue::from_f64(max_len));
        smooth_lengths.push(&JsValue::from_f64(smooth_len));
    });

    js_sys::Reflect::set(&obj, &JsValue::from_str("accelerators"), &accelerators)?;
    js_sys::Reflect::set(&obj, &JsValue::from_str("min_lengths"), &min_lengths)?;
    js_sys::Reflect::set(&obj, &JsValue::from_str("max_lengths"), &max_lengths)?;
    js_sys::Reflect::set(&obj, &JsValue::from_str("smooth_lengths"), &smooth_lengths)?;

    let rows_js = JsValue::from_f64(result.rows as f64);
    js_sys::Reflect::set(&obj, &JsValue::from_str("rows"), &rows_js)?;
    let cols_js = JsValue::from_f64(result.cols as f64);
    js_sys::Reflect::set(&obj, &JsValue::from_str("cols"), &cols_js)?;

    Ok(obj.into())
}
2358
2359#[cfg(test)]
2360mod tests {
2361    use super::*;
2362    use crate::skip_if_unsupported;
2363    use crate::utilities::data_loader::read_candles_from_csv;
2364    #[cfg(feature = "proptest")]
2365    use proptest::prelude::*;
2366
2367    fn check_uma_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2368        skip_if_unsupported!(kernel, test_name);
2369        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2370        let candles = read_candles_from_csv(file_path)?;
2371
2372        let default_params = UmaParams {
2373            accelerator: None,
2374            min_length: None,
2375            max_length: None,
2376            smooth_length: None,
2377        };
2378        let input = UmaInput::from_candles(&candles, "close", default_params);
2379        let output = uma_with_kernel(&input, kernel)?;
2380        assert_eq!(output.values.len(), candles.close.len());
2381
2382        Ok(())
2383    }
2384
2385    fn check_uma_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2386        skip_if_unsupported!(kernel, test_name);
2387        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2388        let candles = read_candles_from_csv(file_path)?;
2389
2390        let input = UmaInput::from_candles(&candles, "close", UmaParams::default());
2391        let result = uma_with_kernel(&input, kernel)?;
2392
2393        let values = &result.values;
2394        let valid_values: Vec<f64> = values.iter().filter(|&&v| !v.is_nan()).copied().collect();
2395
2396        let expected_last_five = [
2397            59665.81830666,
2398            59477.69234542,
2399            59314.50778603,
2400            59218.23033661,
2401            59143.61473211,
2402        ];
2403
2404        if valid_values.len() >= 5 {
2405            let start = valid_values.len().saturating_sub(5);
2406            for (i, &val) in valid_values[start..].iter().enumerate() {
2407                let diff = (val - expected_last_five[i]).abs();
2408                let tolerance = expected_last_five[i] * 0.01;
2409                assert!(
2410                    diff < tolerance || diff < 100.0,
2411                    "[{}] UMA {:?} mismatch at idx {}: got {}, expected {}",
2412                    test_name,
2413                    kernel,
2414                    i,
2415                    val,
2416                    expected_last_five[i]
2417                );
2418            }
2419        }
2420        Ok(())
2421    }
2422
2423    fn check_uma_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2424        skip_if_unsupported!(kernel, test_name);
2425        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2426        let candles = read_candles_from_csv(file_path)?;
2427
2428        let input = UmaInput::with_default_candles(&candles);
2429        match input.data {
2430            UmaData::Candles { source, .. } => assert_eq!(source, "close"),
2431            _ => panic!("Expected UmaData::Candles"),
2432        }
2433        let output = uma_with_kernel(&input, kernel)?;
2434        assert_eq!(output.values.len(), candles.close.len());
2435
2436        Ok(())
2437    }
2438
2439    fn check_uma_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2440        skip_if_unsupported!(kernel, test_name);
2441        let input_data = [10.0, 20.0, 30.0];
2442        let params = UmaParams {
2443            accelerator: Some(1.0),
2444            min_length: Some(5),
2445            max_length: Some(0),
2446            smooth_length: Some(4),
2447        };
2448        let input = UmaInput::from_slice(&input_data, None, params);
2449        let res = uma_with_kernel(&input, kernel);
2450        assert!(
2451            res.is_err(),
2452            "[{}] UMA should fail with zero max_length",
2453            test_name
2454        );
2455        Ok(())
2456    }
2457
2458    fn check_uma_period_exceeds_length(
2459        test_name: &str,
2460        kernel: Kernel,
2461    ) -> Result<(), Box<dyn Error>> {
2462        skip_if_unsupported!(kernel, test_name);
2463        let data_small = [10.0, 20.0, 30.0];
2464        let params = UmaParams {
2465            accelerator: Some(1.0),
2466            min_length: Some(5),
2467            max_length: Some(10),
2468            smooth_length: Some(4),
2469        };
2470        let input = UmaInput::from_slice(&data_small, None, params);
2471        let res = uma_with_kernel(&input, kernel);
2472        assert!(
2473            res.is_err(),
2474            "[{}] UMA should fail with max_length exceeding data length",
2475            test_name
2476        );
2477        Ok(())
2478    }
2479
2480    fn check_uma_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2481        skip_if_unsupported!(kernel, test_name);
2482        let single_point = [42.0];
2483        let params = UmaParams::default();
2484        let input = UmaInput::from_slice(&single_point, None, params);
2485        let res = uma_with_kernel(&input, kernel);
2486        assert!(
2487            res.is_err(),
2488            "[{}] UMA should fail with insufficient data",
2489            test_name
2490        );
2491        Ok(())
2492    }
2493
2494    fn check_uma_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2495        skip_if_unsupported!(kernel, test_name);
2496        let empty: [f64; 0] = [];
2497        let input = UmaInput::from_slice(&empty, None, UmaParams::default());
2498        let res = uma_with_kernel(&input, kernel);
2499        assert!(
2500            matches!(res, Err(UmaError::EmptyInputData)),
2501            "[{}] UMA should fail with empty input",
2502            test_name
2503        );
2504        Ok(())
2505    }
2506
2507    fn check_uma_invalid_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2508        skip_if_unsupported!(kernel, test_name);
2509        let data: Vec<f64> = (0..100).map(|i| 100.0 + i as f64).collect();
2510
2511        let params = UmaParams {
2512            accelerator: Some(0.5),
2513            max_length: Some(10),
2514            ..Default::default()
2515        };
2516        let input = UmaInput::from_slice(&data, None, params);
2517        let res = uma_with_kernel(&input, kernel);
2518        assert!(
2519            matches!(res, Err(UmaError::InvalidAccelerator { .. })),
2520            "[{}] UMA should fail with invalid accelerator, got: {:?}",
2521            test_name,
2522            res
2523        );
2524        Ok(())
2525    }
2526
2527    fn check_uma_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2528        skip_if_unsupported!(kernel, test_name);
2529        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2530        let candles = read_candles_from_csv(file_path)?;
2531
2532        let first_params = UmaParams::default();
2533        let first_input = UmaInput::from_candles(&candles, "close", first_params);
2534        let first_result = uma_with_kernel(&first_input, kernel)?;
2535
2536        let second_params = UmaParams::default();
2537        let second_input = UmaInput::from_slice(&first_result.values, None, second_params);
2538        let second_result = uma_with_kernel(&second_input, kernel)?;
2539
2540        assert_eq!(second_result.values.len(), first_result.values.len());
2541
2542        let valid_count = second_result
2543            .values
2544            .iter()
2545            .filter(|&&v| !v.is_nan())
2546            .count();
2547        assert!(
2548            valid_count > 0,
2549            "[{}] UMA reinput should produce valid values",
2550            test_name
2551        );
2552
2553        Ok(())
2554    }
2555
2556    fn check_uma_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2557        skip_if_unsupported!(kernel, test_name);
2558
2559        let mut data = vec![f64::NAN; 10];
2560        data.extend((0..100).map(|i| 100.0 + i as f64));
2561
2562        let input = UmaInput::from_slice(&data, None, UmaParams::default());
2563        let res = uma_with_kernel(&input, kernel)?;
2564        assert_eq!(res.values.len(), data.len());
2565
2566        let valid_count = res.values[60..].iter().filter(|&&v| !v.is_nan()).count();
2567        assert!(
2568            valid_count > 0,
2569            "[{}] UMA should handle NaN prefix",
2570            test_name
2571        );
2572
2573        Ok(())
2574    }
2575
2576    fn check_uma_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2577        skip_if_unsupported!(kernel, test_name);
2578
2579        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2580        let candles = read_candles_from_csv(file_path)?;
2581
2582        let params = UmaParams::default();
2583        let input = UmaInput::from_candles(&candles, "close", params.clone());
2584        let batch_output = uma_with_kernel(&input, kernel)?.values;
2585
2586        let mut stream = UmaStream::try_new(params)?;
2587        let mut stream_values = Vec::with_capacity(candles.close.len());
2588
2589        for (i, &price) in candles.close.iter().enumerate() {
2590            let volume = if i < candles.volume.len() {
2591                Some(candles.volume[i])
2592            } else {
2593                None
2594            };
2595
2596            match stream.update_with_volume(price, volume) {
2597                Some(uma_val) => stream_values.push(uma_val),
2598                None => stream_values.push(f64::NAN),
2599            }
2600        }
2601
2602        assert_eq!(batch_output.len(), stream_values.len());
2603
2604        let batch_valid: Vec<f64> = batch_output
2605            .iter()
2606            .filter(|&&v| !v.is_nan())
2607            .copied()
2608            .collect();
2609        let stream_valid: Vec<f64> = stream_values
2610            .iter()
2611            .filter(|&&v| !v.is_nan())
2612            .copied()
2613            .collect();
2614
2615        if batch_valid.len() >= 5 && stream_valid.len() >= 5 {
2616            let batch_last = &batch_valid[batch_valid.len() - 5..];
2617            let stream_last = &stream_valid[stream_valid.len() - 5..];
2618
2619            for (i, (&b, &s)) in batch_last.iter().zip(stream_last.iter()).enumerate() {
2620                let diff = (b - s).abs();
2621                let relative_diff = diff / b.abs().max(1.0);
2622                assert!(
2623                    relative_diff < 0.1,
2624                    "[{}] UMA streaming mismatch at idx {}: batch={}, stream={}, diff={}, rel_diff={}",
2625                    test_name,
2626                    i,
2627                    b,
2628                    s,
2629                    diff,
2630                    relative_diff
2631                );
2632            }
2633        }
2634
2635        Ok(())
2636    }
2637
2638    #[cfg(debug_assertions)]
2639    fn check_uma_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2640        skip_if_unsupported!(kernel, test_name);
2641
2642        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2643        let candles = read_candles_from_csv(file_path)?;
2644
2645        let test_params = vec![
2646            UmaParams::default(),
2647            UmaParams {
2648                accelerator: Some(1.5),
2649                min_length: Some(3),
2650                max_length: Some(30),
2651                smooth_length: Some(3),
2652            },
2653            UmaParams {
2654                accelerator: Some(2.0),
2655                min_length: Some(10),
2656                max_length: Some(100),
2657                smooth_length: Some(8),
2658            },
2659        ];
2660
2661        for params in test_params.iter() {
2662            let input = UmaInput::from_candles(&candles, "close", params.clone());
2663            let output = uma_with_kernel(&input, kernel)?;
2664
2665            for (i, &val) in output.values.iter().enumerate() {
2666                if val.is_nan() {
2667                    continue;
2668                }
2669
2670                let bits = val.to_bits();
2671
2672                if bits == 0x11111111_11111111 {
2673                    panic!(
2674                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {}",
2675                        test_name, val, bits, i
2676                    );
2677                }
2678
2679                if bits == 0x22222222_22222222 {
2680                    panic!(
2681                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {}",
2682                        test_name, val, bits, i
2683                    );
2684                }
2685
2686                if bits == 0x33333333_33333333 {
2687                    panic!(
2688                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {}",
2689                        test_name, val, bits, i
2690                    );
2691                }
2692            }
2693        }
2694
2695        Ok(())
2696    }
2697
2698    #[cfg(not(debug_assertions))]
2699    fn check_uma_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2700        Ok(())
2701    }
2702
2703    #[cfg(feature = "proptest")]
2704    #[allow(clippy::float_cmp)]
2705    fn check_uma_property(
2706        test_name: &str,
2707        kernel: Kernel,
2708    ) -> Result<(), Box<dyn std::error::Error>> {
2709        use proptest::prelude::*;
2710        skip_if_unsupported!(kernel, test_name);
2711
2712        let strat = (5usize..=20, 20usize..=50, 2usize..=8, 1.0f64..3.0).prop_flat_map(
2713            |(min_len, max_len, smooth_len, acc)| {
2714                (
2715                    prop::collection::vec(
2716                        (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
2717                        max_len + 10..200,
2718                    ),
2719                    Just(min_len),
2720                    Just(max_len),
2721                    Just(smooth_len),
2722                    Just(acc),
2723                )
2724            },
2725        );
2726
2727        proptest::test_runner::TestRunner::default()
2728            .run(&strat, |(data, min_len, max_len, smooth_len, acc)| {
2729                let params = UmaParams {
2730                    accelerator: Some(acc),
2731                    min_length: Some(min_len),
2732                    max_length: Some(max_len),
2733                    smooth_length: Some(smooth_len),
2734                };
2735                let input = UmaInput::from_slice(&data, None, params);
2736
2737                let UmaOutput { values: out } = uma_with_kernel(&input, kernel).unwrap();
2738                let UmaOutput { values: ref_out } =
2739                    uma_with_kernel(&input, Kernel::Scalar).unwrap();
2740
2741                prop_assert_eq!(out.len(), data.len());
2742                prop_assert_eq!(ref_out.len(), data.len());
2743
2744                for i in 0..data.len() {
2745                    let y = out[i];
2746                    let r = ref_out[i];
2747
2748                    if !y.is_finite() || !r.is_finite() {
2749                        prop_assert!(
2750                            y.to_bits() == r.to_bits(),
2751                            "finite/NaN mismatch idx {i}: {y} vs {r}"
2752                        );
2753                        continue;
2754                    }
2755
2756                    let y_bits = y.to_bits();
2757                    let r_bits = r.to_bits();
2758                    let ulp_diff: u64 = y_bits.abs_diff(r_bits);
2759
2760                    prop_assert!(
2761                        (y - r).abs() <= 1e-6 || ulp_diff <= 10,
2762                        "mismatch idx {i}: {y} vs {r} (ULP={ulp_diff})"
2763                    );
2764                }
2765                Ok(())
2766            })
2767            .unwrap();
2768
2769        Ok(())
2770    }
2771
2772    macro_rules! generate_all_uma_tests {
2773        ($($test_fn:ident),*) => {
2774            paste::paste! {
2775                $(
2776                    #[test]
2777                    fn [<$test_fn _scalar_f64>]() {
2778                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2779                    }
2780                )*
2781                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2782                $(
2783                    #[test]
2784                    fn [<$test_fn _avx2_f64>]() {
2785                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2786                    }
2787                    #[test]
2788                    fn [<$test_fn _avx512_f64>]() {
2789                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2790                    }
2791                )*
2792                #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
2793                $(
2794                    #[test]
2795                    fn [<$test_fn _simd128_f64>]() {
2796                        let _ = $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar);
2797                    }
2798                )*
2799            }
2800        }
2801    }
2802
2803    generate_all_uma_tests!(
2804        check_uma_partial_params,
2805        check_uma_accuracy,
2806        check_uma_default_candles,
2807        check_uma_zero_period,
2808        check_uma_period_exceeds_length,
2809        check_uma_very_small_dataset,
2810        check_uma_empty_input,
2811        check_uma_invalid_params,
2812        check_uma_reinput,
2813        check_uma_nan_handling,
2814        check_uma_streaming,
2815        check_uma_no_poison
2816    );
2817
2818    #[cfg(feature = "proptest")]
2819    generate_all_uma_tests!(check_uma_property);
2820
/// `uma_into_slice` must write exactly the values `uma_with_kernel` returns.
///
/// Both paths run the scalar kernel over the `close` series of the bundled
/// CSV with default parameters. Results are compared through `to_bits`, so
/// two NaNs with the same bit pattern count as equal.
#[test]
fn uma_into_slice_matches_with_kernel() {
    use crate::utilities::data_loader::read_candles_from_csv;

    let candles =
        read_candles_from_csv("src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv").unwrap();
    // The parameters are not needed again, so they are moved in rather than cloned.
    let input = UmaInput::from_candles(&candles, "close", UmaParams::default());

    let via_api = uma_with_kernel(&input, Kernel::Scalar).unwrap().values;
    let mut via_into = vec![0.0; via_api.len()];
    uma_into_slice(&mut via_into, &input, Kernel::Scalar).unwrap();

    assert_eq!(via_api.len(), via_into.len());
    for (idx, (a, b)) in via_api.iter().zip(via_into.iter()).enumerate() {
        // Report where the two outputs diverge, not just the raw bit patterns.
        assert_eq!(
            a.to_bits(),
            b.to_bits(),
            "bit mismatch at index {}: api={} into={}",
            idx,
            a,
            b
        );
    }
}
2837
/// The write-into-buffer entry point must reproduce `uma` on synthetic candles.
///
/// The candle set has 256 bars whose first three open/high/low/close values
/// are NaN. Outputs match when both are NaN or differ by at most `1e-12`.
#[test]
fn test_uma_into_matches_api() {
    const BARS: usize = 256;
    const LEADING_NANS: usize = 3;

    // Blank out the leading bars of a price column.
    let gated = |i: usize, v: f64| if i < LEADING_NANS { f64::NAN } else { v };

    // Close: linear drift plus a sine wobble.
    let close: Vec<f64> = (0..BARS)
        .map(|i| {
            let x = i as f64;
            gated(i, 100.0 + x * 0.1 + (x / 10.0).sin() * 2.0)
        })
        .collect();
    // High / low sit one unit above / below the close; open equals the close.
    let high: Vec<f64> = close
        .iter()
        .enumerate()
        .map(|(i, &c)| gated(i, c + 1.0))
        .collect();
    let low: Vec<f64> = close
        .iter()
        .enumerate()
        .map(|(i, &c)| gated(i, c - 1.0))
        .collect();
    let open: Vec<f64> = close.iter().enumerate().map(|(i, &c)| gated(i, c)).collect();
    let volume: Vec<f64> = (0..BARS).map(|i| 1000.0 + (i % 10) as f64).collect();
    let ts: Vec<i64> = (0..BARS).map(|i| i as i64).collect();

    let candles =
        crate::utilities::data_loader::Candles::new(ts, open, high, low, close, volume);
    let input = UmaInput::from_candles(&candles, "close", UmaParams::default());

    let baseline = uma(&input).unwrap().values;
    let mut via_into = vec![0.0; baseline.len()];

    // Non-wasm builds go through `uma_into`; wasm32 builds with the `wasm`
    // feature call `uma_into_slice` with `Kernel::Auto` instead.
    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
    uma_into(&input, &mut via_into).unwrap();
    #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
    uma_into_slice(&mut via_into, &input, Kernel::Auto).unwrap();

    assert_eq!(baseline.len(), via_into.len());
    for (a, b) in baseline.iter().zip(&via_into) {
        let close_enough = (a.is_nan() && b.is_nan()) || (*a - *b).abs() <= 1e-12;
        assert!(close_enough, "Mismatch: a={} b={}", a, b);
    }
}
2888
/// The `values` entry returned by `uma_batch_py` must downcast to a 2-D
/// `f64` NumPy array.
#[cfg(feature = "python")]
#[test]
fn uma_batch_py_no_copy_shape() {
    pyo3::Python::with_gil(|py| {
        use numpy::{PyArray1, PyArrayMethods};

        // Input series: the ramp 0.0, 1.0, ..., 255.0.
        let ramp: Vec<f64> = (0..256_i32).map(f64::from).collect();
        let series = PyArray1::from_vec(py, ramp);

        // NOTE(review): the four tuples look like (start, end, step) ranges for
        // accelerator, min length, max length and smooth length — confirm
        // against the `uma_batch_py` signature.
        let result = crate::indicators::moving_averages::uma::uma_batch_py(
            py,
            series.readonly(),
            (1.0, 1.0, 0.0),
            (5, 5, 0),
            (50, 50, 0),
            (4, 4, 0),
            None,
            Some("scalar_batch"),
        )
        .unwrap();

        let values = result
            .get_item("values")
            .unwrap()
            .expect("values missing");
        let as_matrix = values.downcast::<numpy::PyArray2<f64>>();

        assert!(as_matrix.is_ok());
    });
}
2911
2912    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2913        skip_if_unsupported!(kernel, test);
2914
2915        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2916        let c = read_candles_from_csv(file)?;
2917
2918        let output = UmaBatchBuilder::new()
2919            .kernel(kernel)
2920            .apply_candles(&c, "close")?;
2921
2922        let def = UmaParams::default();
2923        let row = output.values_for(&def).expect("default row missing");
2924
2925        assert_eq!(row.len(), c.close.len());
2926
2927        let valid_count = row.iter().filter(|&&v| !v.is_nan()).count();
2928        assert!(
2929            valid_count > 0,
2930            "[{}] Batch should produce valid values",
2931            test
2932        );
2933
2934        Ok(())
2935    }
2936
2937    fn check_batch_sweep(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2938        skip_if_unsupported!(kernel, test);
2939
2940        let data: Vec<f64> = (0..100).map(|i| 100.0 + i as f64).collect();
2941
2942        let output = UmaBatchBuilder::new()
2943            .kernel(kernel)
2944            .accelerator_range(1.0, 2.0, 0.5)
2945            .min_length_range(5, 10, 5)
2946            .max_length_range(20, 30, 10)
2947            .smooth_length_range(3, 5, 2)
2948            .apply_slice(&data, None)?;
2949
2950        let expected_combos = 3 * 2 * 2 * 2;
2951        assert_eq!(output.combos.len(), expected_combos);
2952        assert_eq!(output.rows, expected_combos);
2953        assert_eq!(output.cols, data.len());
2954
2955        Ok(())
2956    }
2957
2958    #[cfg(debug_assertions)]
2959    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2960        skip_if_unsupported!(kernel, test);
2961
2962        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2963        let c = read_candles_from_csv(file)?;
2964
2965        let test_configs = vec![
2966            (1.0, 1.5, 0.5, 5, 10, 5, 20, 30, 10, 3, 5, 2),
2967            (2.0, 2.0, 0.0, 10, 10, 0, 50, 50, 0, 4, 4, 0),
2968        ];
2969
2970        for (
2971            cfg_idx,
2972            &(
2973                a_start,
2974                a_end,
2975                a_step,
2976                min_start,
2977                min_end,
2978                min_step,
2979                max_start,
2980                max_end,
2981                max_step,
2982                s_start,
2983                s_end,
2984                s_step,
2985            ),
2986        ) in test_configs.iter().enumerate()
2987        {
2988            let output = UmaBatchBuilder::new()
2989                .kernel(kernel)
2990                .accelerator_range(a_start, a_end, a_step)
2991                .min_length_range(min_start, min_end, min_step)
2992                .max_length_range(max_start, max_end, max_step)
2993                .smooth_length_range(s_start, s_end, s_step)
2994                .apply_candles(&c, "close")?;
2995
2996            for (idx, &val) in output.values.iter().enumerate() {
2997                if val.is_nan() {
2998                    continue;
2999                }
3000
3001                let bits = val.to_bits();
3002                let row = idx / output.cols;
3003                let col = idx % output.cols;
3004                let combo = &output.combos[row];
3005
3006                if bits == 0x11111111_11111111 {
3007                    panic!(
3008                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
3009                        at row {} col {} (flat index {}) with params: acc={}, min={}, max={}, smooth={}",
3010                        test,
3011                        cfg_idx,
3012                        val,
3013                        bits,
3014                        row,
3015                        col,
3016                        idx,
3017                        combo.accelerator.unwrap_or(1.0),
3018                        combo.min_length.unwrap_or(5),
3019                        combo.max_length.unwrap_or(50),
3020                        combo.smooth_length.unwrap_or(4)
3021                    );
3022                }
3023
3024                if bits == 0x22222222_22222222 {
3025                    panic!(
3026                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
3027                        at row {} col {} (flat index {})",
3028                        test, cfg_idx, val, bits, row, col, idx
3029                    );
3030                }
3031
3032                if bits == 0x33333333_33333333 {
3033                    panic!(
3034                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
3035                        at row {} col {} (flat index {})",
3036                        test, cfg_idx, val, bits, row, col, idx
3037                    );
3038                }
3039            }
3040        }
3041
3042        Ok(())
3043    }
3044
3045    #[cfg(not(debug_assertions))]
3046    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
3047        Ok(())
3048    }
3049
3050    macro_rules! gen_batch_tests {
3051        ($fn_name:ident) => {
3052            paste::paste! {
3053                #[test] fn [<$fn_name _scalar>]()      {
3054                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
3055                }
3056                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3057                #[test] fn [<$fn_name _avx2>]()        {
3058                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
3059                }
3060                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3061                #[test] fn [<$fn_name _avx512>]()      {
3062                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
3063                }
3064                #[test] fn [<$fn_name _auto_detect>]() {
3065                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]),
3066                                    Kernel::Auto);
3067                }
3068            }
3069        };
3070    }
3071
3072    gen_batch_tests!(check_batch_default_row);
3073    gen_batch_tests!(check_batch_sweep);
3074    gen_batch_tests!(check_batch_no_poison);
3075}