Skip to main content

vector_ta/indicators/
trend_follower.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::indicators::moving_averages::ema::{ema_with_kernel, EmaInput, EmaParams};
16use crate::indicators::moving_averages::linreg::{
17    linreg_with_kernel, LinRegInput, LinRegParams, LinRegStream,
18};
19use crate::indicators::moving_averages::sma::{sma_with_kernel, SmaInput, SmaParams};
20use crate::indicators::moving_averages::vwma::{vwma_with_kernel, VwmaInput, VwmaParams};
21use crate::indicators::moving_averages::wilders::{
22    wilders_with_kernel, WildersInput, WildersParams,
23};
24use crate::indicators::moving_averages::wma::{wma_with_kernel, WmaInput, WmaParams};
25use crate::utilities::data_loader::Candles;
26use crate::utilities::enums::Kernel;
27use crate::utilities::helpers::{
28    alloc_with_nan_prefix, detect_best_batch_kernel, init_matrix_prefixes, make_uninit_matrix,
29};
30#[cfg(feature = "python")]
31use crate::utilities::kernel_validation::validate_kernel;
32
33#[cfg(not(target_arch = "wasm32"))]
34use rayon::prelude::*;
35use std::collections::VecDeque;
36use std::mem::{ManuallyDrop, MaybeUninit};
37use thiserror::Error;
38
39const CHANNEL_WINDOW: usize = 280;
40
/// Result of running the trend follower indicator.
#[derive(Debug, Clone)]
pub struct TrendFollowerOutput {
    /// One value per input bar; bars for which no value could be computed are left as `NaN`.
    pub values: Vec<f64>,
}
45
/// User-facing configuration of the trend follower indicator.
///
/// Every field is optional; a `None` is replaced by the default applied in the
/// matching `TrendFollowerInput::get_*` accessor.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct TrendFollowerParams {
    /// Base moving-average name: `ema`, `sma`, `rma`, `wma` or `vwma` (ASCII case-insensitive).
    pub matype: Option<String>,
    /// Window over which the highest/lowest trend-MA values are tracked.
    pub trend_period: Option<usize>,
    /// Period of the base moving average applied to `close`.
    pub ma_period: Option<usize>,
    /// Channel width as a percentage of the high/low range over the channel window.
    pub channel_rate_percent: Option<f64>,
    /// Whether the base MA is smoothed by a linear regression.
    pub use_linear_regression: Option<bool>,
    /// Period of that linear-regression smoothing.
    pub linear_regression_period: Option<usize>,
}

impl Default for TrendFollowerParams {
    /// EMA(20) base, 20-bar trend window, 1 % channel rate, linear regression over 5 bars.
    fn default() -> Self {
        let matype = Some(String::from("ema"));
        TrendFollowerParams {
            matype,
            trend_period: Some(20),
            ma_period: Some(20),
            channel_rate_percent: Some(1.0),
            use_linear_regression: Some(true),
            linear_regression_period: Some(5),
        }
    }
}
72
/// Where the indicator reads its high/low/close/volume series from.
#[derive(Debug, Clone)]
pub enum TrendFollowerData<'a> {
    /// Use the `high`, `low`, `close` and `volume` series of a `Candles` container.
    Candles(&'a Candles),
    /// Caller-provided series; their lengths are checked when the indicator runs.
    Slices {
        high: &'a [f64],
        low: &'a [f64],
        close: &'a [f64],
        volume: &'a [f64],
    },
}
83
84#[derive(Debug, Clone)]
85pub struct TrendFollowerInput<'a> {
86    pub data: TrendFollowerData<'a>,
87    pub params: TrendFollowerParams,
88}
89
90impl<'a> TrendFollowerInput<'a> {
91    #[inline]
92    pub fn from_candles(candles: &'a Candles, params: TrendFollowerParams) -> Self {
93        Self {
94            data: TrendFollowerData::Candles(candles),
95            params,
96        }
97    }
98
99    #[inline]
100    pub fn from_slices(
101        high: &'a [f64],
102        low: &'a [f64],
103        close: &'a [f64],
104        volume: &'a [f64],
105        params: TrendFollowerParams,
106    ) -> Self {
107        Self {
108            data: TrendFollowerData::Slices {
109                high,
110                low,
111                close,
112                volume,
113            },
114            params,
115        }
116    }
117
118    #[inline]
119    pub fn with_default_candles(candles: &'a Candles) -> Self {
120        Self::from_candles(candles, TrendFollowerParams::default())
121    }
122
123    #[inline]
124    pub fn as_slices(&self) -> (&[f64], &[f64], &[f64], &[f64]) {
125        match &self.data {
126            TrendFollowerData::Candles(candles) => (
127                candles.high.as_slice(),
128                candles.low.as_slice(),
129                candles.close.as_slice(),
130                candles.volume.as_slice(),
131            ),
132            TrendFollowerData::Slices {
133                high,
134                low,
135                close,
136                volume,
137            } => (*high, *low, *close, *volume),
138        }
139    }
140
141    #[inline]
142    pub fn get_matype(&self) -> &str {
143        self.params.matype.as_deref().unwrap_or("ema")
144    }
145
146    #[inline]
147    pub fn get_trend_period(&self) -> usize {
148        self.params.trend_period.unwrap_or(20)
149    }
150
151    #[inline]
152    pub fn get_ma_period(&self) -> usize {
153        self.params.ma_period.unwrap_or(20)
154    }
155
156    #[inline]
157    pub fn get_channel_rate_percent(&self) -> f64 {
158        self.params.channel_rate_percent.unwrap_or(1.0)
159    }
160
161    #[inline]
162    pub fn get_use_linear_regression(&self) -> bool {
163        self.params.use_linear_regression.unwrap_or(true)
164    }
165
166    #[inline]
167    pub fn get_linear_regression_period(&self) -> usize {
168        self.params.linear_regression_period.unwrap_or(5)
169    }
170}
171
/// Moving-average families supported as the base MA.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum TrendFollowerMaType {
    Ema,
    Sma,
    Rma,
    Wma,
    Vwma,
}

impl TrendFollowerMaType {
    /// Canonical lowercase name (the spelling `parse_matype` accepts).
    #[inline]
    fn as_str(self) -> &'static str {
        use TrendFollowerMaType::*;
        let name = match self {
            Ema => "ema",
            Sma => "sma",
            Rma => "rma",
            Wma => "wma",
            Vwma => "vwma",
        };
        name
    }
}
193
/// Parameters after defaults are applied and `resolve_params` validation has passed.
#[derive(Copy, Clone, Debug)]
struct TrendFollowerResolvedParams {
    matype: TrendFollowerMaType,
    trend_period: usize,
    ma_period: usize,
    /// `channel_rate_percent * 0.01`: the multiplier applied to the channel high/low range.
    channel_rate_fraction: f64,
    use_linear_regression: bool,
    /// Only validated when `use_linear_regression` is true.
    linear_regression_period: usize,
}
203
204#[derive(Clone, Debug)]
205enum TrendFollowerBaseMaStream {
206    Ema(EmaState),
207    Sma(SmaState),
208    Rma(RmaState),
209    Wma(WmaState),
210    Vwma(VwmaState),
211}
212
213impl TrendFollowerBaseMaStream {
214    fn new(matype: TrendFollowerMaType, period: usize) -> Self {
215        match matype {
216            TrendFollowerMaType::Ema => Self::Ema(EmaState::new(period)),
217            TrendFollowerMaType::Sma => Self::Sma(SmaState::new(period)),
218            TrendFollowerMaType::Rma => Self::Rma(RmaState::new(period)),
219            TrendFollowerMaType::Wma => Self::Wma(WmaState::new(period)),
220            TrendFollowerMaType::Vwma => Self::Vwma(VwmaState::new(period)),
221        }
222    }
223
224    fn update(&mut self, value: f64, volume: f64) -> Option<f64> {
225        match self {
226            Self::Ema(state) => state.update(value),
227            Self::Sma(state) => state.update(value),
228            Self::Rma(state) => state.update(value),
229            Self::Wma(state) => state.update(value),
230            Self::Vwma(state) => state.update(value, volume),
231        }
232    }
233}
234
/// Incremental EMA.
///
/// The first `period` finite samples are combined as a running arithmetic mean.
/// After that, standard exponential smoothing with `alpha = 2 / (period + 1)` applies.
/// Non-finite samples are ignored and yield `None`.
#[derive(Clone, Debug)]
struct EmaState {
    period: usize,
    alpha: f64,
    beta: f64,
    value: Option<f64>,
    valid_count: usize,
}

impl EmaState {
    fn new(period: usize) -> Self {
        let alpha = 2.0 / (period as f64 + 1.0);
        Self {
            period,
            alpha,
            beta: 1.0 - alpha,
            value: None,
            valid_count: 0,
        }
    }

    fn update(&mut self, value: f64) -> Option<f64> {
        if !value.is_finite() {
            return None;
        }
        let next = if let Some(prev) = self.value {
            if self.valid_count < self.period {
                // Still seeding: extend the running mean by one sample.
                self.valid_count += 1;
                let n = self.valid_count as f64;
                ((n - 1.0) * prev + value) / n
            } else {
                self.beta.mul_add(prev, self.alpha * value)
            }
        } else {
            self.valid_count = 1;
            value
        };
        self.value = Some(next);
        self.value
    }
}
275
/// Incremental simple moving average over a ring buffer of the last `period` finite samples.
#[derive(Clone, Debug)]
struct SmaState {
    period: usize,
    buffer: Vec<f64>,
    head: usize,
    filled: usize,
    sum: f64,
}

impl SmaState {
    fn new(period: usize) -> Self {
        let buffer = vec![0.0; period];
        Self {
            period,
            buffer,
            head: 0,
            filled: 0,
            sum: 0.0,
        }
    }

    /// Pushes a sample; returns the mean once `period` finite samples are held.
    fn update(&mut self, value: f64) -> Option<f64> {
        if !value.is_finite() {
            return None;
        }
        let slot = self.head;
        if self.filled < self.period {
            self.filled += 1;
        } else {
            // Window is full: the slot being overwritten holds the oldest sample.
            self.sum -= self.buffer[slot];
        }
        self.buffer[slot] = value;
        self.sum += value;
        self.head = if slot + 1 == self.period { 0 } else { slot + 1 };
        (self.filled == self.period).then(|| self.sum / self.period as f64)
    }
}
315
/// Incremental Wilder's moving average (RMA).
///
/// The first `period` finite samples are averaged arithmetically to seed the value.
/// Afterwards each sample moves the average by `(x - prev) / period`.
/// Non-finite samples are ignored and yield `None`.
///
/// Seeding only needs a running sum and count, so no per-sample history is kept.
#[derive(Clone, Debug)]
struct RmaState {
    period: usize,
    /// Number of finite samples accumulated while seeding.
    filled: usize,
    /// Sum of the seeding samples.
    sum: f64,
    /// Current average once seeding has completed.
    value: Option<f64>,
}

