Skip to main content

vector_ta/indicators/
trend_continuation_factor.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::{source_type, Candles};
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18    alloc_with_nan_prefix, detect_best_batch_kernel, init_matrix_prefixes, make_uninit_matrix,
19};
20#[cfg(feature = "python")]
21use crate::utilities::kernel_validation::validate_kernel;
22
23#[cfg(not(target_arch = "wasm32"))]
24use rayon::prelude::*;
25use std::convert::AsRef;
26use std::mem::{ManuallyDrop, MaybeUninit};
27use thiserror::Error;
28
29impl<'a> AsRef<[f64]> for TrendContinuationFactorInput<'a> {
30    #[inline(always)]
31    fn as_ref(&self) -> &[f64] {
32        match &self.data {
33            TrendContinuationFactorData::Candles { candles, source } => {
34                source_type(candles, source)
35            }
36            TrendContinuationFactorData::Slice(slice) => slice,
37        }
38    }
39}
40
41#[derive(Debug, Clone)]
42pub enum TrendContinuationFactorData<'a> {
43    Candles {
44        candles: &'a Candles,
45        source: &'a str,
46    },
47    Slice(&'a [f64]),
48}
49
/// Result of one TCF run.
///
/// `plus_tcf` and `minus_tcf` each hold one value per input sample; positions with no
/// value yet (warm-up) are `NaN`.
#[derive(Debug, Clone)]
pub struct TrendContinuationFactorOutput {
    pub plus_tcf: Vec<f64>,
    pub minus_tcf: Vec<f64>,
}

/// Tunable parameters of the indicator.
///
/// `length` is the summation window; `None` means "use the default" (35).
#[derive(Debug, Clone)]
#[cfg_attr(all(target_arch = "wasm32", feature = "wasm"), derive(Serialize, Deserialize))]
pub struct TrendContinuationFactorParams {
    pub length: Option<usize>,
}

