1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::indicators::moving_averages::ema::{ema_with_kernel, EmaInput, EmaParams};
16use crate::indicators::moving_averages::linreg::{
17 linreg_with_kernel, LinRegInput, LinRegParams, LinRegStream,
18};
19use crate::indicators::moving_averages::sma::{sma_with_kernel, SmaInput, SmaParams};
20use crate::indicators::moving_averages::vwma::{vwma_with_kernel, VwmaInput, VwmaParams};
21use crate::indicators::moving_averages::wilders::{
22 wilders_with_kernel, WildersInput, WildersParams,
23};
24use crate::indicators::moving_averages::wma::{wma_with_kernel, WmaInput, WmaParams};
25use crate::utilities::data_loader::Candles;
26use crate::utilities::enums::Kernel;
27use crate::utilities::helpers::{
28 alloc_with_nan_prefix, detect_best_batch_kernel, init_matrix_prefixes, make_uninit_matrix,
29};
30#[cfg(feature = "python")]
31use crate::utilities::kernel_validation::validate_kernel;
32
33#[cfg(not(target_arch = "wasm32"))]
34use rayon::prelude::*;
35use std::collections::VecDeque;
36use std::mem::{ManuallyDrop, MaybeUninit};
37use thiserror::Error;
38
/// Number of accepted bars in the rolling `max(high)` / `min(low)` window whose
/// range, scaled by the channel rate, gives the channel width `chan`.
const CHANNEL_WINDOW: usize = 280;
40
/// Result of a single-series trend follower computation.
#[derive(Debug, Clone)]
pub struct TrendFollowerOutput {
    /// One value per input bar; bars that produced no value are `NaN`.
    pub values: Vec<f64>,
}
45
/// User-facing parameters for the trend follower.
///
/// Every field is optional; `None` means "use the built-in default"
/// (see the `TrendFollowerInput::get_*` accessors).
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct TrendFollowerParams {
    /// Base moving-average name: `ema`, `sma`, `rma`, `wma` or `vwma` (case-insensitive).
    pub matype: Option<String>,
    /// Window (in accepted bars) for the highest/lowest smoothed MA value.
    pub trend_period: Option<usize>,
    /// Period of the base moving average.
    pub ma_period: Option<usize>,
    /// Channel width as a percentage of the rolling high/low range.
    pub channel_rate_percent: Option<f64>,
    /// Whether the base MA is smoothed by a linear regression.
    pub use_linear_regression: Option<bool>,
    /// Linear-regression period (validated only when regression is enabled).
    pub linear_regression_period: Option<usize>,
}

impl Default for TrendFollowerParams {
    /// EMA(20) base, 20-bar trend window, 1% channel, linear regression of period 5.
    fn default() -> Self {
        let base_ma = String::from("ema");
        let twenty_bars = Some(20);
        TrendFollowerParams {
            matype: Some(base_ma),
            trend_period: twenty_bars,
            ma_period: twenty_bars,
            channel_rate_percent: Some(1.0),
            use_linear_regression: Some(true),
            linear_regression_period: Some(5),
        }
    }
}
72
/// Borrowed OHLCV source for the trend follower.
#[derive(Debug, Clone)]
pub enum TrendFollowerData<'a> {
    /// Read the `high`, `low`, `close` and `volume` columns from a candle set.
    Candles(&'a Candles),
    /// Separate slices; their lengths are checked to be equal before computing.
    Slices {
        high: &'a [f64],
        low: &'a [f64],
        close: &'a [f64],
        volume: &'a [f64],
    },
}
83
/// Data plus parameters for one trend follower computation.
#[derive(Debug, Clone)]
pub struct TrendFollowerInput<'a> {
    /// Price/volume source.
    pub data: TrendFollowerData<'a>,
    /// Optional parameters; unset fields fall back to defaults via the `get_*` accessors.
    pub params: TrendFollowerParams,
}
89
90impl<'a> TrendFollowerInput<'a> {
91 #[inline]
92 pub fn from_candles(candles: &'a Candles, params: TrendFollowerParams) -> Self {
93 Self {
94 data: TrendFollowerData::Candles(candles),
95 params,
96 }
97 }
98
99 #[inline]
100 pub fn from_slices(
101 high: &'a [f64],
102 low: &'a [f64],
103 close: &'a [f64],
104 volume: &'a [f64],
105 params: TrendFollowerParams,
106 ) -> Self {
107 Self {
108 data: TrendFollowerData::Slices {
109 high,
110 low,
111 close,
112 volume,
113 },
114 params,
115 }
116 }
117
118 #[inline]
119 pub fn with_default_candles(candles: &'a Candles) -> Self {
120 Self::from_candles(candles, TrendFollowerParams::default())
121 }
122
123 #[inline]
124 pub fn as_slices(&self) -> (&[f64], &[f64], &[f64], &[f64]) {
125 match &self.data {
126 TrendFollowerData::Candles(candles) => (
127 candles.high.as_slice(),
128 candles.low.as_slice(),
129 candles.close.as_slice(),
130 candles.volume.as_slice(),
131 ),
132 TrendFollowerData::Slices {
133 high,
134 low,
135 close,
136 volume,
137 } => (*high, *low, *close, *volume),
138 }
139 }
140
141 #[inline]
142 pub fn get_matype(&self) -> &str {
143 self.params.matype.as_deref().unwrap_or("ema")
144 }
145
146 #[inline]
147 pub fn get_trend_period(&self) -> usize {
148 self.params.trend_period.unwrap_or(20)
149 }
150
151 #[inline]
152 pub fn get_ma_period(&self) -> usize {
153 self.params.ma_period.unwrap_or(20)
154 }
155
156 #[inline]
157 pub fn get_channel_rate_percent(&self) -> f64 {
158 self.params.channel_rate_percent.unwrap_or(1.0)
159 }
160
161 #[inline]
162 pub fn get_use_linear_regression(&self) -> bool {
163 self.params.use_linear_regression.unwrap_or(true)
164 }
165
166 #[inline]
167 pub fn get_linear_regression_period(&self) -> usize {
168 self.params.linear_regression_period.unwrap_or(5)
169 }
170}
171
/// Moving-average families accepted as the base smoother.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TrendFollowerMaType {
    Ema,
    Sma,
    /// Wilder's smoothing (batch path uses the `wilders` indicator).
    Rma,
    Wma,
    /// Volume-weighted MA; the only variant that consumes the volume series.
    Vwma,
}

impl TrendFollowerMaType {
    /// Canonical lowercase name of the MA type.
    #[inline]
    fn as_str(self) -> &'static str {
        let name = match self {
            TrendFollowerMaType::Ema => "ema",
            TrendFollowerMaType::Sma => "sma",
            TrendFollowerMaType::Rma => "rma",
            TrendFollowerMaType::Wma => "wma",
            TrendFollowerMaType::Vwma => "vwma",
        };
        name
    }
}
193
/// Validated, default-filled parameters produced by `resolve_params`.
#[derive(Copy, Clone, Debug)]
struct TrendFollowerResolvedParams {
    matype: TrendFollowerMaType,
    trend_period: usize,
    ma_period: usize,
    // Fraction, not percent: `channel_rate_percent * 0.01`.
    channel_rate_fraction: f64,
    use_linear_regression: bool,
    linear_regression_period: usize,
}
203
/// Incremental base moving average used by `TrendFollowerStream`,
/// one variant per supported `TrendFollowerMaType`.
#[derive(Clone, Debug)]
enum TrendFollowerBaseMaStream {
    Ema(EmaState),
    Sma(SmaState),
    Rma(RmaState),
    Wma(WmaState),
    Vwma(VwmaState),
}
212
213impl TrendFollowerBaseMaStream {
214 fn new(matype: TrendFollowerMaType, period: usize) -> Self {
215 match matype {
216 TrendFollowerMaType::Ema => Self::Ema(EmaState::new(period)),
217 TrendFollowerMaType::Sma => Self::Sma(SmaState::new(period)),
218 TrendFollowerMaType::Rma => Self::Rma(RmaState::new(period)),
219 TrendFollowerMaType::Wma => Self::Wma(WmaState::new(period)),
220 TrendFollowerMaType::Vwma => Self::Vwma(VwmaState::new(period)),
221 }
222 }
223
224 fn update(&mut self, value: f64, volume: f64) -> Option<f64> {
225 match self {
226 Self::Ema(state) => state.update(value),
227 Self::Sma(state) => state.update(value),
228 Self::Rma(state) => state.update(value),
229 Self::Wma(state) => state.update(value),
230 Self::Vwma(state) => state.update(value, volume),
231 }
232 }
233}
234
/// Streaming EMA seeded with a cumulative mean.
///
/// While fewer than `period` finite samples have been seen, the output is their
/// running arithmetic mean. After that it follows `alpha * x + beta * prev` with
/// `alpha = 2 / (period + 1)`. Non-finite samples are ignored and yield `None`.
#[derive(Clone, Debug)]
struct EmaState {
    period: usize,
    alpha: f64,
    beta: f64,
    value: Option<f64>,
    valid_count: usize,
}

impl EmaState {
    fn new(period: usize) -> Self {
        let alpha = 2.0 / (period as f64 + 1.0);
        EmaState {
            period,
            alpha,
            beta: 1.0 - alpha,
            value: None,
            valid_count: 0,
        }
    }

    fn update(&mut self, value: f64) -> Option<f64> {
        if !value.is_finite() {
            return None;
        }
        let next = if let Some(prev) = self.value {
            if self.valid_count < self.period {
                // Still seeding: extend the cumulative mean by one sample.
                self.valid_count += 1;
                let n = self.valid_count as f64;
                ((n - 1.0) * prev + value) / n
            } else {
                self.beta.mul_add(prev, self.alpha * value)
            }
        } else {
            self.valid_count = 1;
            value
        };
        self.value = Some(next);
        Some(next)
    }
}
275
/// Streaming simple moving average over the last `period` finite samples.
///
/// Uses a ring buffer with a running sum. Returns `None` until the window is
/// full and for non-finite samples, which are ignored.
#[derive(Clone, Debug)]
struct SmaState {
    period: usize,
    buffer: Vec<f64>,
    head: usize,
    filled: usize,
    sum: f64,
}

impl SmaState {
    fn new(period: usize) -> Self {
        SmaState {
            period,
            buffer: vec![0.0; period],
            head: 0,
            filled: 0,
            sum: 0.0,
        }
    }