impl RmaState {
    fn new(period: usize) -> Self {
        Self {
            period,
            filled: 0,
            sum: 0.0,
            value: None,
        }
    }

    fn update(&mut self, value: f64) -> Option<f64> {
        if !value.is_finite() {
            return None;
        }
        if let Some(prev) = self.value {
            let next = prev + (value - prev) / self.period as f64;
            self.value = Some(next);
            return Some(next);
        }
        self.sum += value;
        self.filled += 1;
        if self.filled == self.period {
            let next = self.sum / self.period as f64;
            self.value = Some(next);
            Some(next)
        } else {
            None
        }
    }
}
360
/// Incremental linearly weighted moving average over the last `period` finite samples.
///
/// The oldest sample gets weight 1 and the newest gets weight `period`.
#[derive(Clone, Debug)]
struct WmaState {
    period: usize,
    buffer: Vec<f64>,
    head: usize,
    filled: usize,
}

impl WmaState {
    fn new(period: usize) -> Self {
        let buffer = vec![0.0; period];
        Self {
            period,
            buffer,
            head: 0,
            filled: 0,
        }
    }

    fn update(&mut self, value: f64) -> Option<f64> {
        if !value.is_finite() {
            return None;
        }
        self.buffer[self.head] = value;
        self.head = (self.head + 1) % self.period;
        self.filled = (self.filled + 1).min(self.period);
        if self.filled < self.period {
            return None;
        }
        // After the write, `head` points at the oldest sample, which receives weight 1.
        let (acc, weight_sum) = (0..self.period).fold((0.0, 0.0), |(acc, wsum), i| {
            let w = (i + 1) as f64;
            let sample = self.buffer[(self.head + i) % self.period];
            (acc + sample * w, wsum + w)
        });
        Some(acc / weight_sum)
    }
}
402
/// Incremental volume-weighted moving average over the last `period` bars.
///
/// A bar is ignored unless both price and volume are finite.
#[derive(Clone, Debug)]
struct VwmaState {
    period: usize,
    prices: Vec<f64>,
    volumes: Vec<f64>,
    head: usize,
    filled: usize,
    sum_pv: f64,
    sum_v: f64,
}

impl VwmaState {
    fn new(period: usize) -> Self {
        Self {
            period,
            prices: vec![0.0; period],
            volumes: vec![0.0; period],
            head: 0,
            filled: 0,
            sum_pv: 0.0,
            sum_v: 0.0,
        }
    }

    /// Returns `sum(p*v) / sum(v)` once the window is full and its total volume is non-zero.
    fn update(&mut self, price: f64, volume: f64) -> Option<f64> {
        if !price.is_finite() || !volume.is_finite() {
            return None;
        }
        let slot = self.head;
        if self.filled < self.period {
            self.filled += 1;
        } else {
            let (old_price, old_volume) = (self.prices[slot], self.volumes[slot]);
            self.sum_pv -= old_price * old_volume;
            self.sum_v -= old_volume;
        }
        self.prices[slot] = price;
        self.volumes[slot] = volume;
        self.sum_pv += price * volume;
        self.sum_v += volume;
        self.head = (slot + 1) % self.period;
        let ready = self.filled == self.period && self.sum_v != 0.0;
        ready.then(|| self.sum_pv / self.sum_v)
    }
}
449
450#[derive(Copy, Clone, Debug)]
451pub struct TrendFollowerBuilder {
452    trend_period: Option<usize>,
453    ma_period: Option<usize>,
454    channel_rate_percent: Option<f64>,
455    use_linear_regression: Option<bool>,
456    linear_regression_period: Option<usize>,
457    kernel: Kernel,
458}
459
460impl Default for TrendFollowerBuilder {
461    fn default() -> Self {
462        Self {
463            trend_period: None,
464            ma_period: None,
465            channel_rate_percent: None,
466            use_linear_regression: None,
467            linear_regression_period: None,
468            kernel: Kernel::Auto,
469        }
470    }
471}
472
473impl TrendFollowerBuilder {
474    #[inline]
475    pub fn new() -> Self {
476        Self::default()
477    }
478
479    #[inline]
480    pub fn trend_period(mut self, value: usize) -> Self {
481        self.trend_period = Some(value);
482        self
483    }
484
485    #[inline]
486    pub fn ma_period(mut self, value: usize) -> Self {
487        self.ma_period = Some(value);
488        self
489    }
490
491    #[inline]
492    pub fn channel_rate_percent(mut self, value: f64) -> Self {
493        self.channel_rate_percent = Some(value);
494        self
495    }
496
497    #[inline]
498    pub fn use_linear_regression(mut self, value: bool) -> Self {
499        self.use_linear_regression = Some(value);
500        self
501    }
502
503    #[inline]
504    pub fn linear_regression_period(mut self, value: usize) -> Self {
505        self.linear_regression_period = Some(value);
506        self
507    }
508
509    #[inline]
510    pub fn kernel(mut self, value: Kernel) -> Self {
511        self.kernel = value;
512        self
513    }
514
515    #[inline]
516    fn params(self, matype: &str) -> TrendFollowerParams {
517        TrendFollowerParams {
518            matype: Some(matype.to_string()),
519            trend_period: self.trend_period,
520            ma_period: self.ma_period,
521            channel_rate_percent: self.channel_rate_percent,
522            use_linear_regression: self.use_linear_regression,
523            linear_regression_period: self.linear_regression_period,
524        }
525    }
526
527    #[inline]
528    pub fn apply(self, candles: &Candles) -> Result<TrendFollowerOutput, TrendFollowerError> {
529        let input = TrendFollowerInput::from_candles(candles, self.params("ema"));
530        trend_follower_with_kernel(&input, self.kernel)
531    }
532
533    #[inline]
534    pub fn apply_with_matype(
535        self,
536        candles: &Candles,
537        matype: &str,
538    ) -> Result<TrendFollowerOutput, TrendFollowerError> {
539        let input = TrendFollowerInput::from_candles(candles, self.params(matype));
540        trend_follower_with_kernel(&input, self.kernel)
541    }
542
543    #[inline]
544    pub fn apply_slices(
545        self,
546        high: &[f64],
547        low: &[f64],
548        close: &[f64],
549        volume: &[f64],
550        matype: &str,
551    ) -> Result<TrendFollowerOutput, TrendFollowerError> {
552        let input = TrendFollowerInput::from_slices(high, low, close, volume, self.params(matype));
553        trend_follower_with_kernel(&input, self.kernel)
554    }
555
556    #[inline]
557    pub fn into_stream(self, matype: &str) -> Result<TrendFollowerStream, TrendFollowerError> {
558        TrendFollowerStream::try_new(self.params(matype))
559    }
560}
561
/// Errors produced by the trend follower API.
#[derive(Debug, Error)]
pub enum TrendFollowerError {
    /// The `high` series is empty.
    #[error("trend_follower: Empty input data.")]
    EmptyInputData,
    /// `high`, `low`, `close` and `volume` do not all share the same length.
    #[error(
        "trend_follower: Data length mismatch: high={high_len}, low={low_len}, close={close_len}, volume={volume_len}"
    )]
    DataLengthMismatch {
        high_len: usize,
        low_len: usize,
        close_len: usize,
        volume_len: usize,
    },
    /// No bar has finite high, low and close (plus finite volume when the MA type is VWMA).
    #[error("trend_follower: All values are invalid.")]
    AllValuesNaN,
    /// The MA type name is not `ema`, `sma`, `rma`, `wma` or `vwma` (ASCII case-insensitive).
    #[error("trend_follower: Invalid MA type: {matype}")]
    InvalidMaType { matype: String },
    /// `trend_period` is zero.
    #[error("trend_follower: Invalid trend period: {trend_period}")]
    InvalidTrendPeriod { trend_period: usize },
    /// `ma_period` is zero or larger than the data length.
    #[error("trend_follower: Invalid MA period: {ma_period}, data length = {data_len}")]
    InvalidMaPeriod { ma_period: usize, data_len: usize },
    /// Linear regression is enabled and its period is below 2 or larger than the data length.
    #[error(
        "trend_follower: Invalid linear regression period: {linear_regression_period}, data length = {data_len}"
    )]
    InvalidLinearRegressionPeriod {
        linear_regression_period: usize,
        data_len: usize,
    },
    /// `channel_rate_percent` is non-finite or not strictly positive.
    #[error("trend_follower: Invalid channel rate percent: {channel_rate_percent}")]
    InvalidChannelRatePercent { channel_rate_percent: f64 },
    /// The underlying moving-average routine failed; carries its error message.
    #[error("trend_follower: Moving average computation failed: {0}")]
    MovingAverageError(String),
    /// The underlying linear-regression routine failed; carries its error message.
    #[error("trend_follower: Linear regression computation failed: {0}")]
    LinearRegressionError(String),
    /// A caller-provided output buffer's length differs from the input length.
    #[error("trend_follower: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// Invalid integer sweep range. Not produced by the code shown here; presumably used by a batch API (confirm).
    #[error("trend_follower: Invalid integer range: start={start}, end={end}, step={step}")]
    InvalidRangeUsize {
        start: usize,
        end: usize,
        step: usize,
    },
    /// Invalid float sweep range. Not produced by the code shown here; presumably used by a batch API (confirm).
    #[error("trend_follower: Invalid float range: start={start}, end={end}, step={step}")]
    InvalidRangeF64 { start: f64, end: f64, step: f64 },
    /// A non-batch kernel was supplied to a batch entry point (presumably; not visible here).
    #[error("trend_follower: Invalid kernel for batch path: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
