1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::indicators::ema::{ema_with_kernel, EmaData, EmaInput, EmaParams, EmaStream};
16use crate::indicators::moving_averages::frama::{FramaParams, FramaStream};
17use crate::indicators::moving_averages::ma::{ma_with_kernel, MaData};
18use crate::indicators::moving_averages::ma_stream::{ma_stream, MaStream};
19use crate::utilities::data_loader::Candles;
20use crate::utilities::enums::Kernel;
21use crate::utilities::helpers::{
22 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, make_uninit_matrix,
23};
24#[cfg(feature = "python")]
25use crate::utilities::kernel_validation::validate_kernel;
26use std::error::Error;
27use thiserror::Error;
28
/// Price source used when `EmdTrendParams::source` is `None`.
const DEFAULT_SOURCE: &str = "close";
/// Moving-average type used when `EmdTrendParams::avg_type` is `None`.
const DEFAULT_AVG_TYPE: &str = "SMA";
/// Period used when `EmdTrendParams::length` is `None`; it is also the period of the
/// deviation EMA.
const DEFAULT_LENGTH: usize = 28;
/// Band multiplier used when `EmdTrendParams::mult` is `None`.
const DEFAULT_MULT: f64 = 1.0;
33
/// Parsed price source. `source_value` defines how each variant combines the OHLC fields.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SourceKind {
    /// Open price.
    Open,
    /// High price.
    High,
    /// Low price.
    Low,
    /// Close price.
    Close,
    /// `(open + close) / 2`.
    Oc2,
    /// `(high + low) / 2`.
    Hl2,
    /// `(open + close + close) / 3`.
    Occ3,
    /// `(high + low + close) / 3`.
    Hlc3,
    /// `(open + high + low + close) / 4`.
    Ohlc4,
    /// `(high + low + close + close) / 4`.
    Hlcc4,
}
47
/// Parsed moving-average family used for the centre line.
///
/// `internal_avg_id` maps each variant to the id passed to `ma_with_kernel`. Note that
/// `Rma` maps to `"wilders"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AvgKind {
    Sma,
    Ema,
    Hma,
    Dema,
    Tema,
    Rma,
    /// Odd FRAMA lengths require one extra valid bar (see the `needed` checks).
    Frama,
}
58
/// Borrowed OHLC data for the indicator.
///
/// Either a whole `Candles` container or four parallel slices. For `Slices`, the four
/// slices must be non-empty and of equal length; this is checked when the indicator is
/// computed, not here.
#[derive(Debug, Clone)]
pub enum EmdTrendData<'a> {
    /// Read `open`, `high`, `low` and `close` from a candle container.
    Candles {
        candles: &'a Candles,
    },
    /// Four caller-provided price slices.
    Slices {
        open: &'a [f64],
        high: &'a [f64],
        low: &'a [f64],
        close: &'a [f64],
    },
}
71
/// Per-bar indicator output; every vector has the same length as the input.
#[derive(Debug, Clone)]
pub struct EmdTrendOutput {
    /// Latched trend state.
    ///
    /// Becomes `1.0` after the source crosses above `upper` and `-1.0` after it crosses
    /// below `lower`. Otherwise it keeps the previous state, starting at `0.0`.
    pub direction: Vec<f64>,
    /// Moving average of the source, as returned by the selected average implementation.
    pub average: Vec<f64>,
    /// `average + mult * deviation`, where `deviation` is an EMA of `|source - average|`.
    /// NaN where either term is unavailable.
    pub upper: Vec<f64>,
    /// `average - mult * deviation`; NaN where either term is unavailable.
    pub lower: Vec<f64>,
}
79
/// User-facing parameters. A `None` field falls back to the module default.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct EmdTrendParams {
    /// Price source name, case-insensitive. One of `open`, `high`, `low`, `close`, `oc2`,
    /// `hl2`, `occ3`, `hlc3`, `ohlc4`, `hlcc4`. Default `close`.
    pub source: Option<String>,
    /// Average type, case-insensitive. One of `SMA`, `EMA`, `HMA`, `DEMA`, `TEMA`, `RMA`,
    /// `FRAMA`. Default `SMA`.
    pub avg_type: Option<String>,
    /// Period of both the average and the deviation EMA; must be at least 1. Default 28.
    pub length: Option<usize>,
    /// Band multiplier; must be finite and at least 0.05. Default 1.0.
    pub mult: Option<f64>,
}
91
92impl Default for EmdTrendParams {
93 fn default() -> Self {
94 Self {
95 source: Some(DEFAULT_SOURCE.to_string()),
96 avg_type: Some(DEFAULT_AVG_TYPE.to_string()),
97 length: Some(DEFAULT_LENGTH),
98 mult: Some(DEFAULT_MULT),
99 }
100 }
101}
102
/// Everything needed for one single-series computation: the price data plus parameters.
#[derive(Debug, Clone)]
pub struct EmdTrendInput<'a> {
    /// OHLC data to read from.
    pub data: EmdTrendData<'a>,
    /// Parameters; unset fields are resolved by the `get_*` accessors.
    pub params: EmdTrendParams,
}
108
109impl<'a> EmdTrendInput<'a> {
110 #[inline]
111 pub fn from_candles(candles: &'a Candles, params: EmdTrendParams) -> Self {
112 Self {
113 data: EmdTrendData::Candles { candles },
114 params,
115 }
116 }
117
118 #[inline]
119 pub fn from_slices(
120 open: &'a [f64],
121 high: &'a [f64],
122 low: &'a [f64],
123 close: &'a [f64],
124 params: EmdTrendParams,
125 ) -> Self {
126 Self {
127 data: EmdTrendData::Slices {
128 open,
129 high,
130 low,
131 close,
132 },
133 params,
134 }
135 }
136
137 #[inline]
138 pub fn with_default_candles(candles: &'a Candles) -> Self {
139 Self::from_candles(candles, EmdTrendParams::default())
140 }
141
142 #[inline]
143 pub fn get_source(&self) -> &str {
144 self.params.source.as_deref().unwrap_or(DEFAULT_SOURCE)
145 }
146
147 #[inline]
148 pub fn get_avg_type(&self) -> &str {
149 self.params.avg_type.as_deref().unwrap_or(DEFAULT_AVG_TYPE)
150 }
151
152 #[inline]
153 pub fn get_length(&self) -> usize {
154 self.params.length.unwrap_or(DEFAULT_LENGTH)
155 }
156
157 #[inline]
158 pub fn get_mult(&self) -> f64 {
159 self.params.mult.unwrap_or(DEFAULT_MULT)
160 }
161}
162
/// Builder holding pre-parsed indicator settings and the kernel to use.
///
/// `source` and `avg_type` are validated at set time by the builder's setters. A `None`
/// field means the caller has not set it.
#[derive(Copy, Clone, Debug)]
pub struct EmdTrendBuilder {
    source: Option<SourceKind>,
    avg_type: Option<AvgKind>,
    length: Option<usize>,
    mult: Option<f64>,
    /// Kernel request; `Kernel::Auto` by default.
    kernel: Kernel,
}
171
172impl Default for EmdTrendBuilder {
173 fn default() -> Self {
174 Self {
175 source: None,
176 avg_type: None,
177 length: None,
178 mult: None,
179 kernel: Kernel::Auto,
180 }
181 }
182}
183
184impl EmdTrendBuilder {
185 #[inline(always)]
186 pub fn new() -> Self {
187 Self::default()
188 }
189
190 #[inline(always)]
191 pub fn source(mut self, value: &str) -> Result<Self, EmdTrendError> {
192 self.source = Some(parse_source_kind(value)?);
193 Ok(self)
194 }
195
196 #[inline(always)]
197 pub fn avg_type(mut self, value: &str) -> Result<Self, EmdTrendError> {
198 self.avg_type = Some(parse_avg_kind(value)?);
199 Ok(self)
200 }
201
202 #[inline(always)]
203 pub fn length(mut self, value: usize) -> Self {
204 self.length = Some(value);
205 self
206 }
207
208 #[inline(always)]
209 pub fn mult(mut self, value: f64) -> Self {
210 self.mult = Some(value);
211 self
212 }
213
214 #[inline(always)]
215 pub fn kernel(mut self, value: Kernel) -> Self {
216 self.kernel = value;
217 self
218 }
219}
220
/// Errors produced by the EMD trend indicator.
#[derive(Debug, Error)]
pub enum EmdTrendError {
    /// At least one of the OHLC inputs is empty.
    #[error("emd_trend: Input data slice is empty.")]
    EmptyInputData,
    /// The OHLC inputs do not all have the same length.
    #[error(
        "emd_trend: Input length mismatch: open = {open_len}, high = {high_len}, low = {low_len}, close = {close_len}"
    )]
    InputLengthMismatch {
        open_len: usize,
        high_len: usize,
        low_len: usize,
        close_len: usize,
    },
    /// The selected source series contains no finite value.
    #[error("emd_trend: All values are NaN.")]
    AllValuesNaN,
    /// `source` is not one of the recognised names.
    #[error("emd_trend: Invalid source: {value}")]
    InvalidSource { value: String },
    /// `avg_type` is not one of the recognised names.
    #[error("emd_trend: Invalid avg_type: {value}")]
    InvalidAvgType { value: String },
    /// `length` is zero.
    #[error("emd_trend: Invalid length: {length}")]
    InvalidLength { length: usize },
    /// `mult` is not finite or is below 0.05.
    #[error("emd_trend: Invalid mult: {mult}")]
    InvalidMult { mult: f64 },
    /// The longest run of consecutive finite source values is shorter than required.
    #[error("emd_trend: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A single-series output slice does not have the input's length.
    #[error("emd_trend: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A batch output buffer is not `rows * cols` elements long.
    #[error("emd_trend: Output length mismatch: dst = {dst_len}, expected = {expected_len}")]
    MismatchedOutputLen { dst_len: usize, expected_len: usize },
    /// A batch sweep produced no combinations or contains an invalid value.
    #[error("emd_trend: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: String,
        end: String,
        step: String,
    },
    /// The requested kernel has no batch equivalent.
    #[error("emd_trend: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// Errors forwarded from the underlying average/EMA code, or an internal size overflow.
    #[error("emd_trend: Invalid input: {msg}")]
    InvalidInput { msg: String },
}
261
262#[inline(always)]
263fn source_kind_name(kind: SourceKind) -> &'static str {
264 match kind {
265 SourceKind::Open => "open",
266 SourceKind::High => "high",
267 SourceKind::Low => "low",
268 SourceKind::Close => "close",
269 SourceKind::Oc2 => "oc2",
270 SourceKind::Hl2 => "hl2",
271 SourceKind::Occ3 => "occ3",
272 SourceKind::Hlc3 => "hlc3",
273 SourceKind::Ohlc4 => "ohlc4",
274 SourceKind::Hlcc4 => "hlcc4",
275 }
276}
277
278#[inline(always)]
279fn avg_kind_name(kind: AvgKind) -> &'static str {
280 match kind {
281 AvgKind::Sma => "SMA",
282 AvgKind::Ema => "EMA",
283 AvgKind::Hma => "HMA",
284 AvgKind::Dema => "DEMA",
285 AvgKind::Tema => "TEMA",
286 AvgKind::Rma => "RMA",
287 AvgKind::Frama => "FRAMA",
288 }
289}
290
291#[inline(always)]
292fn parse_source_kind(value: &str) -> Result<SourceKind, EmdTrendError> {
293 if value.eq_ignore_ascii_case("open") {
294 return Ok(SourceKind::Open);
295 }
296 if value.eq_ignore_ascii_case("high") {
297 return Ok(SourceKind::High);
298 }
299 if value.eq_ignore_ascii_case("low") {
300 return Ok(SourceKind::Low);
301 }
302 if value.eq_ignore_ascii_case("close") {
303 return Ok(SourceKind::Close);
304 }
305 if value.eq_ignore_ascii_case("oc2") {
306 return Ok(SourceKind::Oc2);
307 }
308 if value.eq_ignore_ascii_case("hl2") {
309 return Ok(SourceKind::Hl2);
310 }
311 if value.eq_ignore_ascii_case("occ3") {
312 return Ok(SourceKind::Occ3);
313 }
314 if value.eq_ignore_ascii_case("hlc3") {
315 return Ok(SourceKind::Hlc3);
316 }
317 if value.eq_ignore_ascii_case("ohlc4") {
318 return Ok(SourceKind::Ohlc4);
319 }
320 if value.eq_ignore_ascii_case("hlcc4") {
321 return Ok(SourceKind::Hlcc4);
322 }
323 Err(EmdTrendError::InvalidSource {
324 value: value.to_string(),
325 })
326}
327
328#[inline(always)]
329fn parse_avg_kind(value: &str) -> Result<AvgKind, EmdTrendError> {
330 if value.eq_ignore_ascii_case("SMA") {
331 return Ok(AvgKind::Sma);
332 }
333 if value.eq_ignore_ascii_case("EMA") {
334 return Ok(AvgKind::Ema);
335 }
336 if value.eq_ignore_ascii_case("HMA") {
337 return Ok(AvgKind::Hma);
338 }
339 if value.eq_ignore_ascii_case("DEMA") {
340 return Ok(AvgKind::Dema);
341 }
342 if value.eq_ignore_ascii_case("TEMA") {
343 return Ok(AvgKind::Tema);
344 }
345 if value.eq_ignore_ascii_case("RMA") {
346 return Ok(AvgKind::Rma);
347 }
348 if value.eq_ignore_ascii_case("FRAMA") {
349 return Ok(AvgKind::Frama);
350 }
351 Err(EmdTrendError::InvalidAvgType {
352 value: value.to_string(),
353 })
354}
355
356#[inline(always)]
357fn internal_avg_id(kind: AvgKind) -> &'static str {
358 match kind {
359 AvgKind::Sma => "sma",
360 AvgKind::Ema => "ema",
361 AvgKind::Hma => "hma",
362 AvgKind::Dema => "dema",
363 AvgKind::Tema => "tema",
364 AvgKind::Rma => "wilders",
365 AvgKind::Frama => "frama",
366 }
367}
368
369#[inline(always)]
370fn source_value(kind: SourceKind, open: f64, high: f64, low: f64, close: f64) -> f64 {
371 match kind {
372 SourceKind::Open => open,
373 SourceKind::High => high,
374 SourceKind::Low => low,
375 SourceKind::Close => close,
376 SourceKind::Oc2 => {
377 if open.is_finite() && close.is_finite() {
378 0.5 * (open + close)
379 } else {
380 f64::NAN
381 }
382 }
383 SourceKind::Hl2 => {
384 if high.is_finite() && low.is_finite() {
385 0.5 * (high + low)
386 } else {
387 f64::NAN
388 }
389 }
390 SourceKind::Occ3 => {
391 if open.is_finite() && close.is_finite() {
392 (open + close + close) / 3.0
393 } else {
394 f64::NAN
395 }
396 }
397 SourceKind::Hlc3 => {
398 if high.is_finite() && low.is_finite() && close.is_finite() {
399 (high + low + close) / 3.0
400 } else {
401 f64::NAN
402 }
403 }
404 SourceKind::Ohlc4 => {
405 if open.is_finite() && high.is_finite() && low.is_finite() && close.is_finite() {
406 (open + high + low + close) * 0.25
407 } else {
408 f64::NAN
409 }
410 }
411 SourceKind::Hlcc4 => {
412 if high.is_finite() && low.is_finite() && close.is_finite() {
413 (high + low + close + close) * 0.25
414 } else {
415 f64::NAN
416 }
417 }
418 }
419}
420
/// Length of the longest run of consecutive finite values.
///
/// Returns `0` for an empty slice or when nothing is finite.
#[inline(always)]
fn longest_valid_run(values: &[f64]) -> usize {
    let (longest, _) = values
        .iter()
        .fold((0usize, 0usize), |(longest, run), value| {
            let run = if value.is_finite() { run + 1 } else { 0 };
            (longest.max(run), run)
        });
    longest
}
435
/// Expands an inclusive integer sweep `(start, end, step)` into its values.
///
/// `step == 0` or `start == end` yields `[start]`. Otherwise values are generated upward
/// from `min(start, end)` in steps of `step` while they stay at or below
/// `max(start, end)`. The list is reversed when `start > end`, and generation stops rather
/// than overflowing `usize`.
#[inline(always)]
fn axis_usize(range: (usize, usize, usize)) -> Vec<usize> {
    let (start, end, step) = range;
    if step == 0 || start == end {
        return vec![start];
    }
    let (lo, hi) = (start.min(end), start.max(end));
    let mut values: Vec<usize> = std::iter::successors(Some(lo), |&prev| {
        prev.checked_add(step).filter(|&next| next <= hi)
    })
    .collect();
    if start > end {
        values.reverse();
    }
    values
}
461
/// Expands an inclusive floating-point sweep `(start, end, step)` into its values.
///
/// * `step == 0` or `start ≈ end` (within `f64::EPSILON`) yields `[start]`.
/// * The sign of `step` is ignored. When `start > end`, values are generated from `end`
///   upward and then reversed, so the sequence is descending.
/// * Each value is computed as `lo + k * step` rather than by repeated addition, so rounding
///   error does not accumulate along the sweep. For example, the third value of
///   `(1.0, 2.0, 0.1)` is exactly `1.2`.
/// * An end point reached up to a small tolerance is still included.
/// * A non-finite bound or step yields an empty vector, which callers report as an invalid
///   range. Repeated addition would loop forever on an infinite step or bound.
#[inline(always)]
fn axis_f64(range: (f64, f64, f64)) -> Vec<f64> {
    let (start, end, step) = range;
    if step == 0.0 || (start - end).abs() <= f64::EPSILON {
        return vec![start];
    }
    let step = step.abs();
    if !(start.is_finite() && end.is_finite() && step.is_finite()) {
        return Vec::new();
    }
    let (lo, hi, reverse) = if start <= end {
        (start, end, false)
    } else {
        (end, start, true)
    };
    // Tolerance so an end point that is hit only up to rounding is kept.
    let limit = hi + (step * 1e-9 + 1e-12);
    let mut out: Vec<f64> = (0usize..)
        .map(|k| lo + k as f64 * step)
        .take_while(|&value| value <= limit)
        .collect();
    if reverse {
        out.reverse();
    }
    out
}
485
486#[inline(always)]
487fn validate_params_only(
488 source: &str,
489 avg_type: &str,
490 length: usize,
491 mult: f64,
492) -> Result<(SourceKind, AvgKind), EmdTrendError> {
493 let source_kind = parse_source_kind(source)?;
494 let avg_kind = parse_avg_kind(avg_type)?;
495 if length == 0 {
496 return Err(EmdTrendError::InvalidLength { length });
497 }
498 if !mult.is_finite() || mult < 0.05 {
499 return Err(EmdTrendError::InvalidMult { mult });
500 }
501 Ok((source_kind, avg_kind))
502}
503
504#[inline(always)]
505fn input_slices<'a>(
506 input: &'a EmdTrendInput<'a>,
507) -> Result<(&'a [f64], &'a [f64], &'a [f64], &'a [f64]), EmdTrendError> {
508 match &input.data {
509 EmdTrendData::Candles { candles } => Ok((
510 candles.open.as_slice(),
511 candles.high.as_slice(),
512 candles.low.as_slice(),
513 candles.close.as_slice(),
514 )),
515 EmdTrendData::Slices {
516 open,
517 high,
518 low,
519 close,
520 } => Ok((open, high, low, close)),
521 }
522}
523
524#[inline(always)]
525fn build_source_series(
526 open: &[f64],
527 high: &[f64],
528 low: &[f64],
529 close: &[f64],
530 source_kind: SourceKind,
531) -> Result<Vec<f64>, EmdTrendError> {
532 if open.is_empty() || high.is_empty() || low.is_empty() || close.is_empty() {
533 return Err(EmdTrendError::EmptyInputData);
534 }
535 if open.len() != high.len() || open.len() != low.len() || open.len() != close.len() {
536 return Err(EmdTrendError::InputLengthMismatch {
537 open_len: open.len(),
538 high_len: high.len(),
539 low_len: low.len(),
540 close_len: close.len(),
541 });
542 }
543
544 let mut out = Vec::with_capacity(close.len());
545 for i in 0..close.len() {
546 out.push(source_value(
547 source_kind,
548 open[i],
549 high[i],
550 low[i],
551 close[i],
552 ));
553 }
554 if longest_valid_run(&out) == 0 {
555 return Err(EmdTrendError::AllValuesNaN);
556 }
557 Ok(out)
558}
559
/// Parses the unsigned integer that follows a `=` (or `:`) at the start of `segment`.
///
/// Whitespace before the separator and before the digits is skipped. Returns `None` when
/// the segment does not begin with a separator or no digits follow it. This way a number
/// belonging to a later key is never picked up by accident.
#[inline(always)]
fn parse_number_after(segment: &str) -> Option<usize> {
    let rest = segment.trim_start();
    let rest = rest
        .strip_prefix('=')
        .or_else(|| rest.strip_prefix(':'))?
        .trim_start();
    let digits_end = rest
        .find(|c: char| !c.is_ascii_digit())
        .unwrap_or(rest.len());
    rest[..digits_end].parse().ok()
}