    fn update(&mut self, value: f64) -> Option<f64> {
        if !value.is_finite() {
            return None;
        }
        let slot = self.head;
        if self.filled < self.period {
            self.filled += 1;
        } else {
            // Window full: drop the sample being overwritten from the sum.
            self.sum -= self.buffer[slot];
        }
        self.buffer[slot] = value;
        self.sum += value;
        self.head = (slot + 1) % self.period;
        (self.filled == self.period).then(|| self.sum / self.period as f64)
    }
}
315
/// Streaming Wilder's moving average (RMA).
///
/// Seeds with the simple mean of the first `period` finite samples, then applies
/// `prev + (x - prev) / period`. Non-finite samples are ignored and yield `None`.
///
/// Only a running sum is needed for seeding, so no per-sample buffer is kept.
/// The stream path validates `period` only against `usize::MAX`, and a
/// `period`-sized buffer there would be a pointless, potentially huge allocation.
#[derive(Clone, Debug)]
struct RmaState {
    period: usize,
    // Finite samples consumed while seeding; stops advancing once `value` is set.
    filled: usize,
    // Sum of the seeding samples.
    sum: f64,
    // Current average once seeded.
    value: Option<f64>,
}

impl RmaState {
    fn new(period: usize) -> Self {
        debug_assert!(period > 0, "RmaState period must be positive");
        Self {
            period,
            filled: 0,
            sum: 0.0,
            value: None,
        }
    }

    fn update(&mut self, value: f64) -> Option<f64> {
        if !value.is_finite() {
            return None;
        }
        if let Some(prev) = self.value {
            let next = prev + (value - prev) / self.period as f64;
            self.value = Some(next);
            return Some(next);
        }
        self.sum += value;
        self.filled += 1;
        if self.filled == self.period {
            let next = self.sum / self.period as f64;
            self.value = Some(next);
            Some(next)
        } else {
            None
        }
    }
}
360
/// Streaming linearly weighted moving average.
///
/// The oldest sample in the window gets weight 1 and the newest gets weight
/// `period`. Returns `None` until the window is full and for non-finite samples,
/// which are ignored.
#[derive(Clone, Debug)]
struct WmaState {
    period: usize,
    buffer: Vec<f64>,
    head: usize,
    filled: usize,
}

impl WmaState {
    fn new(period: usize) -> Self {
        WmaState {
            period,
            buffer: vec![0.0; period],
            head: 0,
            filled: 0,
        }
    }

    fn update(&mut self, value: f64) -> Option<f64> {
        if !value.is_finite() {
            return None;
        }
        self.buffer[self.head] = value;
        self.head = (self.head + 1) % self.period;
        self.filled = (self.filled + 1).min(self.period);
        if self.filled < self.period {
            return None;
        }
        // `head` now points at the oldest sample; walk forward, oldest first.
        let (weighted, total_weight) =
            (0..self.period).fold((0.0, 0.0), |(acc, wsum), offset| {
                let w = (offset + 1) as f64;
                let slot = (self.head + offset) % self.period;
                (acc + self.buffer[slot] * w, wsum + w)
            });
        Some(weighted / total_weight)
    }
}
402
/// Streaming volume-weighted moving average over the last `period` bars.
///
/// Bars with a non-finite price or volume are ignored and yield `None`.
/// Output is `None` until the window is full, or while the window's total volume is zero.
#[derive(Clone, Debug)]
struct VwmaState {
    period: usize,
    prices: Vec<f64>,
    volumes: Vec<f64>,
    head: usize,
    filled: usize,
    sum_pv: f64,
    sum_v: f64,
}

impl VwmaState {
    fn new(period: usize) -> Self {
        VwmaState {
            period,
            prices: vec![0.0; period],
            volumes: vec![0.0; period],
            head: 0,
            filled: 0,
            sum_pv: 0.0,
            sum_v: 0.0,
        }
    }

    fn update(&mut self, price: f64, volume: f64) -> Option<f64> {
        if !price.is_finite() || !volume.is_finite() {
            return None;
        }
        let slot = self.head;
        if self.filled < self.period {
            self.filled += 1;
        } else {
            // Window full: retire the bar being overwritten.
            let (old_price, old_volume) = (self.prices[slot], self.volumes[slot]);
            self.sum_pv -= old_price * old_volume;
            self.sum_v -= old_volume;
        }
        self.prices[slot] = price;
        self.volumes[slot] = volume;
        self.sum_pv += price * volume;
        self.sum_v += volume;
        self.head = (slot + 1) % self.period;
        let ready = self.filled == self.period && self.sum_v != 0.0;
        ready.then(|| self.sum_pv / self.sum_v)
    }
}
449
/// Fluent builder for one-off trend follower runs and streams.
///
/// Unset fields fall back to the defaults applied by `TrendFollowerInput`;
/// the MA type is supplied at apply time.
#[derive(Copy, Clone, Debug)]
pub struct TrendFollowerBuilder {
    trend_period: Option<usize>,
    ma_period: Option<usize>,
    channel_rate_percent: Option<f64>,
    use_linear_regression: Option<bool>,
    linear_regression_period: Option<usize>,
    // Kernel forwarded to `trend_follower_with_kernel`.
    kernel: Kernel,
}
459
460impl Default for TrendFollowerBuilder {
461 fn default() -> Self {
462 Self {
463 trend_period: None,
464 ma_period: None,
465 channel_rate_percent: None,
466 use_linear_regression: None,
467 linear_regression_period: None,
468 kernel: Kernel::Auto,
469 }
470 }
471}
472
473impl TrendFollowerBuilder {
474 #[inline]
475 pub fn new() -> Self {
476 Self::default()
477 }
478
479 #[inline]
480 pub fn trend_period(mut self, value: usize) -> Self {
481 self.trend_period = Some(value);
482 self
483 }
484
485 #[inline]
486 pub fn ma_period(mut self, value: usize) -> Self {
487 self.ma_period = Some(value);
488 self
489 }
490
491 #[inline]
492 pub fn channel_rate_percent(mut self, value: f64) -> Self {
493 self.channel_rate_percent = Some(value);
494 self
495 }
496
497 #[inline]
498 pub fn use_linear_regression(mut self, value: bool) -> Self {
499 self.use_linear_regression = Some(value);
500 self
501 }
502
503 #[inline]
504 pub fn linear_regression_period(mut self, value: usize) -> Self {
505 self.linear_regression_period = Some(value);
506 self
507 }
508
509 #[inline]
510 pub fn kernel(mut self, value: Kernel) -> Self {
511 self.kernel = value;
512 self
513 }
514
515 #[inline]
516 fn params(self, matype: &str) -> TrendFollowerParams {
517 TrendFollowerParams {
518 matype: Some(matype.to_string()),
519 trend_period: self.trend_period,
520 ma_period: self.ma_period,
521 channel_rate_percent: self.channel_rate_percent,
522 use_linear_regression: self.use_linear_regression,
523 linear_regression_period: self.linear_regression_period,
524 }
525 }
526
527 #[inline]
528 pub fn apply(self, candles: &Candles) -> Result<TrendFollowerOutput, TrendFollowerError> {
529 let input = TrendFollowerInput::from_candles(candles, self.params("ema"));
530 trend_follower_with_kernel(&input, self.kernel)
531 }
532
533 #[inline]
534 pub fn apply_with_matype(
535 self,
536 candles: &Candles,
537 matype: &str,
538 ) -> Result<TrendFollowerOutput, TrendFollowerError> {
539 let input = TrendFollowerInput::from_candles(candles, self.params(matype));
540 trend_follower_with_kernel(&input, self.kernel)
541 }
542
543 #[inline]
544 pub fn apply_slices(
545 self,
546 high: &[f64],
547 low: &[f64],
548 close: &[f64],
549 volume: &[f64],
550 matype: &str,
551 ) -> Result<TrendFollowerOutput, TrendFollowerError> {
552 let input = TrendFollowerInput::from_slices(high, low, close, volume, self.params(matype));
553 trend_follower_with_kernel(&input, self.kernel)
554 }
555
556 #[inline]
557 pub fn into_stream(self, matype: &str) -> Result<TrendFollowerStream, TrendFollowerError> {
558 TrendFollowerStream::try_new(self.params(matype))
559 }
560}
561
/// Errors produced by the trend follower's single-series, streaming and batch entry points.
#[derive(Debug, Error)]
pub enum TrendFollowerError {
    /// The `high` series is empty.
    #[error("trend_follower: Empty input data.")]
    EmptyInputData,
    /// `high`, `low`, `close` and `volume` are not all the same length.
    #[error(
        "trend_follower: Data length mismatch: high={high_len}, low={low_len}, close={close_len}, volume={volume_len}"
    )]
    DataLengthMismatch {
        high_len: usize,
        low_len: usize,
        close_len: usize,
        volume_len: usize,
    },
    /// No bar has finite high/low/close (and finite volume, for VWMA).
    #[error("trend_follower: All values are invalid.")]
    AllValuesNaN,
    /// `matype` is not one of `ema`, `sma`, `rma`, `wma`, `vwma`.
    #[error("trend_follower: Invalid MA type: {matype}")]
    InvalidMaType { matype: String },
    /// `trend_period` is zero.
    #[error("trend_follower: Invalid trend period: {trend_period}")]
    InvalidTrendPeriod { trend_period: usize },
    /// `ma_period` is zero or longer than the data.
    #[error("trend_follower: Invalid MA period: {ma_period}, data length = {data_len}")]
    InvalidMaPeriod { ma_period: usize, data_len: usize },
    /// Regression is enabled and its period is below 2 or longer than the data.
    #[error(
        "trend_follower: Invalid linear regression period: {linear_regression_period}, data length = {data_len}"
    )]
    InvalidLinearRegressionPeriod {
        linear_regression_period: usize,
        data_len: usize,
    },
    /// Channel rate is non-finite or not strictly positive.
    #[error("trend_follower: Invalid channel rate percent: {channel_rate_percent}")]
    InvalidChannelRatePercent { channel_rate_percent: f64 },
    /// The underlying MA indicator failed; carries its error text.
    #[error("trend_follower: Moving average computation failed: {0}")]
    MovingAverageError(String),
    /// The linear-regression indicator failed; carries its error text.
    #[error("trend_follower: Linear regression computation failed: {0}")]
    LinearRegressionError(String),
    /// Caller-provided output slice has the wrong length.
    #[error("trend_follower: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// NOTE(review): presumably raised when expanding an integer batch sweep
    /// (code outside this view) — confirm.
    #[error("trend_follower: Invalid integer range: start={start}, end={end}, step={step}")]
    InvalidRangeUsize {
        start: usize,
        end: usize,
        step: usize,
    },
    /// NOTE(review): presumably raised when expanding a float batch sweep — confirm.
    #[error("trend_follower: Invalid float range: start={start}, end={end}, step={step}")]
    InvalidRangeF64 { start: f64, end: f64, step: f64 },
    /// NOTE(review): presumably a non-batch kernel passed to a batch entry point — confirm.
    #[error("trend_follower: Invalid kernel for batch path: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