impl Default for TrendContinuationFactorParams {
    /// The conventional 35-bar window.
    fn default() -> Self {
        TrendContinuationFactorParams { length: Some(35) }
    }
}
70
71#[derive(Debug, Clone)]
72pub struct TrendContinuationFactorInput<'a> {
73    pub data: TrendContinuationFactorData<'a>,
74    pub params: TrendContinuationFactorParams,
75}
76
77impl<'a> TrendContinuationFactorInput<'a> {
78    #[inline]
79    pub fn from_candles(
80        candles: &'a Candles,
81        source: &'a str,
82        params: TrendContinuationFactorParams,
83    ) -> Self {
84        Self {
85            data: TrendContinuationFactorData::Candles { candles, source },
86            params,
87        }
88    }
89
90    #[inline]
91    pub fn from_slice(slice: &'a [f64], params: TrendContinuationFactorParams) -> Self {
92        Self {
93            data: TrendContinuationFactorData::Slice(slice),
94            params,
95        }
96    }
97
98    #[inline]
99    pub fn with_default_candles(candles: &'a Candles) -> Self {
100        Self::from_candles(candles, "close", TrendContinuationFactorParams::default())
101    }
102
103    #[inline]
104    pub fn get_length(&self) -> usize {
105        self.params.length.unwrap_or(35)
106    }
107}
108
109#[derive(Copy, Clone, Debug)]
110pub struct TrendContinuationFactorBuilder {
111    length: Option<usize>,
112    kernel: Kernel,
113}
114
115impl Default for TrendContinuationFactorBuilder {
116    fn default() -> Self {
117        Self {
118            length: None,
119            kernel: Kernel::Auto,
120        }
121    }
122}
123
124impl TrendContinuationFactorBuilder {
125    #[inline(always)]
126    pub fn new() -> Self {
127        Self::default()
128    }
129
130    #[inline(always)]
131    pub fn length(mut self, value: usize) -> Self {
132        self.length = Some(value);
133        self
134    }
135
136    #[inline(always)]
137    pub fn kernel(mut self, value: Kernel) -> Self {
138        self.kernel = value;
139        self
140    }
141
142    #[inline(always)]
143    pub fn apply(
144        self,
145        candles: &Candles,
146    ) -> Result<TrendContinuationFactorOutput, TrendContinuationFactorError> {
147        let input = TrendContinuationFactorInput::from_candles(
148            candles,
149            "close",
150            TrendContinuationFactorParams {
151                length: self.length,
152            },
153        );
154        trend_continuation_factor_with_kernel(&input, self.kernel)
155    }
156
157    #[inline(always)]
158    pub fn apply_slice(
159        self,
160        data: &[f64],
161    ) -> Result<TrendContinuationFactorOutput, TrendContinuationFactorError> {
162        let input = TrendContinuationFactorInput::from_slice(
163            data,
164            TrendContinuationFactorParams {
165                length: self.length,
166            },
167        );
168        trend_continuation_factor_with_kernel(&input, self.kernel)
169    }
170
171    #[inline(always)]
172    pub fn into_stream(
173        self,
174    ) -> Result<TrendContinuationFactorStream, TrendContinuationFactorError> {
175        TrendContinuationFactorStream::try_new(TrendContinuationFactorParams {
176            length: self.length,
177        })
178    }
179}
180
181#[derive(Debug, Error)]
182pub enum TrendContinuationFactorError {
183    #[error("trend_continuation_factor: Input data slice is empty.")]
184    EmptyInputData,
185    #[error("trend_continuation_factor: All values are NaN.")]
186    AllValuesNaN,
187    #[error(
188        "trend_continuation_factor: Invalid length: length = {length}, data length = {data_len}"
189    )]
190    InvalidLength { length: usize, data_len: usize },
191    #[error(
192        "trend_continuation_factor: Not enough valid data: needed = {needed}, valid = {valid}"
193    )]
194    NotEnoughValidData { needed: usize, valid: usize },
195    #[error(
196        "trend_continuation_factor: Output length mismatch: expected = {expected}, got = {got}"
197    )]
198    OutputLengthMismatch { expected: usize, got: usize },
199    #[error("trend_continuation_factor: Invalid range: start={start}, end={end}, step={step}")]
200    InvalidRange {
201        start: usize,
202        end: usize,
203        step: usize,
204    },
205    #[error("trend_continuation_factor: Invalid kernel for batch: {0:?}")]
206    InvalidKernelForBatch(Kernel),
207}
208
/// Index of the first finite sample (NaN and ±infinity are skipped), or `None` if there is none.
#[inline(always)]
fn first_valid_index(data: &[f64]) -> Option<usize> {
    for (idx, sample) in data.iter().enumerate() {
        if sample.is_finite() {
            return Some(idx);
        }
    }
    None
}
213
214#[inline(always)]
215fn trend_continuation_factor_prepare<'a>(
216    input: &'a TrendContinuationFactorInput,
217) -> Result<(&'a [f64], usize, usize), TrendContinuationFactorError> {
218    let data = input.as_ref();
219    let data_len = data.len();
220    if data_len == 0 {
221        return Err(TrendContinuationFactorError::EmptyInputData);
222    }
223
224    let first = first_valid_index(data).ok_or(TrendContinuationFactorError::AllValuesNaN)?;
225    let length = input.get_length();
226    if length == 0 || length > data_len {
227        return Err(TrendContinuationFactorError::InvalidLength { length, data_len });
228    }
229
230    let valid = data_len - first;
231    if valid <= length {
232        return Err(TrendContinuationFactorError::NotEnoughValidData {
233            needed: length + 1,
234            valid,
235        });
236    }
237
238    Ok((data, length, first))
239}
240
241#[derive(Clone, Debug)]
242pub struct TrendContinuationFactorStream {
243    length: usize,
244    prev: Option<f64>,
245    plus_cf: Option<f64>,
246    minus_cf: Option<f64>,
247    comparisons_seen: usize,
248    head: usize,
249    sum_plus: f64,
250    sum_minus: f64,
251    plus_buffer: Vec<f64>,
252    minus_buffer: Vec<f64>,
253}
254
255impl TrendContinuationFactorStream {
256    #[inline]
257    fn from_length(length: usize) -> Self {
258        Self {
259            length,
260            prev: None,
261            plus_cf: None,
262            minus_cf: None,
263            comparisons_seen: 0,
264            head: 0,
265            sum_plus: 0.0,
266            sum_minus: 0.0,
267            plus_buffer: vec![0.0; length.max(1)],
268            minus_buffer: vec![0.0; length.max(1)],
269        }
270    }
271
272    #[inline]
273    pub fn try_new(
274        params: TrendContinuationFactorParams,
275    ) -> Result<Self, TrendContinuationFactorError> {
276        let length = params.length.unwrap_or(35);
277        if length == 0 {
278            return Err(TrendContinuationFactorError::InvalidLength {
279                length,
280                data_len: 0,
281            });
282        }
283        Ok(Self::from_length(length))
284    }
285
286    #[inline(always)]
287    fn reset(&mut self) {
288        self.prev = None;
289        self.plus_cf = None;
290        self.minus_cf = None;
291        self.comparisons_seen = 0;
292        self.head = 0;
293        self.sum_plus = 0.0;
294        self.sum_minus = 0.0;
295        self.plus_buffer.fill(0.0);
296        self.minus_buffer.fill(0.0);
297    }
298
299    #[inline(always)]
300    pub fn update(&mut self, value: f64) -> Option<(f64, f64)> {
301        if !value.is_finite() {
302            return None;
303        }
304
305        let prev = match self.prev.replace(value) {
306            Some(prev) => prev,
307            None => return None,
308        };
309
310        let change = value - prev;
311        let plus_change = if change > 0.0 { change } else { 0.0 };
312        let minus_change = if change < 0.0 { -change } else { 0.0 };
313
314        let next_plus_cf = if plus_change == 0.0 {
315            0.0
316        } else {
317            plus_change + self.plus_cf.unwrap_or(1.0)
318        };
319        let next_minus_cf = if minus_change == 0.0 {
320            0.0
321        } else {
322            minus_change + self.minus_cf.unwrap_or(1.0)
323        };
324
325        self.plus_cf = Some(next_plus_cf);
326        self.minus_cf = Some(next_minus_cf);
327
328        let plus = plus_change - next_minus_cf;
329        let minus = minus_change - next_plus_cf;
330
331        if self.comparisons_seen < self.length {
332            self.plus_buffer[self.comparisons_seen] = plus;
333            self.minus_buffer[self.comparisons_seen] = minus;
334            self.sum_plus += plus;
335            self.sum_minus += minus;
336            self.comparisons_seen += 1;
337            if self.comparisons_seen < self.length {
338                return None;
339            }
340            return Some((self.sum_plus, self.sum_minus));
341        }
342
343        let old_plus = self.plus_buffer[self.head];
344        let old_minus = self.minus_buffer[self.head];
345        self.plus_buffer[self.head] = plus;
346        self.minus_buffer[self.head] = minus;
347        self.sum_plus += plus - old_plus;
348        self.sum_minus += minus - old_minus;
349        self.head += 1;
350        if self.head == self.length {
351            self.head = 0;
352        }
353
354        Some((self.sum_plus, self.sum_minus))
355    }
356
357    #[inline(always)]
358    pub fn update_reset_on_nan(&mut self, value: f64) -> Option<(f64, f64)> {
359        if !value.is_finite() {
360            self.reset();
361            return None;
362        }
363        self.update(value)
364    }
365}
366
367#[inline(always)]
368fn trend_continuation_factor_compute_into(
369    data: &[f64],
370    length: usize,
371    _first: usize,
372    _kernel: Kernel,
373    out_plus: &mut [f64],
374    out_minus: &mut [f64],
375) {
376    let mut stream = TrendContinuationFactorStream::from_length(length);
377    for i in 0..data.len() {
378        match stream.update_reset_on_nan(data[i]) {
379            Some((plus, minus)) => {
380                out_plus[i] = plus;
381                out_minus[i] = minus;
382            }
383            None => {
384                out_plus[i] = f64::NAN;
385                out_minus[i] = f64::NAN;
386            }
387        }
388    }
389}
390
391#[inline]
392pub fn trend_continuation_factor(
393    input: &TrendContinuationFactorInput,
394) -> Result<TrendContinuationFactorOutput, TrendContinuationFactorError> {
395    trend_continuation_factor_with_kernel(input, Kernel::Auto)
396}
397
398pub fn trend_continuation_factor_with_kernel(
399    input: &TrendContinuationFactorInput,
400    kernel: Kernel,
401) -> Result<TrendContinuationFactorOutput, TrendContinuationFactorError> {
402    let (data, length, first) = trend_continuation_factor_prepare(input)?;
403    let warmup = first + length;
404    let mut plus_tcf = alloc_with_nan_prefix(data.len(), warmup);
405    let mut minus_tcf = alloc_with_nan_prefix(data.len(), warmup);
406    trend_continuation_factor_compute_into(
407        data,
408        length,
409        first,
410        kernel,
411        &mut plus_tcf,
412        &mut minus_tcf,
413    );
414    Ok(TrendContinuationFactorOutput {
415        plus_tcf,
416        minus_tcf,
417    })
418}
419
420pub fn trend_continuation_factor_into_slice(
421    dst_plus_tcf: &mut [f64],
422    dst_minus_tcf: &mut [f64],
423    input: &TrendContinuationFactorInput,
424    kernel: Kernel,
425) -> Result<(), TrendContinuationFactorError> {
426    let (data, length, first) = trend_continuation_factor_prepare(input)?;
427    if dst_plus_tcf.len() != data.len() || dst_minus_tcf.len() != data.len() {
428        return Err(TrendContinuationFactorError::OutputLengthMismatch {
429            expected: data.len(),
430            got: dst_plus_tcf.len().max(dst_minus_tcf.len()),
431        });
432    }
433    trend_continuation_factor_compute_into(
434        data,
435        length,
436        first,
437        kernel,
438        dst_plus_tcf,
439        dst_minus_tcf,
440    );
441    Ok(())
442}
443
444#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
445#[inline]
446pub fn trend_continuation_factor_into(
447    input: &TrendContinuationFactorInput,
448    out_plus_tcf: &mut [f64],
449    out_minus_tcf: &mut [f64],
450) -> Result<(), TrendContinuationFactorError> {
451    trend_continuation_factor_into_slice(out_plus_tcf, out_minus_tcf, input, Kernel::Auto)
452}
453
/// Parameter sweep for batch runs: `length` is an inclusive `(start, end, step)` triple.
#[derive(Clone, Debug)]
pub struct TrendContinuationFactorBatchRange {
    pub length: (usize, usize, usize),
}