/// Extracts `(needed, valid)` from an upstream "not enough valid data" message.
///
/// Expects the crate's usual `... needed = N, valid = M` wording. The `valid` key is searched
/// for only after the `needed` value. Such messages usually contain the word "valid" earlier
/// (as in "Not enough valid data"), and that occurrence must not be mistaken for the key.
#[inline(always)]
fn parse_needed_valid(msg: &str) -> Option<(usize, usize)> {
    const NEEDED_KEY: &str = "needed";
    const VALID_KEY: &str = "valid";

    let after_needed = &msg[msg.find(NEEDED_KEY)? + NEEDED_KEY.len()..];
    let needed = parse_number_after(after_needed)?;

    let after_valid = &after_needed[after_needed.find(VALID_KEY)? + VALID_KEY.len()..];
    let valid = parse_number_after(after_valid)?;

    Some((needed, valid))
}
581
582#[inline(always)]
583fn map_ma_error(err: Box<dyn Error>) -> EmdTrendError {
584 let msg = err.to_string();
585 if let Some((needed, valid)) = parse_needed_valid(&msg) {
586 return EmdTrendError::NotEnoughValidData { needed, valid };
587 }
588 EmdTrendError::InvalidInput { msg }
589}
590
591#[inline(always)]
592fn normalize_single_kernel(kernel: Kernel) -> Kernel {
593 match kernel {
594 Kernel::Auto => detect_best_kernel(),
595 Kernel::ScalarBatch => Kernel::Scalar,
596 Kernel::Avx2Batch => Kernel::Avx2,
597 Kernel::Avx512Batch => Kernel::Avx512,
598 other => other,
599 }
600}
601
602#[inline(always)]
603fn normalize_batch_kernel(kernel: Kernel) -> Result<Kernel, EmdTrendError> {
604 Ok(match kernel {
605 Kernel::Auto => match detect_best_batch_kernel() {
606 Kernel::ScalarBatch => Kernel::Scalar,
607 Kernel::Avx2Batch => Kernel::Avx2,
608 Kernel::Avx512Batch => Kernel::Avx512,
609 other => other,
610 },
611 Kernel::Scalar | Kernel::ScalarBatch => Kernel::Scalar,
612 Kernel::Avx2 | Kernel::Avx2Batch => Kernel::Avx2,
613 Kernel::Avx512 | Kernel::Avx512Batch => Kernel::Avx512,
614 other => return Err(EmdTrendError::InvalidKernelForBatch(other)),
615 })
616}
617
618#[inline(always)]
619fn compute_average_series(
620 src: &[f64],
621 avg_kind: AvgKind,
622 length: usize,
623 kernel: Kernel,
624) -> Result<Vec<f64>, EmdTrendError> {
625 ma_with_kernel(
626 internal_avg_id(avg_kind),
627 MaData::Slice(src),
628 length,
629 kernel,
630 )
631 .map_err(map_ma_error)
632}
633
634#[inline(always)]
635fn compute_deviation_ema(
636 avg: &[f64],
637 src: &[f64],
638 length: usize,
639 kernel: Kernel,
640) -> Result<Vec<f64>, EmdTrendError> {
641 let mut abs_dev = alloc_with_nan_prefix(src.len(), 0);
642 abs_dev.fill(f64::NAN);
643 for i in 0..src.len() {
644 if avg[i].is_finite() && src[i].is_finite() {
645 abs_dev[i] = (src[i] - avg[i]).abs();
646 }
647 }
648 let input = EmaInput {
649 data: EmaData::Slice(&abs_dev),
650 params: EmaParams {
651 period: Some(length),
652 },
653 };
654 ema_with_kernel(&input, kernel)
655 .map(|out| out.values)
656 .map_err(|e| EmdTrendError::InvalidInput { msg: e.to_string() })
657}
658
659fn compute_from_source_into(
660 src: &[f64],
661 avg_kind: AvgKind,
662 length: usize,
663 mult: f64,
664 kernel: Kernel,
665 direction_out: &mut [f64],
666 average_out: &mut [f64],
667 upper_out: &mut [f64],
668 lower_out: &mut [f64],
669) -> Result<(), EmdTrendError> {
670 let len = src.len();
671 if direction_out.len() != len {
672 return Err(EmdTrendError::OutputLengthMismatch {
673 expected: len,
674 got: direction_out.len(),
675 });
676 }
677 if average_out.len() != len {
678 return Err(EmdTrendError::OutputLengthMismatch {
679 expected: len,
680 got: average_out.len(),
681 });
682 }
683 if upper_out.len() != len {
684 return Err(EmdTrendError::OutputLengthMismatch {
685 expected: len,
686 got: upper_out.len(),
687 });
688 }
689 if lower_out.len() != len {
690 return Err(EmdTrendError::OutputLengthMismatch {
691 expected: len,
692 got: lower_out.len(),
693 });
694 }
695
696 let average = compute_average_series(src, avg_kind, length, kernel)?;
697 let deviation = compute_deviation_ema(&average, src, length, kernel)?;
698
699 average_out.copy_from_slice(&average);
700 direction_out.fill(0.0);
701 upper_out.fill(f64::NAN);
702 lower_out.fill(f64::NAN);
703
704 let mut direction = 0.0f64;
705 for i in 0..len {
706 if average[i].is_finite() && deviation[i].is_finite() {
707 upper_out[i] = average[i] + deviation[i] * mult;
708 lower_out[i] = average[i] - deviation[i] * mult;
709 }
710 if i > 0
711 && src[i].is_finite()
712 && src[i - 1].is_finite()
713 && upper_out[i].is_finite()
714 && upper_out[i - 1].is_finite()
715 && src[i] > upper_out[i]
716 && src[i - 1] <= upper_out[i - 1]
717 {
718 direction = 1.0;
719 } else if i > 0
720 && src[i].is_finite()
721 && src[i - 1].is_finite()
722 && lower_out[i].is_finite()
723 && lower_out[i - 1].is_finite()
724 && src[i] < lower_out[i]
725 && src[i - 1] >= lower_out[i - 1]
726 {
727 direction = -1.0;
728 }
729 direction_out[i] = direction;
730 }
731
732 Ok(())
733}
734
735#[inline]
736pub fn emd_trend(input: &EmdTrendInput) -> Result<EmdTrendOutput, EmdTrendError> {
737 emd_trend_with_kernel(input, Kernel::Auto)
738}
739
740pub fn emd_trend_with_kernel(
741 input: &EmdTrendInput,
742 kernel: Kernel,
743) -> Result<EmdTrendOutput, EmdTrendError> {
744 let (open, high, low, close) = input_slices(input)?;
745 let (source_kind, avg_kind) = validate_params_only(
746 input.get_source(),
747 input.get_avg_type(),
748 input.get_length(),
749 input.get_mult(),
750 )?;
751 let src = build_source_series(open, high, low, close, source_kind)?;
752 let longest = longest_valid_run(&src);
753 let needed = if avg_kind == AvgKind::Frama && input.get_length() % 2 == 1 {
754 input.get_length() + 1
755 } else {
756 input.get_length()
757 };
758 if longest < needed {
759 return Err(EmdTrendError::NotEnoughValidData {
760 needed,
761 valid: longest,
762 });
763 }
764
765 let chosen = normalize_single_kernel(kernel);
766 let mut direction = alloc_with_nan_prefix(src.len(), 0);
767 let mut average = alloc_with_nan_prefix(src.len(), 0);
768 let mut upper = alloc_with_nan_prefix(src.len(), 0);
769 let mut lower = alloc_with_nan_prefix(src.len(), 0);
770 direction.fill(0.0);
771 average.fill(f64::NAN);
772 upper.fill(f64::NAN);
773 lower.fill(f64::NAN);
774
775 compute_from_source_into(
776 &src,
777 avg_kind,
778 input.get_length(),
779 input.get_mult(),
780 chosen,
781 &mut direction,
782 &mut average,
783 &mut upper,
784 &mut lower,
785 )?;
786
787 Ok(EmdTrendOutput {
788 direction,
789 average,
790 upper,
791 lower,
792 })
793}
794
795pub fn emd_trend_into_slice(
796 out_direction: &mut [f64],
797 out_average: &mut [f64],
798 out_upper: &mut [f64],
799 out_lower: &mut [f64],
800 input: &EmdTrendInput,
801 kernel: Kernel,
802) -> Result<(), EmdTrendError> {
803 let (open, high, low, close) = input_slices(input)?;
804 let (source_kind, avg_kind) = validate_params_only(
805 input.get_source(),
806 input.get_avg_type(),
807 input.get_length(),
808 input.get_mult(),
809 )?;
810 let src = build_source_series(open, high, low, close, source_kind)?;
811 let longest = longest_valid_run(&src);
812 let needed = if avg_kind == AvgKind::Frama && input.get_length() % 2 == 1 {
813 input.get_length() + 1
814 } else {
815 input.get_length()
816 };
817 if longest < needed {
818 return Err(EmdTrendError::NotEnoughValidData {
819 needed,
820 valid: longest,
821 });
822 }
823
824 let chosen = normalize_single_kernel(kernel);
825 compute_from_source_into(
826 &src,
827 avg_kind,
828 input.get_length(),
829 input.get_mult(),
830 chosen,
831 out_direction,
832 out_average,
833 out_upper,
834 out_lower,
835 )
836}
837
838#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
839pub fn emd_trend_into(
840 input: &EmdTrendInput,
841 out_direction: &mut [f64],
842 out_average: &mut [f64],
843 out_upper: &mut [f64],
844 out_lower: &mut [f64],
845) -> Result<(), EmdTrendError> {
846 emd_trend_into_slice(
847 out_direction,
848 out_average,
849 out_upper,
850 out_lower,
851 input,
852 Kernel::Auto,
853 )
854}
855
/// Parameter sweep for batch evaluation, given as inclusive `(start, end, step)` ranges.
#[derive(Debug, Clone)]
pub struct EmdTrendBatchRange {
    /// Length sweep; a zero step or `start == end` yields a single value.
    pub length: (usize, usize, usize),
    /// Multiplier sweep; a zero step or `start ≈ end` yields a single value.
    pub mult: (f64, f64, f64),
}
861
862impl Default for EmdTrendBatchRange {
863 fn default() -> Self {
864 Self {
865 length: (DEFAULT_LENGTH, DEFAULT_LENGTH, 0),
866 mult: (DEFAULT_MULT, DEFAULT_MULT, 0.0),
867 }
868 }
869}
870
/// Result of a batch run: one row per parameter combination, stored row-major.
///
/// Row `r` of each matrix occupies indices `r * cols .. (r + 1) * cols` and was computed
/// with `combos[r]`.
#[derive(Debug, Clone)]
pub struct EmdTrendBatchOutput {
    pub direction: Vec<f64>,
    pub average: Vec<f64>,
    pub upper: Vec<f64>,
    pub lower: Vec<f64>,
    /// Fully resolved parameters of each row, in row order.
    pub combos: Vec<EmdTrendParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Length of the input series.
    pub cols: usize,
}
881
/// Builder for batch runs.
///
/// Unset `source`/`avg_type` fall back to `close` and `SMA` in `apply_slices`.
#[derive(Clone, Debug)]
pub struct EmdTrendBatchBuilder {
    range: EmdTrendBatchRange,
    source: Option<SourceKind>,
    avg_type: Option<AvgKind>,
    /// Kernel request; `Kernel::Auto` by default.
    kernel: Kernel,
}
889
890impl Default for EmdTrendBatchBuilder {
891 fn default() -> Self {
892 Self {
893 range: EmdTrendBatchRange::default(),
894 source: None,
895 avg_type: None,
896 kernel: Kernel::Auto,
897 }
898 }
899}
900
901impl EmdTrendBatchBuilder {
902 pub fn new() -> Self {
903 Self::default()
904 }
905
906 pub fn source(mut self, value: &str) -> Result<Self, EmdTrendError> {
907 self.source = Some(parse_source_kind(value)?);
908 Ok(self)
909 }
910
911 pub fn avg_type(mut self, value: &str) -> Result<Self, EmdTrendError> {
912 self.avg_type = Some(parse_avg_kind(value)?);
913 Ok(self)
914 }
915
916 pub fn kernel(mut self, value: Kernel) -> Self {
917 self.kernel = value;
918 self
919 }
920
921 pub fn length_range(mut self, start: usize, end: usize, step: usize) -> Self {
922 self.range.length = (start, end, step);
923 self
924 }
925
926 pub fn mult_range(mut self, start: f64, end: f64, step: f64) -> Self {
927 self.range.mult = (start, end, step);
928 self
929 }
930
931 pub fn apply_slices(
932 self,
933 open: &[f64],
934 high: &[f64],
935 low: &[f64],
936 close: &[f64],
937 ) -> Result<EmdTrendBatchOutput, EmdTrendError> {
938 emd_trend_batch_with_kernel(
939 open,
940 high,
941 low,
942 close,
943 &self.range,
944 source_kind_name(self.source.unwrap_or(SourceKind::Close)),
945 avg_kind_name(self.avg_type.unwrap_or(AvgKind::Sma)),
946 self.kernel,
947 )
948 }
949
950 pub fn apply_candles(self, candles: &Candles) -> Result<EmdTrendBatchOutput, EmdTrendError> {
951 self.apply_slices(
952 candles.open.as_slice(),
953 candles.high.as_slice(),
954 candles.low.as_slice(),
955 candles.close.as_slice(),
956 )
957 }
958}
959
960pub fn expand_grid_emd_trend(
961 sweep: &EmdTrendBatchRange,
962 source: &str,
963 avg_type: &str,
964) -> Result<Vec<EmdTrendParams>, EmdTrendError> {
965 let (source_kind, avg_kind) =
966 validate_params_only(source, avg_type, DEFAULT_LENGTH, DEFAULT_MULT)?;
967 let lengths = axis_usize(sweep.length);
968 let mults = axis_f64(sweep.mult);
969 let mut out = Vec::with_capacity(lengths.len().saturating_mul(mults.len()));
970 for &length in &lengths {
971 if length == 0 {
972 return Err(EmdTrendError::InvalidRange {
973 start: sweep.length.0.to_string(),
974 end: sweep.length.1.to_string(),
975 step: sweep.length.2.to_string(),
976 });
977 }
978 for &mult in &mults {
979 if !mult.is_finite() || mult < 0.05 {
980 return Err(EmdTrendError::InvalidRange {
981 start: sweep.mult.0.to_string(),
982 end: sweep.mult.1.to_string(),
983 step: sweep.mult.2.to_string(),
984 });
985 }
986 out.push(EmdTrendParams {
987 source: Some(source_kind_name(source_kind).to_string()),
988 avg_type: Some(avg_kind_name(avg_kind).to_string()),
989 length: Some(length),
990 mult: Some(mult),
991 });
992 }
993 }
994 Ok(out)
995}
996
997pub fn emd_trend_batch_with_kernel(
998 open: &[f64],
999 high: &[f64],
1000 low: &[f64],
1001 close: &[f64],
1002 sweep: &EmdTrendBatchRange,
1003 source: &str,
1004 avg_type: &str,
1005 kernel: Kernel,
1006) -> Result<EmdTrendBatchOutput, EmdTrendError> {
1007 let combos = expand_grid_emd_trend(sweep, source, avg_type)?;
1008 if combos.is_empty() {
1009 return Err(EmdTrendError::InvalidRange {
1010 start: sweep.length.0.to_string(),
1011 end: sweep.length.1.to_string(),
1012 step: sweep.length.2.to_string(),
1013 });
1014 }
1015 let rows = combos.len();
1016 let cols = close.len();
1017 let total = rows
1018 .checked_mul(cols)
1019 .ok_or_else(|| EmdTrendError::InvalidInput {
1020 msg: "emd_trend: rows*cols overflow in batch".to_string(),
1021 })?;
1022
1023 let mut direction_mu = make_uninit_matrix(rows, cols);
1024 let mut average_mu = make_uninit_matrix(rows, cols);
1025 let mut upper_mu = make_uninit_matrix(rows, cols);
1026 let mut lower_mu = make_uninit_matrix(rows, cols);
1027
1028 let mut direction_guard = core::mem::ManuallyDrop::new(direction_mu);
1029 let mut average_guard = core::mem::ManuallyDrop::new(average_mu);
1030 let mut upper_guard = core::mem::ManuallyDrop::new(upper_mu);
1031 let mut lower_guard = core::mem::ManuallyDrop::new(lower_mu);
1032
1033 let direction =
1034 unsafe { std::slice::from_raw_parts_mut(direction_guard.as_mut_ptr() as *mut f64, total) };
1035 let average =
1036 unsafe { std::slice::from_raw_parts_mut(average_guard.as_mut_ptr() as *mut f64, total) };
1037 let upper =
1038 unsafe { std::slice::from_raw_parts_mut(upper_guard.as_mut_ptr() as *mut f64, total) };
1039 let lower =
1040 unsafe { std::slice::from_raw_parts_mut(lower_guard.as_mut_ptr() as *mut f64, total) };
1041
1042 emd_trend_batch_inner_into(
1043 open, high, low, close, sweep, source, avg_type, kernel, false, direction, average, upper,
1044 lower,
1045 )?;
1046
1047 let direction = unsafe {
1048 Vec::from_raw_parts(
1049 direction_guard.as_mut_ptr() as *mut f64,
1050 direction_guard.len(),
1051 direction_guard.capacity(),
1052 )
1053 };
1054 let average = unsafe {
1055 Vec::from_raw_parts(
1056 average_guard.as_mut_ptr() as *mut f64,
1057 average_guard.len(),
1058 average_guard.capacity(),
1059 )
1060 };
1061 let upper = unsafe {
1062 Vec::from_raw_parts(
1063 upper_guard.as_mut_ptr() as *mut f64,
1064 upper_guard.len(),
1065 upper_guard.capacity(),
1066 )
1067 };
1068 let lower = unsafe {
1069 Vec::from_raw_parts(
1070 lower_guard.as_mut_ptr() as *mut f64,
1071 lower_guard.len(),
1072 lower_guard.capacity(),
1073 )
1074 };
1075
1076 Ok(EmdTrendBatchOutput {
1077 direction,
1078 average,
1079 upper,
1080 lower,
1081 combos,
1082 rows,
1083 cols,
1084 })
1085}
1086
1087pub fn emd_trend_batch_slice(
1088 open: &[f64],
1089 high: &[f64],
1090 low: &[f64],
1091 close: &[f64],
1092 sweep: &EmdTrendBatchRange,
1093 source: &str,
1094 avg_type: &str,
1095 kernel: Kernel,
1096) -> Result<EmdTrendBatchOutput, EmdTrendError> {
1097 emd_trend_batch_with_kernel(open, high, low, close, sweep, source, avg_type, kernel)
1098}
1099
1100pub fn emd_trend_batch_par_slice(
1101 open: &[f64],
1102 high: &[f64],
1103 low: &[f64],
1104 close: &[f64],
1105 sweep: &EmdTrendBatchRange,
1106 source: &str,
1107 avg_type: &str,
1108 kernel: Kernel,
1109) -> Result<EmdTrendBatchOutput, EmdTrendError> {
1110 emd_trend_batch_with_kernel(open, high, low, close, sweep, source, avg_type, kernel)
1111}
1112
1113fn emd_trend_batch_inner_into(
1114 open: &[f64],
1115 high: &[f64],
1116 low: &[f64],
1117 close: &[f64],
1118 sweep: &EmdTrendBatchRange,
1119 source: &str,
1120 avg_type: &str,
1121 kernel: Kernel,
1122 _parallel: bool,
1123 out_direction: &mut [f64],
1124 out_average: &mut [f64],
1125 out_upper: &mut [f64],
1126 out_lower: &mut [f64],
1127) -> Result<Vec<EmdTrendParams>, EmdTrendError> {
1128 let (source_kind, avg_kind) =
1129 validate_params_only(source, avg_type, DEFAULT_LENGTH, DEFAULT_MULT)?;
1130 let src = build_source_series(open, high, low, close, source_kind)?;
1131 let longest = longest_valid_run(&src);
1132 if longest == 0 {
1133 return Err(EmdTrendError::AllValuesNaN);
1134 }
1135
1136 let combos = expand_grid_emd_trend(sweep, source, avg_type)?;
1137 let rows = combos.len();
1138 let cols = close.len();
1139 let total = rows
1140 .checked_mul(cols)
1141 .ok_or_else(|| EmdTrendError::InvalidInput {
1142 msg: "emd_trend: rows*cols overflow in batch_into".to_string(),
1143 })?;
1144
1145 if out_direction.len() != total {
1146 return Err(EmdTrendError::MismatchedOutputLen {
1147 dst_len: out_direction.len(),
1148 expected_len: total,
1149 });
1150 }
1151 if out_average.len() != total {
1152 return Err(EmdTrendError::MismatchedOutputLen {
1153 dst_len: out_average.len(),
1154 expected_len: total,
1155 });
1156 }
1157 if out_upper.len() != total {
1158 return Err(EmdTrendError::MismatchedOutputLen {
1159 dst_len: out_upper.len(),
1160 expected_len: total,
1161 });
1162 }
1163 if out_lower.len() != total {
1164 return Err(EmdTrendError::MismatchedOutputLen {
1165 dst_len: out_lower.len(),
1166 expected_len: total,
1167 });
1168 }
1169
1170 let chosen = normalize_batch_kernel(kernel)?;
1171 for (row, combo) in combos.iter().enumerate() {
1172 let length = combo.length.unwrap_or(DEFAULT_LENGTH);
1173 let mult = combo.mult.unwrap_or(DEFAULT_MULT);
1174 let needed = if avg_kind == AvgKind::Frama && length % 2 == 1 {
1175 length + 1
1176 } else {
1177 length
1178 };
1179 if longest < needed {
1180 return Err(EmdTrendError::NotEnoughValidData {
1181 needed,
1182 valid: longest,
1183 });
1184 }
1185
1186 let start = row * cols;
1187 let end = start + cols;
1188 compute_from_source_into(
1189 &src,
1190 avg_kind,
1191 length,
1192 mult,
1193 chosen,
1194 &mut out_direction[start..end],
1195 &mut out_average[start..end],
1196 &mut out_upper[start..end],
1197 &mut out_lower[start..end],
1198 )?;
1199 }
1200
1201 Ok(combos)
1202}
1203
/// Bar-by-bar EMD trend calculator.
///
/// Stores every finite source value seen so far and recomputes the whole series on each
/// update. Per-update cost therefore grows linearly with history length.
pub struct EmdTrendStream {
    source_kind: SourceKind,
    avg_kind: AvgKind,
    /// Last emitted trend state (`0.0` until the first band crossing).
    direction: f64,
    length: usize,
    mult: f64,
    /// Every finite source value pushed so far; never truncated.
    src_history: Vec<f64>,
}
1212
1213impl EmdTrendStream {
1214 pub fn try_new(params: EmdTrendParams) -> Result<Self, EmdTrendError> {
1215 let source = params.source.as_deref().unwrap_or(DEFAULT_SOURCE);
1216 let avg_type = params.avg_type.as_deref().unwrap_or(DEFAULT_AVG_TYPE);
1217 let length = params.length.unwrap_or(DEFAULT_LENGTH);
1218 let mult = params.mult.unwrap_or(DEFAULT_MULT);
1219 let (source_kind, avg_kind) = validate_params_only(source, avg_type, length, mult)?;
1220
1221 Ok(Self {
1222 source_kind,
1223 avg_kind,
1224 direction: 0.0,
1225 length,
1226 mult,
1227 src_history: Vec::new(),
1228 })
1229 }
1230
1231 pub fn update(
1232 &mut self,
1233 open: f64,
1234 high: f64,
1235 low: f64,
1236 close: f64,
1237 ) -> Option<(f64, f64, f64, f64)> {
1238 let src = source_value(self.source_kind, open, high, low, close);
1239 if !src.is_finite() {
1240 return None;
1241 }
1242 self.src_history.push(src);
1243 let len = self.src_history.len();
1244 let mut direction = vec![0.0; len];
1245 let mut average = vec![f64::NAN; len];
1246 let mut upper = vec![f64::NAN; len];
1247 let mut lower = vec![f64::NAN; len];
1248
1249 match compute_from_source_into(
1250 &self.src_history,
1251 self.avg_kind,
1252 self.length,
1253 self.mult,
1254 Kernel::Scalar,
1255 &mut direction,
1256 &mut average,
1257 &mut upper,
1258 &mut lower,
1259 ) {
1260 Ok(()) => {
1261 self.direction = direction[len - 1];
1262 Some((
1263 self.direction,
1264 average[len - 1],
1265 upper[len - 1],
1266 lower[len - 1],
1267 ))
1268 }
1269 Err(EmdTrendError::NotEnoughValidData { .. }) => {
1270 Some((self.direction, f64::NAN, f64::NAN, f64::NAN))
1271 }
1272 Err(_) => Some((self.direction, f64::NAN, f64::NAN, f64::NAN)),
1273 }
1274 }
1275}
1276
/// Python-facing wrapper, exposed as `EmdTrendStream`, around the native
/// [`EmdTrendStream`] incremental calculator.
#[cfg(feature = "python")]
#[pyclass(name = "EmdTrendStream")]
pub struct EmdTrendStreamPy {
    /// The wrapped native stream that holds all state.
    inner: EmdTrendStream,
}
1282
#[cfg(feature = "python")]
#[pymethods]
impl EmdTrendStreamPy {
    /// Creates a new stream.
    ///
    /// `source` and `avg_type` default to the native defaults when `None`. `length` and
    /// `mult` default to the same values as the `emd_trend` function. Raises
    /// `ValueError` when the parameters are rejected.
    #[new]
    #[pyo3(signature = (source=None, avg_type=None, length=28, mult=1.0))]
    fn new(
        source: Option<&str>,
        avg_type: Option<&str>,
        length: usize,
        mult: f64,
    ) -> PyResult<Self> {
        let inner = EmdTrendStream::try_new(EmdTrendParams {
            source: source.map(str::to_string),
            avg_type: avg_type.map(str::to_string),
            length: Some(length),
            mult: Some(mult),
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(Self { inner })
    }

    /// Feeds one bar; returns `(direction, average, upper, lower)`, or `None` when the
    /// bar's source value is not finite.
    fn update(
        &mut self,
        open: f64,
        high: f64,
        low: f64,
        close: f64,
    ) -> Option<(f64, f64, f64, f64)> {
        self.inner.update(open, high, low, close)
    }
}
1313
/// Python entry point for a single EMD trend computation.
///
/// Returns the `(direction, average, upper, lower)` arrays. Raises `ValueError` when the
/// kernel name or the indicator parameters are rejected, or when the computation fails.
#[cfg(feature = "python")]
#[pyfunction(
    name = "emd_trend",
    signature = (open, high, low, close, source="close", avg_type="SMA", length=28, mult=1.0, kernel=None)
)]
pub fn emd_trend_py<'py>(
    py: Python<'py>,
    open: PyReadonlyArray1<'py, f64>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    source: &str,
    avg_type: &str,
    length: usize,
    mult: f64,
    kernel: Option<&str>,
) -> PyResult<(
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
)> {
    // The arrays are borrowed in order, so a non-contiguous input is reported before a
    // bad kernel name.
    let (o, h, l, c) = (
        open.as_slice()?,
        high.as_slice()?,
        low.as_slice()?,
        close.as_slice()?,
    );
    let kern = validate_kernel(kernel, false)?;
    let params = EmdTrendParams {
        source: Some(source.to_string()),
        avg_type: Some(avg_type.to_string()),
        length: Some(length),
        mult: Some(mult),
    };
    let input = EmdTrendInput::from_slices(o, h, l, c, params);

    // Release the GIL while the native computation runs.
    let result = py
        .allow_threads(|| emd_trend_with_kernel(&input, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let direction = result.direction.into_pyarray(py);
    let average = result.average.into_pyarray(py);
    let upper = result.upper.into_pyarray(py);
    let lower = result.lower.into_pyarray(py);
    Ok((direction, average, upper, lower))
}
1363
/// Python entry point for a parameter sweep over `length` and `mult`.
///
/// Returns a dict with four `(rows, cols)` matrices (`direction`, `average`, `upper`,
/// `lower`), the per-row `lengths` and `mults`, and the `rows`/`cols` counts. Raises
/// `ValueError` on an invalid kernel name or when the batch computation fails.
#[cfg(feature = "python")]
#[pyfunction(
    name = "emd_trend_batch",
    signature = (open, high, low, close, length_range=(28, 28, 0), mult_range=(1.0, 1.0, 0.0), source="close", avg_type="SMA", kernel=None)
)]
pub fn emd_trend_batch_py<'py>(
    py: Python<'py>,
    open: PyReadonlyArray1<'py, f64>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    length_range: (usize, usize, usize),
    mult_range: (f64, f64, f64),
    source: &str,
    avg_type: &str,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let open = open.as_slice()?;
    let high = high.as_slice()?;
    let low = low.as_slice()?;
    let close = close.as_slice()?;
    let kern = validate_kernel(kernel, true)?;
    let sweep = EmdTrendBatchRange {
        length: length_range,
        mult: mult_range,
    };

    // Release the GIL while the native sweep runs.
    let batch = py
        .allow_threads(|| {
            emd_trend_batch_with_kernel(open, high, low, close, &sweep, source, avg_type, kern)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let (rows, cols) = (batch.rows, batch.cols);
    let lengths: Vec<u64> = batch
        .combos
        .iter()
        .map(|combo| combo.length.unwrap_or(DEFAULT_LENGTH) as u64)
        .collect();
    let mults: Vec<f64> = batch
        .combos
        .iter()
        .map(|combo| combo.mult.unwrap_or(DEFAULT_MULT))
        .collect();

    let dict = PyDict::new(py);
    // Each flat row-major output becomes a (rows, cols) matrix.
    for (key, flat) in [
        ("direction", batch.direction),
        ("average", batch.average),
        ("upper", batch.upper),
        ("lower", batch.lower),
    ] {
        let matrix = PyArray1::from_vec(py, flat).reshape((rows, cols))?;
        dict.set_item(key, matrix)?;
    }
    dict.set_item("lengths", lengths.into_pyarray(py))?;
    dict.set_item("mults", mults.into_pyarray(py))?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
1445
/// Registers the EMD trend functions and the streaming class on the Python module `m`.
#[cfg(feature = "python")]
pub fn register_emd_trend_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    let single = wrap_pyfunction!(emd_trend_py, m)?;
    m.add_function(single)?;
    let batch = wrap_pyfunction!(emd_trend_batch_py, m)?;
    m.add_function(batch)?;
    m.add_class::<EmdTrendStreamPy>()
}
1453
/// Configuration object deserialized by `emd_trend_batch_js`.
///
/// `length_range` and `mult_range` are `[start, end, step]` triples; any other length is
/// rejected by `emd_trend_batch_js`. When `source` or `avg_type` is absent,
/// `DEFAULT_SOURCE` or `DEFAULT_AVG_TYPE` is used.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmdTrendBatchConfig {
    pub length_range: Vec<usize>,
    pub mult_range: Vec<f64>,
    pub source: Option<String>,
    pub avg_type: Option<String>,
}
1462
/// JS entry point for a single EMD trend computation.
///
/// Returns a plain object with `direction`, `average`, `upper` and `lower` arrays.
///
/// # Errors
/// Returns the indicator error message as a JS string. Serialization failures are also
/// returned as JS errors instead of trapping the wasm module.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = emd_trend_js)]
pub fn emd_trend_js(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    source: &str,
    avg_type: &str,
    length: usize,
    mult: f64,
) -> Result<JsValue, JsValue> {
    let params = EmdTrendParams {
        source: Some(source.to_string()),
        avg_type: Some(avg_type.to_string()),
        length: Some(length),
        mult: Some(mult),
    };
    let input = EmdTrendInput::from_slices(open, high, low, close, params);
    let out = emd_trend_with_kernel(&input, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let obj = js_sys::Object::new();
    for (key, values) in [
        ("direction", &out.direction),
        ("average", &out.average),
        ("upper", &out.upper),
        ("lower", &out.lower),
    ] {
        // `serde_wasm_bindgen::Error` converts into `JsValue`, so `?` propagates it.
        let array = serde_wasm_bindgen::to_value(values)?;
        js_sys::Reflect::set(&obj, &JsValue::from_str(key), &array)?;
    }
    Ok(obj.into())
}
1515
/// JS entry point for a `length`/`mult` parameter sweep.
///
/// `config` must deserialize into [`EmdTrendBatchConfig`]. The returned object holds the
/// flat row-major `direction`, `average`, `upper` and `lower` arrays, the `rows`/`cols`
/// counts, and the per-row parameter `combos`.
///
/// # Errors
/// Returns a JS string on an invalid config (including ranges that are not exactly three
/// elements), on an indicator error, or on a serialization failure.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = emd_trend_batch_js)]
pub fn emd_trend_batch_js(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let config: EmdTrendBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {e}")))?;
    let (length, mult) = match (config.length_range.as_slice(), config.mult_range.as_slice()) {
        (&[l_start, l_end, l_step], &[m_start, m_end, m_step]) => {
            ((l_start, l_end, l_step), (m_start, m_end, m_step))
        }
        _ => {
            return Err(JsValue::from_str(
                "Invalid config: every range must have exactly 3 elements [start, end, step]",
            ))
        }
    };
    let source = config.source.as_deref().unwrap_or(DEFAULT_SOURCE);
    let avg_type = config.avg_type.as_deref().unwrap_or(DEFAULT_AVG_TYPE);
    let sweep = EmdTrendBatchRange { length, mult };
    let out = emd_trend_batch_with_kernel(
        open,
        high,
        low,
        close,
        &sweep,
        source,
        avg_type,
        Kernel::Auto,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let obj = js_sys::Object::new();
    for (key, values) in [
        ("direction", &out.direction),
        ("average", &out.average),
        ("upper", &out.upper),
        ("lower", &out.lower),
    ] {
        // `serde_wasm_bindgen::Error` converts into `JsValue`, so `?` propagates it.
        let array = serde_wasm_bindgen::to_value(values)?;
        js_sys::Reflect::set(&obj, &JsValue::from_str(key), &array)?;
    }
    js_sys::Reflect::set(
        &obj,
        &JsValue::from_str("rows"),
        &JsValue::from_f64(out.rows as f64),
    )?;
    js_sys::Reflect::set(
        &obj,
        &JsValue::from_str("cols"),
        &JsValue::from_f64(out.cols as f64),
    )?;
    let combos = serde_wasm_bindgen::to_value(&out.combos)?;
    js_sys::Reflect::set(&obj, &JsValue::from_str("combos"), &combos)?;
    Ok(obj.into())
}
1595
/// Allocates an uninitialized buffer with capacity for `4 * len` f64 values and hands
/// ownership to the caller.
///
/// The buffer must be released with `emd_trend_free` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn emd_trend_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the allocation alive after this function returns.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(4 * len));
    buffer.as_mut_ptr()
}
1604
/// Releases a buffer previously returned by `emd_trend_alloc(len)`.
///
/// A null `ptr` is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn emd_trend_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` was produced by `emd_trend_alloc(len)` from a `Vec<f64>` with capacity
    // `4 * len` and has not been freed yet. The length is 0 because the contents may
    // never have been initialized, and `Vec::from_raw_parts` requires the first `length`
    // elements to be initialized. f64 has no destructor, so only the allocation is freed.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, 4 * len));
    }
}
1614
/// Computes EMD trend from raw wasm-memory pointers into a caller-owned buffer.
///
/// Each input pointer must address `len` initialized f64 values. `out_ptr` must address
/// `4 * len` writable f64 values, which receive four consecutive planes of `len` values:
/// direction, average, upper and lower.
///
/// The output buffer may overlap an input (for example, in-place use). In that case the
/// result is computed into owned buffers and then copied, so a mutable view of the output
/// never coexists with the input slices.
///
/// # Errors
/// Returns a JS string on a null pointer, when `4 * len` overflows, or when the indicator
/// rejects the inputs or parameters.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn emd_trend_into(
    open_ptr: *const f64,
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    source: &str,
    avg_type: &str,
    length: usize,
    mult: f64,
) -> Result<(), JsValue> {
    if open_ptr.is_null()
        || high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || out_ptr.is_null()
    {
        return Err(JsValue::from_str("null pointer passed to emd_trend_into"));
    }
    let out_len = len
        .checked_mul(4)
        .ok_or_else(|| JsValue::from_str("4*len overflow in emd_trend_into"))?;

    // Byte-range overlap test between the output buffer and each input buffer.
    let elem = std::mem::size_of::<f64>();
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(out_len.saturating_mul(elem));
    let in_bytes = len.saturating_mul(elem);
    let aliased = [open_ptr, high_ptr, low_ptr, close_ptr].iter().any(|&p| {
        let start = p as usize;
        start < out_end && out_start < start.saturating_add(in_bytes)
    });

    let params = EmdTrendParams {
        source: Some(source.to_string()),
        avg_type: Some(avg_type.to_string()),
        length: Some(length),
        mult: Some(mult),
    };

    // SAFETY: all pointers are non-null (checked above). The caller guarantees that each
    // input addresses `len` initialized f64 values and that `out_ptr` addresses `out_len`
    // writable f64 values, all valid for this call. When the output overlaps an input, the
    // mutable output view is only created after the last use of the input slices.
    unsafe {
        let open = std::slice::from_raw_parts(open_ptr, len);
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        let input = EmdTrendInput::from_slices(open, high, low, close, params);

        if aliased {
            let computed = emd_trend_with_kernel(&input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            let out = std::slice::from_raw_parts_mut(out_ptr, out_len);
            let (direction, rest) = out.split_at_mut(len);
            let (average, rest) = rest.split_at_mut(len);
            let (upper, lower) = rest.split_at_mut(len);
            direction.copy_from_slice(&computed.direction);
            average.copy_from_slice(&computed.average);
            upper.copy_from_slice(&computed.upper);
            lower.copy_from_slice(&computed.lower);
            Ok(())
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, out_len);
            let (direction, rest) = out.split_at_mut(len);
            let (average, rest) = rest.split_at_mut(len);
            let (upper, lower) = rest.split_at_mut(len);
            emd_trend_into_slice(direction, average, upper, lower, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))
        }
    }
}
1668
/// Runs a `length`/`mult` sweep from raw wasm-memory pointers into a caller-owned buffer.
///
/// Each input pointer must address `len` initialized f64 values. `out_ptr` must address
/// `4 * rows * len` writable f64 values, which receive four consecutive row-major planes
/// of `rows * len` values: direction, average, upper and lower. Returns `rows`, the
/// number of parameter combinations.
///
/// The output buffer may overlap an input. In that case the sweep is computed into owned
/// buffers and then copied, so a mutable view of the output never coexists with the input
/// slices.
///
/// # Errors
/// Returns a JS string on a null pointer, an invalid sweep, size overflow, or an indicator
/// error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn emd_trend_batch_into(
    open_ptr: *const f64,
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    length_start: usize,
    length_end: usize,
    length_step: usize,
    mult_start: f64,
    mult_end: f64,
    mult_step: f64,
    source: &str,
    avg_type: &str,
) -> Result<usize, JsValue> {
    if open_ptr.is_null()
        || high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || out_ptr.is_null()
    {
        return Err(JsValue::from_str(
            "null pointer passed to emd_trend_batch_into",
        ));
    }

    let sweep = EmdTrendBatchRange {
        length: (length_start, length_end, length_step),
        mult: (mult_start, mult_end, mult_step),
    };
    let combos = expand_grid_emd_trend(&sweep, source, avg_type)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let split = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow in emd_trend_batch_into"))?;
    let total = split
        .checked_mul(4)
        .ok_or_else(|| JsValue::from_str("4*rows*cols overflow in emd_trend_batch_into"))?;

    // Byte-range overlap test between the output buffer and each input buffer.
    let elem = std::mem::size_of::<f64>();
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(total.saturating_mul(elem));
    let in_bytes = len.saturating_mul(elem);
    let aliased = [open_ptr, high_ptr, low_ptr, close_ptr].iter().any(|&p| {
        let start = p as usize;
        start < out_end && out_start < start.saturating_add(in_bytes)
    });

    // SAFETY: all pointers are non-null (checked above). The caller guarantees that each
    // input addresses `len` initialized f64 values and that `out_ptr` addresses `total`
    // writable f64 values, all valid for this call. When the output overlaps an input, the
    // mutable output view is only created after the last use of the input slices.
    unsafe {
        let open = std::slice::from_raw_parts(open_ptr, len);
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);

        if aliased {
            let batch = emd_trend_batch_with_kernel(
                open,
                high,
                low,
                close,
                &sweep,
                source,
                avg_type,
                Kernel::Auto,
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
            let out = std::slice::from_raw_parts_mut(out_ptr, total);
            let (direction, rest) = out.split_at_mut(split);
            let (average, rest) = rest.split_at_mut(split);
            let (upper, lower) = rest.split_at_mut(split);
            direction.copy_from_slice(&batch.direction);
            average.copy_from_slice(&batch.average);
            upper.copy_from_slice(&batch.upper);
            lower.copy_from_slice(&batch.lower);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, total);
            let (direction, rest) = out.split_at_mut(split);
            let (average, rest) = rest.split_at_mut(split);
            let (upper, lower) = rest.split_at_mut(split);
            emd_trend_batch_inner_into(
                open,
                high,
                low,
                close,
                &sweep,
                source,
                avg_type,
                Kernel::Auto,
                false,
                direction,
                average,
                upper,
                lower,
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(rows)
}
1741
#[cfg(test)]
mod tests {
    use super::*;
    use crate::indicators::dispatch::{
        compute_cpu, IndicatorComputeRequest, IndicatorDataRef, IndicatorSeries, ParamKV,
        ParamValue,
    };

    /// Deterministic, smoothly varying OHLC series of length `len`.
    fn sample_ohlc(len: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
        let mut open = Vec::with_capacity(len);
        let mut high = Vec::with_capacity(len);
        let mut low = Vec::with_capacity(len);
        let mut close = Vec::with_capacity(len);
        for i in 0..len {
            let x = i as f64;
            let base = 100.0 + x * 0.04 + (x * 0.11).sin() * 1.8;
            let o = base + (x * 0.17).cos() * 0.6;
            let c = base + (x * 0.09).sin() * 0.9;
            let h = o.max(c) + 0.8 + (x * 0.05).cos().abs() * 0.4;
            let l = o.min(c) - 0.8 - (x * 0.07).sin().abs() * 0.3;
            open.push(o);
            high.push(h);
            low.push(l);
            close.push(c);
        }
        (open, high, low, close)
    }

    /// Element-wise comparison that treats NaN == NaN and allows 1e-10 absolute error.
    fn assert_vec_eq_nan(lhs: &[f64], rhs: &[f64]) {
        assert_eq!(lhs.len(), rhs.len());
        for (idx, (&l, &r)) in lhs.iter().zip(rhs.iter()).enumerate() {
            if l.is_nan() && r.is_nan() {
                continue;
            }
            assert!(
                (l - r).abs() <= 1e-10,
                "mismatch at {idx}: lhs={l}, rhs={r}"
            );
        }
    }

    #[test]
    fn emd_trend_output_contract() -> Result<(), Box<dyn Error>> {
        let (open, high, low, close) = sample_ohlc(256);
        let out = emd_trend_with_kernel(
            &EmdTrendInput::from_slices(&open, &high, &low, &close, EmdTrendParams::default()),
            Kernel::Scalar,
        )?;
        assert_eq!(out.direction.len(), close.len());
        assert_eq!(out.average.len(), close.len());
        assert_eq!(out.upper.len(), close.len());
        assert_eq!(out.lower.len(), close.len());
        assert!(out.average.iter().any(|v| v.is_finite()));
        Ok(())
    }

    #[test]
    fn emd_trend_exact_small_case() -> Result<(), Box<dyn Error>> {
        let open = vec![1.0, 1.0, 1.0, 10.0, 10.0];
        let high = vec![1.0, 1.0, 1.0, 10.0, 10.0];
        let low = vec![1.0, 1.0, 1.0, 10.0, 10.0];
        let close = vec![1.0, 1.0, 1.0, 10.0, 10.0];
        let out = emd_trend_with_kernel(
            &EmdTrendInput::from_slices(
                &open,
                &high,
                &low,
                &close,
                EmdTrendParams {
                    source: Some("close".to_string()),
                    avg_type: Some("SMA".to_string()),
                    length: Some(2),
                    mult: Some(1.0),
                },
            ),
            Kernel::Scalar,
        )?;
        assert_vec_eq_nan(&out.direction, &[0.0, 0.0, 0.0, 1.0, 1.0]);
        assert_vec_eq_nan(&out.average, &[f64::NAN, 1.0, 1.0, 5.5, 10.0]);
        assert_vec_eq_nan(&out.upper, &[f64::NAN, 1.0, 1.0, 8.5, 11.0]);
        assert_vec_eq_nan(&out.lower, &[f64::NAN, 1.0, 1.0, 2.5, 9.0]);
        Ok(())
    }

    #[test]
    fn emd_trend_into_matches_api() -> Result<(), Box<dyn Error>> {
        let (open, high, low, close) = sample_ohlc(200);
        let input =
            EmdTrendInput::from_slices(&open, &high, &low, &close, EmdTrendParams::default());
        let base = emd_trend(&input)?;
        let mut direction = vec![0.0; close.len()];
        let mut average = vec![f64::NAN; close.len()];
        let mut upper = vec![f64::NAN; close.len()];
        let mut lower = vec![f64::NAN; close.len()];
        emd_trend_into(&input, &mut direction, &mut average, &mut upper, &mut lower)?;
        assert_vec_eq_nan(&direction, &base.direction);
        assert_vec_eq_nan(&average, &base.average);
        assert_vec_eq_nan(&upper, &base.upper);
        assert_vec_eq_nan(&lower, &base.lower);
        Ok(())
    }

    #[test]
    fn emd_trend_batch_single_matches_single() -> Result<(), Box<dyn Error>> {
        let (open, high, low, close) = sample_ohlc(180);
        let single = emd_trend_with_kernel(
            &EmdTrendInput::from_slices(&open, &high, &low, &close, EmdTrendParams::default()),
            Kernel::Scalar,
        )?;
        let batch = emd_trend_batch_with_kernel(
            &open,
            &high,
            &low,
            &close,
            &EmdTrendBatchRange::default(),
            "close",
            "SMA",
            Kernel::ScalarBatch,
        )?;
        assert_eq!(batch.rows, 1);
        assert_eq!(batch.cols, close.len());
        assert_vec_eq_nan(&batch.direction[..close.len()], &single.direction);
        assert_vec_eq_nan(&batch.average[..close.len()], &single.average);
        assert_vec_eq_nan(&batch.upper[..close.len()], &single.upper);
        assert_vec_eq_nan(&batch.lower[..close.len()], &single.lower);
        Ok(())
    }

    #[test]
    fn emd_trend_stream_contract() -> Result<(), Box<dyn Error>> {
        let (open, high, low, close) = sample_ohlc(220);
        let mut stream = EmdTrendStream::try_new(EmdTrendParams {
            source: Some("close".to_string()),
            avg_type: Some("SMA".to_string()),
            length: Some(28),
            mult: Some(1.0),
        })?;
        let mut points = Vec::with_capacity(close.len());
        for i in 0..close.len() {
            points.push(stream.update(open[i], high[i], low[i], close[i]).unwrap());
        }
        assert_eq!(points.len(), close.len());
        assert!(points.iter().any(|(_, avg, _, _)| avg.is_finite()));
        Ok(())
    }

    /// The stream recomputes over its whole history, so each emitted point should equal
    /// the full computation at the same index, warm-up included.
    #[test]
    fn emd_trend_stream_matches_full_computation() -> Result<(), Box<dyn Error>> {
        let (open, high, low, close) = sample_ohlc(96);
        let make_params = || EmdTrendParams {
            source: Some("close".to_string()),
            avg_type: Some("SMA".to_string()),
            length: Some(10),
            mult: Some(1.5),
        };
        let full = emd_trend_with_kernel(
            &EmdTrendInput::from_slices(&open, &high, &low, &close, make_params()),
            Kernel::Scalar,
        )?;

        let mut stream = EmdTrendStream::try_new(make_params())?;
        let mut direction = Vec::with_capacity(close.len());
        let mut average = Vec::with_capacity(close.len());
        let mut upper = Vec::with_capacity(close.len());
        let mut lower = Vec::with_capacity(close.len());
        for i in 0..close.len() {
            let (d, a, u, l) = stream
                .update(open[i], high[i], low[i], close[i])
                .expect("finite bars always yield a stream point");
            direction.push(d);
            average.push(a);
            upper.push(u);
            lower.push(l);
        }
        assert_vec_eq_nan(&direction, &full.direction);
        assert_vec_eq_nan(&average, &full.average);
        assert_vec_eq_nan(&upper, &full.upper);
        assert_vec_eq_nan(&lower, &full.lower);

        // A bar whose selected source is not finite is rejected.
        assert!(stream
            .update(f64::NAN, f64::NAN, f64::NAN, f64::NAN)
            .is_none());
        Ok(())
    }

    #[test]
    fn emd_trend_rejects_invalid_params() {
        let (open, high, low, close) = sample_ohlc(64);
        let err = emd_trend(&EmdTrendInput::from_slices(
            &open,
            &high,
            &low,
            &close,
            EmdTrendParams {
                source: Some("bad".to_string()),
                avg_type: Some("SMA".to_string()),
                length: Some(28),
                mult: Some(1.0),
            },
        ))
        .unwrap_err();
        assert!(matches!(err, EmdTrendError::InvalidSource { .. }));
    }

    #[test]
    fn emd_trend_dispatch_compute_returns_average() -> Result<(), Box<dyn Error>> {
        let (open, high, low, close) = sample_ohlc(192);
        let params = [
            ParamKV {
                key: "source",
                value: ParamValue::EnumString("close"),
            },
            ParamKV {
                key: "avg_type",
                value: ParamValue::EnumString("SMA"),
            },
            ParamKV {
                key: "length",
                value: ParamValue::Int(28),
            },
            ParamKV {
                key: "mult",
                value: ParamValue::Float(1.0),
            },
        ];
        let out = compute_cpu(IndicatorComputeRequest {
            indicator_id: "emd_trend",
            output_id: Some("average"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            params: &params,
            kernel: Kernel::Scalar,
        })?;
        assert_eq!(out.output_id, "average");
        match out.series {
            IndicatorSeries::F64(values) => assert_eq!(values.len(), close.len()),
            other => panic!("expected f64 series, got {:?}", other),
        }
        Ok(())
    }
}