609
610#[inline]
611fn parse_matype(matype: &str) -> Result<TrendFollowerMaType, TrendFollowerError> {
612 if matype.eq_ignore_ascii_case("ema") {
613 return Ok(TrendFollowerMaType::Ema);
614 }
615 if matype.eq_ignore_ascii_case("sma") {
616 return Ok(TrendFollowerMaType::Sma);
617 }
618 if matype.eq_ignore_ascii_case("rma") {
619 return Ok(TrendFollowerMaType::Rma);
620 }
621 if matype.eq_ignore_ascii_case("wma") {
622 return Ok(TrendFollowerMaType::Wma);
623 }
624 if matype.eq_ignore_ascii_case("vwma") {
625 return Ok(TrendFollowerMaType::Vwma);
626 }
627 Err(TrendFollowerError::InvalidMaType {
628 matype: matype.to_string(),
629 })
630}
631
632#[inline]
633fn resolve_params(
634 input: &TrendFollowerInput<'_>,
635 data_len: usize,
636) -> Result<TrendFollowerResolvedParams, TrendFollowerError> {
637 let trend_period = input.get_trend_period();
638 if trend_period < 1 {
639 return Err(TrendFollowerError::InvalidTrendPeriod { trend_period });
640 }
641
642 let ma_period = input.get_ma_period();
643 if ma_period == 0 || ma_period > data_len {
644 return Err(TrendFollowerError::InvalidMaPeriod {
645 ma_period,
646 data_len,
647 });
648 }
649
650 let linear_regression_period = input.get_linear_regression_period();
651 if input.get_use_linear_regression()
652 && (linear_regression_period < 2 || linear_regression_period > data_len)
653 {
654 return Err(TrendFollowerError::InvalidLinearRegressionPeriod {
655 linear_regression_period,
656 data_len,
657 });
658 }
659
660 let channel_rate_percent = input.get_channel_rate_percent();
661 if !channel_rate_percent.is_finite() || channel_rate_percent <= 0.0 {
662 return Err(TrendFollowerError::InvalidChannelRatePercent {
663 channel_rate_percent,
664 });
665 }
666
667 Ok(TrendFollowerResolvedParams {
668 matype: parse_matype(input.get_matype())?,
669 trend_period,
670 ma_period,
671 channel_rate_fraction: channel_rate_percent * 0.01,
672 use_linear_regression: input.get_use_linear_regression(),
673 linear_regression_period,
674 })
675}
676
/// Index of the first bar whose high, low and close are all finite.
///
/// When `needs_volume` is set, the volume must be finite too. Returns `None`
/// if no such bar exists.
#[inline]
fn first_valid_bar(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    needs_volume: bool,
) -> Option<usize> {
    for i in 0..high.len() {
        let prices_ok = high[i].is_finite() && low[i].is_finite() && close[i].is_finite();
        if prices_ok && (!needs_volume || volume[i].is_finite()) {
            return Some(i);
        }
    }
    None
}
692
/// True when every bar from `first` onward has finite high, low and close.
///
/// When `needs_volume` is set, the volume must be finite too. An empty tail
/// counts as clean.
#[inline]
fn data_is_clean(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    first: usize,
    needs_volume: bool,
) -> bool {
    (first..high.len()).all(|i| {
        let prices_ok = high[i].is_finite() && low[i].is_finite() && close[i].is_finite();
        prices_ok && (!needs_volume || volume[i].is_finite())
    })
}
712
713#[inline]
714fn trend_follower_prepare<'a>(
715 input: &'a TrendFollowerInput<'a>,
716) -> Result<
717 (
718 &'a [f64],
719 &'a [f64],
720 &'a [f64],
721 &'a [f64],
722 TrendFollowerResolvedParams,
723 usize,
724 ),
725 TrendFollowerError,
726> {
727 let (high, low, close, volume) = input.as_slices();
728 if high.is_empty() {
729 return Err(TrendFollowerError::EmptyInputData);
730 }
731 if high.len() != low.len() || high.len() != close.len() || high.len() != volume.len() {
732 return Err(TrendFollowerError::DataLengthMismatch {
733 high_len: high.len(),
734 low_len: low.len(),
735 close_len: close.len(),
736 volume_len: volume.len(),
737 });
738 }
739 let params = resolve_params(input, high.len())?;
740 let first = first_valid_bar(
741 high,
742 low,
743 close,
744 volume,
745 params.matype == TrendFollowerMaType::Vwma,
746 )
747 .ok_or(TrendFollowerError::AllValuesNaN)?;
748 Ok((high, low, close, volume, params, first))
749}
750
751#[inline]
752fn compute_ma_series(
753 close: &[f64],
754 volume: &[f64],
755 params: TrendFollowerResolvedParams,
756 kernel: Kernel,
757) -> Result<Vec<f64>, TrendFollowerError> {
758 match params.matype {
759 TrendFollowerMaType::Ema => ema_with_kernel(
760 &EmaInput::from_slice(
761 close,
762 EmaParams {
763 period: Some(params.ma_period),
764 },
765 ),
766 kernel,
767 )
768 .map(|out| out.values)
769 .map_err(|e| TrendFollowerError::MovingAverageError(e.to_string())),
770 TrendFollowerMaType::Sma => sma_with_kernel(
771 &SmaInput::from_slice(
772 close,
773 SmaParams {
774 period: Some(params.ma_period),
775 },
776 ),
777 kernel,
778 )
779 .map(|out| out.values)
780 .map_err(|e| TrendFollowerError::MovingAverageError(e.to_string())),
781 TrendFollowerMaType::Rma => wilders_with_kernel(
782 &WildersInput::from_slice(
783 close,
784 WildersParams {
785 period: Some(params.ma_period),
786 },
787 ),
788 kernel,
789 )
790 .map(|out| out.values)
791 .map_err(|e| TrendFollowerError::MovingAverageError(e.to_string())),
792 TrendFollowerMaType::Wma => wma_with_kernel(
793 &WmaInput::from_slice(
794 close,
795 WmaParams {
796 period: Some(params.ma_period),
797 },
798 ),
799 kernel,
800 )
801 .map(|out| out.values)
802 .map_err(|e| TrendFollowerError::MovingAverageError(e.to_string())),
803 TrendFollowerMaType::Vwma => vwma_with_kernel(
804 &VwmaInput::from_slice(
805 close,
806 volume,
807 VwmaParams {
808 period: Some(params.ma_period),
809 },
810 ),
811 kernel,
812 )
813 .map(|out| out.values)
814 .map_err(|e| TrendFollowerError::MovingAverageError(e.to_string())),
815 }
816}
817
/// Pushes `(idx, value)` onto a monotonic deque whose front is the window maximum.
///
/// Entries older than `window` bars are dropped from the front. Entries not
/// greater than `value` are dropped from the back.
#[inline]
fn push_max(queue: &mut VecDeque<(usize, f64)>, idx: usize, value: f64, window: usize) {
    let oldest_kept = idx.saturating_add(1).saturating_sub(window);
    while queue.front().map_or(false, |&(j, _)| j < oldest_kept) {
        queue.pop_front();
    }
    while queue.back().map_or(false, |&(_, v)| v <= value) {
        queue.pop_back();
    }
    queue.push_back((idx, value));
}
837
/// Pushes `(idx, value)` onto a monotonic deque whose front is the window minimum.
///
/// Entries older than `window` bars are dropped from the front. Entries not
/// smaller than `value` are dropped from the back.
#[inline]
fn push_min(queue: &mut VecDeque<(usize, f64)>, idx: usize, value: f64, window: usize) {
    let oldest_kept = idx.saturating_add(1).saturating_sub(window);
    while queue.front().map_or(false, |&(j, _)| j < oldest_kept) {
        queue.pop_front();
    }
    while queue.back().map_or(false, |&(_, v)| v >= value) {
        queue.pop_back();
    }
    queue.push_back((idx, value));
}
857
/// Drops front entries whose index has fallen outside the `window` bars ending at `idx`.
#[inline]
fn evict_front(queue: &mut VecDeque<(usize, f64)>, idx: usize, window: usize) {
    let oldest_kept = idx.saturating_add(1).saturating_sub(window);
    while queue.front().map_or(false, |&(j, _)| j < oldest_kept) {
        queue.pop_front();
    }
}
869
/// Vectorised path, used when every bar from `first` onward is valid.
///
/// Computes the full base MA (and, if enabled, its linear regression) with the
/// batch indicators. It then sweeps once with monotonic deques:
/// - `high_max` / `low_min`: rolling `CHANNEL_WINDOW` high/low, giving
///   `chan = (max_high - min_low) * channel_rate_fraction`;
/// - `ma_max` / `ma_min`: rolling `trend_period` max/min (`hh`, `ll`) of the
///   smoothed MA.
///
/// For each bar, the output is `trend * |hh - ll| / chan`:
/// - `trend` is `+1` if `ma > ll + chan`, `-1` if `ma < hh - chan`, `0` otherwise;
/// - `trend` is also `0` when `|hh - ll| <= chan`.
///
/// Slots of `out` that are skipped keep their previous contents. Both callers
/// pre-fill `out` with `NaN`. The output is written as `NaN` when `ma` is
/// non-finite or `chan` is zero or non-finite.
///
/// # Errors
/// `MovingAverageError` / `LinearRegressionError` from the batch indicators.
fn trend_follower_compute_clean_into(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    params: TrendFollowerResolvedParams,
    first: usize,
    kernel: Kernel,
    out: &mut [f64],
) -> Result<(), TrendFollowerError> {
    let base_ma = compute_ma_series(close, volume, params, kernel)?;
    let trend_ma = if params.use_linear_regression {
        linreg_with_kernel(
            &LinRegInput::from_slice(
                &base_ma,
                LinRegParams {
                    period: Some(params.linear_regression_period),
                },
            ),
            kernel,
        )
        .map(|series| series.values)
        .map_err(|e| TrendFollowerError::LinearRegressionError(e.to_string()))?
    } else {
        base_ma
    };

    let mut high_max = VecDeque::with_capacity(CHANNEL_WINDOW);
    let mut low_min = VecDeque::with_capacity(CHANNEL_WINDOW);
    let mut ma_max = VecDeque::with_capacity(params.trend_period.max(1));
    let mut ma_min = VecDeque::with_capacity(params.trend_period.max(1));

    for i in first..high.len() {
        // Expire stale entries, including on bars where nothing new is pushed
        // into the MA deques.
        evict_front(&mut high_max, i, CHANNEL_WINDOW);
        evict_front(&mut low_min, i, CHANNEL_WINDOW);
        evict_front(&mut ma_max, i, params.trend_period);
        evict_front(&mut ma_min, i, params.trend_period);

        push_max(&mut high_max, i, high[i], CHANNEL_WINDOW);
        push_min(&mut low_min, i, low[i], CHANNEL_WINDOW);

        // NOTE(review): trend_ma presumably carries a NaN warm-up prefix from
        // the batch MA/linreg; such bars never enter the MA deques.
        let ma = trend_ma[i];
        if ma.is_finite() {
            push_max(&mut ma_max, i, ma, params.trend_period);
            push_min(&mut ma_min, i, ma, params.trend_period);
        }

        let (hh, ll) = match (ma_max.front(), ma_min.front()) {
            (Some((_, hh)), Some((_, ll))) => (*hh, *ll),
            _ => continue,
        };
        let (channel_high, channel_low) = match (high_max.front(), low_min.front()) {
            (Some((_, hi)), Some((_, lo))) => (*hi, *lo),
            _ => continue,
        };
        let chan = (channel_high - channel_low) * params.channel_rate_fraction;
        if !ma.is_finite() || !chan.is_finite() || chan == 0.0 {
            out[i] = f64::NAN;
            continue;
        }

        let diff = (hh - ll).abs();
        let trend = if diff > chan {
            if ma > ll + chan {
                1.0
            } else if ma < hh - chan {
                -1.0
            } else {
                0.0
            }
        } else {
            0.0
        };
        out[i] = trend * diff / chan;
    }

    Ok(())
}
948
949fn trend_follower_compute_fallback_into(
950 high: &[f64],
951 low: &[f64],
952 close: &[f64],
953 volume: &[f64],
954 input: &TrendFollowerInput<'_>,
955 out: &mut [f64],
956) -> Result<(), TrendFollowerError> {
957 let mut stream = TrendFollowerStream::try_new(input.params.clone())?;
958 for i in 0..high.len() {
959 out[i] = stream
960 .update_reset_on_nan(high[i], low[i], close[i], volume[i])
961 .unwrap_or(f64::NAN);
962 }
963 Ok(())
964}
965
966fn trend_follower_compute_into(
967 high: &[f64],
968 low: &[f64],
969 close: &[f64],
970 volume: &[f64],
971 input: &TrendFollowerInput<'_>,
972 params: TrendFollowerResolvedParams,
973 first: usize,
974 kernel: Kernel,
975 out: &mut [f64],
976) -> Result<(), TrendFollowerError> {
977 if data_is_clean(
978 high,
979 low,
980 close,
981 volume,
982 first,
983 params.matype == TrendFollowerMaType::Vwma,
984 ) {
985 trend_follower_compute_clean_into(high, low, close, volume, params, first, kernel, out)
986 } else {
987 trend_follower_compute_fallback_into(high, low, close, volume, input, out)
988 }
989}
990
991#[inline]
992pub fn trend_follower(
993 input: &TrendFollowerInput<'_>,
994) -> Result<TrendFollowerOutput, TrendFollowerError> {
995 trend_follower_with_kernel(input, Kernel::Auto)
996}
997
998pub fn trend_follower_with_kernel(
999 input: &TrendFollowerInput<'_>,
1000 kernel: Kernel,
1001) -> Result<TrendFollowerOutput, TrendFollowerError> {
1002 let (high, low, close, volume, params, first) = trend_follower_prepare(input)?;
1003 let mut out = alloc_with_nan_prefix(close.len(), close.len());
1004 trend_follower_compute_into(
1005 high, low, close, volume, input, params, first, kernel, &mut out,
1006 )?;
1007 Ok(TrendFollowerOutput { values: out })
1008}
1009
1010#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1011#[inline]
1012pub fn trend_follower_into(
1013 input: &TrendFollowerInput<'_>,
1014 out: &mut [f64],
1015) -> Result<(), TrendFollowerError> {
1016 trend_follower_into_slice(out, input, Kernel::Auto)
1017}
1018
1019pub fn trend_follower_into_slice(
1020 out: &mut [f64],
1021 input: &TrendFollowerInput<'_>,
1022 kernel: Kernel,
1023) -> Result<(), TrendFollowerError> {
1024 let (high, low, close, volume, params, first) = trend_follower_prepare(input)?;
1025 if out.len() != close.len() {
1026 return Err(TrendFollowerError::OutputLengthMismatch {
1027 expected: close.len(),
1028 got: out.len(),
1029 });
1030 }
1031 out.fill(f64::NAN);
1032 trend_follower_compute_into(high, low, close, volume, input, params, first, kernel, out)
1033}
1034
/// Incremental (bar-by-bar) trend follower.
#[derive(Clone, Debug)]
pub struct TrendFollowerStream {
    matype: TrendFollowerMaType,
    trend_period: usize,
    ma_period: usize,
    // `channel_rate_percent * 0.01`.
    channel_rate_fraction: f64,
    use_linear_regression: bool,
    linear_regression_period: usize,
    // Base MA state; fed close (and volume for VWMA).
    ma_stream: TrendFollowerBaseMaStream,
    // Present only when linear-regression smoothing is enabled.
    linreg_stream: Option<LinRegStream>,
    // Count of accepted bars; used as the index for the monotonic deques.
    index: usize,
    // Rolling CHANNEL_WINDOW max(high) / min(low).
    high_max: VecDeque<(usize, f64)>,
    low_min: VecDeque<(usize, f64)>,
    // Rolling trend_period max/min of the smoothed MA.
    ma_max: VecDeque<(usize, f64)>,
    ma_min: VecDeque<(usize, f64)>,
}
1051
impl TrendFollowerStream {
    /// Builds a stream from user parameters.
    ///
    /// Parameters are validated by `resolve_params` with a data length of
    /// `usize::MAX`. Only the lower bounds of `ma_period` and
    /// `linear_regression_period` are therefore enforced.
    ///
    /// # Errors
    /// Any validation error from `resolve_params`, or `LinearRegressionError`
    /// when the linear-regression sub-stream cannot be created.
    #[inline]
    pub fn try_new(params: TrendFollowerParams) -> Result<Self, TrendFollowerError> {
        // Dummy one-bar input: only used to reuse the default/validation logic.
        let input = TrendFollowerInput::from_slices(&[1.0], &[1.0], &[1.0], &[1.0], params);
        let resolved = resolve_params(&input, usize::MAX)?;
        let ma_stream = TrendFollowerBaseMaStream::new(resolved.matype, resolved.ma_period);
        let linreg_stream = if resolved.use_linear_regression {
            Some(
                LinRegStream::try_new(LinRegParams {
                    period: Some(resolved.linear_regression_period),
                })
                .map_err(|e| TrendFollowerError::LinearRegressionError(e.to_string()))?,
            )
        } else {
            None
        };
        Ok(Self {
            matype: resolved.matype,
            trend_period: resolved.trend_period,
            ma_period: resolved.ma_period,
            channel_rate_fraction: resolved.channel_rate_fraction,
            use_linear_regression: resolved.use_linear_regression,
            linear_regression_period: resolved.linear_regression_period,
            ma_stream,
            linreg_stream,
            index: 0,
            high_max: VecDeque::with_capacity(CHANNEL_WINDOW),
            low_min: VecDeque::with_capacity(CHANNEL_WINDOW),
            ma_max: VecDeque::with_capacity(resolved.trend_period.max(1)),
            ma_min: VecDeque::with_capacity(resolved.trend_period.max(1)),
        })
    }