impl Default for TrendContinuationFactorBatchRange {
    /// Sweeps lengths 35 through 200 in steps of 1.
    fn default() -> Self {
        let (start, end, step) = (35, 200, 1);
        TrendContinuationFactorBatchRange {
            length: (start, end, step),
        }
    }
}
466
467#[derive(Clone, Debug, Default)]
468pub struct TrendContinuationFactorBatchBuilder {
469    range: TrendContinuationFactorBatchRange,
470    kernel: Kernel,
471}
472
473impl TrendContinuationFactorBatchBuilder {
474    pub fn new() -> Self {
475        Self::default()
476    }
477
478    pub fn kernel(mut self, kernel: Kernel) -> Self {
479        self.kernel = kernel;
480        self
481    }
482
483    #[inline]
484    pub fn length_range(mut self, start: usize, end: usize, step: usize) -> Self {
485        self.range.length = (start, end, step);
486        self
487    }
488
489    #[inline]
490    pub fn length_static(mut self, length: usize) -> Self {
491        self.range.length = (length, length, 0);
492        self
493    }
494
495    pub fn apply_slice(
496        self,
497        data: &[f64],
498    ) -> Result<TrendContinuationFactorBatchOutput, TrendContinuationFactorError> {
499        trend_continuation_factor_batch_with_kernel(data, &self.range, self.kernel)
500    }
501
502    pub fn apply_candles(
503        self,
504        candles: &Candles,
505        source: &str,
506    ) -> Result<TrendContinuationFactorBatchOutput, TrendContinuationFactorError> {
507        self.apply_slice(source_type(candles, source))
508    }
509}
510
511#[derive(Clone, Debug)]
512pub struct TrendContinuationFactorBatchOutput {
513    pub plus_tcf: Vec<f64>,
514    pub minus_tcf: Vec<f64>,
515    pub combos: Vec<TrendContinuationFactorParams>,
516    pub rows: usize,
517    pub cols: usize,
518}
519
520impl TrendContinuationFactorBatchOutput {
521    pub fn row_for_params(&self, params: &TrendContinuationFactorParams) -> Option<usize> {
522        self.combos
523            .iter()
524            .position(|combo| combo.length.unwrap_or(35) == params.length.unwrap_or(35))
525    }
526
527    pub fn plus_tcf_for(&self, params: &TrendContinuationFactorParams) -> Option<&[f64]> {
528        self.row_for_params(params).map(|row| {
529            let start = row * self.cols;
530            &self.plus_tcf[start..start + self.cols]
531        })
532    }
533
534    pub fn minus_tcf_for(&self, params: &TrendContinuationFactorParams) -> Option<&[f64]> {
535        self.row_for_params(params).map(|row| {
536            let start = row * self.cols;
537            &self.minus_tcf[start..start + self.cols]
538        })
539    }
540}
541
542fn axis_usize(range: (usize, usize, usize)) -> Result<Vec<usize>, TrendContinuationFactorError> {
543    let (start, end, step) = range;
544    if step == 0 || start == end {
545        return Ok(vec![start]);
546    }
547
548    let mut out = Vec::new();
549    if start < end {
550        let mut value = start;
551        while value <= end {
552            out.push(value);
553            match value.checked_add(step) {
554                Some(next) if next > value => value = next,
555                _ => break,
556            }
557        }
558    } else {
559        let mut value = start;
560        while value >= end {
561            out.push(value);
562            if value < end + step {
563                break;
564            }
565            value = value.saturating_sub(step);
566            if value == 0 {
567                break;
568            }
569        }
570    }
571
572    if out.is_empty() {
573        return Err(TrendContinuationFactorError::InvalidRange { start, end, step });
574    }
575    Ok(out)
576}
577
578pub fn expand_grid_trend_continuation_factor(
579    sweep: &TrendContinuationFactorBatchRange,
580) -> Result<Vec<TrendContinuationFactorParams>, TrendContinuationFactorError> {
581    Ok(axis_usize(sweep.length)?
582        .into_iter()
583        .map(|length| TrendContinuationFactorParams {
584            length: Some(length),
585        })
586        .collect())
587}
588
589pub fn trend_continuation_factor_batch_with_kernel(
590    data: &[f64],
591    sweep: &TrendContinuationFactorBatchRange,
592    kernel: Kernel,
593) -> Result<TrendContinuationFactorBatchOutput, TrendContinuationFactorError> {
594    let batch_kernel = match kernel {
595        Kernel::Auto => Kernel::ScalarBatch,
596        other if other.is_batch() => other,
597        other => return Err(TrendContinuationFactorError::InvalidKernelForBatch(other)),
598    };
599    trend_continuation_factor_batch_impl(data, sweep, batch_kernel.to_non_batch(), true)
600}
601
602pub fn trend_continuation_factor_batch_slice(
603    data: &[f64],
604    sweep: &TrendContinuationFactorBatchRange,
605) -> Result<TrendContinuationFactorBatchOutput, TrendContinuationFactorError> {
606    trend_continuation_factor_batch_impl(data, sweep, Kernel::Scalar, false)
607}
608
609pub fn trend_continuation_factor_batch_par_slice(
610    data: &[f64],
611    sweep: &TrendContinuationFactorBatchRange,
612) -> Result<TrendContinuationFactorBatchOutput, TrendContinuationFactorError> {
613    trend_continuation_factor_batch_impl(data, sweep, Kernel::Scalar, true)
614}
615
616fn trend_continuation_factor_batch_impl(
617    data: &[f64],
618    sweep: &TrendContinuationFactorBatchRange,
619    kernel: Kernel,
620    parallel: bool,
621) -> Result<TrendContinuationFactorBatchOutput, TrendContinuationFactorError> {
622    let combos = expand_grid_trend_continuation_factor(sweep)?;
623    let rows = combos.len();
624    let cols = data.len();
625
626    if cols == 0 {
627        return Err(TrendContinuationFactorError::EmptyInputData);
628    }
629
630    let first = first_valid_index(data).ok_or(TrendContinuationFactorError::AllValuesNaN)?;
631    let max_length = combos
632        .iter()
633        .map(|params| params.length.unwrap_or(35))
634        .max()
635        .unwrap_or(35);
636    let valid = cols - first;
637    if valid <= max_length {
638        return Err(TrendContinuationFactorError::NotEnoughValidData {
639            needed: max_length + 1,
640            valid,
641        });
642    }
643
644    let warmups: Vec<usize> = combos
645        .iter()
646        .map(|params| first + params.length.unwrap_or(35))
647        .collect();
648
649    let mut plus_matrix = make_uninit_matrix(rows, cols);
650    init_matrix_prefixes(&mut plus_matrix, cols, &warmups);
651    let mut minus_matrix = make_uninit_matrix(rows, cols);
652    init_matrix_prefixes(&mut minus_matrix, cols, &warmups);
653
654    let mut plus_guard = ManuallyDrop::new(plus_matrix);
655    let mut minus_guard = ManuallyDrop::new(minus_matrix);
656
657    let plus_mu: &mut [MaybeUninit<f64>] =
658        unsafe { std::slice::from_raw_parts_mut(plus_guard.as_mut_ptr(), plus_guard.len()) };
659    let minus_mu: &mut [MaybeUninit<f64>] =
660        unsafe { std::slice::from_raw_parts_mut(minus_guard.as_mut_ptr(), minus_guard.len()) };
661
662    let do_row = |row: usize,
663                  row_plus_mu: &mut [MaybeUninit<f64>],
664                  row_minus_mu: &mut [MaybeUninit<f64>]| {
665        let length = combos[row].length.unwrap_or(35);
666        let dst_plus = unsafe {
667            std::slice::from_raw_parts_mut(row_plus_mu.as_mut_ptr() as *mut f64, row_plus_mu.len())
668        };
669        let dst_minus = unsafe {
670            std::slice::from_raw_parts_mut(
671                row_minus_mu.as_mut_ptr() as *mut f64,
672                row_minus_mu.len(),
673            )
674        };
675        trend_continuation_factor_compute_into(data, length, first, kernel, dst_plus, dst_minus);
676    };
677
678    if parallel {
679        #[cfg(not(target_arch = "wasm32"))]
680        plus_mu
681            .par_chunks_mut(cols)
682            .zip(minus_mu.par_chunks_mut(cols))
683            .enumerate()
684            .for_each(|(row, (row_plus_mu, row_minus_mu))| do_row(row, row_plus_mu, row_minus_mu));
685        #[cfg(target_arch = "wasm32")]
686        for (row, (row_plus_mu, row_minus_mu)) in plus_mu
687            .chunks_mut(cols)
688            .zip(minus_mu.chunks_mut(cols))
689            .enumerate()
690        {
691            do_row(row, row_plus_mu, row_minus_mu);
692        }
693    } else {
694        for (row, (row_plus_mu, row_minus_mu)) in plus_mu
695            .chunks_mut(cols)
696            .zip(minus_mu.chunks_mut(cols))
697            .enumerate()
698        {
699            do_row(row, row_plus_mu, row_minus_mu);
700        }
701    }
702
703    let plus_tcf = unsafe {
704        Vec::from_raw_parts(
705            plus_guard.as_mut_ptr() as *mut f64,
706            plus_guard.len(),
707            plus_guard.capacity(),
708        )
709    };
710    let minus_tcf = unsafe {
711        Vec::from_raw_parts(
712            minus_guard.as_mut_ptr() as *mut f64,
713            minus_guard.len(),
714            minus_guard.capacity(),
715        )
716    };
717
718    Ok(TrendContinuationFactorBatchOutput {
719        plus_tcf,
720        minus_tcf,
721        combos,
722        rows,
723        cols,
724    })
725}
726
727fn trend_continuation_factor_batch_inner_into(
728    data: &[f64],
729    sweep: &TrendContinuationFactorBatchRange,
730    kernel: Kernel,
731    parallel: bool,
732    out_plus: &mut [f64],
733    out_minus: &mut [f64],
734) -> Result<(), TrendContinuationFactorError> {
735    let combos = expand_grid_trend_continuation_factor(sweep)?;
736    let rows = combos.len();
737    let cols = data.len();
738    if rows.checked_mul(cols) != Some(out_plus.len()) || out_minus.len() != out_plus.len() {
739        return Err(TrendContinuationFactorError::OutputLengthMismatch {
740            expected: rows * cols,
741            got: out_plus.len().max(out_minus.len()),
742        });
743    }
744
745    let first = first_valid_index(data).ok_or(TrendContinuationFactorError::AllValuesNaN)?;
746    for (row, params) in combos.iter().enumerate() {
747        let length = params.length.unwrap_or(35);
748        let row_plus = &mut out_plus[row * cols..(row + 1) * cols];
749        let row_minus = &mut out_minus[row * cols..(row + 1) * cols];
750        row_plus.fill(f64::NAN);
751        row_minus.fill(f64::NAN);
752        if cols - first <= length {
753            return Err(TrendContinuationFactorError::NotEnoughValidData {
754                needed: length + 1,
755                valid: cols - first,
756            });
757        }
758    }
759
760    let do_row = |row: usize, row_plus: &mut [f64], row_minus: &mut [f64]| {
761        let length = combos[row].length.unwrap_or(35);
762        trend_continuation_factor_compute_into(data, length, first, kernel, row_plus, row_minus);
763    };
764
765    if parallel {
766        #[cfg(not(target_arch = "wasm32"))]
767        out_plus
768            .par_chunks_mut(cols)
769            .zip(out_minus.par_chunks_mut(cols))
770            .enumerate()
771            .for_each(|(row, (row_plus, row_minus))| do_row(row, row_plus, row_minus));
772        #[cfg(target_arch = "wasm32")]
773        for (row, (row_plus, row_minus)) in out_plus
774            .chunks_mut(cols)
775            .zip(out_minus.chunks_mut(cols))
776            .enumerate()
777        {
778            do_row(row, row_plus, row_minus);
779        }
780    } else {
781        for (row, (row_plus, row_minus)) in out_plus
782            .chunks_mut(cols)
783            .zip(out_minus.chunks_mut(cols))
784            .enumerate()
785        {
786            do_row(row, row_plus, row_minus);
787        }
788    }
789
790    Ok(())
791}
792
// Python binding: returns `(plus_tcf, minus_tcf)` NumPy arrays for one window length.
// The computation runs with the GIL released; errors surface as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "trend_continuation_factor")]
#[pyo3(signature = (data, length=35, kernel=None))]
pub fn trend_continuation_factor_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    length: usize,
    kernel: Option<&str>,
) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
    let series = data.as_slice()?;
    let kernel = validate_kernel(kernel, false)?;
    let params = TrendContinuationFactorParams {
        length: Some(length),
    };
    let input = TrendContinuationFactorInput::from_slice(series, params);

    let result = py.allow_threads(|| trend_continuation_factor_with_kernel(&input, kernel));
    let TrendContinuationFactorOutput {
        plus_tcf,
        minus_tcf,
    } = result.map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok((plus_tcf.into_pyarray(py), minus_tcf.into_pyarray(py)))
}
818
// Python wrapper around the incremental stream. Non-finite inputs reset its state.
#[cfg(feature = "python")]
#[pyclass(name = "TrendContinuationFactorStream")]
pub struct TrendContinuationFactorStreamPy {
    stream: TrendContinuationFactorStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl TrendContinuationFactorStreamPy {
    // Rejects a zero length with `ValueError`.
    #[new]
    #[pyo3(signature = (length=35))]
    fn new(length: usize) -> PyResult<Self> {
        let params = TrendContinuationFactorParams {
            length: Some(length),
        };
        match TrendContinuationFactorStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    // Returns `(plus, minus)` once warmed up, otherwise `None`.
    fn update(&mut self, value: f64) -> Option<(f64, f64)> {
        self.stream.update_reset_on_nan(value)
    }
}
842
// Python batch binding. Returns a dict with:
// - `plus_tcf` / `minus_tcf`: `(rows, cols)` arrays;
// - `lengths`: the swept lengths as u64;
// - `rows` and `cols`.
// Output arrays are allocated uninitialized and filled by
// `trend_continuation_factor_batch_inner_into` with the GIL released.
#[cfg(feature = "python")]
#[pyfunction(name = "trend_continuation_factor_batch")]
#[pyo3(signature = (data, length_range, kernel=None))]
pub fn trend_continuation_factor_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    length_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let to_py_err = |e: TrendContinuationFactorError| PyValueError::new_err(e.to_string());

    let series = data.as_slice()?;
    let sweep = TrendContinuationFactorBatchRange {
        length: length_range,
    };
    let combos = expand_grid_trend_continuation_factor(&sweep).map_err(to_py_err)?;
    let rows = combos.len();
    let cols = series.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    let plus_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let minus_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_plus = unsafe { plus_arr.as_slice_mut()? };
    let out_minus = unsafe { minus_arr.as_slice_mut()? };
    let kernel = validate_kernel(kernel, true)?;

    py.allow_threads(|| {
        let batch_kernel = if matches!(kernel, Kernel::Auto) {
            detect_best_batch_kernel()
        } else {
            kernel
        };
        trend_continuation_factor_batch_inner_into(
            series,
            &sweep,
            batch_kernel.to_non_batch(),
            true,
            out_plus,
            out_minus,
        )
    })
    .map_err(to_py_err)?;

    let lengths: Vec<u64> = combos
        .iter()
        .map(|params| params.length.unwrap_or(35) as u64)
        .collect();

    let dict = PyDict::new(py);
    dict.set_item("plus_tcf", plus_arr.reshape((rows, cols))?)?;
    dict.set_item("minus_tcf", minus_arr.reshape((rows, cols))?)?;
    dict.set_item("lengths", lengths.into_pyarray(py))?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
900
/// Adds the TCF Python bindings to `m`:
/// - the single-series function;
/// - the batch function;
/// - the stream class.
#[cfg(feature = "python")]
pub fn register_trend_continuation_factor_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    let single = wrap_pyfunction!(trend_continuation_factor_py, m)?;
    m.add_function(single)?;
    let batch = wrap_pyfunction!(trend_continuation_factor_batch_py, m)?;
    m.add_function(batch)?;
    m.add_class::<TrendContinuationFactorStreamPy>()?;
    Ok(())
}
908
/// JS batch configuration. `length_range` is expected to be `[start, end, step]`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
struct TrendContinuationFactorBatchConfig {
    length_range: Vec<usize>,
}

/// Serialized batch result returned to JS: flattened row-major matrices plus their shape
/// and the parameter set of each row.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
struct TrendContinuationFactorBatchJsOutput {
    plus_tcf: Vec<f64>,
    minus_tcf: Vec<f64>,
    rows: usize,
    cols: usize,
    combos: Vec<TrendContinuationFactorParams>,
}

/// Serialized single-series result returned to JS.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
struct TrendContinuationFactorJsOutput {
    plus_tcf: Vec<f64>,
    minus_tcf: Vec<f64>,
}
931
// JS binding: returns `{ plus_tcf, minus_tcf }` for one window length.
// Errors become JS string values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "trend_continuation_factor_js")]
pub fn trend_continuation_factor_js(data: &[f64], length: usize) -> Result<JsValue, JsValue> {
    let params = TrendContinuationFactorParams {
        length: Some(length),
    };
    let input = TrendContinuationFactorInput::from_slice(data, params);
    let TrendContinuationFactorOutput {
        plus_tcf,
        minus_tcf,
    } = match trend_continuation_factor(&input) {
        Ok(output) => output,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };
    let payload = TrendContinuationFactorJsOutput {
        plus_tcf,
        minus_tcf,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {e}")))
}
949
// JS batch binding.
// Input: `config` must deserialize to `{ length_range: [start, end, step] }`.
// Output: flattened row-major `plus_tcf` / `minus_tcf`, plus `rows`, `cols` and `combos`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "trend_continuation_factor_batch_js")]
pub fn trend_continuation_factor_batch_js(
    data: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let config: TrendContinuationFactorBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {e}")))?;

    let (start, end, step) = match *config.length_range.as_slice() {
        [start, end, step] => (start, end, step),
        _ => {
            return Err(JsValue::from_str(
                "Invalid config: length_range must have exactly 3 elements [start, end, step]",
            ))
        }
    };
    let sweep = TrendContinuationFactorBatchRange {
        length: (start, end, step),
    };

    let TrendContinuationFactorBatchOutput {
        plus_tcf,
        minus_tcf,
        combos,
        rows,
        cols,
    } = trend_continuation_factor_batch_slice(data, &sweep)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    serde_wasm_bindgen::to_value(&TrendContinuationFactorBatchJsOutput {
        plus_tcf,
        minus_tcf,
        rows,
        cols,
        combos,
    })
    .map_err(|e| JsValue::from_str(&format!("Serialization error: {e}")))
}
981
/// Reserves an uninitialized buffer with room for `2 * len` `f64` values in wasm
/// linear memory and leaks it to the caller.
///
/// The size matches the `out_ptr` argument of `trend_continuation_factor_into`.
/// That function writes the plus series into the first `len` slots and the
/// minus series into the next `len`. Release the buffer with
/// `trend_continuation_factor_free(ptr, len)` using the same `len`.
///
/// # Panics
/// Panics (traps in wasm) if `len * 2` overflows `usize`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trend_continuation_factor_alloc(len: usize) -> *mut f64 {
    // A plain `len * 2` wraps silently in release builds and would hand back a
    // buffer far smaller than callers are about to write into.
    let capacity = len
        .checked_mul(2)
        .expect("trend_continuation_factor_alloc: len * 2 overflows usize");
    let mut buffer = Vec::<f64>::with_capacity(capacity);
    let ptr = buffer.as_mut_ptr();
    // Ownership of the allocation passes to the caller; it is reclaimed in
    // `trend_continuation_factor_free`.
    std::mem::forget(buffer);
    ptr
}
990
/// Releases a buffer obtained from `trend_continuation_factor_alloc(len)`.
///
/// `len` must be the same value that was passed to the allocator. A null `ptr`
/// is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trend_continuation_factor_free(ptr: *mut f64, len: usize) {
    // `Vec::from_raw_parts` on a null pointer is undefined behaviour.
    if ptr.is_null() {
        return;
    }
    // Same capacity arithmetic as the allocator.
    let capacity = len * 2;
    // SAFETY: `ptr` came from `trend_continuation_factor_alloc(len)`, i.e. a
    // `Vec<f64>` with capacity `len * 2` that was leaked via `mem::forget`.
    // The length is 0 because the contents may never have been initialized;
    // `f64` has no destructor, so only the allocation itself is released.
    unsafe {
        drop(Vec::<f64>::from_raw_parts(ptr, 0, capacity));
    }
}
998
/// Computes TCF from a raw input buffer into a raw output buffer (zero-copy JS
/// path).
///
/// * `in_ptr`: points to `len` input values.
/// * `out_ptr`: points to room for `2 * len` values, for example a buffer from
///   `trend_continuation_factor_alloc(len)`. The first `len` slots receive
///   `plus_tcf` and the next `len` receive `minus_tcf`.
/// * `length`: TCF lookback length.
///
/// The input may live inside the output buffer, e.g. when a caller writes the
/// prices into the allocated buffer and computes in place. In that case the
/// input is copied first, so the computation never holds a shared view and a
/// mutable view of the same memory.
///
/// # Errors
/// Returns a JS string error if either pointer is null, if `2 * len` overflows
/// `usize`, or if the indicator computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trend_continuation_factor_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    length: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to trend_continuation_factor_into",
        ));
    }
    let out_len = len.checked_mul(2).ok_or_else(|| {
        JsValue::from_str("trend_continuation_factor_into: output length overflows usize")
    })?;

    // Detect any overlap between the input range and the two-plane output range.
    let in_lo = in_ptr as usize;
    let in_hi = in_ptr.wrapping_add(len) as usize;
    let out_lo = out_ptr as usize;
    let out_hi = out_ptr.wrapping_add(out_len) as usize;
    let aliased = in_lo < out_hi && out_lo < in_hi;

    let copied: Option<Vec<f64>> = if aliased {
        // SAFETY: the caller guarantees `in_ptr` addresses `len` readable f64s.
        // The temporary slice is dropped before any mutable view is created.
        Some(unsafe { std::slice::from_raw_parts(in_ptr, len) }.to_vec())
    } else {
        None
    };
    let data: &[f64] = match copied.as_deref() {
        Some(buf) => buf,
        // SAFETY: `len` readable f64s at `in_ptr`; this range does not overlap
        // the output range (checked above).
        None => unsafe { std::slice::from_raw_parts(in_ptr, len) },
    };

    // SAFETY: the caller guarantees `out_ptr` addresses `2 * len` writable f64s,
    // and `data` no longer refers to any part of that range.
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, out_len) };
    let (out_plus, out_minus) = out.split_at_mut(len);

    let input = TrendContinuationFactorInput::from_slice(
        data,
        TrendContinuationFactorParams {
            length: Some(length),
        },
    );
    trend_continuation_factor_into_slice(out_plus, out_minus, &input, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))
}
1026
/// Computes TCF for a JS-provided `data` slice and writes the result through
/// `out_ptr`.
///
/// `out_ptr` must address `2 * data.len()` writable values. The plus series
/// fills the first `data.len()` slots and the minus series the following
/// `data.len()`.
///
/// # Errors
/// Returns a JS string error for a null `out_ptr` or when the indicator
/// computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "trend_continuation_factor_into_host")]
pub fn trend_continuation_factor_into_host(
    data: &[f64],
    out_ptr: *mut f64,
    length: usize,
) -> Result<(), JsValue> {
    if out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to trend_continuation_factor_into_host",
        ));
    }

    let cols = data.len();
    let params = TrendContinuationFactorParams {
        length: Some(length),
    };
    let input = TrendContinuationFactorInput::from_slice(data, params);

    // SAFETY: the caller guarantees `out_ptr` addresses `2 * data.len()`
    // writable f64 values for the duration of this call.
    let whole = unsafe { std::slice::from_raw_parts_mut(out_ptr, cols * 2) };
    let (plus_half, minus_half) = whole.split_at_mut(cols);

    trend_continuation_factor_into_slice(plus_half, minus_half, &input, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))
}
1052
/// Runs a `length` sweep over raw input and writes every row into a raw output
/// buffer (zero-copy JS path).
///
/// * `in_ptr`: points to `len` input values.
/// * `out_ptr`: points to room for `2 * rows * len` values. The first
///   `rows * len` slots receive the plus planes and the next `rows * len` the
///   minus planes.
/// * `length_start`, `length_end`, `length_step`: the sweep, forwarded as a
///   `TrendContinuationFactorBatchRange`.
///
/// Returns the number of rows (parameter combos) written. If the input lies
/// inside the output range, it is copied first so the computation never holds a
/// shared view and a mutable view of the same memory.
///
/// # Errors
/// Returns a JS string error for null pointers, an invalid sweep, an output size
/// that overflows `usize`, or a failed computation.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trend_continuation_factor_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    length_start: usize,
    length_end: usize,
    length_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to trend_continuation_factor_batch_into",
        ));
    }
    let sweep = TrendContinuationFactorBatchRange {
        length: (length_start, length_end, length_step),
    };
    let combos = expand_grid_trend_continuation_factor(&sweep)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();

    // `rows * len` values per plane, two planes in total; reject sizes that
    // would silently wrap.
    let overflow =
        || JsValue::from_str("trend_continuation_factor_batch_into: output size overflows usize");
    let plane = rows.checked_mul(len).ok_or_else(overflow)?;
    let total = plane.checked_mul(2).ok_or_else(overflow)?;

    // Detect any overlap between the input range and the output range.
    let in_lo = in_ptr as usize;
    let in_hi = in_ptr.wrapping_add(len) as usize;
    let out_lo = out_ptr as usize;
    let out_hi = out_ptr.wrapping_add(total) as usize;
    let aliased = in_lo < out_hi && out_lo < in_hi;

    let copied: Option<Vec<f64>> = if aliased {
        // SAFETY: the caller guarantees `in_ptr` addresses `len` readable f64s.
        // The temporary slice is dropped before any mutable view is created.
        Some(unsafe { std::slice::from_raw_parts(in_ptr, len) }.to_vec())
    } else {
        None
    };
    let data: &[f64] = match copied.as_deref() {
        Some(buf) => buf,
        // SAFETY: `len` readable f64s at `in_ptr`; this range does not overlap
        // the output range (checked above).
        None => unsafe { std::slice::from_raw_parts(in_ptr, len) },
    };

    // SAFETY: the caller guarantees `out_ptr` addresses `2 * rows * len`
    // writable f64s, and `data` no longer refers to any part of that range.
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
    let (out_plus, out_minus) = out.split_at_mut(plane);
    trend_continuation_factor_batch_inner_into(
        data,
        &sweep,
        Kernel::Scalar,
        false,
        out_plus,
        out_minus,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;
    Ok(rows)
}
1090
#[cfg(test)]
mod tests {
    use super::*;
    use crate::indicators::dispatch::{
        compute_cpu_batch, IndicatorBatchRequest, IndicatorDataRef, IndicatorParamSet, ParamKV,
        ParamValue,
    };