609
610#[inline]
611fn parse_matype(matype: &str) -> Result<TrendFollowerMaType, TrendFollowerError> {
612    if matype.eq_ignore_ascii_case("ema") {
613        return Ok(TrendFollowerMaType::Ema);
614    }
615    if matype.eq_ignore_ascii_case("sma") {
616        return Ok(TrendFollowerMaType::Sma);
617    }
618    if matype.eq_ignore_ascii_case("rma") {
619        return Ok(TrendFollowerMaType::Rma);
620    }
621    if matype.eq_ignore_ascii_case("wma") {
622        return Ok(TrendFollowerMaType::Wma);
623    }
624    if matype.eq_ignore_ascii_case("vwma") {
625        return Ok(TrendFollowerMaType::Vwma);
626    }
627    Err(TrendFollowerError::InvalidMaType {
628        matype: matype.to_string(),
629    })
630}
631
632#[inline]
633fn resolve_params(
634    input: &TrendFollowerInput<'_>,
635    data_len: usize,
636) -> Result<TrendFollowerResolvedParams, TrendFollowerError> {
637    let trend_period = input.get_trend_period();
638    if trend_period < 1 {
639        return Err(TrendFollowerError::InvalidTrendPeriod { trend_period });
640    }
641
642    let ma_period = input.get_ma_period();
643    if ma_period == 0 || ma_period > data_len {
644        return Err(TrendFollowerError::InvalidMaPeriod {
645            ma_period,
646            data_len,
647        });
648    }
649
650    let linear_regression_period = input.get_linear_regression_period();
651    if input.get_use_linear_regression()
652        && (linear_regression_period < 2 || linear_regression_period > data_len)
653    {
654        return Err(TrendFollowerError::InvalidLinearRegressionPeriod {
655            linear_regression_period,
656            data_len,
657        });
658    }
659
660    let channel_rate_percent = input.get_channel_rate_percent();
661    if !channel_rate_percent.is_finite() || channel_rate_percent <= 0.0 {
662        return Err(TrendFollowerError::InvalidChannelRatePercent {
663            channel_rate_percent,
664        });
665    }
666
667    Ok(TrendFollowerResolvedParams {
668        matype: parse_matype(input.get_matype())?,
669        trend_period,
670        ma_period,
671        channel_rate_fraction: channel_rate_percent * 0.01,
672        use_linear_regression: input.get_use_linear_regression(),
673        linear_regression_period,
674    })
675}
676
/// Index of the first bar whose high, low and close are all finite.
///
/// When `needs_volume` is set, the bar's volume must also be finite.
/// Returns `None` if no bar qualifies (including empty input).
#[inline]
fn first_valid_bar(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    needs_volume: bool,
) -> Option<usize> {
    for idx in 0..high.len() {
        let prices_ok = high[idx].is_finite() && low[idx].is_finite() && close[idx].is_finite();
        if prices_ok && (!needs_volume || volume[idx].is_finite()) {
            return Some(idx);
        }
    }
    None
}
692
/// True when every bar from `first` to the end has finite high, low and close.
///
/// When `needs_volume` is set, each of those bars must also have a finite volume.
/// An empty range (`first >= high.len()`) counts as clean.
#[inline]
fn data_is_clean(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    first: usize,
    needs_volume: bool,
) -> bool {
    (first..high.len()).all(|i| {
        high[i].is_finite()
            && low[i].is_finite()
            && close[i].is_finite()
            && (!needs_volume || volume[i].is_finite())
    })
}
712
713#[inline]
714fn trend_follower_prepare<'a>(
715    input: &'a TrendFollowerInput<'a>,
716) -> Result<
717    (
718        &'a [f64],
719        &'a [f64],
720        &'a [f64],
721        &'a [f64],
722        TrendFollowerResolvedParams,
723        usize,
724    ),
725    TrendFollowerError,
726> {
727    let (high, low, close, volume) = input.as_slices();
728    if high.is_empty() {
729        return Err(TrendFollowerError::EmptyInputData);
730    }
731    if high.len() != low.len() || high.len() != close.len() || high.len() != volume.len() {
732        return Err(TrendFollowerError::DataLengthMismatch {
733            high_len: high.len(),
734            low_len: low.len(),
735            close_len: close.len(),
736            volume_len: volume.len(),
737        });
738    }
739    let params = resolve_params(input, high.len())?;
740    let first = first_valid_bar(
741        high,
742        low,
743        close,
744        volume,
745        params.matype == TrendFollowerMaType::Vwma,
746    )
747    .ok_or(TrendFollowerError::AllValuesNaN)?;
748    Ok((high, low, close, volume, params, first))
749}
750
751#[inline]
752fn compute_ma_series(
753    close: &[f64],
754    volume: &[f64],
755    params: TrendFollowerResolvedParams,
756    kernel: Kernel,
757) -> Result<Vec<f64>, TrendFollowerError> {
758    match params.matype {
759        TrendFollowerMaType::Ema => ema_with_kernel(
760            &EmaInput::from_slice(
761                close,
762                EmaParams {
763                    period: Some(params.ma_period),
764                },
765            ),
766            kernel,
767        )
768        .map(|out| out.values)
769        .map_err(|e| TrendFollowerError::MovingAverageError(e.to_string())),
770        TrendFollowerMaType::Sma => sma_with_kernel(
771            &SmaInput::from_slice(
772                close,
773                SmaParams {
774                    period: Some(params.ma_period),
775                },
776            ),
777            kernel,
778        )
779        .map(|out| out.values)
780        .map_err(|e| TrendFollowerError::MovingAverageError(e.to_string())),
781        TrendFollowerMaType::Rma => wilders_with_kernel(
782            &WildersInput::from_slice(
783                close,
784                WildersParams {
785                    period: Some(params.ma_period),
786                },
787            ),
788            kernel,
789        )
790        .map(|out| out.values)
791        .map_err(|e| TrendFollowerError::MovingAverageError(e.to_string())),
792        TrendFollowerMaType::Wma => wma_with_kernel(
793            &WmaInput::from_slice(
794                close,
795                WmaParams {
796                    period: Some(params.ma_period),
797                },
798            ),
799            kernel,
800        )
801        .map(|out| out.values)
802        .map_err(|e| TrendFollowerError::MovingAverageError(e.to_string())),
803        TrendFollowerMaType::Vwma => vwma_with_kernel(
804            &VwmaInput::from_slice(
805                close,
806                volume,
807                VwmaParams {
808                    period: Some(params.ma_period),
809                },
810            ),
811            kernel,
812        )
813        .map(|out| out.values)
814        .map_err(|e| TrendFollowerError::MovingAverageError(e.to_string())),
815    }
816}
817
/// Pushes `(idx, value)` onto a monotonically non-increasing deque.
///
/// Entries outside the `window` bars ending at `idx` are evicted first, so the
/// front holds the window maximum.
#[inline]
fn push_max(queue: &mut VecDeque<(usize, f64)>, idx: usize, value: f64, window: usize) {
    evict_front(queue, idx, window);
    while queue.back().map_or(false, |&(_, old)| old <= value) {
        queue.pop_back();
    }
    queue.push_back((idx, value));
}

/// Pushes `(idx, value)` onto a monotonically non-decreasing deque.
///
/// Entries outside the `window` bars ending at `idx` are evicted first, so the
/// front holds the window minimum.
#[inline]
fn push_min(queue: &mut VecDeque<(usize, f64)>, idx: usize, value: f64, window: usize) {
    evict_front(queue, idx, window);
    while queue.back().map_or(false, |&(_, old)| old >= value) {
        queue.pop_back();
    }
    queue.push_back((idx, value));
}