    /// Discards all history while keeping the configured parameters.
    ///
    /// # Errors
    /// `LinearRegressionError` if the linear-regression sub-stream cannot be rebuilt.
    #[inline]
    pub fn reset(&mut self) -> Result<(), TrendFollowerError> {
        self.ma_stream = TrendFollowerBaseMaStream::new(self.matype, self.ma_period);
        self.linreg_stream = if self.use_linear_regression {
            Some(
                LinRegStream::try_new(LinRegParams {
                    period: Some(self.linear_regression_period),
                })
                .map_err(|e| TrendFollowerError::LinearRegressionError(e.to_string()))?,
            )
        } else {
            None
        };
        self.index = 0;
        self.high_max.clear();
        self.low_min.clear();
        self.ma_max.clear();
        self.ma_min.clear();
        Ok(())
    }

    /// Feeds one bar and returns its indicator value.
    ///
    /// Returns `None`, without touching any state, when high/low/close is
    /// non-finite, or when volume is non-finite for VWMA.
    ///
    /// Also returns `None` when the MA (or linear-regression) stage yields no
    /// value, e.g. during warm-up.
    ///
    /// Returns `Some(NaN)` when the smoothed MA is non-finite or the channel width
    /// is zero or non-finite. Otherwise returns `trend * |hh - ll| / chan`, as in
    /// the vectorised path.
    #[inline]
    pub fn update(&mut self, high: f64, low: f64, close: f64, volume: f64) -> Option<f64> {
        let needs_volume = self.matype == TrendFollowerMaType::Vwma;
        if !(high.is_finite() && low.is_finite() && close.is_finite())
            || (needs_volume && !volume.is_finite())
        {
            return None;
        }

        // Deque indices count accepted bars only, so windows span valid bars.
        let idx = self.index;
        evict_front(&mut self.high_max, idx, CHANNEL_WINDOW);
        evict_front(&mut self.low_min, idx, CHANNEL_WINDOW);
        evict_front(&mut self.ma_max, idx, self.trend_period);
        evict_front(&mut self.ma_min, idx, self.trend_period);

        push_max(&mut self.high_max, idx, high, CHANNEL_WINDOW);
        push_min(&mut self.low_min, idx, low, CHANNEL_WINDOW);

        let base_ma = self.ma_stream.update(close, volume);
        let ma = if self.use_linear_regression {
            match (base_ma, self.linreg_stream.as_mut()) {
                (Some(value), Some(stream)) => stream.update(value),
                _ => None,
            }
        } else {
            base_ma
        };

        // Advance even if no MA value is available yet: the bar was accepted.
        self.index = idx + 1;

        let Some(ma) = ma else {
            return None;
        };
        if ma.is_finite() {
            push_max(&mut self.ma_max, idx, ma, self.trend_period);
            push_min(&mut self.ma_min, idx, ma, self.trend_period);
        } else {
            return Some(f64::NAN);
        }

        // Both deques were just pushed to, so these `None` arms are defensive only.
        let (hh, ll) = match (self.ma_max.front(), self.ma_min.front()) {
            (Some((_, hh)), Some((_, ll))) => (*hh, *ll),
            _ => return None,
        };
        let (channel_high, channel_low) = match (self.high_max.front(), self.low_min.front()) {
            (Some((_, hi)), Some((_, lo))) => (*hi, *lo),
            _ => return None,
        };
        let chan = (channel_high - channel_low) * self.channel_rate_fraction;
        if !chan.is_finite() || chan == 0.0 {
            return Some(f64::NAN);
        }

        let diff = (hh - ll).abs();
        let trend = if diff > chan {
            if ma > ll + chan {
                1.0
            } else if ma < hh - chan {
                -1.0
            } else {
                0.0
            }
        } else {
            0.0
        };
        Some(trend * diff / chan)
    }