    /// Deterministic synthetic series: a slow linear drift plus a sine wobble,
    /// so the data contains both rising and falling runs.
    fn sample_data(len: usize) -> Vec<f64> {
        (0..len)
            .map(|i| 100.0 + (i as f64 * 0.17).sin() * 2.5 + i as f64 * 0.03)
            .collect()
    }

    /// Straightforward reference implementation of TCF, used as a test oracle.
    ///
    /// The function works as follows:
    /// * For every bar `i >= 1` it splits the bar-to-bar change into an up move
    ///   and a down move.
    /// * It tracks running change factors (`*_cf`). A factor resets to 0 on a
    ///   bar with no move in that direction. For the very first change, the
    ///   previous factor is taken as 1.0.
    /// * It keeps rolling sums over the last `length` terms
    ///   `plus_change - minus_cf` and `minus_change - plus_cf`.
    ///
    /// Output starts at the bar where the window first fills. Earlier bars stay
    /// NaN, and every bar stays NaN when `data.len() <= length`.
    fn naive_tcf(data: &[f64], length: usize) -> (Vec<f64>, Vec<f64>) {
        let mut plus_out = vec![f64::NAN; data.len()];
        let mut minus_out = vec![f64::NAN; data.len()];
        if data.len() <= length {
            return (plus_out, minus_out);
        }

        let mut plus_cf: Option<f64> = None;
        let mut minus_cf: Option<f64> = None;
        // Ring buffers holding the last `length` terms of each rolling sum.
        let mut plus_terms = vec![0.0; length];
        let mut minus_terms = vec![0.0; length];
        let mut head = 0usize;
        let mut seen = 0usize;
        let mut sum_plus = 0.0;
        let mut sum_minus = 0.0;

        for i in 1..data.len() {
            let change = data[i] - data[i - 1];
            let plus_change = if change > 0.0 { change } else { 0.0 };
            let minus_change = if change < 0.0 { -change } else { 0.0 };
            let next_plus_cf = if plus_change == 0.0 {
                0.0
            } else {
                plus_change + plus_cf.unwrap_or(1.0)
            };
            let next_minus_cf = if minus_change == 0.0 {
                0.0
            } else {
                minus_change + minus_cf.unwrap_or(1.0)
            };
            plus_cf = Some(next_plus_cf);
            minus_cf = Some(next_minus_cf);

            let plus = plus_change - next_minus_cf;
            let minus = minus_change - next_plus_cf;

            if seen < length {
                // Warm-up: fill the window; emit once it becomes full.
                plus_terms[seen] = plus;
                minus_terms[seen] = minus;
                sum_plus += plus;
                sum_minus += minus;
                seen += 1;
                if seen == length {
                    plus_out[i] = sum_plus;
                    minus_out[i] = sum_minus;
                }
            } else {
                // Steady state: swap the oldest term for the newest one.
                sum_plus += plus - plus_terms[head];
                sum_minus += minus - minus_terms[head];
                plus_terms[head] = plus;
                minus_terms[head] = minus;
                head += 1;
                if head == length {
                    head = 0;
                }
                plus_out[i] = sum_plus;
                minus_out[i] = sum_minus;
            }
        }

        (plus_out, minus_out)
    }

