Skip to main content

vector_ta/indicators/
emd_trend.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::indicators::ema::{ema_with_kernel, EmaData, EmaInput, EmaParams, EmaStream};
16use crate::indicators::moving_averages::frama::{FramaParams, FramaStream};
17use crate::indicators::moving_averages::ma::{ma_with_kernel, MaData};
18use crate::indicators::moving_averages::ma_stream::{ma_stream, MaStream};
19use crate::utilities::data_loader::Candles;
20use crate::utilities::enums::Kernel;
21use crate::utilities::helpers::{
22    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, make_uninit_matrix,
23};
24#[cfg(feature = "python")]
25use crate::utilities::kernel_validation::validate_kernel;
26use std::error::Error;
27use thiserror::Error;
28
/// Price source used when `EmdTrendParams::source` is `None`.
const DEFAULT_SOURCE: &str = "close";
/// Moving-average family used when `EmdTrendParams::avg_type` is `None`.
const DEFAULT_AVG_TYPE: &str = "SMA";
/// Window used for both the moving average and the deviation EMA.
const DEFAULT_LENGTH: usize = 28;
/// Band-width multiplier applied to the smoothed absolute deviation.
const DEFAULT_MULT: f64 = 1.0;
33
/// Price series (or OHLC blend) the indicator is computed on.
///
/// Parsed case-insensitively from the `source` parameter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SourceKind {
    /// Raw open price.
    Open,
    /// Raw high price.
    High,
    /// Raw low price.
    Low,
    /// Raw close price.
    Close,
    /// `(open + close) / 2`.
    Oc2,
    /// `(high + low) / 2`.
    Hl2,
    /// `(open + 2 * close) / 3`.
    Occ3,
    /// `(high + low + close) / 3`.
    Hlc3,
    /// `(open + high + low + close) / 4`.
    Ohlc4,
    /// `(high + low + 2 * close) / 4`.
    Hlcc4,
}
47
/// Moving-average family used for the central line.
///
/// Parsed case-insensitively from the `avg_type` parameter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AvgKind {
    /// Simple moving average.
    Sma,
    /// Exponential moving average.
    Ema,
    /// Hull moving average.
    Hma,
    /// Double exponential moving average.
    Dema,
    /// Triple exponential moving average.
    Tema,
    /// Running moving average (dispatched to the `wilders` implementation).
    Rma,
    /// Fractal adaptive moving average.
    Frama,
}
58
59#[derive(Debug, Clone)]
60pub enum EmdTrendData<'a> {
61    Candles {
62        candles: &'a Candles,
63    },
64    Slices {
65        open: &'a [f64],
66        high: &'a [f64],
67        low: &'a [f64],
68        close: &'a [f64],
69    },
70}
71
/// Per-bar results of EMD Trend; every vector has the input's length.
#[derive(Debug, Clone)]
pub struct EmdTrendOutput {
    /// Trend state. It becomes `1.0` after the source crosses above the upper
    /// band and `-1.0` after it crosses below the lower band. It holds the last
    /// state between crosses and is `0.0` before the first cross.
    pub direction: Vec<f64>,
    /// The selected moving average of the source.
    pub average: Vec<f64>,
    /// `average + mult * deviation`, or NaN where either term is not finite.
    pub upper: Vec<f64>,
    /// `average - mult * deviation`, or NaN where either term is not finite.
    pub lower: Vec<f64>,
}
79
/// User-facing parameters for EMD Trend; `None` selects the module default.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct EmdTrendParams {
    /// Source name: open, high, low, close, oc2, hl2, occ3, hlc3, ohlc4 or
    /// hlcc4 (case-insensitive).
    pub source: Option<String>,
    /// Average name: SMA, EMA, HMA, DEMA, TEMA, RMA or FRAMA (case-insensitive).
    pub avg_type: Option<String>,
    /// Window for the average and the deviation EMA; must be at least 1.
    pub length: Option<usize>,
    /// Band multiplier; must be finite and at least 0.05.
    pub mult: Option<f64>,
}
91
92impl Default for EmdTrendParams {
93    fn default() -> Self {
94        Self {
95            source: Some(DEFAULT_SOURCE.to_string()),
96            avg_type: Some(DEFAULT_AVG_TYPE.to_string()),
97            length: Some(DEFAULT_LENGTH),
98            mult: Some(DEFAULT_MULT),
99        }
100    }
101}
102
103#[derive(Debug, Clone)]
104pub struct EmdTrendInput<'a> {
105    pub data: EmdTrendData<'a>,
106    pub params: EmdTrendParams,
107}
108
109impl<'a> EmdTrendInput<'a> {
110    #[inline]
111    pub fn from_candles(candles: &'a Candles, params: EmdTrendParams) -> Self {
112        Self {
113            data: EmdTrendData::Candles { candles },
114            params,
115        }
116    }
117
118    #[inline]
119    pub fn from_slices(
120        open: &'a [f64],
121        high: &'a [f64],
122        low: &'a [f64],
123        close: &'a [f64],
124        params: EmdTrendParams,
125    ) -> Self {
126        Self {
127            data: EmdTrendData::Slices {
128                open,
129                high,
130                low,
131                close,
132            },
133            params,
134        }
135    }
136
137    #[inline]
138    pub fn with_default_candles(candles: &'a Candles) -> Self {
139        Self::from_candles(candles, EmdTrendParams::default())
140    }
141
142    #[inline]
143    pub fn get_source(&self) -> &str {
144        self.params.source.as_deref().unwrap_or(DEFAULT_SOURCE)
145    }
146
147    #[inline]
148    pub fn get_avg_type(&self) -> &str {
149        self.params.avg_type.as_deref().unwrap_or(DEFAULT_AVG_TYPE)
150    }
151
152    #[inline]
153    pub fn get_length(&self) -> usize {
154        self.params.length.unwrap_or(DEFAULT_LENGTH)
155    }
156
157    #[inline]
158    pub fn get_mult(&self) -> f64 {
159        self.params.mult.unwrap_or(DEFAULT_MULT)
160    }
161}
162
163#[derive(Copy, Clone, Debug)]
164pub struct EmdTrendBuilder {
165    source: Option<SourceKind>,
166    avg_type: Option<AvgKind>,
167    length: Option<usize>,
168    mult: Option<f64>,
169    kernel: Kernel,
170}
171
172impl Default for EmdTrendBuilder {
173    fn default() -> Self {
174        Self {
175            source: None,
176            avg_type: None,
177            length: None,
178            mult: None,
179            kernel: Kernel::Auto,
180        }
181    }
182}
183
184impl EmdTrendBuilder {
185    #[inline(always)]
186    pub fn new() -> Self {
187        Self::default()
188    }
189
190    #[inline(always)]
191    pub fn source(mut self, value: &str) -> Result<Self, EmdTrendError> {
192        self.source = Some(parse_source_kind(value)?);
193        Ok(self)
194    }
195
196    #[inline(always)]
197    pub fn avg_type(mut self, value: &str) -> Result<Self, EmdTrendError> {
198        self.avg_type = Some(parse_avg_kind(value)?);
199        Ok(self)
200    }
201
202    #[inline(always)]
203    pub fn length(mut self, value: usize) -> Self {
204        self.length = Some(value);
205        self
206    }
207
208    #[inline(always)]
209    pub fn mult(mut self, value: f64) -> Self {
210        self.mult = Some(value);
211        self
212    }
213
214    #[inline(always)]
215    pub fn kernel(mut self, value: Kernel) -> Self {
216        self.kernel = value;
217        self
218    }
219}
220
221#[derive(Debug, Error)]
222pub enum EmdTrendError {
223    #[error("emd_trend: Input data slice is empty.")]
224    EmptyInputData,
225    #[error(
226        "emd_trend: Input length mismatch: open = {open_len}, high = {high_len}, low = {low_len}, close = {close_len}"
227    )]
228    InputLengthMismatch {
229        open_len: usize,
230        high_len: usize,
231        low_len: usize,
232        close_len: usize,
233    },
234    #[error("emd_trend: All values are NaN.")]
235    AllValuesNaN,
236    #[error("emd_trend: Invalid source: {value}")]
237    InvalidSource { value: String },
238    #[error("emd_trend: Invalid avg_type: {value}")]
239    InvalidAvgType { value: String },
240    #[error("emd_trend: Invalid length: {length}")]
241    InvalidLength { length: usize },
242    #[error("emd_trend: Invalid mult: {mult}")]
243    InvalidMult { mult: f64 },
244    #[error("emd_trend: Not enough valid data: needed = {needed}, valid = {valid}")]
245    NotEnoughValidData { needed: usize, valid: usize },
246    #[error("emd_trend: Output length mismatch: expected = {expected}, got = {got}")]
247    OutputLengthMismatch { expected: usize, got: usize },
248    #[error("emd_trend: Output length mismatch: dst = {dst_len}, expected = {expected_len}")]
249    MismatchedOutputLen { dst_len: usize, expected_len: usize },
250    #[error("emd_trend: Invalid range: start={start}, end={end}, step={step}")]
251    InvalidRange {
252        start: String,
253        end: String,
254        step: String,
255    },
256    #[error("emd_trend: Invalid kernel for batch: {0:?}")]
257    InvalidKernelForBatch(Kernel),
258    #[error("emd_trend: Invalid input: {msg}")]
259    InvalidInput { msg: String },
260}
261
262#[inline(always)]
263fn source_kind_name(kind: SourceKind) -> &'static str {
264    match kind {
265        SourceKind::Open => "open",
266        SourceKind::High => "high",
267        SourceKind::Low => "low",
268        SourceKind::Close => "close",
269        SourceKind::Oc2 => "oc2",
270        SourceKind::Hl2 => "hl2",
271        SourceKind::Occ3 => "occ3",
272        SourceKind::Hlc3 => "hlc3",
273        SourceKind::Ohlc4 => "ohlc4",
274        SourceKind::Hlcc4 => "hlcc4",
275    }
276}
277
278#[inline(always)]
279fn avg_kind_name(kind: AvgKind) -> &'static str {
280    match kind {
281        AvgKind::Sma => "SMA",
282        AvgKind::Ema => "EMA",
283        AvgKind::Hma => "HMA",
284        AvgKind::Dema => "DEMA",
285        AvgKind::Tema => "TEMA",
286        AvgKind::Rma => "RMA",
287        AvgKind::Frama => "FRAMA",
288    }
289}
290
291#[inline(always)]
292fn parse_source_kind(value: &str) -> Result<SourceKind, EmdTrendError> {
293    if value.eq_ignore_ascii_case("open") {
294        return Ok(SourceKind::Open);
295    }
296    if value.eq_ignore_ascii_case("high") {
297        return Ok(SourceKind::High);
298    }
299    if value.eq_ignore_ascii_case("low") {
300        return Ok(SourceKind::Low);
301    }
302    if value.eq_ignore_ascii_case("close") {
303        return Ok(SourceKind::Close);
304    }
305    if value.eq_ignore_ascii_case("oc2") {
306        return Ok(SourceKind::Oc2);
307    }
308    if value.eq_ignore_ascii_case("hl2") {
309        return Ok(SourceKind::Hl2);
310    }
311    if value.eq_ignore_ascii_case("occ3") {
312        return Ok(SourceKind::Occ3);
313    }
314    if value.eq_ignore_ascii_case("hlc3") {
315        return Ok(SourceKind::Hlc3);
316    }
317    if value.eq_ignore_ascii_case("ohlc4") {
318        return Ok(SourceKind::Ohlc4);
319    }
320    if value.eq_ignore_ascii_case("hlcc4") {
321        return Ok(SourceKind::Hlcc4);
322    }
323    Err(EmdTrendError::InvalidSource {
324        value: value.to_string(),
325    })
326}
327
328#[inline(always)]
329fn parse_avg_kind(value: &str) -> Result<AvgKind, EmdTrendError> {
330    if value.eq_ignore_ascii_case("SMA") {
331        return Ok(AvgKind::Sma);
332    }
333    if value.eq_ignore_ascii_case("EMA") {
334        return Ok(AvgKind::Ema);
335    }
336    if value.eq_ignore_ascii_case("HMA") {
337        return Ok(AvgKind::Hma);
338    }
339    if value.eq_ignore_ascii_case("DEMA") {
340        return Ok(AvgKind::Dema);
341    }
342    if value.eq_ignore_ascii_case("TEMA") {
343        return Ok(AvgKind::Tema);
344    }
345    if value.eq_ignore_ascii_case("RMA") {
346        return Ok(AvgKind::Rma);
347    }
348    if value.eq_ignore_ascii_case("FRAMA") {
349        return Ok(AvgKind::Frama);
350    }
351    Err(EmdTrendError::InvalidAvgType {
352        value: value.to_string(),
353    })
354}
355
356#[inline(always)]
357fn internal_avg_id(kind: AvgKind) -> &'static str {
358    match kind {
359        AvgKind::Sma => "sma",
360        AvgKind::Ema => "ema",
361        AvgKind::Hma => "hma",
362        AvgKind::Dema => "dema",
363        AvgKind::Tema => "tema",
364        AvgKind::Rma => "wilders",
365        AvgKind::Frama => "frama",
366    }
367}
368
369#[inline(always)]
370fn source_value(kind: SourceKind, open: f64, high: f64, low: f64, close: f64) -> f64 {
371    match kind {
372        SourceKind::Open => open,
373        SourceKind::High => high,
374        SourceKind::Low => low,
375        SourceKind::Close => close,
376        SourceKind::Oc2 => {
377            if open.is_finite() && close.is_finite() {
378                0.5 * (open + close)
379            } else {
380                f64::NAN
381            }
382        }
383        SourceKind::Hl2 => {
384            if high.is_finite() && low.is_finite() {
385                0.5 * (high + low)
386            } else {
387                f64::NAN
388            }
389        }
390        SourceKind::Occ3 => {
391            if open.is_finite() && close.is_finite() {
392                (open + close + close) / 3.0
393            } else {
394                f64::NAN
395            }
396        }
397        SourceKind::Hlc3 => {
398            if high.is_finite() && low.is_finite() && close.is_finite() {
399                (high + low + close) / 3.0
400            } else {
401                f64::NAN
402            }
403        }
404        SourceKind::Ohlc4 => {
405            if open.is_finite() && high.is_finite() && low.is_finite() && close.is_finite() {
406                (open + high + low + close) * 0.25
407            } else {
408                f64::NAN
409            }
410        }
411        SourceKind::Hlcc4 => {
412            if high.is_finite() && low.is_finite() && close.is_finite() {
413                (high + low + close + close) * 0.25
414            } else {
415                f64::NAN
416            }
417        }
418    }
419}
420
/// Length of the longest contiguous run of finite values in `values`.
#[inline(always)]
fn longest_valid_run(values: &[f64]) -> usize {
    let (longest, _) = values.iter().fold((0usize, 0usize), |(longest, run), v| {
        let run = if v.is_finite() { run + 1 } else { 0 };
        (longest.max(run), run)
    });
    longest
}
435
/// Expands a `(start, end, step)` integer sweep into its grid values.
///
/// A zero step or `start == end` yields `[start]`. Otherwise the values are
/// `lo, lo + step, ...` up to `hi`, where `lo`/`hi` are the smaller/larger
/// endpoint. They are reversed when `start > end`, and generation stops
/// before `usize` overflow.
#[inline(always)]
fn axis_usize(range: (usize, usize, usize)) -> Vec<usize> {
    let (start, end, step) = range;
    if step == 0 || start == end {
        return vec![start];
    }
    let descending = start > end;
    let (lo, hi) = (start.min(end), start.max(end));
    let mut values: Vec<usize> = std::iter::successors(Some(lo), |&current| {
        current.checked_add(step).filter(|&next| next <= hi)
    })
    .collect();
    if descending {
        values.reverse();
    }
    values
}
461
/// Expands a `(start, end, step)` float sweep into its grid values.
///
/// * `step == 0.0` or `start ≈ end` yields `[start]`.
/// * The sign of `step` is ignored. Values run from the smaller endpoint
///   upward and are reversed when `start > end`.
/// * The upper endpoint is included when it lies within a small tolerance of
///   a grid point, so `(1.0, 1.5, 0.1)` yields six values.
/// * Non-finite endpoints or a NaN step yield an empty vector (callers report
///   this as an invalid range) instead of looping forever.
/// * An infinite step, or a step too small to change the value at this
///   magnitude, yields only the lower endpoint.
#[inline(always)]
fn axis_f64(range: (f64, f64, f64)) -> Vec<f64> {
    let (start, end, step) = range;
    if step == 0.0 || (start - end).abs() <= f64::EPSILON {
        return vec![start];
    }
    if !start.is_finite() || !end.is_finite() || step.is_nan() {
        return Vec::new();
    }
    let step = step.abs();
    let (lo, hi, reverse) = if start <= end {
        (start, end, false)
    } else {
        (end, start, true)
    };

    let mut out = vec![lo];
    if step.is_finite() {
        let eps = step * 1e-9 + 1e-12;
        let mut prev = lo;
        let mut i = 1usize;
        loop {
            // Derive each point from its index rather than by repeated addition
            // so rounding error does not accumulate across the sweep.
            let value = lo + step * i as f64;
            // `value <= prev` means the step no longer advances the value at
            // this magnitude; stopping avoids an unbounded loop.
            if value > hi + eps || value <= prev {
                break;
            }
            out.push(value);
            prev = value;
            i += 1;
        }
    }
    if reverse {
        out.reverse();
    }
    out
}
485
486#[inline(always)]
487fn validate_params_only(
488    source: &str,
489    avg_type: &str,
490    length: usize,
491    mult: f64,
492) -> Result<(SourceKind, AvgKind), EmdTrendError> {
493    let source_kind = parse_source_kind(source)?;
494    let avg_kind = parse_avg_kind(avg_type)?;
495    if length == 0 {
496        return Err(EmdTrendError::InvalidLength { length });
497    }
498    if !mult.is_finite() || mult < 0.05 {
499        return Err(EmdTrendError::InvalidMult { mult });
500    }
501    Ok((source_kind, avg_kind))
502}
503
504#[inline(always)]
505fn input_slices<'a>(
506    input: &'a EmdTrendInput<'a>,
507) -> Result<(&'a [f64], &'a [f64], &'a [f64], &'a [f64]), EmdTrendError> {
508    match &input.data {
509        EmdTrendData::Candles { candles } => Ok((
510            candles.open.as_slice(),
511            candles.high.as_slice(),
512            candles.low.as_slice(),
513            candles.close.as_slice(),
514        )),
515        EmdTrendData::Slices {
516            open,
517            high,
518            low,
519            close,
520        } => Ok((open, high, low, close)),
521    }
522}
523
524#[inline(always)]
525fn build_source_series(
526    open: &[f64],
527    high: &[f64],
528    low: &[f64],
529    close: &[f64],
530    source_kind: SourceKind,
531) -> Result<Vec<f64>, EmdTrendError> {
532    if open.is_empty() || high.is_empty() || low.is_empty() || close.is_empty() {
533        return Err(EmdTrendError::EmptyInputData);
534    }
535    if open.len() != high.len() || open.len() != low.len() || open.len() != close.len() {
536        return Err(EmdTrendError::InputLengthMismatch {
537            open_len: open.len(),
538            high_len: high.len(),
539            low_len: low.len(),
540            close_len: close.len(),
541        });
542    }
543
544    let mut out = Vec::with_capacity(close.len());
545    for i in 0..close.len() {
546        out.push(source_value(
547            source_kind,
548            open[i],
549            high[i],
550            low[i],
551            close[i],
552        ));
553    }
554    if longest_valid_run(&out) == 0 {
555        return Err(EmdTrendError::AllValuesNaN);
556    }
557    Ok(out)
558}
559
/// Parses the unsigned integer that follows the first `=` in `segment`.
///
/// Whitespace after `=` is skipped and only leading ASCII digits are read.
/// Returns `None` when there is no `=` or no digits follow it.
#[inline(always)]
fn parse_number_after(segment: &str) -> Option<usize> {
    let (_, rhs) = segment.split_once('=')?;
    let rhs = rhs.trim_start();
    let digits_end = rhs
        .char_indices()
        .find(|&(_, c)| !c.is_ascii_digit())
        .map_or(rhs.len(), |(idx, _)| idx);
    rhs[..digits_end].parse().ok()
}
570
571#[inline(always)]
572fn parse_needed_valid(msg: &str) -> Option<(usize, usize)> {
573    let needed_key = "needed";
574    let valid_key = "valid";
575    let needed_idx = msg.find(needed_key)?;
576    let valid_idx = msg.find(valid_key)?;
577    let needed = parse_number_after(&msg[needed_idx + needed_key.len()..])?;
578    let valid = parse_number_after(&msg[valid_idx + valid_key.len()..])?;
579    Some((needed, valid))
580}
581
582#[inline(always)]
583fn map_ma_error(err: Box<dyn Error>) -> EmdTrendError {
584    let msg = err.to_string();
585    if let Some((needed, valid)) = parse_needed_valid(&msg) {
586        return EmdTrendError::NotEnoughValidData { needed, valid };
587    }
588    EmdTrendError::InvalidInput { msg }
589}
590
591#[inline(always)]
592fn normalize_single_kernel(kernel: Kernel) -> Kernel {
593    match kernel {
594        Kernel::Auto => detect_best_kernel(),
595        Kernel::ScalarBatch => Kernel::Scalar,
596        Kernel::Avx2Batch => Kernel::Avx2,
597        Kernel::Avx512Batch => Kernel::Avx512,
598        other => other,
599    }
600}
601
602#[inline(always)]
603fn normalize_batch_kernel(kernel: Kernel) -> Result<Kernel, EmdTrendError> {
604    Ok(match kernel {
605        Kernel::Auto => match detect_best_batch_kernel() {
606            Kernel::ScalarBatch => Kernel::Scalar,
607            Kernel::Avx2Batch => Kernel::Avx2,
608            Kernel::Avx512Batch => Kernel::Avx512,
609            other => other,
610        },
611        Kernel::Scalar | Kernel::ScalarBatch => Kernel::Scalar,
612        Kernel::Avx2 | Kernel::Avx2Batch => Kernel::Avx2,
613        Kernel::Avx512 | Kernel::Avx512Batch => Kernel::Avx512,
614        other => return Err(EmdTrendError::InvalidKernelForBatch(other)),
615    })
616}
617
618#[inline(always)]
619fn compute_average_series(
620    src: &[f64],
621    avg_kind: AvgKind,
622    length: usize,
623    kernel: Kernel,
624) -> Result<Vec<f64>, EmdTrendError> {
625    ma_with_kernel(
626        internal_avg_id(avg_kind),
627        MaData::Slice(src),
628        length,
629        kernel,
630    )
631    .map_err(map_ma_error)
632}
633
634#[inline(always)]
635fn compute_deviation_ema(
636    avg: &[f64],
637    src: &[f64],
638    length: usize,
639    kernel: Kernel,
640) -> Result<Vec<f64>, EmdTrendError> {
641    let mut abs_dev = alloc_with_nan_prefix(src.len(), 0);
642    abs_dev.fill(f64::NAN);
643    for i in 0..src.len() {
644        if avg[i].is_finite() && src[i].is_finite() {
645            abs_dev[i] = (src[i] - avg[i]).abs();
646        }
647    }
648    let input = EmaInput {
649        data: EmaData::Slice(&abs_dev),
650        params: EmaParams {
651            period: Some(length),
652        },
653    };
654    ema_with_kernel(&input, kernel)
655        .map(|out| out.values)
656        .map_err(|e| EmdTrendError::InvalidInput { msg: e.to_string() })
657}
658
659fn compute_from_source_into(
660    src: &[f64],
661    avg_kind: AvgKind,
662    length: usize,
663    mult: f64,
664    kernel: Kernel,
665    direction_out: &mut [f64],
666    average_out: &mut [f64],
667    upper_out: &mut [f64],
668    lower_out: &mut [f64],
669) -> Result<(), EmdTrendError> {
670    let len = src.len();
671    if direction_out.len() != len {
672        return Err(EmdTrendError::OutputLengthMismatch {
673            expected: len,
674            got: direction_out.len(),
675        });
676    }
677    if average_out.len() != len {
678        return Err(EmdTrendError::OutputLengthMismatch {
679            expected: len,
680            got: average_out.len(),
681        });
682    }
683    if upper_out.len() != len {
684        return Err(EmdTrendError::OutputLengthMismatch {
685            expected: len,
686            got: upper_out.len(),
687        });
688    }
689    if lower_out.len() != len {
690        return Err(EmdTrendError::OutputLengthMismatch {
691            expected: len,
692            got: lower_out.len(),
693        });
694    }
695
696    let average = compute_average_series(src, avg_kind, length, kernel)?;
697    let deviation = compute_deviation_ema(&average, src, length, kernel)?;
698
699    average_out.copy_from_slice(&average);
700    direction_out.fill(0.0);
701    upper_out.fill(f64::NAN);
702    lower_out.fill(f64::NAN);
703
704    let mut direction = 0.0f64;
705    for i in 0..len {
706        if average[i].is_finite() && deviation[i].is_finite() {
707            upper_out[i] = average[i] + deviation[i] * mult;
708            lower_out[i] = average[i] - deviation[i] * mult;
709        }
710        if i > 0
711            && src[i].is_finite()
712            && src[i - 1].is_finite()
713            && upper_out[i].is_finite()
714            && upper_out[i - 1].is_finite()
715            && src[i] > upper_out[i]
716            && src[i - 1] <= upper_out[i - 1]
717        {
718            direction = 1.0;
719        } else if i > 0
720            && src[i].is_finite()
721            && src[i - 1].is_finite()
722            && lower_out[i].is_finite()
723            && lower_out[i - 1].is_finite()
724            && src[i] < lower_out[i]
725            && src[i - 1] >= lower_out[i - 1]
726        {
727            direction = -1.0;
728        }
729        direction_out[i] = direction;
730    }
731
732    Ok(())
733}
734
735#[inline]
736pub fn emd_trend(input: &EmdTrendInput) -> Result<EmdTrendOutput, EmdTrendError> {
737    emd_trend_with_kernel(input, Kernel::Auto)
738}
739
740pub fn emd_trend_with_kernel(
741    input: &EmdTrendInput,
742    kernel: Kernel,
743) -> Result<EmdTrendOutput, EmdTrendError> {
744    let (open, high, low, close) = input_slices(input)?;
745    let (source_kind, avg_kind) = validate_params_only(
746        input.get_source(),
747        input.get_avg_type(),
748        input.get_length(),
749        input.get_mult(),
750    )?;
751    let src = build_source_series(open, high, low, close, source_kind)?;
752    let longest = longest_valid_run(&src);
753    let needed = if avg_kind == AvgKind::Frama && input.get_length() % 2 == 1 {
754        input.get_length() + 1
755    } else {
756        input.get_length()
757    };
758    if longest < needed {
759        return Err(EmdTrendError::NotEnoughValidData {
760            needed,
761            valid: longest,
762        });
763    }
764
765    let chosen = normalize_single_kernel(kernel);
766    let mut direction = alloc_with_nan_prefix(src.len(), 0);
767    let mut average = alloc_with_nan_prefix(src.len(), 0);
768    let mut upper = alloc_with_nan_prefix(src.len(), 0);
769    let mut lower = alloc_with_nan_prefix(src.len(), 0);
770    direction.fill(0.0);
771    average.fill(f64::NAN);
772    upper.fill(f64::NAN);
773    lower.fill(f64::NAN);
774
775    compute_from_source_into(
776        &src,
777        avg_kind,
778        input.get_length(),
779        input.get_mult(),
780        chosen,
781        &mut direction,
782        &mut average,
783        &mut upper,
784        &mut lower,
785    )?;
786
787    Ok(EmdTrendOutput {
788        direction,
789        average,
790        upper,
791        lower,
792    })
793}
794
795pub fn emd_trend_into_slice(
796    out_direction: &mut [f64],
797    out_average: &mut [f64],
798    out_upper: &mut [f64],
799    out_lower: &mut [f64],
800    input: &EmdTrendInput,
801    kernel: Kernel,
802) -> Result<(), EmdTrendError> {
803    let (open, high, low, close) = input_slices(input)?;
804    let (source_kind, avg_kind) = validate_params_only(
805        input.get_source(),
806        input.get_avg_type(),
807        input.get_length(),
808        input.get_mult(),
809    )?;
810    let src = build_source_series(open, high, low, close, source_kind)?;
811    let longest = longest_valid_run(&src);
812    let needed = if avg_kind == AvgKind::Frama && input.get_length() % 2 == 1 {
813        input.get_length() + 1
814    } else {
815        input.get_length()
816    };
817    if longest < needed {
818        return Err(EmdTrendError::NotEnoughValidData {
819            needed,
820            valid: longest,
821        });
822    }
823
824    let chosen = normalize_single_kernel(kernel);
825    compute_from_source_into(
826        &src,
827        avg_kind,
828        input.get_length(),
829        input.get_mult(),
830        chosen,
831        out_direction,
832        out_average,
833        out_upper,
834        out_lower,
835    )
836}
837
838#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
839pub fn emd_trend_into(
840    input: &EmdTrendInput,
841    out_direction: &mut [f64],
842    out_average: &mut [f64],
843    out_upper: &mut [f64],
844    out_lower: &mut [f64],
845) -> Result<(), EmdTrendError> {
846    emd_trend_into_slice(
847        out_direction,
848        out_average,
849        out_upper,
850        out_lower,
851        input,
852        Kernel::Auto,
853    )
854}
855
856#[derive(Debug, Clone)]
857pub struct EmdTrendBatchRange {
858    pub length: (usize, usize, usize),
859    pub mult: (f64, f64, f64),
860}
861
862impl Default for EmdTrendBatchRange {
863    fn default() -> Self {
864        Self {
865            length: (DEFAULT_LENGTH, DEFAULT_LENGTH, 0),
866            mult: (DEFAULT_MULT, DEFAULT_MULT, 0.0),
867        }
868    }
869}
870
871#[derive(Debug, Clone)]
872pub struct EmdTrendBatchOutput {
873    pub direction: Vec<f64>,
874    pub average: Vec<f64>,
875    pub upper: Vec<f64>,
876    pub lower: Vec<f64>,
877    pub combos: Vec<EmdTrendParams>,
878    pub rows: usize,
879    pub cols: usize,
880}
881
882#[derive(Clone, Debug)]
883pub struct EmdTrendBatchBuilder {
884    range: EmdTrendBatchRange,
885    source: Option<SourceKind>,
886    avg_type: Option<AvgKind>,
887    kernel: Kernel,
888}
889
890impl Default for EmdTrendBatchBuilder {
891    fn default() -> Self {
892        Self {
893            range: EmdTrendBatchRange::default(),
894            source: None,
895            avg_type: None,
896            kernel: Kernel::Auto,
897        }
898    }
899}
900
901impl EmdTrendBatchBuilder {
902    pub fn new() -> Self {
903        Self::default()
904    }
905
906    pub fn source(mut self, value: &str) -> Result<Self, EmdTrendError> {
907        self.source = Some(parse_source_kind(value)?);
908        Ok(self)
909    }
910
911    pub fn avg_type(mut self, value: &str) -> Result<Self, EmdTrendError> {
912        self.avg_type = Some(parse_avg_kind(value)?);
913        Ok(self)
914    }
915
916    pub fn kernel(mut self, value: Kernel) -> Self {
917        self.kernel = value;
918        self
919    }
920
921    pub fn length_range(mut self, start: usize, end: usize, step: usize) -> Self {
922        self.range.length = (start, end, step);
923        self
924    }
925
926    pub fn mult_range(mut self, start: f64, end: f64, step: f64) -> Self {
927        self.range.mult = (start, end, step);
928        self
929    }
930
931    pub fn apply_slices(
932        self,
933        open: &[f64],
934        high: &[f64],
935        low: &[f64],
936        close: &[f64],
937    ) -> Result<EmdTrendBatchOutput, EmdTrendError> {
938        emd_trend_batch_with_kernel(
939            open,
940            high,
941            low,
942            close,
943            &self.range,
944            source_kind_name(self.source.unwrap_or(SourceKind::Close)),
945            avg_kind_name(self.avg_type.unwrap_or(AvgKind::Sma)),
946            self.kernel,
947        )
948    }
949
950    pub fn apply_candles(self, candles: &Candles) -> Result<EmdTrendBatchOutput, EmdTrendError> {
951        self.apply_slices(
952            candles.open.as_slice(),
953            candles.high.as_slice(),
954            candles.low.as_slice(),
955            candles.close.as_slice(),
956        )
957    }
958}
959
960pub fn expand_grid_emd_trend(
961    sweep: &EmdTrendBatchRange,
962    source: &str,
963    avg_type: &str,
964) -> Result<Vec<EmdTrendParams>, EmdTrendError> {
965    let (source_kind, avg_kind) =
966        validate_params_only(source, avg_type, DEFAULT_LENGTH, DEFAULT_MULT)?;
967    let lengths = axis_usize(sweep.length);
968    let mults = axis_f64(sweep.mult);
969    let mut out = Vec::with_capacity(lengths.len().saturating_mul(mults.len()));
970    for &length in &lengths {
971        if length == 0 {
972            return Err(EmdTrendError::InvalidRange {
973                start: sweep.length.0.to_string(),
974                end: sweep.length.1.to_string(),
975                step: sweep.length.2.to_string(),
976            });
977        }
978        for &mult in &mults {
979            if !mult.is_finite() || mult < 0.05 {
980                return Err(EmdTrendError::InvalidRange {
981                    start: sweep.mult.0.to_string(),
982                    end: sweep.mult.1.to_string(),
983                    step: sweep.mult.2.to_string(),
984                });
985            }
986            out.push(EmdTrendParams {
987                source: Some(source_kind_name(source_kind).to_string()),
988                avg_type: Some(avg_kind_name(avg_kind).to_string()),
989                length: Some(length),
990                mult: Some(mult),
991            });
992        }
993    }
994    Ok(out)
995}
996
997pub fn emd_trend_batch_with_kernel(
998    open: &[f64],
999    high: &[f64],
1000    low: &[f64],
1001    close: &[f64],
1002    sweep: &EmdTrendBatchRange,
1003    source: &str,
1004    avg_type: &str,
1005    kernel: Kernel,
1006) -> Result<EmdTrendBatchOutput, EmdTrendError> {
1007    let combos = expand_grid_emd_trend(sweep, source, avg_type)?;
1008    if combos.is_empty() {
1009        return Err(EmdTrendError::InvalidRange {
1010            start: sweep.length.0.to_string(),
1011            end: sweep.length.1.to_string(),
1012            step: sweep.length.2.to_string(),
1013        });
1014    }
1015    let rows = combos.len();
1016    let cols = close.len();
1017    let total = rows
1018        .checked_mul(cols)
1019        .ok_or_else(|| EmdTrendError::InvalidInput {
1020            msg: "emd_trend: rows*cols overflow in batch".to_string(),
1021        })?;
1022
1023    let mut direction_mu = make_uninit_matrix(rows, cols);
1024    let mut average_mu = make_uninit_matrix(rows, cols);
1025    let mut upper_mu = make_uninit_matrix(rows, cols);
1026    let mut lower_mu = make_uninit_matrix(rows, cols);
1027
1028    let mut direction_guard = core::mem::ManuallyDrop::new(direction_mu);
1029    let mut average_guard = core::mem::ManuallyDrop::new(average_mu);
1030    let mut upper_guard = core::mem::ManuallyDrop::new(upper_mu);
1031    let mut lower_guard = core::mem::ManuallyDrop::new(lower_mu);
1032
1033    let direction =
1034        unsafe { std::slice::from_raw_parts_mut(direction_guard.as_mut_ptr() as *mut f64, total) };
1035    let average =
1036        unsafe { std::slice::from_raw_parts_mut(average_guard.as_mut_ptr() as *mut f64, total) };
1037    let upper =
1038        unsafe { std::slice::from_raw_parts_mut(upper_guard.as_mut_ptr() as *mut f64, total) };
1039    let lower =
1040        unsafe { std::slice::from_raw_parts_mut(lower_guard.as_mut_ptr() as *mut f64, total) };
1041
1042    emd_trend_batch_inner_into(
1043        open, high, low, close, sweep, source, avg_type, kernel, false, direction, average, upper,
1044        lower,
1045    )?;
1046
1047    let direction = unsafe {
1048        Vec::from_raw_parts(
1049            direction_guard.as_mut_ptr() as *mut f64,
1050            direction_guard.len(),
1051            direction_guard.capacity(),
1052        )
1053    };
1054    let average = unsafe {
1055        Vec::from_raw_parts(
1056            average_guard.as_mut_ptr() as *mut f64,
1057            average_guard.len(),
1058            average_guard.capacity(),
1059        )
1060    };
1061    let upper = unsafe {
1062        Vec::from_raw_parts(
1063            upper_guard.as_mut_ptr() as *mut f64,
1064            upper_guard.len(),
1065            upper_guard.capacity(),
1066        )
1067    };
1068    let lower = unsafe {
1069        Vec::from_raw_parts(
1070            lower_guard.as_mut_ptr() as *mut f64,
1071            lower_guard.len(),
1072            lower_guard.capacity(),
1073        )
1074    };
1075
1076    Ok(EmdTrendBatchOutput {
1077        direction,
1078        average,
1079        upper,
1080        lower,
1081        combos,
1082        rows,
1083        cols,
1084    })
1085}
1086
1087pub fn emd_trend_batch_slice(
1088    open: &[f64],
1089    high: &[f64],
1090    low: &[f64],
1091    close: &[f64],
1092    sweep: &EmdTrendBatchRange,
1093    source: &str,
1094    avg_type: &str,
1095    kernel: Kernel,
1096) -> Result<EmdTrendBatchOutput, EmdTrendError> {
1097    emd_trend_batch_with_kernel(open, high, low, close, sweep, source, avg_type, kernel)
1098}
1099
1100pub fn emd_trend_batch_par_slice(
1101    open: &[f64],
1102    high: &[f64],
1103    low: &[f64],
1104    close: &[f64],
1105    sweep: &EmdTrendBatchRange,
1106    source: &str,
1107    avg_type: &str,
1108    kernel: Kernel,
1109) -> Result<EmdTrendBatchOutput, EmdTrendError> {
1110    emd_trend_batch_with_kernel(open, high, low, close, sweep, source, avg_type, kernel)
1111}
1112
1113fn emd_trend_batch_inner_into(
1114    open: &[f64],
1115    high: &[f64],
1116    low: &[f64],
1117    close: &[f64],
1118    sweep: &EmdTrendBatchRange,
1119    source: &str,
1120    avg_type: &str,
1121    kernel: Kernel,
1122    _parallel: bool,
1123    out_direction: &mut [f64],
1124    out_average: &mut [f64],
1125    out_upper: &mut [f64],
1126    out_lower: &mut [f64],
1127) -> Result<Vec<EmdTrendParams>, EmdTrendError> {
1128    let (source_kind, avg_kind) =
1129        validate_params_only(source, avg_type, DEFAULT_LENGTH, DEFAULT_MULT)?;
1130    let src = build_source_series(open, high, low, close, source_kind)?;
1131    let longest = longest_valid_run(&src);
1132    if longest == 0 {
1133        return Err(EmdTrendError::AllValuesNaN);
1134    }
1135
1136    let combos = expand_grid_emd_trend(sweep, source, avg_type)?;
1137    let rows = combos.len();
1138    let cols = close.len();
1139    let total = rows
1140        .checked_mul(cols)
1141        .ok_or_else(|| EmdTrendError::InvalidInput {
1142            msg: "emd_trend: rows*cols overflow in batch_into".to_string(),
1143        })?;
1144
1145    if out_direction.len() != total {
1146        return Err(EmdTrendError::MismatchedOutputLen {
1147            dst_len: out_direction.len(),
1148            expected_len: total,
1149        });
1150    }
1151    if out_average.len() != total {
1152        return Err(EmdTrendError::MismatchedOutputLen {
1153            dst_len: out_average.len(),
1154            expected_len: total,
1155        });
1156    }
1157    if out_upper.len() != total {
1158        return Err(EmdTrendError::MismatchedOutputLen {
1159            dst_len: out_upper.len(),
1160            expected_len: total,
1161        });
1162    }
1163    if out_lower.len() != total {
1164        return Err(EmdTrendError::MismatchedOutputLen {
1165            dst_len: out_lower.len(),
1166            expected_len: total,
1167        });
1168    }
1169
1170    let chosen = normalize_batch_kernel(kernel)?;
1171    for (row, combo) in combos.iter().enumerate() {
1172        let length = combo.length.unwrap_or(DEFAULT_LENGTH);
1173        let mult = combo.mult.unwrap_or(DEFAULT_MULT);
1174        let needed = if avg_kind == AvgKind::Frama && length % 2 == 1 {
1175            length + 1
1176        } else {
1177            length
1178        };
1179        if longest < needed {
1180            return Err(EmdTrendError::NotEnoughValidData {
1181                needed,
1182                valid: longest,
1183            });
1184        }
1185
1186        let start = row * cols;
1187        let end = start + cols;
1188        compute_from_source_into(
1189            &src,
1190            avg_kind,
1191            length,
1192            mult,
1193            chosen,
1194            &mut out_direction[start..end],
1195            &mut out_average[start..end],
1196            &mut out_upper[start..end],
1197            &mut out_lower[start..end],
1198        )?;
1199    }
1200
1201    Ok(combos)
1202}
1203
1204pub struct EmdTrendStream {
1205    source_kind: SourceKind,
1206    avg_kind: AvgKind,
1207    direction: f64,
1208    length: usize,
1209    mult: f64,
1210    src_history: Vec<f64>,
1211}
1212
1213impl EmdTrendStream {
1214    pub fn try_new(params: EmdTrendParams) -> Result<Self, EmdTrendError> {
1215        let source = params.source.as_deref().unwrap_or(DEFAULT_SOURCE);
1216        let avg_type = params.avg_type.as_deref().unwrap_or(DEFAULT_AVG_TYPE);
1217        let length = params.length.unwrap_or(DEFAULT_LENGTH);
1218        let mult = params.mult.unwrap_or(DEFAULT_MULT);
1219        let (source_kind, avg_kind) = validate_params_only(source, avg_type, length, mult)?;
1220
1221        Ok(Self {
1222            source_kind,
1223            avg_kind,
1224            direction: 0.0,
1225            length,
1226            mult,
1227            src_history: Vec::new(),
1228        })
1229    }
1230
1231    pub fn update(
1232        &mut self,
1233        open: f64,
1234        high: f64,
1235        low: f64,
1236        close: f64,
1237    ) -> Option<(f64, f64, f64, f64)> {
1238        let src = source_value(self.source_kind, open, high, low, close);
1239        if !src.is_finite() {
1240            return None;
1241        }
1242        self.src_history.push(src);
1243        let len = self.src_history.len();
1244        let mut direction = vec![0.0; len];
1245        let mut average = vec![f64::NAN; len];
1246        let mut upper = vec![f64::NAN; len];
1247        let mut lower = vec![f64::NAN; len];
1248
1249        match compute_from_source_into(
1250            &self.src_history,
1251            self.avg_kind,
1252            self.length,
1253            self.mult,
1254            Kernel::Scalar,
1255            &mut direction,
1256            &mut average,
1257            &mut upper,
1258            &mut lower,
1259        ) {
1260            Ok(()) => {
1261                self.direction = direction[len - 1];
1262                Some((
1263                    self.direction,
1264                    average[len - 1],
1265                    upper[len - 1],
1266                    lower[len - 1],
1267                ))
1268            }
1269            Err(EmdTrendError::NotEnoughValidData { .. }) => {
1270                Some((self.direction, f64::NAN, f64::NAN, f64::NAN))
1271            }
1272            Err(_) => Some((self.direction, f64::NAN, f64::NAN, f64::NAN)),
1273        }
1274    }
1275}
1276
/// Python wrapper around [`EmdTrendStream`], exposed to Python as `EmdTrendStream`.
#[cfg(feature = "python")]
#[pyclass(name = "EmdTrendStream")]
pub struct EmdTrendStreamPy {
    inner: EmdTrendStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl EmdTrendStreamPy {
    /// Creates a stream.
    ///
    /// All arguments are optional and default to the same values as the `emd_trend` function
    /// (`source=None` -> "close", `avg_type=None` -> "SMA", `length=28`, `mult=1.0`). Passing
    /// all four positionally still works. Invalid parameters raise `ValueError`.
    #[new]
    #[pyo3(signature = (source=None, avg_type=None, length=28, mult=1.0))]
    fn new(
        source: Option<&str>,
        avg_type: Option<&str>,
        length: usize,
        mult: f64,
    ) -> PyResult<Self> {
        let inner = EmdTrendStream::try_new(EmdTrendParams {
            source: source.map(str::to_string),
            avg_type: avg_type.map(str::to_string),
            length: Some(length),
            mult: Some(mult),
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(Self { inner })
    }

    /// Feeds one OHLC bar; returns `(direction, average, upper, lower)`, or `None` when the
    /// selected source value is not finite.
    fn update(
        &mut self,
        open: f64,
        high: f64,
        low: f64,
        close: f64,
    ) -> Option<(f64, f64, f64, f64)> {
        self.inner.update(open, high, low, close)
    }
}
1313
/// Python entry point for a single EMD-trend run.
///
/// Returns `(direction, average, upper, lower)` as 1-D NumPy arrays. The computation runs with
/// the GIL released; indicator errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(
    name = "emd_trend",
    signature = (open, high, low, close, source="close", avg_type="SMA", length=28, mult=1.0, kernel=None)
)]
pub fn emd_trend_py<'py>(
    py: Python<'py>,
    open: PyReadonlyArray1<'py, f64>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    source: &str,
    avg_type: &str,
    length: usize,
    mult: f64,
    kernel: Option<&str>,
) -> PyResult<(
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
)> {
    let params = EmdTrendParams {
        source: Some(source.to_string()),
        avg_type: Some(avg_type.to_string()),
        length: Some(length),
        mult: Some(mult),
    };
    // Borrow the NumPy buffers in place (fails for non-contiguous arrays).
    let (open, high, low, close) = (
        open.as_slice()?,
        high.as_slice()?,
        low.as_slice()?,
        close.as_slice()?,
    );
    let input = EmdTrendInput::from_slices(open, high, low, close, params);
    let kern = validate_kernel(kernel, false)?;

    let EmdTrendOutput {
        direction,
        average,
        upper,
        lower,
        ..
    } = py
        .allow_threads(|| emd_trend_with_kernel(&input, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok((
        direction.into_pyarray(py),
        average.into_pyarray(py),
        upper.into_pyarray(py),
        lower.into_pyarray(py),
    ))
}
1363
/// Python entry point for an EMD-trend parameter sweep.
///
/// Returns a dict with `direction`, `average`, `upper`, `lower` as `(rows, cols)` arrays,
/// plus per-row `lengths` / `mults` and the `rows` / `cols` counts. Indicator errors are
/// raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(
    name = "emd_trend_batch",
    signature = (open, high, low, close, length_range=(28, 28, 0), mult_range=(1.0, 1.0, 0.0), source="close", avg_type="SMA", kernel=None)
)]
pub fn emd_trend_batch_py<'py>(
    py: Python<'py>,
    open: PyReadonlyArray1<'py, f64>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    length_range: (usize, usize, usize),
    mult_range: (f64, f64, f64),
    source: &str,
    avg_type: &str,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let sweep = EmdTrendBatchRange {
        length: length_range,
        mult: mult_range,
    };
    let (open, high, low, close) = (
        open.as_slice()?,
        high.as_slice()?,
        low.as_slice()?,
        close.as_slice()?,
    );
    let kern = validate_kernel(kernel, true)?;
    let batch = py
        .allow_threads(|| {
            emd_trend_batch_with_kernel(open, high, low, close, &sweep, source, avg_type, kern)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let (rows, cols) = (batch.rows, batch.cols);
    let dict = PyDict::new(py);
    // The batch buffers are flat row-major; expose them to NumPy as 2-D.
    let to_matrix = |values: Vec<f64>| PyArray1::from_vec(py, values).reshape((rows, cols));
    dict.set_item("direction", to_matrix(batch.direction)?)?;
    dict.set_item("average", to_matrix(batch.average)?)?;
    dict.set_item("upper", to_matrix(batch.upper)?)?;
    dict.set_item("lower", to_matrix(batch.lower)?)?;

    let lengths: Vec<u64> = batch
        .combos
        .iter()
        .map(|combo| combo.length.unwrap_or(DEFAULT_LENGTH) as u64)
        .collect();
    dict.set_item("lengths", lengths.into_pyarray(py))?;
    let mults: Vec<f64> = batch
        .combos
        .iter()
        .map(|combo| combo.mult.unwrap_or(DEFAULT_MULT))
        .collect();
    dict.set_item("mults", mults.into_pyarray(py))?;

    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
1445
/// Registers `emd_trend`, `emd_trend_batch` and the `EmdTrendStream` class on `m`.
#[cfg(feature = "python")]
pub fn register_emd_trend_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    let single = wrap_pyfunction!(emd_trend_py, m)?;
    m.add_function(single)?;
    let batch = wrap_pyfunction!(emd_trend_batch_py, m)?;
    m.add_function(batch)?;
    m.add_class::<EmdTrendStreamPy>()
}
1453
/// Sweep configuration deserialized from the JS `config` value of `emd_trend_batch_js`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct EmdTrendBatchConfig {
    /// `[start, end, step]` for the length sweep; `emd_trend_batch_js` requires exactly 3 items.
    pub length_range: Vec<usize>,
    /// `[start, end, step]` for the multiplier sweep; `emd_trend_batch_js` requires exactly 3 items.
    pub mult_range: Vec<f64>,
    /// Source series name; `None` falls back to `DEFAULT_SOURCE` ("close").
    pub source: Option<String>,
    /// Average type name; `None` falls back to `DEFAULT_AVG_TYPE` ("SMA").
    pub avg_type: Option<String>,
}
1462
/// JS entry point for a single EMD-trend run.
///
/// Returns a plain object `{ direction, average, upper, lower }` of number arrays. Indicator
/// and serialization failures are returned as JS errors instead of trapping the wasm
/// instance.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = emd_trend_js)]
pub fn emd_trend_js(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    source: &str,
    avg_type: &str,
    length: usize,
    mult: f64,
) -> Result<JsValue, JsValue> {
    let params = EmdTrendParams {
        source: Some(source.to_string()),
        avg_type: Some(avg_type.to_string()),
        length: Some(length),
        mult: Some(mult),
    };
    let out = emd_trend_with_kernel(
        &EmdTrendInput::from_slices(open, high, low, close, params),
        Kernel::Auto,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let obj = js_sys::Object::new();
    for (key, values) in [
        ("direction", &out.direction),
        ("average", &out.average),
        ("upper", &out.upper),
        ("lower", &out.lower),
    ] {
        // Propagate serialization failures rather than `unwrap`, which would abort the module.
        let array = serde_wasm_bindgen::to_value(values).map_err(JsValue::from)?;
        js_sys::Reflect::set(&obj, &JsValue::from_str(key), &array)?;
    }
    Ok(obj.into())
}
1515
/// JS entry point for an EMD-trend parameter sweep.
///
/// `config` must deserialize into [`EmdTrendBatchConfig`]; both ranges need exactly three
/// elements `[start, end, step]`. Returns `{ direction, average, upper, lower, rows, cols,
/// combos }`, where the four series are flat row-major arrays. Indicator and serialization
/// failures are returned as JS errors instead of trapping the wasm instance.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = emd_trend_batch_js)]
pub fn emd_trend_batch_js(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let config: EmdTrendBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {e}")))?;
    if config.length_range.len() != 3 || config.mult_range.len() != 3 {
        return Err(JsValue::from_str(
            "Invalid config: every range must have exactly 3 elements [start, end, step]",
        ));
    }
    let source = config.source.as_deref().unwrap_or(DEFAULT_SOURCE);
    let avg_type = config.avg_type.as_deref().unwrap_or(DEFAULT_AVG_TYPE);
    let sweep = EmdTrendBatchRange {
        length: (
            config.length_range[0],
            config.length_range[1],
            config.length_range[2],
        ),
        mult: (
            config.mult_range[0],
            config.mult_range[1],
            config.mult_range[2],
        ),
    };
    let out = emd_trend_batch_with_kernel(
        open,
        high,
        low,
        close,
        &sweep,
        source,
        avg_type,
        Kernel::Auto,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let obj = js_sys::Object::new();
    for (key, values) in [
        ("direction", &out.direction),
        ("average", &out.average),
        ("upper", &out.upper),
        ("lower", &out.lower),
    ] {
        // Propagate serialization failures rather than `unwrap`, which would abort the module.
        let array = serde_wasm_bindgen::to_value(values).map_err(JsValue::from)?;
        js_sys::Reflect::set(&obj, &JsValue::from_str(key), &array)?;
    }
    js_sys::Reflect::set(
        &obj,
        &JsValue::from_str("rows"),
        &JsValue::from_f64(out.rows as f64),
    )?;
    js_sys::Reflect::set(
        &obj,
        &JsValue::from_str("cols"),
        &JsValue::from_f64(out.cols as f64),
    )?;
    let combos = serde_wasm_bindgen::to_value(&out.combos).map_err(JsValue::from)?;
    js_sys::Reflect::set(&obj, &JsValue::from_str("combos"), &combos)?;
    Ok(obj.into())
}
1595
/// Allocates an uninitialised buffer with capacity for `4 * len` `f64`s.
///
/// The buffer is sized for the four output series written by `emd_trend_into`. It must be
/// released with `emd_trend_free(ptr, len)` using the same `len`.
///
/// # Panics
/// If `4 * len` overflows `usize`. In release builds the old unchecked multiply wrapped
/// instead, which could hand JS a tiny or dangling buffer.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn emd_trend_alloc(len: usize) -> *mut f64 {
    let capacity = len
        .checked_mul(4)
        .expect("emd_trend_alloc: 4 * len overflows usize");
    let mut buffer = Vec::<f64>::with_capacity(capacity);
    let ptr = buffer.as_mut_ptr();
    // Ownership is handed to the JS side; reclaimed in `emd_trend_free`.
    std::mem::forget(buffer);
    ptr
}
1604
/// Releases a buffer obtained from `emd_trend_alloc(len)`; `len` must be the same value.
///
/// A null `ptr` is ignored.
///
/// # Panics
/// If `4 * len` overflows `usize`. This cannot happen for a `len` that `emd_trend_alloc`
/// accepted.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn emd_trend_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    let capacity = len
        .checked_mul(4)
        .expect("emd_trend_free: 4 * len overflows usize");
    // SAFETY: `ptr` came from `emd_trend_alloc(len)`, i.e. a `Vec<f64>` with exactly this
    // capacity. Length 0 is used because the contents may never have been initialised; f64
    // needs no per-element drop, so only the allocation is released.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, capacity));
    }
}
1614
/// Computes EMD trend from raw wasm-memory pointers into a caller-provided buffer.
///
/// Each input must point to `len` readable `f64`s. `out_ptr` must point to `4 * len` writable
/// `f64`s, which are filled as `[direction | average | upper | lower]`, each `len` long.
/// The output may overlap any input (e.g. for in-place reuse). In that case the result is
/// computed into temporary buffers first and copied, so shared and mutable views of the same
/// memory never coexist.
///
/// # Errors
/// A JS error string for null pointers, `4 * len` overflow, or any indicator error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn emd_trend_into(
    open_ptr: *const f64,
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    source: &str,
    avg_type: &str,
    length: usize,
    mult: f64,
) -> Result<(), JsValue> {
    if open_ptr.is_null()
        || high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || out_ptr.is_null()
    {
        return Err(JsValue::from_str("null pointer passed to emd_trend_into"));
    }
    let out_len = len
        .checked_mul(4)
        .ok_or_else(|| JsValue::from_str("4*len overflow in emd_trend_into"))?;

    // Compare byte address ranges; empty ranges never count as overlapping.
    let f64_size = core::mem::size_of::<f64>();
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(out_len.saturating_mul(f64_size));
    let in_bytes = len.saturating_mul(f64_size);
    let aliased = [open_ptr, high_ptr, low_ptr, close_ptr].iter().any(|&p| {
        let in_start = p as usize;
        let in_end = in_start.saturating_add(in_bytes);
        in_start < out_end && out_start < in_end
    });

    let params = EmdTrendParams {
        source: Some(source.to_string()),
        avg_type: Some(avg_type.to_string()),
        length: Some(length),
        mult: Some(mult),
    };

    if aliased {
        // SAFETY: pointers are non-null (checked above) and the caller guarantees `len`
        // readable f64s behind each input. These shared slices only live inside this block,
        // ending before the mutable output slice is created below.
        let computed = unsafe {
            let open = std::slice::from_raw_parts(open_ptr, len);
            let high = std::slice::from_raw_parts(high_ptr, len);
            let low = std::slice::from_raw_parts(low_ptr, len);
            let close = std::slice::from_raw_parts(close_ptr, len);
            emd_trend_with_kernel(
                &EmdTrendInput::from_slices(open, high, low, close, params),
                Kernel::Auto,
            )
        }
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

        // SAFETY: the caller guarantees `4 * len` writable f64s at `out_ptr`, and no other
        // reference into that memory is live at this point.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, out_len) };
        let (direction, tail) = out.split_at_mut(len);
        let (average, tail) = tail.split_at_mut(len);
        let (upper, lower) = tail.split_at_mut(len);
        direction.copy_from_slice(&computed.direction);
        average.copy_from_slice(&computed.average);
        upper.copy_from_slice(&computed.upper);
        lower.copy_from_slice(&computed.lower);
        return Ok(());
    }

    // SAFETY: same pointer/length guarantees as above, and the output range is disjoint from
    // every input range, so the shared and mutable slices never alias.
    unsafe {
        let open = std::slice::from_raw_parts(open_ptr, len);
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        let out = std::slice::from_raw_parts_mut(out_ptr, out_len);
        let (direction, tail) = out.split_at_mut(len);
        let (average, tail) = tail.split_at_mut(len);
        let (upper, lower) = tail.split_at_mut(len);
        emd_trend_into_slice(
            direction,
            average,
            upper,
            lower,
            &EmdTrendInput::from_slices(open, high, low, close, params),
            Kernel::Auto,
        )
        .map_err(|e| JsValue::from_str(&e.to_string()))
    }
}
1668
/// Runs an EMD-trend sweep from raw wasm-memory pointers into a caller-provided buffer.
///
/// Each input must point to `len` readable `f64`s. `out_ptr` must point to
/// `4 * rows * len` writable `f64`s, which are filled as `[direction | average | upper | lower]`.
/// Each section is a row-major `rows x len` matrix. Returns `rows`, the number of parameter
/// combinations.
/// The output may overlap any input. In that case the result is computed into temporary
/// buffers first and copied, so shared and mutable views of the same memory never coexist.
///
/// # Errors
/// A JS error string for null pointers, size overflow, or any indicator error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn emd_trend_batch_into(
    open_ptr: *const f64,
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    length_start: usize,
    length_end: usize,
    length_step: usize,
    mult_start: f64,
    mult_end: f64,
    mult_step: f64,
    source: &str,
    avg_type: &str,
) -> Result<usize, JsValue> {
    if open_ptr.is_null()
        || high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || out_ptr.is_null()
    {
        return Err(JsValue::from_str(
            "null pointer passed to emd_trend_batch_into",
        ));
    }

    let sweep = EmdTrendBatchRange {
        length: (length_start, length_end, length_step),
        mult: (mult_start, mult_end, mult_step),
    };
    let combos = expand_grid_emd_trend(&sweep, source, avg_type)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let split = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow in emd_trend_batch_into"))?;
    let total = split
        .checked_mul(4)
        .ok_or_else(|| JsValue::from_str("4*rows*cols overflow in emd_trend_batch_into"))?;

    // Compare byte address ranges; empty ranges never count as overlapping.
    let f64_size = core::mem::size_of::<f64>();
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(total.saturating_mul(f64_size));
    let in_bytes = len.saturating_mul(f64_size);
    let aliased = [open_ptr, high_ptr, low_ptr, close_ptr].iter().any(|&p| {
        let in_start = p as usize;
        let in_end = in_start.saturating_add(in_bytes);
        in_start < out_end && out_start < in_end
    });

    if aliased {
        // SAFETY: pointers are non-null (checked above) and the caller guarantees `len`
        // readable f64s behind each input. These shared slices only live inside this block,
        // ending before the mutable output slice is created below.
        let computed = unsafe {
            let open = std::slice::from_raw_parts(open_ptr, len);
            let high = std::slice::from_raw_parts(high_ptr, len);
            let low = std::slice::from_raw_parts(low_ptr, len);
            let close = std::slice::from_raw_parts(close_ptr, len);
            emd_trend_batch_with_kernel(
                open,
                high,
                low,
                close,
                &sweep,
                source,
                avg_type,
                Kernel::Auto,
            )
        }
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

        // SAFETY: the caller guarantees `4 * rows * len` writable f64s at `out_ptr`, and no
        // other reference into that memory is live at this point.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        let (direction, tail) = out.split_at_mut(split);
        let (average, tail) = tail.split_at_mut(split);
        let (upper, lower) = tail.split_at_mut(split);
        direction.copy_from_slice(&computed.direction);
        average.copy_from_slice(&computed.average);
        upper.copy_from_slice(&computed.upper);
        lower.copy_from_slice(&computed.lower);
        return Ok(rows);
    }

    // SAFETY: same pointer/length guarantees as above, and the output range is disjoint from
    // every input range, so the shared and mutable slices never alias.
    unsafe {
        let open = std::slice::from_raw_parts(open_ptr, len);
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        let out = std::slice::from_raw_parts_mut(out_ptr, total);
        let (direction, tail) = out.split_at_mut(split);
        let (average, tail) = tail.split_at_mut(split);
        let (upper, lower) = tail.split_at_mut(split);
        emd_trend_batch_inner_into(
            open,
            high,
            low,
            close,
            &sweep,
            source,
            avg_type,
            Kernel::Auto,
            false,
            direction,
            average,
            upper,
            lower,
        )
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}
1741
#[cfg(test)]
mod tests {
    use super::*;
    use crate::indicators::dispatch::{
        compute_cpu, IndicatorComputeRequest, IndicatorDataRef, IndicatorSeries, ParamKV,
        ParamValue,
    };

    /// Deterministic, smoothly varying OHLC data with `high >= max(open, close)` and
    /// `low <= min(open, close)`.
    fn sample_ohlc(len: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
        let mut open = Vec::with_capacity(len);
        let mut high = Vec::with_capacity(len);
        let mut low = Vec::with_capacity(len);
        let mut close = Vec::with_capacity(len);
        for i in 0..len {
            let x = i as f64;
            let base = 100.0 + x * 0.04 + (x * 0.11).sin() * 1.8;
            let o = base + (x * 0.17).cos() * 0.6;
            let c = base + (x * 0.09).sin() * 0.9;
            let h = o.max(c) + 0.8 + (x * 0.05).cos().abs() * 0.4;
            let l = o.min(c) - 0.8 - (x * 0.07).sin().abs() * 0.3;
            open.push(o);
            high.push(h);
            low.push(l);
            close.push(c);
        }
        (open, high, low, close)
    }

    /// Element-wise comparison treating NaN == NaN and allowing 1e-10 absolute error.
    fn assert_vec_eq_nan(lhs: &[f64], rhs: &[f64]) {
        assert_eq!(lhs.len(), rhs.len());
        for (idx, (&l, &r)) in lhs.iter().zip(rhs.iter()).enumerate() {
            if l.is_nan() && r.is_nan() {
                continue;
            }
            assert!(
                (l - r).abs() <= 1e-10,
                "mismatch at {idx}: lhs={l}, rhs={r}"
            );
        }
    }

    /// Close-sourced SMA params with the given length and multiplier.
    fn close_sma_params(length: Option<usize>, mult: Option<f64>) -> EmdTrendParams {
        EmdTrendParams {
            source: Some("close".to_string()),
            avg_type: Some("SMA".to_string()),
            length,
            mult,
        }
    }

    #[test]
    fn emd_trend_output_contract() -> Result<(), Box<dyn Error>> {
        let (open, high, low, close) = sample_ohlc(256);
        let out = emd_trend_with_kernel(
            &EmdTrendInput::from_slices(&open, &high, &low, &close, EmdTrendParams::default()),
            Kernel::Scalar,
        )?;
        assert_eq!(out.direction.len(), close.len());
        assert_eq!(out.average.len(), close.len());
        assert_eq!(out.upper.len(), close.len());
        assert_eq!(out.lower.len(), close.len());
        assert!(out.average.iter().any(|v| v.is_finite()));
        Ok(())
    }

    #[test]
    fn emd_trend_exact_small_case() -> Result<(), Box<dyn Error>> {
        let open = vec![1.0, 1.0, 1.0, 10.0, 10.0];
        let high = vec![1.0, 1.0, 1.0, 10.0, 10.0];
        let low = vec![1.0, 1.0, 1.0, 10.0, 10.0];
        let close = vec![1.0, 1.0, 1.0, 10.0, 10.0];
        let out = emd_trend_with_kernel(
            &EmdTrendInput::from_slices(
                &open,
                &high,
                &low,
                &close,
                EmdTrendParams {
                    source: Some("close".to_string()),
                    avg_type: Some("SMA".to_string()),
                    length: Some(2),
                    mult: Some(1.0),
                },
            ),
            Kernel::Scalar,
        )?;
        assert_vec_eq_nan(&out.direction, &[0.0, 0.0, 0.0, 1.0, 1.0]);
        assert_vec_eq_nan(&out.average, &[f64::NAN, 1.0, 1.0, 5.5, 10.0]);
        assert_vec_eq_nan(&out.upper, &[f64::NAN, 1.0, 1.0, 8.5, 11.0]);
        assert_vec_eq_nan(&out.lower, &[f64::NAN, 1.0, 1.0, 2.5, 9.0]);
        Ok(())
    }

    #[test]
    fn emd_trend_into_matches_api() -> Result<(), Box<dyn Error>> {
        let (open, high, low, close) = sample_ohlc(200);
        let input =
            EmdTrendInput::from_slices(&open, &high, &low, &close, EmdTrendParams::default());
        let base = emd_trend(&input)?;
        let mut direction = vec![0.0; close.len()];
        let mut average = vec![f64::NAN; close.len()];
        let mut upper = vec![f64::NAN; close.len()];
        let mut lower = vec![f64::NAN; close.len()];
        emd_trend_into(&input, &mut direction, &mut average, &mut upper, &mut lower)?;
        assert_vec_eq_nan(&direction, &base.direction);
        assert_vec_eq_nan(&average, &base.average);
        assert_vec_eq_nan(&upper, &base.upper);
        assert_vec_eq_nan(&lower, &base.lower);
        Ok(())
    }

    #[test]
    fn emd_trend_batch_single_matches_single() -> Result<(), Box<dyn Error>> {
        let (open, high, low, close) = sample_ohlc(180);
        let single = emd_trend_with_kernel(
            &EmdTrendInput::from_slices(&open, &high, &low, &close, EmdTrendParams::default()),
            Kernel::Scalar,
        )?;
        let batch = emd_trend_batch_with_kernel(
            &open,
            &high,
            &low,
            &close,
            &EmdTrendBatchRange::default(),
            "close",
            "SMA",
            Kernel::ScalarBatch,
        )?;
        assert_eq!(batch.rows, 1);
        assert_eq!(batch.cols, close.len());
        assert_vec_eq_nan(&batch.direction[..close.len()], &single.direction);
        assert_vec_eq_nan(&batch.average[..close.len()], &single.average);
        assert_vec_eq_nan(&batch.upper[..close.len()], &single.upper);
        assert_vec_eq_nan(&batch.lower[..close.len()], &single.lower);
        Ok(())
    }

    /// Every row of a multi-combination sweep must equal a standalone run with that row's
    /// parameters (row order is taken from `batch.combos`, not assumed).
    #[test]
    fn emd_trend_batch_multi_row_matches_single() -> Result<(), Box<dyn Error>> {
        let (open, high, low, close) = sample_ohlc(160);
        let batch = emd_trend_batch_with_kernel(
            &open,
            &high,
            &low,
            &close,
            &EmdTrendBatchRange {
                length: (20, 30, 5),
                mult: (1.0, 2.0, 1.0),
            },
            "close",
            "SMA",
            Kernel::ScalarBatch,
        )?;
        assert_eq!(batch.rows, batch.combos.len());
        assert!(batch.rows > 1);
        assert_eq!(batch.cols, close.len());
        for (row, combo) in batch.combos.iter().enumerate() {
            let single = emd_trend_with_kernel(
                &EmdTrendInput::from_slices(
                    &open,
                    &high,
                    &low,
                    &close,
                    close_sma_params(combo.length, combo.mult),
                ),
                Kernel::Scalar,
            )?;
            let span = row * batch.cols..(row + 1) * batch.cols;
            assert_vec_eq_nan(&batch.direction[span.clone()], &single.direction);
            assert_vec_eq_nan(&batch.average[span.clone()], &single.average);
            assert_vec_eq_nan(&batch.upper[span.clone()], &single.upper);
            assert_vec_eq_nan(&batch.lower[span], &single.lower);
        }
        Ok(())
    }

    #[test]
    fn emd_trend_stream_contract() -> Result<(), Box<dyn Error>> {
        let (open, high, low, close) = sample_ohlc(220);
        let mut stream = EmdTrendStream::try_new(EmdTrendParams {
            source: Some("close".to_string()),
            avg_type: Some("SMA".to_string()),
            length: Some(28),
            mult: Some(1.0),
        })?;
        let mut points = Vec::with_capacity(close.len());
        for i in 0..close.len() {
            points.push(stream.update(open[i], high[i], low[i], close[i]).unwrap());
        }
        assert_eq!(points.len(), close.len());
        assert!(points.iter().any(|(_, avg, _, _)| avg.is_finite()));
        Ok(())
    }

    /// The stream recomputes over its accumulated history, so each emitted point must match
    /// the corresponding index of a one-shot run over the whole series.
    #[test]
    fn emd_trend_stream_matches_single() -> Result<(), Box<dyn Error>> {
        let (open, high, low, close) = sample_ohlc(120);
        let single = emd_trend_with_kernel(
            &EmdTrendInput::from_slices(
                &open,
                &high,
                &low,
                &close,
                close_sma_params(Some(14), Some(1.0)),
            ),
            Kernel::Scalar,
        )?;
        let mut stream = EmdTrendStream::try_new(close_sma_params(Some(14), Some(1.0)))?;
        for i in 0..close.len() {
            let (d, a, u, l) = stream
                .update(open[i], high[i], low[i], close[i])
                .expect("finite input always yields a point");
            assert_vec_eq_nan(
                &[d, a, u, l],
                &[
                    single.direction[i],
                    single.average[i],
                    single.upper[i],
                    single.lower[i],
                ],
            );
        }
        Ok(())
    }

    #[test]
    fn emd_trend_rejects_invalid_params() {
        let (open, high, low, close) = sample_ohlc(64);
        let err = emd_trend(&EmdTrendInput::from_slices(
            &open,
            &high,
            &low,
            &close,
            EmdTrendParams {
                source: Some("bad".to_string()),
                avg_type: Some("SMA".to_string()),
                length: Some(28),
                mult: Some(1.0),
            },
        ))
        .unwrap_err();
        assert!(matches!(err, EmdTrendError::InvalidSource { .. }));
    }

    #[test]
    fn emd_trend_dispatch_compute_returns_average() -> Result<(), Box<dyn Error>> {
        let (open, high, low, close) = sample_ohlc(192);
        let params = [
            ParamKV {
                key: "source",
                value: ParamValue::EnumString("close"),
            },
            ParamKV {
                key: "avg_type",
                value: ParamValue::EnumString("SMA"),
            },
            ParamKV {
                key: "length",
                value: ParamValue::Int(28),
            },
            ParamKV {
                key: "mult",
                value: ParamValue::Float(1.0),
            },
        ];
        let out = compute_cpu(IndicatorComputeRequest {
            indicator_id: "emd_trend",
            output_id: Some("average"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            params: &params,
            kernel: Kernel::Scalar,
        })?;
        assert_eq!(out.output_id, "average");
        match out.series {
            IndicatorSeries::F64(values) => assert_eq!(values.len(), close.len()),
            other => panic!("expected f64 series, got {:?}", other),
        }
        Ok(())
    }
}