/// Drops front entries whose index is older than the `window` bars ending at `idx`.
///
/// That means indices below `idx + 1 - window`, computed with saturating arithmetic.
#[inline]
fn evict_front(queue: &mut VecDeque<(usize, f64)>, idx: usize, window: usize) {
    let oldest_kept = idx.saturating_add(1).saturating_sub(window);
    while queue.front().map_or(false, |&(old_idx, _)| old_idx < oldest_kept) {
        queue.pop_front();
    }
}
869
/// Vectorized path, used by `trend_follower_compute_into` when every bar from `first`
/// onward is finite.
///
/// Steps:
/// 1. Compute the base MA with `compute_ma_series`. When `use_linear_regression` is set,
///    smooth it with a linear regression to get the trend MA.
/// 2. For each bar `i` in `first..len`, track:
///    - the highest `high` and lowest `low` over the last `CHANNEL_WINDOW` bars; their
///      range times `channel_rate_fraction` is the channel width `chan`;
///    - the highest (`hh`) and lowest (`ll`) finite trend-MA value over the last
///      `trend_period` bars.
/// 3. Write `trend * |hh - ll| / chan` to `out[i]`, where `trend` is `+1`, `-1` or `0`.
///
/// Bars skipped with `continue` before any write are left untouched, so callers must
/// pre-fill `out` (e.g. `trend_follower_into_slice` fills it with `NaN`).
///
/// # Errors
/// Propagates MA failures, and wraps linear-regression failures in `LinearRegressionError`.
fn trend_follower_compute_clean_into(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    params: TrendFollowerResolvedParams,
    first: usize,
    kernel: Kernel,
    out: &mut [f64],
) -> Result<(), TrendFollowerError> {
    let base_ma = compute_ma_series(close, volume, params, kernel)?;
    let trend_ma = if params.use_linear_regression {
        linreg_with_kernel(
            &LinRegInput::from_slice(
                &base_ma,
                LinRegParams {
                    period: Some(params.linear_regression_period),
                },
            ),
            kernel,
        )
        .map(|series| series.values)
        .map_err(|e| TrendFollowerError::LinearRegressionError(e.to_string()))?
    } else {
        base_ma
    };

    // Monotonic deques of (bar index, value): fronts hold the window max/min.
    let mut high_max = VecDeque::with_capacity(CHANNEL_WINDOW);
    let mut low_min = VecDeque::with_capacity(CHANNEL_WINDOW);
    let mut ma_max = VecDeque::with_capacity(params.trend_period.max(1));
    let mut ma_min = VecDeque::with_capacity(params.trend_period.max(1));

    for i in first..high.len() {
        // Expire stale entries even when nothing is pushed below (non-finite MA bars).
        evict_front(&mut high_max, i, CHANNEL_WINDOW);
        evict_front(&mut low_min, i, CHANNEL_WINDOW);
        evict_front(&mut ma_max, i, params.trend_period);
        evict_front(&mut ma_min, i, params.trend_period);

        push_max(&mut high_max, i, high[i], CHANNEL_WINDOW);
        push_min(&mut low_min, i, low[i], CHANNEL_WINDOW);

        // Non-finite trend-MA values (e.g. MA warm-up) are never pushed.
        let ma = trend_ma[i];
        if ma.is_finite() {
            push_max(&mut ma_max, i, ma, params.trend_period);
            push_min(&mut ma_min, i, ma, params.trend_period);
        }

        // No finite trend-MA inside the window yet: leave out[i] as the caller filled it.
        let (hh, ll) = match (ma_max.front(), ma_min.front()) {
            (Some((_, hh)), Some((_, ll))) => (*hh, *ll),
            _ => continue,
        };
        // Both channel deques were just pushed to, so this normally matches.
        let (channel_high, channel_low) = match (high_max.front(), low_min.front()) {
            (Some((_, hi)), Some((_, lo))) => (*hi, *lo),
            _ => continue,
        };
        let chan = (channel_high - channel_low) * params.channel_rate_fraction;
        // The current MA may be non-finite even though hh/ll come from earlier bars.
        if !ma.is_finite() || !chan.is_finite() || chan == 0.0 {
            out[i] = f64::NAN;
            continue;
        }

        // A direction is only assigned when the MA's range over trend_period exceeds
        // the channel width; the upward test is checked first.
        let diff = (hh - ll).abs();
        let trend = if diff > chan {
            if ma > ll + chan {
                1.0
            } else if ma < hh - chan {
                -1.0
            } else {
                0.0
            }
        } else {
            0.0
        };
        out[i] = trend * diff / chan;
    }

    Ok(())
}
948
949fn trend_follower_compute_fallback_into(
950    high: &[f64],
951    low: &[f64],
952    close: &[f64],
953    volume: &[f64],
954    input: &TrendFollowerInput<'_>,
955    out: &mut [f64],
956) -> Result<(), TrendFollowerError> {
957    let mut stream = TrendFollowerStream::try_new(input.params.clone())?;
958    for i in 0..high.len() {
959        out[i] = stream
960            .update_reset_on_nan(high[i], low[i], close[i], volume[i])
961            .unwrap_or(f64::NAN);
962    }
963    Ok(())
964}
965
966fn trend_follower_compute_into(
967    high: &[f64],
968    low: &[f64],
969    close: &[f64],
970    volume: &[f64],
971    input: &TrendFollowerInput<'_>,
972    params: TrendFollowerResolvedParams,
973    first: usize,
974    kernel: Kernel,
975    out: &mut [f64],
976) -> Result<(), TrendFollowerError> {
977    if data_is_clean(
978        high,
979        low,
980        close,
981        volume,
982        first,
983        params.matype == TrendFollowerMaType::Vwma,
984    ) {
985        trend_follower_compute_clean_into(high, low, close, volume, params, first, kernel, out)
986    } else {
987        trend_follower_compute_fallback_into(high, low, close, volume, input, out)
988    }
989}
990
/// Computes the trend follower using `Kernel::Auto`.
///
/// Equivalent to `trend_follower_with_kernel(input, Kernel::Auto)`.
#[inline]
pub fn trend_follower(
    input: &TrendFollowerInput<'_>,
) -> Result<TrendFollowerOutput, TrendFollowerError> {
    trend_follower_with_kernel(input, Kernel::Auto)
}
997
998pub fn trend_follower_with_kernel(
999    input: &TrendFollowerInput<'_>,
1000    kernel: Kernel,
1001) -> Result<TrendFollowerOutput, TrendFollowerError> {
1002    let (high, low, close, volume, params, first) = trend_follower_prepare(input)?;
1003    let mut out = alloc_with_nan_prefix(close.len(), close.len());
1004    trend_follower_compute_into(
1005        high, low, close, volume, input, params, first, kernel, &mut out,
1006    )?;
1007    Ok(TrendFollowerOutput { values: out })
1008}
1009
/// Writes the trend follower into `out` using `Kernel::Auto`.
///
/// Not compiled for the wasm32 + `wasm` feature build. See `trend_follower_into_slice`
/// for the length requirement and errors.
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
#[inline]
pub fn trend_follower_into(
    input: &TrendFollowerInput<'_>,
    out: &mut [f64],
) -> Result<(), TrendFollowerError> {
    trend_follower_into_slice(out, input, Kernel::Auto)
}
1018
1019pub fn trend_follower_into_slice(
1020    out: &mut [f64],
1021    input: &TrendFollowerInput<'_>,
1022    kernel: Kernel,
1023) -> Result<(), TrendFollowerError> {
1024    let (high, low, close, volume, params, first) = trend_follower_prepare(input)?;
1025    if out.len() != close.len() {
1026        return Err(TrendFollowerError::OutputLengthMismatch {
1027            expected: close.len(),
1028            got: out.len(),
1029        });
1030    }
1031    out.fill(f64::NAN);
1032    trend_follower_compute_into(high, low, close, volume, input, params, first, kernel, out)
1033}
1034
/// Incremental (bar-by-bar) Trend Follower state.
///
/// Build it with `TrendFollowerStream::try_new` and feed bars through `update` or
/// `update_reset_on_nan`.
#[derive(Clone, Debug)]
pub struct TrendFollowerStream {
    /// Base moving-average type; VWMA additionally requires a finite volume per bar.
    matype: TrendFollowerMaType,
    /// Window (in bars) of the rolling max/min taken over the smoothed MA.
    trend_period: usize,
    /// Period of the base moving average.
    ma_period: usize,
    /// Multiplier applied to the high/low channel width to obtain the threshold `chan`.
    /// NOTE(review): presumably `channel_rate_percent / 100` — confirm in `resolve_params`.
    channel_rate_fraction: f64,
    /// Whether the base MA is further smoothed by a linear-regression stream.
    use_linear_regression: bool,
    /// Period of the linear-regression smoother (only used when enabled).
    linear_regression_period: usize,
    /// Streaming base moving average, fed with `(close, volume)`.
    ma_stream: TrendFollowerBaseMaStream,
    /// Linear-regression smoother; `Some` exactly when `use_linear_regression` is true.
    linreg_stream: Option<LinRegStream>,
    /// Index that will be assigned to the next accepted bar.
    index: usize,
    /// `(index, high)` candidates; the front is read as the highest high of the
    /// `CHANNEL_WINDOW`-bar channel.
    high_max: VecDeque<(usize, f64)>,
    /// `(index, low)` candidates; the front is read as the lowest low of the channel.
    low_min: VecDeque<(usize, f64)>,
    /// `(index, ma)` candidates; the front is read as the highest smoothed MA over
    /// `trend_period` bars.
    ma_max: VecDeque<(usize, f64)>,
    /// `(index, ma)` candidates; the front is read as the lowest smoothed MA over
    /// `trend_period` bars.
    ma_min: VecDeque<(usize, f64)>,
}
1051
1052impl TrendFollowerStream {
1053    #[inline]
1054    pub fn try_new(params: TrendFollowerParams) -> Result<Self, TrendFollowerError> {
1055        let input = TrendFollowerInput::from_slices(&[1.0], &[1.0], &[1.0], &[1.0], params);
1056        let resolved = resolve_params(&input, usize::MAX)?;
1057        let ma_stream = TrendFollowerBaseMaStream::new(resolved.matype, resolved.ma_period);
1058        let linreg_stream = if resolved.use_linear_regression {
1059            Some(
1060                LinRegStream::try_new(LinRegParams {
1061                    period: Some(resolved.linear_regression_period),
1062                })
1063                .map_err(|e| TrendFollowerError::LinearRegressionError(e.to_string()))?,
1064            )
1065        } else {
1066            None
1067        };
1068        Ok(Self {
1069            matype: resolved.matype,
1070            trend_period: resolved.trend_period,
1071            ma_period: resolved.ma_period,
1072            channel_rate_fraction: resolved.channel_rate_fraction,
1073            use_linear_regression: resolved.use_linear_regression,
1074            linear_regression_period: resolved.linear_regression_period,
1075            ma_stream,
1076            linreg_stream,
1077            index: 0,
1078            high_max: VecDeque::with_capacity(CHANNEL_WINDOW),
1079            low_min: VecDeque::with_capacity(CHANNEL_WINDOW),
1080            ma_max: VecDeque::with_capacity(resolved.trend_period.max(1)),
1081            ma_min: VecDeque::with_capacity(resolved.trend_period.max(1)),
1082        })
1083    }
1084
1085    #[inline]
1086    pub fn reset(&mut self) -> Result<(), TrendFollowerError> {
1087        self.ma_stream = TrendFollowerBaseMaStream::new(self.matype, self.ma_period);
1088        self.linreg_stream = if self.use_linear_regression {
1089            Some(
1090                LinRegStream::try_new(LinRegParams {
1091                    period: Some(self.linear_regression_period),
1092                })
1093                .map_err(|e| TrendFollowerError::LinearRegressionError(e.to_string()))?,
1094            )
1095        } else {
1096            None
1097        };
1098        self.index = 0;
1099        self.high_max.clear();
1100        self.low_min.clear();
1101        self.ma_max.clear();
1102        self.ma_min.clear();
1103        Ok(())
1104    }
1105
1106    #[inline]
1107    pub fn update(&mut self, high: f64, low: f64, close: f64, volume: f64) -> Option<f64> {
1108        let needs_volume = self.matype == TrendFollowerMaType::Vwma;
1109        if !(high.is_finite() && low.is_finite() && close.is_finite())
1110            || (needs_volume && !volume.is_finite())
1111        {
1112            return None;
1113        }
1114
1115        let idx = self.index;
1116        evict_front(&mut self.high_max, idx, CHANNEL_WINDOW);
1117        evict_front(&mut self.low_min, idx, CHANNEL_WINDOW);
1118        evict_front(&mut self.ma_max, idx, self.trend_period);
1119        evict_front(&mut self.ma_min, idx, self.trend_period);
1120
1121        push_max(&mut self.high_max, idx, high, CHANNEL_WINDOW);
1122        push_min(&mut self.low_min, idx, low, CHANNEL_WINDOW);
1123
1124        let base_ma = self.ma_stream.update(close, volume);
1125        let ma = if self.use_linear_regression {
1126            match (base_ma, self.linreg_stream.as_mut()) {
1127                (Some(value), Some(stream)) => stream.update(value),
1128                _ => None,
1129            }
1130        } else {
1131            base_ma
1132        };
1133
1134        self.index = idx + 1;
1135
1136        let Some(ma) = ma else {
1137            return None;
1138        };
1139        if ma.is_finite() {
1140            push_max(&mut self.ma_max, idx, ma, self.trend_period);
1141            push_min(&mut self.ma_min, idx, ma, self.trend_period);
1142        } else {
1143            return Some(f64::NAN);
1144        }
1145
1146        let (hh, ll) = match (self.ma_max.front(), self.ma_min.front()) {
1147            (Some((_, hh)), Some((_, ll))) => (*hh, *ll),
1148            _ => return None,
1149        };
1150        let (channel_high, channel_low) = match (self.high_max.front(), self.low_min.front()) {
1151            (Some((_, hi)), Some((_, lo))) => (*hi, *lo),
1152            _ => return None,
1153        };
1154        let chan = (channel_high - channel_low) * self.channel_rate_fraction;
1155        if !chan.is_finite() || chan == 0.0 {
1156            return Some(f64::NAN);
1157        }
1158
1159        let diff = (hh - ll).abs();
1160        let trend = if diff > chan {
1161            if ma > ll + chan {
1162                1.0
1163            } else if ma < hh - chan {
1164                -1.0
1165            } else {
1166                0.0
1167            }
1168        } else {
1169            0.0
1170        };
1171        Some(trend * diff / chan)
1172    }
1173
1174    #[inline]
1175    pub fn update_reset_on_nan(
1176        &mut self,
1177        high: f64,
1178        low: f64,
1179        close: f64,
1180        volume: f64,
1181    ) -> Option<f64> {
1182        let needs_volume = self.matype == TrendFollowerMaType::Vwma;
1183        if !(high.is_finite() && low.is_finite() && close.is_finite())
1184            || (needs_volume && !volume.is_finite())
1185        {
1186            let _ = self.reset();
1187            return None;
1188        }
1189        self.update(high, low, close, volume)
1190    }
1191}
1192
/// Parameter grid for a Trend Follower batch run.
///
/// Numeric axes are `(start, end, step)` triples expanded by `axis_usize` / `axis_f64`. A
/// zero step or `start == end` yields the single value `start`.
#[derive(Clone, Debug)]
pub struct TrendFollowerBatchRange {
    /// Rolling max/min window over the smoothed MA.
    pub trend_period: (usize, usize, usize),
    /// Base moving-average period.
    pub ma_period: (usize, usize, usize),
    /// Channel-rate percentage axis (floating point).
    pub channel_rate_percent: (f64, f64, f64),
    /// Linear-regression smoothing period axis.
    pub linear_regression_period: (usize, usize, usize),
    /// MA type names `(first, second, _)`. `axis_string` uses only the first two entries
    /// and collapses them when they are equal ignoring ASCII case; the third is unused.
    pub matype: (String, String, String),
    /// Applied to every combination in the grid (not swept).
    pub use_linear_regression: bool,
}
1202
1203impl Default for TrendFollowerBatchRange {
1204    fn default() -> Self {
1205        Self {
1206            trend_period: (20, 20, 0),
1207            ma_period: (20, 20, 0),
1208            channel_rate_percent: (1.0, 1.0, 0.0),
1209            linear_regression_period: (5, 5, 0),
1210            matype: ("ema".to_string(), "ema".to_string(), String::new()),
1211            use_linear_regression: true,
1212        }
1213    }
1214}
1215
1216#[derive(Clone, Debug, Default)]
1217pub struct TrendFollowerBatchBuilder {
1218    range: TrendFollowerBatchRange,
1219    kernel: Kernel,
1220}
1221
1222impl TrendFollowerBatchBuilder {
1223    pub fn new() -> Self {
1224        Self::default()
1225    }
1226
1227    pub fn kernel(mut self, kernel: Kernel) -> Self {
1228        self.kernel = kernel;
1229        self
1230    }
1231
1232    pub fn trend_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1233        self.range.trend_period = (start, end, step);
1234        self
1235    }
1236
1237    pub fn ma_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1238        self.range.ma_period = (start, end, step);
1239        self
1240    }
1241
1242    pub fn channel_rate_percent_range(mut self, start: f64, end: f64, step: f64) -> Self {
1243        self.range.channel_rate_percent = (start, end, step);
1244        self
1245    }
1246
1247    pub fn linear_regression_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1248        self.range.linear_regression_period = (start, end, step);
1249        self
1250    }
1251
1252    pub fn matype_static<S: Into<String>>(mut self, value: S) -> Self {
1253        let value = value.into();
1254        self.range.matype = (value.clone(), value, String::new());
1255        self
1256    }
1257
1258    pub fn use_linear_regression(mut self, value: bool) -> Self {
1259        self.range.use_linear_regression = value;
1260        self
1261    }
1262
1263    pub fn apply_slices(
1264        self,
1265        high: &[f64],
1266        low: &[f64],
1267        close: &[f64],
1268        volume: &[f64],
1269    ) -> Result<TrendFollowerBatchOutput, TrendFollowerError> {
1270        trend_follower_batch_with_kernel(high, low, close, volume, &self.range, self.kernel)
1271    }
1272
1273    pub fn apply_candles(
1274        self,
1275        candles: &Candles,
1276    ) -> Result<TrendFollowerBatchOutput, TrendFollowerError> {
1277        self.apply_slices(&candles.high, &candles.low, &candles.close, &candles.volume)
1278    }
1279}
1280
1281#[derive(Clone, Debug)]
1282pub struct TrendFollowerBatchOutput {
1283    pub values: Vec<f64>,
1284    pub combos: Vec<TrendFollowerParams>,
1285    pub rows: usize,
1286    pub cols: usize,
1287}
1288
1289impl TrendFollowerBatchOutput {
1290    pub fn row_for_params(&self, params: &TrendFollowerParams) -> Option<usize> {
1291        let matype = params
1292            .matype
1293            .as_deref()
1294            .unwrap_or("ema")
1295            .to_ascii_lowercase();
1296        self.combos.iter().position(|combo| {
1297            combo.trend_period.unwrap_or(20) == params.trend_period.unwrap_or(20)
1298                && combo.ma_period.unwrap_or(20) == params.ma_period.unwrap_or(20)
1299                && (combo.channel_rate_percent.unwrap_or(1.0)
1300                    - params.channel_rate_percent.unwrap_or(1.0))
1301                .abs()
1302                    <= 1e-12
1303                && combo.use_linear_regression.unwrap_or(true)
1304                    == params.use_linear_regression.unwrap_or(true)
1305                && combo.linear_regression_period.unwrap_or(5)
1306                    == params.linear_regression_period.unwrap_or(5)
1307                && combo
1308                    .matype
1309                    .as_deref()
1310                    .unwrap_or("ema")
1311                    .eq_ignore_ascii_case(&matype)
1312        })
1313    }
1314
1315    pub fn values_for(&self, params: &TrendFollowerParams) -> Option<&[f64]> {
1316        self.row_for_params(params).map(|row| {
1317            let start = row * self.cols;
1318            &self.values[start..start + self.cols]
1319        })
1320    }
1321}
1322
1323#[inline]
1324fn axis_usize(range: (usize, usize, usize)) -> Result<Vec<usize>, TrendFollowerError> {
1325    let (start, end, step) = range;
1326    if start == 0 || end == 0 {
1327        return Err(TrendFollowerError::InvalidRangeUsize { start, end, step });
1328    }
1329    if step == 0 || start == end {
1330        return Ok(vec![start]);
1331    }
1332    let mut out = Vec::new();
1333    if start < end {
1334        let mut value = start;
1335        while value <= end {
1336            out.push(value);
1337            match value.checked_add(step) {
1338                Some(next) if next > value => value = next,
1339                _ => break,
1340            }
1341        }
1342    } else {
1343        let mut value = start;
1344        while value >= end {
1345            out.push(value);
1346            if value < end + step {
1347                break;
1348            }
1349            value = value.saturating_sub(step);
1350            if value == 0 {
1351                break;
1352            }
1353        }
1354    }
1355    if out.is_empty() {
1356        return Err(TrendFollowerError::InvalidRangeUsize { start, end, step });
1357    }
1358    Ok(out)
1359}
1360
1361#[inline]
1362fn axis_f64(range: (f64, f64, f64)) -> Result<Vec<f64>, TrendFollowerError> {
1363    let (start, end, step) = range;
1364    if !start.is_finite() || !end.is_finite() || !step.is_finite() {
1365        return Err(TrendFollowerError::InvalidRangeF64 { start, end, step });
1366    }
1367    if step == 0.0 || (start - end).abs() <= 1e-12 {
1368        return Ok(vec![start]);
1369    }
1370    if step < 0.0 {
1371        return Err(TrendFollowerError::InvalidRangeF64 { start, end, step });
1372    }
1373    let mut out = Vec::new();
1374    if start < end {
1375        let mut value = start;
1376        while value <= end + 1e-12 {
1377            out.push(value);
1378            value += step;
1379        }
1380    } else {
1381        let mut value = start;
1382        while value >= end - 1e-12 {
1383            out.push(value);
1384            value -= step;
1385        }
1386    }
1387    if out.is_empty() {
1388        return Err(TrendFollowerError::InvalidRangeF64 { start, end, step });
1389    }
1390    Ok(out)
1391}
1392
1393#[inline]
1394fn axis_string(range: (String, String, String)) -> Vec<String> {
1395    if range.0.eq_ignore_ascii_case(&range.1) {
1396        vec![range.0]
1397    } else {
1398        vec![range.0, range.1]
1399    }
1400}
1401
1402pub fn expand_grid_trend_follower(
1403    range: &TrendFollowerBatchRange,
1404) -> Result<Vec<TrendFollowerParams>, TrendFollowerError> {
1405    let trend_periods = axis_usize(range.trend_period)?;
1406    let ma_periods = axis_usize(range.ma_period)?;
1407    let channel_rates = axis_f64(range.channel_rate_percent)?;
1408    let linear_regression_periods = axis_usize(range.linear_regression_period)?;
1409    let matypes = axis_string(range.matype.clone());
1410
1411    let mut out = Vec::new();
1412    for trend_period in &trend_periods {
1413        for ma_period in &ma_periods {
1414            for channel_rate_percent in &channel_rates {
1415                for linear_regression_period in &linear_regression_periods {
1416                    for matype in &matypes {
1417                        out.push(TrendFollowerParams {
1418                            matype: Some(matype.to_ascii_lowercase()),
1419                            trend_period: Some(*trend_period),
1420                            ma_period: Some(*ma_period),
1421                            channel_rate_percent: Some(*channel_rate_percent),
1422                            use_linear_regression: Some(range.use_linear_regression),
1423                            linear_regression_period: Some(*linear_regression_period),
1424                        });
1425                    }
1426                }
1427            }
1428        }
1429    }
1430    Ok(out)
1431}
1432
1433pub fn trend_follower_batch_with_kernel(
1434    high: &[f64],
1435    low: &[f64],
1436    close: &[f64],
1437    volume: &[f64],
1438    range: &TrendFollowerBatchRange,
1439    kernel: Kernel,
1440) -> Result<TrendFollowerBatchOutput, TrendFollowerError> {
1441    let batch_kernel = match kernel {
1442        Kernel::Auto => Kernel::ScalarBatch,
1443        other if other.is_batch() => other,
1444        other => return Err(TrendFollowerError::InvalidKernelForBatch(other)),
1445    };
1446    trend_follower_batch_impl(
1447        high,
1448        low,
1449        close,
1450        volume,
1451        range,
1452        batch_kernel.to_non_batch(),
1453        true,
1454    )
1455}
1456
1457pub fn trend_follower_batch_slice(
1458    high: &[f64],
1459    low: &[f64],
1460    close: &[f64],
1461    volume: &[f64],
1462    range: &TrendFollowerBatchRange,
1463) -> Result<TrendFollowerBatchOutput, TrendFollowerError> {
1464    trend_follower_batch_impl(high, low, close, volume, range, Kernel::Scalar, false)
1465}
1466
1467pub fn trend_follower_batch_par_slice(
1468    high: &[f64],
1469    low: &[f64],
1470    close: &[f64],
1471    volume: &[f64],
1472    range: &TrendFollowerBatchRange,
1473) -> Result<TrendFollowerBatchOutput, TrendFollowerError> {
1474    trend_follower_batch_impl(high, low, close, volume, range, Kernel::Scalar, true)
1475}
1476
1477fn trend_follower_batch_impl(
1478    high: &[f64],
1479    low: &[f64],
1480    close: &[f64],
1481    volume: &[f64],
1482    range: &TrendFollowerBatchRange,
1483    kernel: Kernel,
1484    parallel: bool,
1485) -> Result<TrendFollowerBatchOutput, TrendFollowerError> {
1486    if high.len() != low.len() || high.len() != close.len() || high.len() != volume.len() {
1487        return Err(TrendFollowerError::DataLengthMismatch {
1488            high_len: high.len(),
1489            low_len: low.len(),
1490            close_len: close.len(),
1491            volume_len: volume.len(),
1492        });
1493    }
1494    if high.is_empty() {
1495        return Err(TrendFollowerError::EmptyInputData);
1496    }
1497
1498    let combos = expand_grid_trend_follower(range)?;
1499    let rows = combos.len();
1500    let cols = close.len();
1501    let mut matrix = make_uninit_matrix(rows, cols);
1502    init_matrix_prefixes(&mut matrix, cols, &vec![cols; rows]);
1503
1504    let mut guard = ManuallyDrop::new(matrix);
1505    let out_mu: &mut [MaybeUninit<f64>] =
1506        unsafe { std::slice::from_raw_parts_mut(guard.as_mut_ptr(), guard.len()) };
1507
1508    let do_row = |row: usize, row_mu: &mut [MaybeUninit<f64>]| {
1509        let out = unsafe {
1510            std::slice::from_raw_parts_mut(row_mu.as_mut_ptr() as *mut f64, row_mu.len())
1511        };
1512        let input = TrendFollowerInput::from_slices(high, low, close, volume, combos[row].clone());
1513        let _ = trend_follower_into_slice(out, &input, kernel);
1514    };
1515
1516    if parallel {
1517        #[cfg(not(target_arch = "wasm32"))]
1518        out_mu
1519            .par_chunks_mut(cols)
1520            .enumerate()
1521            .for_each(|(row, row_mu)| do_row(row, row_mu));
1522        #[cfg(target_arch = "wasm32")]
1523        for (row, row_mu) in out_mu.chunks_mut(cols).enumerate() {
1524            do_row(row, row_mu);
1525        }
1526    } else {
1527        for (row, row_mu) in out_mu.chunks_mut(cols).enumerate() {
1528            do_row(row, row_mu);
1529        }
1530    }
1531
1532    let values = unsafe {
1533        Vec::from_raw_parts(
1534            guard.as_mut_ptr() as *mut f64,
1535            guard.len(),
1536            guard.capacity(),
1537        )
1538    };
1539
1540    Ok(TrendFollowerBatchOutput {
1541        values,
1542        combos,
1543        rows,
1544        cols,
1545    })
1546}
1547
1548fn trend_follower_batch_inner_into(
1549    high: &[f64],
1550    low: &[f64],
1551    close: &[f64],
1552    volume: &[f64],
1553    range: &TrendFollowerBatchRange,
1554    kernel: Kernel,
1555    parallel: bool,
1556    out: &mut [f64],
1557) -> Result<(), TrendFollowerError> {
1558    if high.len() != low.len() || high.len() != close.len() || high.len() != volume.len() {
1559        return Err(TrendFollowerError::DataLengthMismatch {
1560            high_len: high.len(),
1561            low_len: low.len(),
1562            close_len: close.len(),
1563            volume_len: volume.len(),
1564        });
1565    }
1566    let combos = expand_grid_trend_follower(range)?;
1567    let rows = combos.len();
1568    let cols = close.len();
1569    if rows.checked_mul(cols) != Some(out.len()) {
1570        return Err(TrendFollowerError::OutputLengthMismatch {
1571            expected: rows * cols,
1572            got: out.len(),
1573        });
1574    }
1575
1576    for row_out in out.chunks_mut(cols) {
1577        row_out.fill(f64::NAN);
1578    }
1579
1580    let do_row = |row: usize, row_out: &mut [f64]| {
1581        let input = TrendFollowerInput::from_slices(high, low, close, volume, combos[row].clone());
1582        let _ = trend_follower_into_slice(row_out, &input, kernel);
1583    };
1584
1585    if parallel {
1586        #[cfg(not(target_arch = "wasm32"))]
1587        out.par_chunks_mut(cols)
1588            .enumerate()
1589            .for_each(|(row, row_out)| do_row(row, row_out));
1590        #[cfg(target_arch = "wasm32")]
1591        for (row, row_out) in out.chunks_mut(cols).enumerate() {
1592            do_row(row, row_out);
1593        }
1594    } else {
1595        for (row, row_out) in out.chunks_mut(cols).enumerate() {
1596            do_row(row, row_out);
1597        }
1598    }
1599
1600    Ok(())
1601}
1602
1603#[cfg(feature = "python")]
1604#[pyfunction(name = "trend_follower")]
1605#[pyo3(signature = (high, low, close, volume, matype="ema", trend_period=20, ma_period=20, channel_rate_percent=1.0, use_linear_regression=true, linear_regression_period=5, kernel=None))]
1606pub fn trend_follower_py<'py>(
1607    py: Python<'py>,
1608    high: PyReadonlyArray1<'py, f64>,
1609    low: PyReadonlyArray1<'py, f64>,
1610    close: PyReadonlyArray1<'py, f64>,
1611    volume: PyReadonlyArray1<'py, f64>,
1612    matype: &str,
1613    trend_period: usize,
1614    ma_period: usize,
1615    channel_rate_percent: f64,
1616    use_linear_regression: bool,
1617    linear_regression_period: usize,
1618    kernel: Option<&str>,
1619) -> PyResult<Bound<'py, PyArray1<f64>>> {
1620    let high = high.as_slice()?;
1621    let low = low.as_slice()?;
1622    let close = close.as_slice()?;
1623    let volume = volume.as_slice()?;
1624    let kernel = validate_kernel(kernel, false)?;
1625    let input = TrendFollowerInput::from_slices(
1626        high,
1627        low,
1628        close,
1629        volume,
1630        TrendFollowerParams {
1631            matype: Some(matype.to_string()),
1632            trend_period: Some(trend_period),
1633            ma_period: Some(ma_period),
1634            channel_rate_percent: Some(channel_rate_percent),
1635            use_linear_regression: Some(use_linear_regression),
1636            linear_regression_period: Some(linear_regression_period),
1637        },
1638    );
1639    let output = py
1640        .allow_threads(|| trend_follower_with_kernel(&input, kernel))
1641        .map_err(|e| PyValueError::new_err(e.to_string()))?;
1642    Ok(output.values.into_pyarray(py))
1643}
1644
1645#[cfg(feature = "python")]
1646#[pyclass(name = "TrendFollowerStream")]
1647pub struct TrendFollowerStreamPy {
1648    stream: TrendFollowerStream,
1649}
1650
1651#[cfg(feature = "python")]
1652#[pymethods]
1653impl TrendFollowerStreamPy {
1654    #[new]
1655    #[pyo3(signature = (matype="ema", trend_period=20, ma_period=20, channel_rate_percent=1.0, use_linear_regression=true, linear_regression_period=5))]
1656    fn new(
1657        matype: &str,
1658        trend_period: usize,
1659        ma_period: usize,
1660        channel_rate_percent: f64,
1661        use_linear_regression: bool,
1662        linear_regression_period: usize,
1663    ) -> PyResult<Self> {
1664        let stream = TrendFollowerStream::try_new(TrendFollowerParams {
1665            matype: Some(matype.to_string()),
1666            trend_period: Some(trend_period),
1667            ma_period: Some(ma_period),
1668            channel_rate_percent: Some(channel_rate_percent),
1669            use_linear_regression: Some(use_linear_regression),
1670            linear_regression_period: Some(linear_regression_period),
1671        })
1672        .map_err(|e| PyValueError::new_err(e.to_string()))?;
1673        Ok(Self { stream })
1674    }
1675
1676    fn update(&mut self, high: f64, low: f64, close: f64, volume: f64) -> Option<f64> {
1677        self.stream.update_reset_on_nan(high, low, close, volume)
1678    }
1679}
1680
1681#[cfg(feature = "python")]
1682#[pyfunction(name = "trend_follower_batch")]
1683#[pyo3(signature = (high, low, close, volume, trend_period_range=(20, 20, 0), ma_period_range=(20, 20, 0), channel_rate_percent_range=(1.0, 1.0, 0.0), linear_regression_period_range=(5, 5, 0), matype="ema", use_linear_regression=true, kernel=None))]
1684pub fn trend_follower_batch_py<'py>(
1685    py: Python<'py>,
1686    high: PyReadonlyArray1<'py, f64>,
1687    low: PyReadonlyArray1<'py, f64>,
1688    close: PyReadonlyArray1<'py, f64>,
1689    volume: PyReadonlyArray1<'py, f64>,
1690    trend_period_range: (usize, usize, usize),
1691    ma_period_range: (usize, usize, usize),
1692    channel_rate_percent_range: (f64, f64, f64),
1693    linear_regression_period_range: (usize, usize, usize),
1694    matype: &str,
1695    use_linear_regression: bool,
1696    kernel: Option<&str>,
1697) -> PyResult<Bound<'py, PyDict>> {
1698    let high = high.as_slice()?;
1699    let low = low.as_slice()?;
1700    let close = close.as_slice()?;
1701    let volume = volume.as_slice()?;
1702    let range = TrendFollowerBatchRange {
1703        trend_period: trend_period_range,
1704        ma_period: ma_period_range,
1705        channel_rate_percent: channel_rate_percent_range,
1706        linear_regression_period: linear_regression_period_range,
1707        matype: (matype.to_string(), matype.to_string(), String::new()),
1708        use_linear_regression,
1709    };
1710    let combos =
1711        expand_grid_trend_follower(&range).map_err(|e| PyValueError::new_err(e.to_string()))?;
1712    let rows = combos.len();
1713    let cols = close.len();
1714    let total = rows
1715        .checked_mul(cols)
1716        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;
1717    let arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
1718    let out = unsafe { arr.as_slice_mut()? };
1719    let kernel = validate_kernel(kernel, true)?;
1720
1721    py.allow_threads(|| {
1722        let batch_kernel = match kernel {
1723            Kernel::Auto => detect_best_batch_kernel(),
1724            other => other,
1725        };
1726        trend_follower_batch_inner_into(
1727            high,
1728            low,
1729            close,
1730            volume,
1731            &range,
1732            batch_kernel.to_non_batch(),
1733            true,
1734            out,
1735        )
1736    })
1737    .map_err(|e| PyValueError::new_err(e.to_string()))?;
1738
1739    let dict = PyDict::new(py);
1740    dict.set_item("values", arr.reshape((rows, cols))?)?;
1741    dict.set_item(
1742        "trend_periods",
1743        combos
1744            .iter()
1745            .map(|params| params.trend_period.unwrap_or(20) as u64)
1746            .collect::<Vec<_>>()
1747            .into_pyarray(py),
1748    )?;
1749    dict.set_item(
1750        "ma_periods",
1751        combos
1752            .iter()
1753            .map(|params| params.ma_period.unwrap_or(20) as u64)
1754            .collect::<Vec<_>>()
1755            .into_pyarray(py),
1756    )?;
1757    dict.set_item(
1758        "channel_rate_percents",
1759        combos
1760            .iter()
1761            .map(|params| params.channel_rate_percent.unwrap_or(1.0))
1762            .collect::<Vec<_>>()
1763            .into_pyarray(py),
1764    )?;
1765    dict.set_item(
1766        "linear_regression_periods",
1767        combos
1768            .iter()
1769            .map(|params| params.linear_regression_period.unwrap_or(5) as u64)
1770            .collect::<Vec<_>>()
1771            .into_pyarray(py),
1772    )?;
1773    dict.set_item(
1774        "matypes",
1775        combos
1776            .iter()
1777            .map(|params| params.matype.as_deref().unwrap_or("ema").to_string())
1778            .collect::<Vec<_>>(),
1779    )?;
1780    dict.set_item(
1781        "use_linear_regression",
1782        combos
1783            .iter()
1784            .map(|params| params.use_linear_regression.unwrap_or(true))
1785            .collect::<Vec<_>>()
1786            .into_pyarray(py),
1787    )?;
1788    dict.set_item("rows", rows)?;
1789    dict.set_item("cols", cols)?;
1790    Ok(dict)
1791}
1792
1793#[cfg(feature = "python")]
1794pub fn register_trend_follower_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
1795    m.add_function(wrap_pyfunction!(trend_follower_py, m)?)?;
1796    m.add_function(wrap_pyfunction!(trend_follower_batch_py, m)?)?;
1797    m.add_class::<TrendFollowerStreamPy>()?;
1798    Ok(())
1799}
1800
1801#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1802#[derive(Debug, Clone, Serialize, Deserialize)]
1803struct TrendFollowerBatchConfig {
1804    trend_period_range: Vec<usize>,
1805    ma_period_range: Vec<usize>,
1806    channel_rate_percent_range: Vec<f64>,
1807    linear_regression_period_range: Vec<usize>,
1808    matype: String,
1809    use_linear_regression: bool,
1810}
1811
1812#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1813#[derive(Debug, Clone, Serialize, Deserialize)]
1814struct TrendFollowerBatchJsOutput {
1815    values: Vec<f64>,
1816    rows: usize,
1817    cols: usize,
1818    combos: Vec<TrendFollowerParams>,
1819}
1820
1821#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1822#[wasm_bindgen(js_name = "trend_follower_js")]
1823pub fn trend_follower_js(
1824    high: &[f64],
1825    low: &[f64],
1826    close: &[f64],
1827    volume: &[f64],
1828    matype: &str,
1829    trend_period: usize,
1830    ma_period: usize,
1831    channel_rate_percent: f64,
1832    use_linear_regression: bool,
1833    linear_regression_period: usize,
1834) -> Result<Vec<f64>, JsValue> {
1835    let input = TrendFollowerInput::from_slices(
1836        high,
1837        low,
1838        close,
1839        volume,
1840        TrendFollowerParams {
1841            matype: Some(matype.to_string()),
1842            trend_period: Some(trend_period),
1843            ma_period: Some(ma_period),
1844            channel_rate_percent: Some(channel_rate_percent),
1845            use_linear_regression: Some(use_linear_regression),
1846            linear_regression_period: Some(linear_regression_period),
1847        },
1848    );
1849    trend_follower(&input)
1850        .map(|out| out.values)
1851        .map_err(|e| JsValue::from_str(&e.to_string()))
1852}
1853
1854#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1855#[wasm_bindgen(js_name = "trend_follower_batch_js")]
1856pub fn trend_follower_batch_js(
1857    high: &[f64],
1858    low: &[f64],
1859    close: &[f64],
1860    volume: &[f64],
1861    config: JsValue,
1862) -> Result<JsValue, JsValue> {
1863    let config: TrendFollowerBatchConfig = serde_wasm_bindgen::from_value(config)
1864        .map_err(|e| JsValue::from_str(&format!("Invalid config: {e}")))?;
1865    if config.trend_period_range.len() != 3
1866        || config.ma_period_range.len() != 3
1867        || config.channel_rate_percent_range.len() != 3
1868        || config.linear_regression_period_range.len() != 3
1869    {
1870        return Err(JsValue::from_str(
1871            "Invalid config: all *_range fields must have exactly 3 elements",
1872        ));
1873    }
1874    let range = TrendFollowerBatchRange {
1875        trend_period: (
1876            config.trend_period_range[0],
1877            config.trend_period_range[1],
1878            config.trend_period_range[2],
1879        ),
1880        ma_period: (
1881            config.ma_period_range[0],
1882            config.ma_period_range[1],
1883            config.ma_period_range[2],
1884        ),
1885        channel_rate_percent: (
1886            config.channel_rate_percent_range[0],
1887            config.channel_rate_percent_range[1],
1888            config.channel_rate_percent_range[2],
1889        ),
1890        linear_regression_period: (
1891            config.linear_regression_period_range[0],
1892            config.linear_regression_period_range[1],
1893            config.linear_regression_period_range[2],
1894        ),
1895        matype: (config.matype.clone(), config.matype, String::new()),
1896        use_linear_regression: config.use_linear_regression,
1897    };
1898    let batch = trend_follower_batch_slice(high, low, close, volume, &range)
1899        .map_err(|e| JsValue::from_str(&e.to_string()))?;
1900    serde_wasm_bindgen::to_value(&TrendFollowerBatchJsOutput {
1901        values: batch.values,
1902        rows: batch.rows,
1903        cols: batch.cols,
1904        combos: batch.combos,
1905    })
1906    .map_err(|e| JsValue::from_str(&format!("Serialization error: {e}")))
1907}
1908
/// Reserves room for `len` `f64` values in WASM linear memory and returns a pointer to it.
///
/// The memory is not initialized. It is intentionally leaked so that JavaScript can fill it.
/// It must later be released with `trend_follower_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trend_follower_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop keeps the Vec from freeing its allocation when this function returns.
    let mut buffer = ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1917