    /// Element-wise comparison with an absolute tolerance of 1e-10; NaN must
    /// appear at the same positions in both slices.
    fn assert_close(a: &[f64], b: &[f64]) {
        assert_eq!(a.len(), b.len());
        for i in 0..a.len() {
            if a[i].is_nan() || b[i].is_nan() {
                assert!(
                    a[i].is_nan() && b[i].is_nan(),
                    "nan mismatch at {i}: {} vs {}",
                    a[i],
                    b[i]
                );
            } else {
                assert!(
                    (a[i] - b[i]).abs() <= 1e-10,
                    "mismatch at {i}: {} vs {}",
                    a[i],
                    b[i]
                );
            }
        }
    }

    /// The public entry point agrees with the naive oracle.
    #[test]
    fn trend_continuation_factor_matches_naive() {
        let data = sample_data(256);
        let input = TrendContinuationFactorInput::from_slice(
            &data,
            TrendContinuationFactorParams { length: Some(35) },
        );
        let out = trend_continuation_factor(&input).expect("indicator");
        let (plus_ref, minus_ref) = naive_tcf(&data, 35);
        assert_close(&out.plus_tcf, &plus_ref);
        assert_close(&out.minus_tcf, &minus_ref);
    }

    /// The caller-provided-buffer variant produces the same output as the
    /// allocating API.
    #[test]
    fn trend_continuation_factor_into_matches_api() {
        let data = sample_data(192);
        let input = TrendContinuationFactorInput::from_slice(
            &data,
            TrendContinuationFactorParams { length: Some(20) },
        );
        let baseline = trend_continuation_factor(&input).expect("baseline");
        let mut plus_out = vec![0.0; data.len()];
        let mut minus_out = vec![0.0; data.len()];
        trend_continuation_factor_into(&input, &mut plus_out, &mut minus_out).expect("into");
        assert_close(&baseline.plus_tcf, &plus_out);
        assert_close(&baseline.minus_tcf, &minus_out);
    }