    /// Like `update`, but an invalid bar also resets all history.
    #[inline]
    pub fn update_reset_on_nan(
        &mut self,
        high: f64,
        low: f64,
        close: f64,
        volume: f64,
    ) -> Option<f64> {
        let needs_volume = self.matype == TrendFollowerMaType::Vwma;
        if !(high.is_finite() && low.is_finite() && close.is_finite())
            || (needs_volume && !volume.is_finite())
        {
            // reset() can only fail while rebuilding the LinRegStream. That is
            // presumably impossible here, because try_new already accepted the
            // same period, so the result is ignored.
            let _ = self.reset();
            return None;
        }
        self.update(high, low, close, volume)
    }
}
1192
/// Parameter sweep for batch computation.
///
/// NOTE(review): each tuple appears to be `(start, end, step)`, and the defaults
/// use step 0 for a single value. Expansion happens outside this view, so confirm there.
#[derive(Clone, Debug)]
pub struct TrendFollowerBatchRange {
    pub trend_period: (usize, usize, usize),
    pub ma_period: (usize, usize, usize),
    pub channel_rate_percent: (f64, f64, f64),
    pub linear_regression_period: (usize, usize, usize),
    // MA type given as strings; `matype_static` sets start = end and an empty third element.
    pub matype: (String, String, String),
    pub use_linear_regression: bool,
}
1202
1203impl Default for TrendFollowerBatchRange {
1204 fn default() -> Self {
1205 Self {
1206 trend_period: (20, 20, 0),
1207 ma_period: (20, 20, 0),
1208 channel_rate_percent: (1.0, 1.0, 0.0),
1209 linear_regression_period: (5, 5, 0),
1210 matype: ("ema".to_string(), "ema".to_string(), String::new()),
1211 use_linear_regression: true,
1212 }
1213 }
1214}
1215
/// Fluent builder for batch Trend Follower runs.
///
/// Starts from `TrendFollowerBatchRange::default()` and the default `Kernel`; the setters
/// overwrite individual sweep axes and `apply_*` runs `trend_follower_batch_with_kernel`.
#[derive(Clone, Debug, Default)]
pub struct TrendFollowerBatchBuilder {
    // Parameter sweep being built.
    range: TrendFollowerBatchRange,
    // Kernel forwarded to `trend_follower_batch_with_kernel` (must be `Auto` or a batch kernel).
    kernel: Kernel,
}
1221
1222impl TrendFollowerBatchBuilder {
1223 pub fn new() -> Self {
1224 Self::default()
1225 }
1226
1227 pub fn kernel(mut self, kernel: Kernel) -> Self {
1228 self.kernel = kernel;
1229 self
1230 }
1231
1232 pub fn trend_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1233 self.range.trend_period = (start, end, step);
1234 self
1235 }
1236
1237 pub fn ma_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1238 self.range.ma_period = (start, end, step);
1239 self
1240 }
1241
1242 pub fn channel_rate_percent_range(mut self, start: f64, end: f64, step: f64) -> Self {
1243 self.range.channel_rate_percent = (start, end, step);
1244 self
1245 }
1246
1247 pub fn linear_regression_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1248 self.range.linear_regression_period = (start, end, step);
1249 self
1250 }
1251
1252 pub fn matype_static<S: Into<String>>(mut self, value: S) -> Self {
1253 let value = value.into();
1254 self.range.matype = (value.clone(), value, String::new());
1255 self
1256 }
1257
1258 pub fn use_linear_regression(mut self, value: bool) -> Self {
1259 self.range.use_linear_regression = value;
1260 self
1261 }
1262
1263 pub fn apply_slices(
1264 self,
1265 high: &[f64],
1266 low: &[f64],
1267 close: &[f64],
1268 volume: &[f64],
1269 ) -> Result<TrendFollowerBatchOutput, TrendFollowerError> {
1270 trend_follower_batch_with_kernel(high, low, close, volume, &self.range, self.kernel)
1271 }
1272
1273 pub fn apply_candles(
1274 self,
1275 candles: &Candles,
1276 ) -> Result<TrendFollowerBatchOutput, TrendFollowerError> {
1277 self.apply_slices(&candles.high, &candles.low, &candles.close, &candles.volume)
1278 }
1279}
1280
/// Result of a batch Trend Follower run.
///
/// `values` is a row-major `rows * cols` matrix: row `i` holds the series computed for
/// `combos[i]`, and `cols` equals the input length.
#[derive(Clone, Debug)]
pub struct TrendFollowerBatchOutput {
    /// Flattened row-major output matrix.
    pub values: Vec<f64>,
    /// Parameter combination for each row, in row order.
    pub combos: Vec<TrendFollowerParams>,
    /// Number of rows (`combos.len()`).
    pub rows: usize,
    /// Number of columns (input series length).
    pub cols: usize,
}
1288
1289impl TrendFollowerBatchOutput {
1290 pub fn row_for_params(&self, params: &TrendFollowerParams) -> Option<usize> {
1291 let matype = params
1292 .matype
1293 .as_deref()
1294 .unwrap_or("ema")
1295 .to_ascii_lowercase();
1296 self.combos.iter().position(|combo| {
1297 combo.trend_period.unwrap_or(20) == params.trend_period.unwrap_or(20)
1298 && combo.ma_period.unwrap_or(20) == params.ma_period.unwrap_or(20)
1299 && (combo.channel_rate_percent.unwrap_or(1.0)
1300 - params.channel_rate_percent.unwrap_or(1.0))
1301 .abs()
1302 <= 1e-12
1303 && combo.use_linear_regression.unwrap_or(true)
1304 == params.use_linear_regression.unwrap_or(true)
1305 && combo.linear_regression_period.unwrap_or(5)
1306 == params.linear_regression_period.unwrap_or(5)
1307 && combo
1308 .matype
1309 .as_deref()
1310 .unwrap_or("ema")
1311 .eq_ignore_ascii_case(&matype)
1312 })
1313 }
1314
1315 pub fn values_for(&self, params: &TrendFollowerParams) -> Option<&[f64]> {
1316 self.row_for_params(params).map(|row| {
1317 let start = row * self.cols;
1318 &self.values[start..start + self.cols]
1319 })
1320 }
1321}
1322
1323#[inline]
1324fn axis_usize(range: (usize, usize, usize)) -> Result<Vec<usize>, TrendFollowerError> {
1325 let (start, end, step) = range;
1326 if start == 0 || end == 0 {
1327 return Err(TrendFollowerError::InvalidRangeUsize { start, end, step });
1328 }
1329 if step == 0 || start == end {
1330 return Ok(vec![start]);
1331 }
1332 let mut out = Vec::new();
1333 if start < end {
1334 let mut value = start;
1335 while value <= end {
1336 out.push(value);
1337 match value.checked_add(step) {
1338 Some(next) if next > value => value = next,
1339 _ => break,
1340 }
1341 }
1342 } else {
1343 let mut value = start;
1344 while value >= end {
1345 out.push(value);
1346 if value < end + step {
1347 break;
1348 }
1349 value = value.saturating_sub(step);
1350 if value == 0 {
1351 break;
1352 }
1353 }
1354 }
1355 if out.is_empty() {
1356 return Err(TrendFollowerError::InvalidRangeUsize { start, end, step });
1357 }
1358 Ok(out)
1359}
1360
1361#[inline]
1362fn axis_f64(range: (f64, f64, f64)) -> Result<Vec<f64>, TrendFollowerError> {
1363 let (start, end, step) = range;
1364 if !start.is_finite() || !end.is_finite() || !step.is_finite() {
1365 return Err(TrendFollowerError::InvalidRangeF64 { start, end, step });
1366 }
1367 if step == 0.0 || (start - end).abs() <= 1e-12 {
1368 return Ok(vec![start]);
1369 }
1370 if step < 0.0 {
1371 return Err(TrendFollowerError::InvalidRangeF64 { start, end, step });
1372 }
1373 let mut out = Vec::new();
1374 if start < end {
1375 let mut value = start;
1376 while value <= end + 1e-12 {
1377 out.push(value);
1378 value += step;
1379 }
1380 } else {
1381 let mut value = start;
1382 while value >= end - 1e-12 {
1383 out.push(value);
1384 value -= step;
1385 }
1386 }
1387 if out.is_empty() {
1388 return Err(TrendFollowerError::InvalidRangeF64 { start, end, step });
1389 }
1390 Ok(out)
1391}
1392
/// Expands a `(first, last, _)` string sweep: one entry when `first` and `last` are equal
/// ignoring ASCII case, otherwise both. The third component is ignored.
#[inline]
fn axis_string(range: (String, String, String)) -> Vec<String> {
    let (first, last, _unused) = range;
    match first.eq_ignore_ascii_case(&last) {
        true => vec![first],
        false => vec![first, last],
    }
}
1401
1402pub fn expand_grid_trend_follower(
1403 range: &TrendFollowerBatchRange,
1404) -> Result<Vec<TrendFollowerParams>, TrendFollowerError> {
1405 let trend_periods = axis_usize(range.trend_period)?;
1406 let ma_periods = axis_usize(range.ma_period)?;
1407 let channel_rates = axis_f64(range.channel_rate_percent)?;
1408 let linear_regression_periods = axis_usize(range.linear_regression_period)?;
1409 let matypes = axis_string(range.matype.clone());
1410
1411 let mut out = Vec::new();
1412 for trend_period in &trend_periods {
1413 for ma_period in &ma_periods {
1414 for channel_rate_percent in &channel_rates {
1415 for linear_regression_period in &linear_regression_periods {
1416 for matype in &matypes {
1417 out.push(TrendFollowerParams {
1418 matype: Some(matype.to_ascii_lowercase()),
1419 trend_period: Some(*trend_period),
1420 ma_period: Some(*ma_period),
1421 channel_rate_percent: Some(*channel_rate_percent),
1422 use_linear_regression: Some(range.use_linear_regression),
1423 linear_regression_period: Some(*linear_regression_period),
1424 });
1425 }
1426 }
1427 }
1428 }
1429 }
1430 Ok(out)
1431}
1432
1433pub fn trend_follower_batch_with_kernel(
1434 high: &[f64],
1435 low: &[f64],
1436 close: &[f64],
1437 volume: &[f64],
1438 range: &TrendFollowerBatchRange,
1439 kernel: Kernel,
1440) -> Result<TrendFollowerBatchOutput, TrendFollowerError> {
1441 let batch_kernel = match kernel {
1442 Kernel::Auto => Kernel::ScalarBatch,
1443 other if other.is_batch() => other,
1444 other => return Err(TrendFollowerError::InvalidKernelForBatch(other)),
1445 };
1446 trend_follower_batch_impl(
1447 high,
1448 low,
1449 close,
1450 volume,
1451 range,
1452 batch_kernel.to_non_batch(),
1453 true,
1454 )
1455}
1456
1457pub fn trend_follower_batch_slice(
1458 high: &[f64],
1459 low: &[f64],
1460 close: &[f64],
1461 volume: &[f64],
1462 range: &TrendFollowerBatchRange,
1463) -> Result<TrendFollowerBatchOutput, TrendFollowerError> {
1464 trend_follower_batch_impl(high, low, close, volume, range, Kernel::Scalar, false)
1465}
1466
1467pub fn trend_follower_batch_par_slice(
1468 high: &[f64],
1469 low: &[f64],
1470 close: &[f64],
1471 volume: &[f64],
1472 range: &TrendFollowerBatchRange,
1473) -> Result<TrendFollowerBatchOutput, TrendFollowerError> {
1474 trend_follower_batch_impl(high, low, close, volume, range, Kernel::Scalar, true)
1475}
1476
/// Shared engine behind the public batch entry points.
///
/// It validates the inputs, expands the parameter grid, and evaluates every combination
/// into one row-major `rows * cols` matrix.
///
/// `kernel` is the per-row (non-batch) kernel handed to `trend_follower_into_slice`.
/// `parallel` enables rayon row parallelism on non-wasm targets; wasm always runs rows
/// sequentially.
///
/// Row `i` of the returned `values` holds the result for `combos[i]`. Errors from an
/// individual row are discarded (`let _ = ...`), so such a row keeps whatever was already
/// in the matrix. NOTE(review): presumably NaN from `init_matrix_prefixes`; confirm the
/// helper's semantics.
///
/// # Errors
/// `DataLengthMismatch`, `EmptyInputData`, or any grid-expansion error.
fn trend_follower_batch_impl(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    range: &TrendFollowerBatchRange,
    kernel: Kernel,
    parallel: bool,
) -> Result<TrendFollowerBatchOutput, TrendFollowerError> {
    if high.len() != low.len() || high.len() != close.len() || high.len() != volume.len() {
        return Err(TrendFollowerError::DataLengthMismatch {
            high_len: high.len(),
            low_len: low.len(),
            close_len: close.len(),
            volume_len: volume.len(),
        });
    }
    // Also guarantees `cols > 0`, which `chunks_mut(cols)` / `par_chunks_mut(cols)` require.
    if high.is_empty() {
        return Err(TrendFollowerError::EmptyInputData);
    }

    let combos = expand_grid_trend_follower(range)?;
    let rows = combos.len();
    let cols = close.len();
    let mut matrix = make_uninit_matrix(rows, cols);
    // Every row gets a warm-up prefix of `cols`, i.e. the whole row.
    // NOTE(review): assumed to initialise every cell before rows are viewed as `f64`; confirm.
    init_matrix_prefixes(&mut matrix, cols, &vec![cols; rows]);

    // Take manual ownership of the buffer so it can be re-wrapped as a `Vec<f64>` at the end
    // without a double free.
    let mut guard = ManuallyDrop::new(matrix);
    let out_mu: &mut [MaybeUninit<f64>] =
        unsafe { std::slice::from_raw_parts_mut(guard.as_mut_ptr(), guard.len()) };

    let do_row = |row: usize, row_mu: &mut [MaybeUninit<f64>]| {
        // Views the row as `f64`; relies on the cells having been initialised above.
        let out = unsafe {
            std::slice::from_raw_parts_mut(row_mu.as_mut_ptr() as *mut f64, row_mu.len())
        };
        let input = TrendFollowerInput::from_slices(high, low, close, volume, combos[row].clone());
        // Per-row errors are intentionally ignored; see the function docs.
        let _ = trend_follower_into_slice(out, &input, kernel);
    };

    if parallel {
        #[cfg(not(target_arch = "wasm32"))]
        out_mu
            .par_chunks_mut(cols)
            .enumerate()
            .for_each(|(row, row_mu)| do_row(row, row_mu));
        #[cfg(target_arch = "wasm32")]
        for (row, row_mu) in out_mu.chunks_mut(cols).enumerate() {
            do_row(row, row_mu);
        }
    } else {
        for (row, row_mu) in out_mu.chunks_mut(cols).enumerate() {
            do_row(row, row_mu);
        }
    }

    // Reclaim the allocation as `Vec<f64>`. `MaybeUninit<f64>` has the same layout as `f64`,
    // and pointer/length/capacity come from the original vector held in `guard`.
    let values = unsafe {
        Vec::from_raw_parts(
            guard.as_mut_ptr() as *mut f64,
            guard.len(),
            guard.capacity(),
        )
    };

    Ok(TrendFollowerBatchOutput {
        values,
        combos,
        rows,
        cols,
    })
}
1547
1548fn trend_follower_batch_inner_into(
1549 high: &[f64],
1550 low: &[f64],
1551 close: &[f64],
1552 volume: &[f64],
1553 range: &TrendFollowerBatchRange,
1554 kernel: Kernel,
1555 parallel: bool,
1556 out: &mut [f64],
1557) -> Result<(), TrendFollowerError> {
1558 if high.len() != low.len() || high.len() != close.len() || high.len() != volume.len() {
1559 return Err(TrendFollowerError::DataLengthMismatch {
1560 high_len: high.len(),
1561 low_len: low.len(),
1562 close_len: close.len(),
1563 volume_len: volume.len(),
1564 });
1565 }
1566 let combos = expand_grid_trend_follower(range)?;
1567 let rows = combos.len();
1568 let cols = close.len();
1569 if rows.checked_mul(cols) != Some(out.len()) {
1570 return Err(TrendFollowerError::OutputLengthMismatch {
1571 expected: rows * cols,
1572 got: out.len(),
1573 });
1574 }
1575
1576 for row_out in out.chunks_mut(cols) {
1577 row_out.fill(f64::NAN);
1578 }
1579
1580 let do_row = |row: usize, row_out: &mut [f64]| {
1581 let input = TrendFollowerInput::from_slices(high, low, close, volume, combos[row].clone());
1582 let _ = trend_follower_into_slice(row_out, &input, kernel);
1583 };
1584
1585 if parallel {
1586 #[cfg(not(target_arch = "wasm32"))]
1587 out.par_chunks_mut(cols)
1588 .enumerate()
1589 .for_each(|(row, row_out)| do_row(row, row_out));
1590 #[cfg(target_arch = "wasm32")]
1591 for (row, row_out) in out.chunks_mut(cols).enumerate() {
1592 do_row(row, row_out);
1593 }
1594 } else {
1595 for (row, row_out) in out.chunks_mut(cols).enumerate() {
1596 do_row(row, row_out);
1597 }
1598 }
1599
1600 Ok(())
1601}
1602
/// Python entry point: computes the Trend Follower series for one parameter set.
///
/// The computation runs with the GIL released; library errors surface as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "trend_follower")]
#[pyo3(signature = (high, low, close, volume, matype="ema", trend_period=20, ma_period=20, channel_rate_percent=1.0, use_linear_regression=true, linear_regression_period=5, kernel=None))]
pub fn trend_follower_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    volume: PyReadonlyArray1<'py, f64>,
    matype: &str,
    trend_period: usize,
    ma_period: usize,
    channel_rate_percent: f64,
    use_linear_regression: bool,
    linear_regression_period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let (high, low, close, volume) = (
        high.as_slice()?,
        low.as_slice()?,
        close.as_slice()?,
        volume.as_slice()?,
    );
    let kernel = validate_kernel(kernel, false)?;
    let params = TrendFollowerParams {
        matype: Some(matype.to_string()),
        trend_period: Some(trend_period),
        ma_period: Some(ma_period),
        channel_rate_percent: Some(channel_rate_percent),
        use_linear_regression: Some(use_linear_regression),
        linear_regression_period: Some(linear_regression_period),
    };
    let input = TrendFollowerInput::from_slices(high, low, close, volume, params);
    let computed = py.allow_threads(|| trend_follower_with_kernel(&input, kernel));
    match computed {
        Ok(output) => Ok(output.values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
1644
// Python wrapper around `TrendFollowerStream`, exposed to Python as `TrendFollowerStream`.
#[cfg(feature = "python")]
#[pyclass(name = "TrendFollowerStream")]
pub struct TrendFollowerStreamPy {
    // Underlying incremental calculator.
    stream: TrendFollowerStream,
}
1650
#[cfg(feature = "python")]
#[pymethods]
impl TrendFollowerStreamPy {
    /// Creates a stream; invalid parameters raise `ValueError`.
    #[new]
    #[pyo3(signature = (matype="ema", trend_period=20, ma_period=20, channel_rate_percent=1.0, use_linear_regression=true, linear_regression_period=5))]
    fn new(
        matype: &str,
        trend_period: usize,
        ma_period: usize,
        channel_rate_percent: f64,
        use_linear_regression: bool,
        linear_regression_period: usize,
    ) -> PyResult<Self> {
        let params = TrendFollowerParams {
            matype: Some(matype.to_string()),
            trend_period: Some(trend_period),
            ma_period: Some(ma_period),
            channel_rate_percent: Some(channel_rate_percent),
            use_linear_regression: Some(use_linear_regression),
            linear_regression_period: Some(linear_regression_period),
        };
        TrendFollowerStream::try_new(params)
            .map(|stream| Self { stream })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Pushes one bar; non-finite input resets the stream and returns `None`.
    fn update(&mut self, high: f64, low: f64, close: f64, volume: f64) -> Option<f64> {
        let inner = &mut self.stream;
        inner.update_reset_on_nan(high, low, close, volume)
    }
}
1680
/// Python entry point for a parameter sweep.
///
/// Returns a dict holding the `(rows, cols)` `values` matrix, a per-row array for each swept
/// parameter, the per-row MA type list, and `rows`/`cols`. The computation runs with the GIL
/// released.
#[cfg(feature = "python")]
#[pyfunction(name = "trend_follower_batch")]
#[pyo3(signature = (high, low, close, volume, trend_period_range=(20, 20, 0), ma_period_range=(20, 20, 0), channel_rate_percent_range=(1.0, 1.0, 0.0), linear_regression_period_range=(5, 5, 0), matype="ema", use_linear_regression=true, kernel=None))]
pub fn trend_follower_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    volume: PyReadonlyArray1<'py, f64>,
    trend_period_range: (usize, usize, usize),
    ma_period_range: (usize, usize, usize),
    channel_rate_percent_range: (f64, f64, f64),
    linear_regression_period_range: (usize, usize, usize),
    matype: &str,
    use_linear_regression: bool,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let (high, low, close, volume) = (
        high.as_slice()?,
        low.as_slice()?,
        close.as_slice()?,
        volume.as_slice()?,
    );
    let matype_name = matype.to_string();
    let range = TrendFollowerBatchRange {
        trend_period: trend_period_range,
        ma_period: ma_period_range,
        channel_rate_percent: channel_rate_percent_range,
        linear_regression_period: linear_regression_period_range,
        matype: (matype_name.clone(), matype_name, String::new()),
        use_linear_regression,
    };
    let combos =
        expand_grid_trend_follower(&range).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let (rows, cols) = (combos.len(), close.len());
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;
    // The array is fully written by `trend_follower_batch_inner_into` before it is exposed.
    let arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out = unsafe { arr.as_slice_mut()? };
    let kernel = validate_kernel(kernel, true)?;

    py.allow_threads(|| {
        let batch_kernel = if matches!(kernel, Kernel::Auto) {
            detect_best_batch_kernel()
        } else {
            kernel
        };
        trend_follower_batch_inner_into(
            high,
            low,
            close,
            volume,
            &range,
            batch_kernel.to_non_batch(),
            true,
            out,
        )
    })
    .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let trend_periods: Vec<u64> = combos
        .iter()
        .map(|p| p.trend_period.unwrap_or(20) as u64)
        .collect();
    let ma_periods: Vec<u64> = combos
        .iter()
        .map(|p| p.ma_period.unwrap_or(20) as u64)
        .collect();
    let channel_rates: Vec<f64> = combos
        .iter()
        .map(|p| p.channel_rate_percent.unwrap_or(1.0))
        .collect();
    let linreg_periods: Vec<u64> = combos
        .iter()
        .map(|p| p.linear_regression_period.unwrap_or(5) as u64)
        .collect();
    let matype_names: Vec<String> = combos
        .iter()
        .map(|p| p.matype.as_deref().unwrap_or("ema").to_string())
        .collect();
    let linreg_flags: Vec<bool> = combos
        .iter()
        .map(|p| p.use_linear_regression.unwrap_or(true))
        .collect();

    let dict = PyDict::new(py);
    dict.set_item("values", arr.reshape((rows, cols))?)?;
    dict.set_item("trend_periods", trend_periods.into_pyarray(py))?;
    dict.set_item("ma_periods", ma_periods.into_pyarray(py))?;
    dict.set_item("channel_rate_percents", channel_rates.into_pyarray(py))?;
    dict.set_item("linear_regression_periods", linreg_periods.into_pyarray(py))?;
    dict.set_item("matypes", matype_names)?;
    dict.set_item("use_linear_regression", linreg_flags.into_pyarray(py))?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
1792
/// Registers the Trend Follower functions and stream class on the given Python module.
#[cfg(feature = "python")]
pub fn register_trend_follower_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    let single = wrap_pyfunction!(trend_follower_py, m)?;
    m.add_function(single)?;
    let batch = wrap_pyfunction!(trend_follower_batch_py, m)?;
    m.add_function(batch)?;
    m.add_class::<TrendFollowerStreamPy>()
}
1800
/// JS-side configuration for `trend_follower_batch_js`, deserialized via `serde_wasm_bindgen`.
///
/// Every `*_range` field must contain exactly three elements `[start, end, step]`;
/// `trend_follower_batch_js` rejects any other length.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
struct TrendFollowerBatchConfig {
    trend_period_range: Vec<usize>,
    ma_period_range: Vec<usize>,
    channel_rate_percent_range: Vec<f64>,
    linear_regression_period_range: Vec<usize>,
    // Single MA type applied to every combination.
    matype: String,
    use_linear_regression: bool,
}
1811
/// Serialized result of `trend_follower_batch_js`; `values` is row-major `rows * cols`,
/// with row `i` corresponding to `combos[i]`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
struct TrendFollowerBatchJsOutput {
    values: Vec<f64>,
    rows: usize,
    cols: usize,
    combos: Vec<TrendFollowerParams>,
}
1820
/// WASM entry point: computes the Trend Follower series for one parameter set and returns
/// it as a new array. Library errors are returned as JS string errors.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "trend_follower_js")]
pub fn trend_follower_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    matype: &str,
    trend_period: usize,
    ma_period: usize,
    channel_rate_percent: f64,
    use_linear_regression: bool,
    linear_regression_period: usize,
) -> Result<Vec<f64>, JsValue> {
    let params = TrendFollowerParams {
        matype: Some(matype.to_string()),
        trend_period: Some(trend_period),
        ma_period: Some(ma_period),
        channel_rate_percent: Some(channel_rate_percent),
        use_linear_regression: Some(use_linear_regression),
        linear_regression_period: Some(linear_regression_period),
    };
    let input = TrendFollowerInput::from_slices(high, low, close, volume, params);
    match trend_follower(&input) {
        Ok(output) => Ok(output.values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
1853
/// WASM entry point for a parameter sweep driven by a JS config object.
///
/// Each `*_range` in the config must be `[start, end, step]`. Returns an object with the
/// row-major `values`, `rows`, `cols` and the per-row `combos`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "trend_follower_batch_js")]
pub fn trend_follower_batch_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    // Converts a 3-element slice into a tuple; any other length yields `None`.
    fn as_triple<T: Copy>(values: &[T]) -> Option<(T, T, T)> {
        match values {
            [a, b, c] => Some((*a, *b, *c)),
            _ => None,
        }
    }

    let config: TrendFollowerBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {e}")))?;
    let (
        Some(trend_period),
        Some(ma_period),
        Some(channel_rate_percent),
        Some(linear_regression_period),
    ) = (
        as_triple(&config.trend_period_range),
        as_triple(&config.ma_period_range),
        as_triple(&config.channel_rate_percent_range),
        as_triple(&config.linear_regression_period_range),
    )
    else {
        return Err(JsValue::from_str(
            "Invalid config: all *_range fields must have exactly 3 elements",
        ));
    };
    let matype = config.matype;
    let range = TrendFollowerBatchRange {
        trend_period,
        ma_period,
        channel_rate_percent,
        linear_regression_period,
        matype: (matype.clone(), matype, String::new()),
        use_linear_regression: config.use_linear_regression,
    };
    let batch = trend_follower_batch_slice(high, low, close, volume, &range)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    let payload = TrendFollowerBatchJsOutput {
        values: batch.values,
        rows: batch.rows,
        cols: batch.cols,
        combos: batch.combos,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {e}")))
}
1908
/// Allocates an uninitialised buffer with room for `len` `f64`s and hands ownership to JS.
/// Release it with `trend_follower_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trend_follower_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the destructor, leaking the allocation to the caller.
    let mut buffer = ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1917
/// Releases a buffer obtained from `trend_follower_alloc(len)`.
///
/// A null `ptr` is ignored. `Vec::from_raw_parts` on a null pointer would be undefined
/// behaviour, and the sibling `*_into` functions already treat null as invalid input.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trend_follower_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: caller contract: `ptr`/`len` come from `trend_follower_alloc(len)` and are freed once.
    unsafe {
        drop(Vec::from_raw_parts(ptr, len, len));
    }
}
1925
/// WASM raw-pointer entry point: computes one Trend Follower series into `out_ptr`.
///
/// All pointers must reference at least `len` `f64`s.
///
/// `out_ptr` may equal one of the input pointers (in-place use). In that case the result is
/// computed into a scratch buffer first, because viewing the same memory as both `&[f64]` and
/// `&mut [f64]` is undefined behaviour and could clobber inputs before they are read. Partial
/// overlaps are NOT detected.
///
/// # Errors
/// A JS string error for null pointers or any library error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trend_follower_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    volume_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    matype: &str,
    trend_period: usize,
    ma_period: usize,
    channel_rate_percent: f64,
    use_linear_regression: bool,
    linear_regression_period: usize,
) -> Result<(), JsValue> {
    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || volume_ptr.is_null()
        || out_ptr.is_null()
    {
        return Err(JsValue::from_str(
            "null pointer passed to trend_follower_into",
        ));
    }
    let params = TrendFollowerParams {
        matype: Some(matype.to_string()),
        trend_period: Some(trend_period),
        ma_period: Some(ma_period),
        channel_rate_percent: Some(channel_rate_percent),
        use_linear_regression: Some(use_linear_regression),
        linear_regression_period: Some(linear_regression_period),
    };
    let aliased = [high_ptr, low_ptr, close_ptr, volume_ptr].contains(&(out_ptr as *const f64));

    // SAFETY: caller contract: every pointer is non-null (checked above) and valid for `len`
    // elements. The mutable output view is only created when it cannot alias a live input view.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        let volume = std::slice::from_raw_parts(volume_ptr, len);
        let input = TrendFollowerInput::from_slices(high, low, close, volume, params);
        if aliased {
            let mut scratch = vec![0.0_f64; len];
            trend_follower_into_slice(&mut scratch, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            // Input views are no longer used past this point.
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
            Ok(())
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            trend_follower_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))
        }
    }
}
1976
/// WASM entry point that takes input arrays by value-copy and writes `close.len()` results to
/// `out_ptr`. Returns a JS string error for a null output pointer or any library error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "trend_follower_into_host")]
pub fn trend_follower_into_host(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    out_ptr: *mut f64,
    matype: &str,
    trend_period: usize,
    ma_period: usize,
    channel_rate_percent: f64,
    use_linear_regression: bool,
    linear_regression_period: usize,
) -> Result<(), JsValue> {
    if out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to trend_follower_into_host",
        ));
    }
    let params = TrendFollowerParams {
        matype: Some(matype.to_string()),
        trend_period: Some(trend_period),
        ma_period: Some(ma_period),
        channel_rate_percent: Some(channel_rate_percent),
        use_linear_regression: Some(use_linear_regression),
        linear_regression_period: Some(linear_regression_period),
    };
    let input = TrendFollowerInput::from_slices(high, low, close, volume, params);
    // SAFETY: caller contract: `out_ptr` is non-null (checked) and valid for `close.len()` f64s.
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, close.len()) };
    trend_follower_into_slice(out, &input, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))
}
2017
/// WASM raw-pointer entry point for a parameter sweep.
///
/// Inputs must each hold `len` values. `out_ptr` must have room for `rows * len` values, where
/// `rows` is the number of parameter combinations; `rows` is returned on success.
///
/// The output size is computed with `checked_mul`. On wasm32 `usize` is 32-bit, so an
/// unchecked product could wrap and produce a wrong-length slice. If `out_ptr` equals one of
/// the input pointers, results go through a scratch buffer to avoid aliasing `&[f64]` with
/// `&mut [f64]`; partial overlaps are NOT detected.
///
/// # Errors
/// A JS string error for null pointers, an overflowing output size, or any library error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trend_follower_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    volume_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    trend_period_start: usize,
    trend_period_end: usize,
    trend_period_step: usize,
    ma_period_start: usize,
    ma_period_end: usize,
    ma_period_step: usize,
    channel_rate_percent_start: f64,
    channel_rate_percent_end: f64,
    channel_rate_percent_step: f64,
    linear_regression_period_start: usize,
    linear_regression_period_end: usize,
    linear_regression_period_step: usize,
    matype: &str,
    use_linear_regression: bool,
) -> Result<usize, JsValue> {
    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || volume_ptr.is_null()
        || out_ptr.is_null()
    {
        return Err(JsValue::from_str(
            "null pointer passed to trend_follower_batch_into",
        ));
    }
    let range = TrendFollowerBatchRange {
        trend_period: (trend_period_start, trend_period_end, trend_period_step),
        ma_period: (ma_period_start, ma_period_end, ma_period_step),
        channel_rate_percent: (
            channel_rate_percent_start,
            channel_rate_percent_end,
            channel_rate_percent_step,
        ),
        linear_regression_period: (
            linear_regression_period_start,
            linear_regression_period_end,
            linear_regression_period_step,
        ),
        matype: (matype.to_string(), matype.to_string(), String::new()),
        use_linear_regression,
    };
    let combos =
        expand_grid_trend_follower(&range).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let total = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;
    let aliased = [high_ptr, low_ptr, close_ptr, volume_ptr].contains(&(out_ptr as *const f64));

    // SAFETY: caller contract: inputs are valid for `len` f64s and the output for `total` f64s;
    // the mutable output view is only created when it cannot alias a live input view.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        let volume = std::slice::from_raw_parts(volume_ptr, len);
        if aliased {
            let mut scratch = vec![0.0_f64; total];
            trend_follower_batch_inner_into(
                high,
                low,
                close,
                volume,
                &range,
                Kernel::Scalar,
                false,
                &mut scratch,
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
            // Input views are no longer used past this point.
            std::slice::from_raw_parts_mut(out_ptr, total).copy_from_slice(&scratch);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, total);
            trend_follower_batch_inner_into(
                high,
                low,
                close,
                volume,
                &range,
                Kernel::Scalar,
                false,
                out,
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }
    Ok(rows)
}
2091
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic synthetic OHLCV series of `len` bars with a gentle uptrend plus
    /// sinusoidal noise, so every MA type produces finite, varying output.
    fn sample_ohlcv(len: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
        let mut high = Vec::with_capacity(len);
        let mut low = Vec::with_capacity(len);
        let mut close = Vec::with_capacity(len);
        let mut volume = Vec::with_capacity(len);
        for i in 0..len {
            let t = i as f64;
            let base = 100.0 + t * 0.19 + (t * 0.13).sin() * 1.8;
            let mid = base + (t * 0.027).cos() * 0.6;
            high.push(mid + 1.2 + (t * 0.05).sin().abs());
            low.push(mid - 1.1 - (t * 0.04).cos().abs());
            close.push(mid);
            volume.push(1000.0 + t * 11.0 + (i % 9) as f64 * 17.0);
        }
        (high, low, close, volume)
    }

    /// Asserts equal length and element-wise agreement within 1e-9, with NaN matching NaN.
    fn assert_close(a: &[f64], b: &[f64]) {
        assert_eq!(a.len(), b.len());
        for (i, (&x, &y)) in a.iter().zip(b).enumerate() {
            if x.is_nan() || y.is_nan() {
                assert!(x.is_nan() && y.is_nan(), "nan mismatch at {i}");
            } else {
                assert!((x - y).abs() <= 1e-9, "value mismatch at {i}: {} vs {}", x, y);
            }
        }
    }

    #[test]
    fn trend_follower_into_matches_api() {
        let (high, low, close, volume) = sample_ohlcv(128);
        let input = TrendFollowerInput::from_slices(
            &high,
            &low,
            &close,
            &volume,
            TrendFollowerParams::default(),
        );
        let expected = trend_follower(&input).unwrap();
        let mut actual = vec![f64::NAN; close.len()];
        trend_follower_into_slice(&mut actual, &input, Kernel::Auto).unwrap();
        assert_close(&expected.values, &actual);
    }

    #[test]
    fn trend_follower_stream_matches_batch_with_nan_gap() {
        let (mut high, mut low, mut close, mut volume) = sample_ohlcv(128);
        for series in [&mut high, &mut low, &mut close, &mut volume] {
            series[48] = f64::NAN;
        }
        let input = TrendFollowerInput::from_slices(
            &high,
            &low,
            &close,
            &volume,
            TrendFollowerParams::default(),
        );
        let batch = trend_follower(&input).unwrap();
        let mut stream = TrendFollowerStream::try_new(TrendFollowerParams::default()).unwrap();
        let streamed: Vec<f64> = (0..close.len())
            .map(|i| {
                stream
                    .update_reset_on_nan(high[i], low[i], close[i], volume[i])
                    .unwrap_or(f64::NAN)
            })
            .collect();
        assert_close(&batch.values, &streamed);
    }

    #[test]
    fn trend_follower_batch_single_param_matches_single() {
        let (high, low, close, volume) = sample_ohlcv(128);
        let params = TrendFollowerParams {
            matype: Some("wma".to_string()),
            trend_period: Some(14),
            ma_period: Some(9),
            channel_rate_percent: Some(1.1),
            use_linear_regression: Some(false),
            linear_regression_period: Some(5),
        };
        let single = trend_follower(&TrendFollowerInput::from_slices(
            &high,
            &low,
            &close,
            &volume,
            params.clone(),
        ))
        .unwrap();
        let range = TrendFollowerBatchRange {
            trend_period: (14, 14, 0),
            ma_period: (9, 9, 0),
            channel_rate_percent: (1.1, 1.1, 0.0),
            linear_regression_period: (5, 5, 0),
            matype: ("wma".to_string(), "wma".to_string(), String::new()),
            use_linear_regression: false,
        };
        let batch =
            trend_follower_batch_with_kernel(&high, &low, &close, &volume, &range, Kernel::Auto)
                .unwrap();
        assert_eq!(batch.rows, 1);
        assert_close(&single.values, &batch.values[..close.len()]);
    }

    #[test]
    fn trend_follower_vwma_depends_on_volume() {
        let (high, low, close, volume) = sample_ohlcv(96);
        let volume_b: Vec<f64> = volume.iter().rev().copied().collect();
        let params = TrendFollowerParams {
            matype: Some("vwma".to_string()),
            trend_period: Some(20),
            ma_period: Some(12),
            channel_rate_percent: Some(1.0),
            use_linear_regression: Some(false),
            linear_regression_period: Some(5),
        };
        let a = trend_follower(&TrendFollowerInput::from_slices(
            &high,
            &low,
            &close,
            &volume,
            params.clone(),
        ))
        .unwrap();
        let b = trend_follower(&TrendFollowerInput::from_slices(
            &high, &low, &close, &volume_b, params,
        ))
        .unwrap();
        let differs = a
            .values
            .iter()
            .zip(&b.values)
            .any(|(x, y)| x.is_finite() && y.is_finite() && (x - y).abs() > 1e-9);
        assert!(differs);
    }

    #[test]
    fn trend_follower_invalid_matype_rejected() {
        let (high, low, close, volume) = sample_ohlcv(64);
        let params = TrendFollowerParams {
            matype: Some("hma".to_string()),
            ..TrendFollowerParams::default()
        };
        let err = trend_follower(&TrendFollowerInput::from_slices(
            &high, &low, &close, &volume, params,
        ))
        .unwrap_err();
        assert!(matches!(err, TrendFollowerError::InvalidMaType { .. }));
    }
}