/// Releases a buffer previously returned by `trend_follower_alloc(len)`.
///
/// `ptr` must come from `trend_follower_alloc` with the same `len` and must not be freed
/// twice. A null `ptr` is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trend_follower_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr`/`len` describe the allocation made by `Vec::<f64>::with_capacity(len)` in
    // `trend_follower_alloc`, so capacity `len` matches the original layout. The length is 0
    // because `Vec::from_raw_parts` requires the first `length` elements to be initialized,
    // and this buffer may never have been written. `f64` has no destructor, so nothing is lost.
    unsafe {
        let _ = Vec::from_raw_parts(ptr, 0, len);
    }
}
1925
/// Computes Trend Follower from raw WASM-memory buffers and writes `len` values to `out_ptr`.
///
/// Pointer requirements:
/// - every input pointer must be valid for `len` reads;
/// - `out_ptr` must be valid for `len` writes (typically buffers from `trend_follower_alloc`).
///
/// `out_ptr` may alias or overlap one of the inputs for in-place use. In that case the result
/// is computed into a temporary buffer and copied out afterwards.
///
/// # Errors
/// Returns a JS string error if any pointer is null or the indicator rejects the input.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trend_follower_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    volume_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    matype: &str,
    trend_period: usize,
    ma_period: usize,
    channel_rate_percent: f64,
    use_linear_regression: bool,
    linear_regression_period: usize,
) -> Result<(), JsValue> {
    /// Returns `true` when two non-empty `f64` buffers share at least one byte.
    fn overlaps(a: *const f64, a_len: usize, b: *const f64, b_len: usize) -> bool {
        if a_len == 0 || b_len == 0 {
            return false;
        }
        let elem = std::mem::size_of::<f64>();
        let (a_start, b_start) = (a as usize, b as usize);
        let a_end = a_start.saturating_add(a_len.saturating_mul(elem));
        let b_end = b_start.saturating_add(b_len.saturating_mul(elem));
        a_start < b_end && b_start < a_end
    }

    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || volume_ptr.is_null()
        || out_ptr.is_null()
    {
        return Err(JsValue::from_str(
            "null pointer passed to trend_follower_into",
        ));
    }

    // SAFETY: null pointers were rejected above. The caller guarantees each input pointer is
    // valid for `len` reads for the duration of this call.
    let (high, low, close, volume) = unsafe {
        (
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
            std::slice::from_raw_parts(close_ptr, len),
            std::slice::from_raw_parts(volume_ptr, len),
        )
    };
    let input = TrendFollowerInput::from_slices(
        high,
        low,
        close,
        volume,
        TrendFollowerParams {
            matype: Some(matype.to_string()),
            trend_period: Some(trend_period),
            ma_period: Some(ma_period),
            channel_rate_percent: Some(channel_rate_percent),
            use_linear_regression: Some(use_linear_regression),
            linear_regression_period: Some(linear_regression_period),
        },
    );

    let out_const = out_ptr as *const f64;
    let aliased = [high_ptr, low_ptr, close_ptr, volume_ptr]
        .iter()
        .any(|&p| overlaps(p, len, out_const, len));

    if aliased {
        // Compute into a scratch buffer first. Creating `&mut [f64]` over memory that is also
        // borrowed as an input slice is undefined behaviour. The kernel could also read inputs
        // after it has already overwritten them.
        let mut scratch = vec![f64::NAN; len];
        trend_follower_into_slice(scratch.as_mut_slice(), &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: `out_ptr` is valid for `len` writes and `scratch` is a distinct allocation of
        // exactly `len` elements. The input slices are not used after this point.
        unsafe { std::ptr::copy_nonoverlapping(scratch.as_ptr(), out_ptr, len) };
    } else {
        // SAFETY: `out_ptr` is valid for `len` writes and, per the check above, does not
        // overlap any input buffer, so this is the only live reference to that memory.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        trend_follower_into_slice(out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(())
}
1976
/// Computes Trend Follower for slices passed from JavaScript and writes into `out_ptr`.
///
/// `out_ptr` must be valid for `close.len()` writes. Input length consistency is left to
/// `trend_follower_into_slice`.
///
/// # Errors
/// Returns a JS string error if `out_ptr` is null or the indicator rejects the input.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "trend_follower_into_host")]
pub fn trend_follower_into_host(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    out_ptr: *mut f64,
    matype: &str,
    trend_period: usize,
    ma_period: usize,
    channel_rate_percent: f64,
    use_linear_regression: bool,
    linear_regression_period: usize,
) -> Result<(), JsValue> {
    if out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to trend_follower_into_host",
        ));
    }

    let params = TrendFollowerParams {
        matype: Some(String::from(matype)),
        trend_period: Some(trend_period),
        ma_period: Some(ma_period),
        channel_rate_percent: Some(channel_rate_percent),
        use_linear_regression: Some(use_linear_regression),
        linear_regression_period: Some(linear_regression_period),
    };
    let input = TrendFollowerInput::from_slices(high, low, close, volume, params);

    // SAFETY: `out_ptr` is non-null (checked above). The caller guarantees it is valid for
    // `close.len()` writes.
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, close.len()) };
    trend_follower_into_slice(out, &input, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))
}
2017
/// Runs a Trend Follower parameter sweep over raw WASM-memory buffers.
///
/// The output is written row-major into `out_ptr`, one row of `len` values per parameter
/// combination.
///
/// Pointer requirements:
/// - every input pointer must be valid for `len` reads;
/// - `out_ptr` must be valid for `rows * len` writes.
///
/// `out_ptr` may overlap an input buffer; the sweep is then computed into a temporary buffer
/// and copied out. The sweep runs with the scalar kernel and no parallelism.
///
/// Returns the number of rows (parameter combinations) written.
///
/// # Errors
/// Returns a JS string error if any pointer is null, the parameter grid is invalid,
/// `rows * len` overflows `usize`, or the batch computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trend_follower_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    volume_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    trend_period_start: usize,
    trend_period_end: usize,
    trend_period_step: usize,
    ma_period_start: usize,
    ma_period_end: usize,
    ma_period_step: usize,
    channel_rate_percent_start: f64,
    channel_rate_percent_end: f64,
    channel_rate_percent_step: f64,
    linear_regression_period_start: usize,
    linear_regression_period_end: usize,
    linear_regression_period_step: usize,
    matype: &str,
    use_linear_regression: bool,
) -> Result<usize, JsValue> {
    /// Returns `true` when two non-empty `f64` buffers share at least one byte.
    fn overlaps(a: *const f64, a_len: usize, b: *const f64, b_len: usize) -> bool {
        if a_len == 0 || b_len == 0 {
            return false;
        }
        let elem = std::mem::size_of::<f64>();
        let (a_start, b_start) = (a as usize, b as usize);
        let a_end = a_start.saturating_add(a_len.saturating_mul(elem));
        let b_end = b_start.saturating_add(b_len.saturating_mul(elem));
        a_start < b_end && b_start < a_end
    }

    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || volume_ptr.is_null()
        || out_ptr.is_null()
    {
        return Err(JsValue::from_str(
            "null pointer passed to trend_follower_batch_into",
        ));
    }

    // SAFETY: null pointers were rejected above. The caller guarantees each input pointer is
    // valid for `len` reads for the duration of this call.
    let (high, low, close, volume) = unsafe {
        (
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
            std::slice::from_raw_parts(close_ptr, len),
            std::slice::from_raw_parts(volume_ptr, len),
        )
    };
    let range = TrendFollowerBatchRange {
        trend_period: (trend_period_start, trend_period_end, trend_period_step),
        ma_period: (ma_period_start, ma_period_end, ma_period_step),
        channel_rate_percent: (
            channel_rate_percent_start,
            channel_rate_percent_end,
            channel_rate_percent_step,
        ),
        linear_regression_period: (
            linear_regression_period_start,
            linear_regression_period_end,
            linear_regression_period_step,
        ),
        // The moving-average type is not swept: start == end, empty step.
        matype: (matype.to_string(), matype.to_string(), String::new()),
        use_linear_regression,
    };

    let combos =
        expand_grid_trend_follower(&range).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    // An overflowing product would otherwise produce an undersized output slice.
    let total = rows.checked_mul(len).ok_or_else(|| {
        JsValue::from_str("trend_follower_batch_into: rows * len overflows usize")
    })?;

    let out_const = out_ptr as *const f64;
    let aliased = [high_ptr, low_ptr, close_ptr, volume_ptr]
        .iter()
        .any(|&p| overlaps(p, len, out_const, total));

    if aliased {
        // The output overlaps an input. Creating `&mut [f64]` over borrowed input memory would
        // be undefined behaviour, so compute into scratch space and copy it out afterwards.
        let mut scratch = vec![f64::NAN; total];
        trend_follower_batch_inner_into(
            high,
            low,
            close,
            volume,
            &range,
            Kernel::Scalar,
            false,
            scratch.as_mut_slice(),
        )
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: `out_ptr` is valid for `total` writes and `scratch` is a distinct allocation of
        // exactly `total` elements. The input slices are not used after this point.
        unsafe { std::ptr::copy_nonoverlapping(scratch.as_ptr(), out_ptr, total) };
    } else {
        // SAFETY: `out_ptr` is valid for `total` writes and does not overlap any input buffer.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        trend_follower_batch_inner_into(
            high,
            low,
            close,
            volume,
            &range,
            Kernel::Scalar,
            false,
            out,
        )
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(rows)
}
2091
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a deterministic synthetic OHLCV series of `len` bars.
    ///
    /// The series has:
    /// - a gentle uptrend with sinusoidal wobble for `close`;
    /// - `high`/`low` bracketing `close`;
    /// - a rising, periodically varying `volume`.
    fn sample_ohlcv(len: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
        let mid = |t: f64| 100.0 + t * 0.19 + (t * 0.13).sin() * 1.8 + (t * 0.027).cos() * 0.6;
        let close: Vec<f64> = (0..len).map(|i| mid(i as f64)).collect();
        let high = close
            .iter()
            .enumerate()
            .map(|(i, &c)| c + 1.2 + (i as f64 * 0.05).sin().abs())
            .collect();
        let low = close
            .iter()
            .enumerate()
            .map(|(i, &c)| c - 1.1 - (i as f64 * 0.04).cos().abs())
            .collect();
        let volume = (0..len)
            .map(|i| 1000.0 + i as f64 * 11.0 + (i % 9) as f64 * 17.0)
            .collect();
        (high, low, close, volume)
    }

    /// Asserts element-wise equality within 1e-9, treating NaN as equal only to NaN.
    fn assert_close(a: &[f64], b: &[f64]) {
        assert_eq!(a.len(), b.len());
        for (i, (&x, &y)) in a.iter().zip(b).enumerate() {
            if x.is_nan() || y.is_nan() {
                assert!(x.is_nan() && y.is_nan(), "nan mismatch at {i}");
            } else {
                assert!((x - y).abs() <= 1e-9, "value mismatch at {i}: {} vs {}", x, y);
            }
        }
    }

    #[test]
    fn trend_follower_into_matches_api() {
        let (high, low, close, volume) = sample_ohlcv(128);
        let input = TrendFollowerInput::from_slices(
            &high,
            &low,
            &close,
            &volume,
            TrendFollowerParams::default(),
        );
        let expected = trend_follower(&input).unwrap();
        let mut written = vec![f64::NAN; close.len()];
        trend_follower_into_slice(&mut written, &input, Kernel::Auto).unwrap();
        assert_close(&expected.values, &written);
    }

    #[test]
    fn trend_follower_stream_matches_batch_with_nan_gap() {
        let (mut high, mut low, mut close, mut volume) = sample_ohlcv(128);
        let gap = 48;
        for series in [&mut high, &mut low, &mut close, &mut volume] {
            series[gap] = f64::NAN;
        }
        let input = TrendFollowerInput::from_slices(
            &high,
            &low,
            &close,
            &volume,
            TrendFollowerParams::default(),
        );
        let expected = trend_follower(&input).unwrap();
        let mut stream = TrendFollowerStream::try_new(TrendFollowerParams::default()).unwrap();
        let streamed: Vec<f64> = (0..close.len())
            .map(|i| {
                stream
                    .update_reset_on_nan(high[i], low[i], close[i], volume[i])
                    .unwrap_or(f64::NAN)
            })
            .collect();
        assert_close(&expected.values, &streamed);
    }

    #[test]
    fn trend_follower_batch_single_param_matches_single() {
        let (high, low, close, volume) = sample_ohlcv(128);
        let params = TrendFollowerParams {
            matype: Some("wma".to_string()),
            trend_period: Some(14),
            ma_period: Some(9),
            channel_rate_percent: Some(1.1),
            use_linear_regression: Some(false),
            linear_regression_period: Some(5),
        };
        let expected = trend_follower(&TrendFollowerInput::from_slices(
            &high, &low, &close, &volume, params,
        ))
        .unwrap();
        let degenerate_grid = TrendFollowerBatchRange {
            trend_period: (14, 14, 0),
            ma_period: (9, 9, 0),
            channel_rate_percent: (1.1, 1.1, 0.0),
            linear_regression_period: (5, 5, 0),
            matype: ("wma".to_string(), "wma".to_string(), String::new()),
            use_linear_regression: false,
        };
        let grid = trend_follower_batch_with_kernel(
            &high,
            &low,
            &close,
            &volume,
            &degenerate_grid,
            Kernel::Auto,
        )
        .unwrap();
        assert_eq!(grid.rows, 1);
        assert_close(&expected.values, &grid.values[..close.len()]);
    }

    #[test]
    fn trend_follower_vwma_depends_on_volume() {
        let (high, low, close, volume) = sample_ohlcv(96);
        let reversed_volume: Vec<f64> = volume.iter().rev().copied().collect();
        let params = TrendFollowerParams {
            matype: Some("vwma".to_string()),
            trend_period: Some(20),
            ma_period: Some(12),
            channel_rate_percent: Some(1.0),
            use_linear_regression: Some(false),
            linear_regression_period: Some(5),
        };
        let original = trend_follower(&TrendFollowerInput::from_slices(
            &high,
            &low,
            &close,
            &volume,
            params.clone(),
        ))
        .unwrap();
        let swapped = trend_follower(&TrendFollowerInput::from_slices(
            &high,
            &low,
            &close,
            &reversed_volume,
            params,
        ))
        .unwrap();
        let differs = original
            .values
            .iter()
            .zip(&swapped.values)
            .any(|(x, y)| x.is_finite() && y.is_finite() && (x - y).abs() > 1e-9);
        assert!(differs);
    }

    #[test]
    fn trend_follower_invalid_matype_rejected() {
        let (high, low, close, volume) = sample_ohlcv(64);
        let params = TrendFollowerParams {
            matype: Some("hma".to_string()),
            ..TrendFollowerParams::default()
        };
        let err = trend_follower(&TrendFollowerInput::from_slices(
            &high, &low, &close, &volume, params,
        ))
        .unwrap_err();
        assert!(matches!(err, TrendFollowerError::InvalidMaType { .. }));
    }
}