    /// Feeding values one at a time through the stream reproduces the batch
    /// series; bars where the stream yields nothing must be NaN in the batch.
    #[test]
    fn trend_continuation_factor_stream_matches_batch() {
        let data = sample_data(192);
        let batch = trend_continuation_factor(&TrendContinuationFactorInput::from_slice(
            &data,
            TrendContinuationFactorParams { length: Some(18) },
        ))
        .expect("batch");
        let mut stream = TrendContinuationFactorStream::try_new(TrendContinuationFactorParams {
            length: Some(18),
        })
        .expect("stream");
        let mut plus = vec![f64::NAN; data.len()];
        let mut minus = vec![f64::NAN; data.len()];
        for (i, value) in data.iter().enumerate() {
            if let Some((p, m)) = stream.update_reset_on_nan(*value) {
                plus[i] = p;
                minus[i] = m;
            }
        }
        assert_close(&batch.plus_tcf, &plus);
        assert_close(&batch.minus_tcf, &minus);
    }

    /// A single-combo sweep yields one row identical to the direct computation.
    #[test]
    fn trend_continuation_factor_batch_single_param_matches_single() {
        let data = sample_data(192);
        let batch = trend_continuation_factor_batch_with_kernel(
            &data,
            &TrendContinuationFactorBatchRange {
                length: (12, 12, 0),
            },
            Kernel::ScalarBatch,
        )
        .expect("batch");
        let input = TrendContinuationFactorInput::from_slice(
            &data,
            TrendContinuationFactorParams { length: Some(12) },
        );
        let direct = trend_continuation_factor_with_kernel(&input, Kernel::Scalar).expect("direct");
        assert_eq!(batch.rows, 1);
        assert_eq!(batch.cols, data.len());
        assert_close(&batch.plus_tcf[..data.len()], &direct.plus_tcf);
        assert_close(&batch.minus_tcf[..data.len()], &direct.minus_tcf);
    }

    /// A multi-combo sweep lays rows out consecutively. Each row must match the
    /// direct computation for the combo recorded at that row index, which
    /// exercises the row offsets that the single-combo test cannot reach.
    #[test]
    fn trend_continuation_factor_batch_multi_param_rows_match_single() {
        let data = sample_data(192);
        let batch = trend_continuation_factor_batch_with_kernel(
            &data,
            &TrendContinuationFactorBatchRange {
                length: (10, 14, 2),
            },
            Kernel::ScalarBatch,
        )
        .expect("batch");
        let cols = data.len();
        assert_eq!(batch.cols, cols);
        assert_eq!(batch.rows, batch.combos.len());
        assert_eq!(batch.plus_tcf.len(), batch.rows * cols);
        assert_eq!(batch.minus_tcf.len(), batch.rows * cols);

        let lengths: Vec<usize> = batch
            .combos
            .iter()
            .map(|p| p.length.expect("combo length"))
            .collect();
        assert_eq!(lengths, vec![10, 12, 14]);

        for (row, combo) in batch.combos.iter().enumerate() {
            let input = TrendContinuationFactorInput::from_slice(&data, combo.clone());
            let direct =
                trend_continuation_factor_with_kernel(&input, Kernel::Scalar).expect("direct");
            let span = row * cols..(row + 1) * cols;
            assert_close(&batch.plus_tcf[span.clone()], &direct.plus_tcf);
            assert_close(&batch.minus_tcf[span], &direct.minus_tcf);
        }
    }

    /// `length == 0` is rejected with `InvalidLength`.
    #[test]
    fn trend_continuation_factor_rejects_invalid_length() {
        let data = sample_data(32);
        let input = TrendContinuationFactorInput::from_slice(
            &data,
            TrendContinuationFactorParams { length: Some(0) },
        );
        let err = trend_continuation_factor(&input).unwrap_err();
        assert!(matches!(
            err,
            TrendContinuationFactorError::InvalidLength { .. }
        ));
    }

    /// The generic dispatch layer routes `trend_continuation_factor` /
    /// `plus_tcf` to this indicator and returns the same values.
    #[test]
    fn trend_continuation_factor_dispatch_matches_direct() {
        let data = sample_data(160);
        let combo = [ParamKV {
            key: "length",
            value: ParamValue::Int(16),
        }];
        let combos = [IndicatorParamSet { params: &combo }];

        let req = IndicatorBatchRequest {
            indicator_id: "trend_continuation_factor",
            output_id: Some("plus_tcf"),
            data: IndicatorDataRef::Slice { values: &data },
            combos: &combos,
            kernel: Kernel::Auto,
        };
        let out = compute_cpu_batch(req).expect("dispatch");

        let input = TrendContinuationFactorInput::from_slice(
            &data,
            TrendContinuationFactorParams { length: Some(16) },
        );
        let direct = trend_continuation_factor(&input).expect("direct");
        assert_eq!(out.rows, 1);
        assert_eq!(out.cols, data.len());
        let values = out.values_f64.expect("values");
        assert_close(&values, &direct.plus_tcf);
    }
}