1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::{source_type, Candles};
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18 alloc_with_nan_prefix, detect_best_batch_kernel, init_matrix_prefixes, make_uninit_matrix,
19};
20#[cfg(feature = "python")]
21use crate::utilities::kernel_validation::validate_kernel;
22
23#[cfg(not(target_arch = "wasm32"))]
24use rayon::prelude::*;
25use std::convert::AsRef;
26use std::mem::{ManuallyDrop, MaybeUninit};
27use thiserror::Error;
28
29impl<'a> AsRef<[f64]> for TrendContinuationFactorInput<'a> {
30 #[inline(always)]
31 fn as_ref(&self) -> &[f64] {
32 match &self.data {
33 TrendContinuationFactorData::Candles { candles, source } => {
34 source_type(candles, source)
35 }
36 TrendContinuationFactorData::Slice(slice) => slice,
37 }
38 }
39}
40
41#[derive(Debug, Clone)]
42pub enum TrendContinuationFactorData<'a> {
43 Candles {
44 candles: &'a Candles,
45 source: &'a str,
46 },
47 Slice(&'a [f64]),
48}
49
50#[derive(Debug, Clone)]
51pub struct TrendContinuationFactorOutput {
52 pub plus_tcf: Vec<f64>,
53 pub minus_tcf: Vec<f64>,
54}
55
56#[derive(Debug, Clone)]
57#[cfg_attr(
58 all(target_arch = "wasm32", feature = "wasm"),
59 derive(Serialize, Deserialize)
60)]
61pub struct TrendContinuationFactorParams {
62 pub length: Option<usize>,
63}
64
65impl Default for TrendContinuationFactorParams {
66 fn default() -> Self {
67 Self { length: Some(35) }
68 }
69}
70
71#[derive(Debug, Clone)]
72pub struct TrendContinuationFactorInput<'a> {
73 pub data: TrendContinuationFactorData<'a>,
74 pub params: TrendContinuationFactorParams,
75}
76
77impl<'a> TrendContinuationFactorInput<'a> {
78 #[inline]
79 pub fn from_candles(
80 candles: &'a Candles,
81 source: &'a str,
82 params: TrendContinuationFactorParams,
83 ) -> Self {
84 Self {
85 data: TrendContinuationFactorData::Candles { candles, source },
86 params,
87 }
88 }
89
90 #[inline]
91 pub fn from_slice(slice: &'a [f64], params: TrendContinuationFactorParams) -> Self {
92 Self {
93 data: TrendContinuationFactorData::Slice(slice),
94 params,
95 }
96 }
97
98 #[inline]
99 pub fn with_default_candles(candles: &'a Candles) -> Self {
100 Self::from_candles(candles, "close", TrendContinuationFactorParams::default())
101 }
102
103 #[inline]
104 pub fn get_length(&self) -> usize {
105 self.params.length.unwrap_or(35)
106 }
107}
108
109#[derive(Copy, Clone, Debug)]
110pub struct TrendContinuationFactorBuilder {
111 length: Option<usize>,
112 kernel: Kernel,
113}
114
115impl Default for TrendContinuationFactorBuilder {
116 fn default() -> Self {
117 Self {
118 length: None,
119 kernel: Kernel::Auto,
120 }
121 }
122}
123
124impl TrendContinuationFactorBuilder {
125 #[inline(always)]
126 pub fn new() -> Self {
127 Self::default()
128 }
129
130 #[inline(always)]
131 pub fn length(mut self, value: usize) -> Self {
132 self.length = Some(value);
133 self
134 }
135
136 #[inline(always)]
137 pub fn kernel(mut self, value: Kernel) -> Self {
138 self.kernel = value;
139 self
140 }
141
142 #[inline(always)]
143 pub fn apply(
144 self,
145 candles: &Candles,
146 ) -> Result<TrendContinuationFactorOutput, TrendContinuationFactorError> {
147 let input = TrendContinuationFactorInput::from_candles(
148 candles,
149 "close",
150 TrendContinuationFactorParams {
151 length: self.length,
152 },
153 );
154 trend_continuation_factor_with_kernel(&input, self.kernel)
155 }
156
157 #[inline(always)]
158 pub fn apply_slice(
159 self,
160 data: &[f64],
161 ) -> Result<TrendContinuationFactorOutput, TrendContinuationFactorError> {
162 let input = TrendContinuationFactorInput::from_slice(
163 data,
164 TrendContinuationFactorParams {
165 length: self.length,
166 },
167 );
168 trend_continuation_factor_with_kernel(&input, self.kernel)
169 }
170
171 #[inline(always)]
172 pub fn into_stream(
173 self,
174 ) -> Result<TrendContinuationFactorStream, TrendContinuationFactorError> {
175 TrendContinuationFactorStream::try_new(TrendContinuationFactorParams {
176 length: self.length,
177 })
178 }
179}
180
181#[derive(Debug, Error)]
182pub enum TrendContinuationFactorError {
183 #[error("trend_continuation_factor: Input data slice is empty.")]
184 EmptyInputData,
185 #[error("trend_continuation_factor: All values are NaN.")]
186 AllValuesNaN,
187 #[error(
188 "trend_continuation_factor: Invalid length: length = {length}, data length = {data_len}"
189 )]
190 InvalidLength { length: usize, data_len: usize },
191 #[error(
192 "trend_continuation_factor: Not enough valid data: needed = {needed}, valid = {valid}"
193 )]
194 NotEnoughValidData { needed: usize, valid: usize },
195 #[error(
196 "trend_continuation_factor: Output length mismatch: expected = {expected}, got = {got}"
197 )]
198 OutputLengthMismatch { expected: usize, got: usize },
199 #[error("trend_continuation_factor: Invalid range: start={start}, end={end}, step={step}")]
200 InvalidRange {
201 start: usize,
202 end: usize,
203 step: usize,
204 },
205 #[error("trend_continuation_factor: Invalid kernel for batch: {0:?}")]
206 InvalidKernelForBatch(Kernel),
207}
208
/// Index of the first finite sample (NaN and ±inf are skipped), or `None` if none exists.
#[inline(always)]
fn first_valid_index(data: &[f64]) -> Option<usize> {
    for (idx, sample) in data.iter().enumerate() {
        if sample.is_finite() {
            return Some(idx);
        }
    }
    None
}
213
214#[inline(always)]
215fn trend_continuation_factor_prepare<'a>(
216 input: &'a TrendContinuationFactorInput,
217) -> Result<(&'a [f64], usize, usize), TrendContinuationFactorError> {
218 let data = input.as_ref();
219 let data_len = data.len();
220 if data_len == 0 {
221 return Err(TrendContinuationFactorError::EmptyInputData);
222 }
223
224 let first = first_valid_index(data).ok_or(TrendContinuationFactorError::AllValuesNaN)?;
225 let length = input.get_length();
226 if length == 0 || length > data_len {
227 return Err(TrendContinuationFactorError::InvalidLength { length, data_len });
228 }
229
230 let valid = data_len - first;
231 if valid <= length {
232 return Err(TrendContinuationFactorError::NotEnoughValidData {
233 needed: length + 1,
234 valid,
235 });
236 }
237
238 Ok((data, length, first))
239}
240
/// Incremental TCF state: feed one price per call and get the windowed sums back.
#[derive(Debug, Clone)]
pub struct TrendContinuationFactorStream {
    /// Window length (number of price changes summed).
    length: usize,
    /// Last accepted finite price.
    prev: Option<f64>,
    /// Running upward continuation factor (`None` before the first change).
    plus_cf: Option<f64>,
    /// Running downward continuation factor (`None` before the first change).
    minus_cf: Option<f64>,
    /// Changes stored so far while the window fills (caps at `length`).
    comparisons_seen: usize,
    /// Ring-buffer slot to overwrite next once the window is full.
    head: usize,
    /// Sum of the values currently held in `plus_buffer`.
    sum_plus: f64,
    /// Sum of the values currently held in `minus_buffer`.
    sum_minus: f64,
    /// Per-change `plus_change - minus_cf` terms inside the window.
    plus_buffer: Vec<f64>,
    /// Per-change `minus_change - plus_cf` terms inside the window.
    minus_buffer: Vec<f64>,
}
254
255impl TrendContinuationFactorStream {
256 #[inline]
257 fn from_length(length: usize) -> Self {
258 Self {
259 length,
260 prev: None,
261 plus_cf: None,
262 minus_cf: None,
263 comparisons_seen: 0,
264 head: 0,
265 sum_plus: 0.0,
266 sum_minus: 0.0,
267 plus_buffer: vec![0.0; length.max(1)],
268 minus_buffer: vec![0.0; length.max(1)],
269 }
270 }
271
272 #[inline]
273 pub fn try_new(
274 params: TrendContinuationFactorParams,
275 ) -> Result<Self, TrendContinuationFactorError> {
276 let length = params.length.unwrap_or(35);
277 if length == 0 {
278 return Err(TrendContinuationFactorError::InvalidLength {
279 length,
280 data_len: 0,
281 });
282 }
283 Ok(Self::from_length(length))
284 }
285
286 #[inline(always)]
287 fn reset(&mut self) {
288 self.prev = None;
289 self.plus_cf = None;
290 self.minus_cf = None;
291 self.comparisons_seen = 0;
292 self.head = 0;
293 self.sum_plus = 0.0;
294 self.sum_minus = 0.0;
295 self.plus_buffer.fill(0.0);
296 self.minus_buffer.fill(0.0);
297 }
298
299 #[inline(always)]
300 pub fn update(&mut self, value: f64) -> Option<(f64, f64)> {
301 if !value.is_finite() {
302 return None;
303 }
304
305 let prev = match self.prev.replace(value) {
306 Some(prev) => prev,
307 None => return None,
308 };
309
310 let change = value - prev;
311 let plus_change = if change > 0.0 { change } else { 0.0 };
312 let minus_change = if change < 0.0 { -change } else { 0.0 };
313
314 let next_plus_cf = if plus_change == 0.0 {
315 0.0
316 } else {
317 plus_change + self.plus_cf.unwrap_or(1.0)
318 };
319 let next_minus_cf = if minus_change == 0.0 {
320 0.0
321 } else {
322 minus_change + self.minus_cf.unwrap_or(1.0)
323 };
324
325 self.plus_cf = Some(next_plus_cf);
326 self.minus_cf = Some(next_minus_cf);
327
328 let plus = plus_change - next_minus_cf;
329 let minus = minus_change - next_plus_cf;
330
331 if self.comparisons_seen < self.length {
332 self.plus_buffer[self.comparisons_seen] = plus;
333 self.minus_buffer[self.comparisons_seen] = minus;
334 self.sum_plus += plus;
335 self.sum_minus += minus;
336 self.comparisons_seen += 1;
337 if self.comparisons_seen < self.length {
338 return None;
339 }
340 return Some((self.sum_plus, self.sum_minus));
341 }
342
343 let old_plus = self.plus_buffer[self.head];
344 let old_minus = self.minus_buffer[self.head];
345 self.plus_buffer[self.head] = plus;
346 self.minus_buffer[self.head] = minus;
347 self.sum_plus += plus - old_plus;
348 self.sum_minus += minus - old_minus;
349 self.head += 1;
350 if self.head == self.length {
351 self.head = 0;
352 }
353
354 Some((self.sum_plus, self.sum_minus))
355 }
356
357 #[inline(always)]
358 pub fn update_reset_on_nan(&mut self, value: f64) -> Option<(f64, f64)> {
359 if !value.is_finite() {
360 self.reset();
361 return None;
362 }
363 self.update(value)
364 }
365}
366
367#[inline(always)]
368fn trend_continuation_factor_compute_into(
369 data: &[f64],
370 length: usize,
371 _first: usize,
372 _kernel: Kernel,
373 out_plus: &mut [f64],
374 out_minus: &mut [f64],
375) {
376 let mut stream = TrendContinuationFactorStream::from_length(length);
377 for i in 0..data.len() {
378 match stream.update_reset_on_nan(data[i]) {
379 Some((plus, minus)) => {
380 out_plus[i] = plus;
381 out_minus[i] = minus;
382 }
383 None => {
384 out_plus[i] = f64::NAN;
385 out_minus[i] = f64::NAN;
386 }
387 }
388 }
389}
390
391#[inline]
392pub fn trend_continuation_factor(
393 input: &TrendContinuationFactorInput,
394) -> Result<TrendContinuationFactorOutput, TrendContinuationFactorError> {
395 trend_continuation_factor_with_kernel(input, Kernel::Auto)
396}
397
398pub fn trend_continuation_factor_with_kernel(
399 input: &TrendContinuationFactorInput,
400 kernel: Kernel,
401) -> Result<TrendContinuationFactorOutput, TrendContinuationFactorError> {
402 let (data, length, first) = trend_continuation_factor_prepare(input)?;
403 let warmup = first + length;
404 let mut plus_tcf = alloc_with_nan_prefix(data.len(), warmup);
405 let mut minus_tcf = alloc_with_nan_prefix(data.len(), warmup);
406 trend_continuation_factor_compute_into(
407 data,
408 length,
409 first,
410 kernel,
411 &mut plus_tcf,
412 &mut minus_tcf,
413 );
414 Ok(TrendContinuationFactorOutput {
415 plus_tcf,
416 minus_tcf,
417 })
418}
419
420pub fn trend_continuation_factor_into_slice(
421 dst_plus_tcf: &mut [f64],
422 dst_minus_tcf: &mut [f64],
423 input: &TrendContinuationFactorInput,
424 kernel: Kernel,
425) -> Result<(), TrendContinuationFactorError> {
426 let (data, length, first) = trend_continuation_factor_prepare(input)?;
427 if dst_plus_tcf.len() != data.len() || dst_minus_tcf.len() != data.len() {
428 return Err(TrendContinuationFactorError::OutputLengthMismatch {
429 expected: data.len(),
430 got: dst_plus_tcf.len().max(dst_minus_tcf.len()),
431 });
432 }
433 trend_continuation_factor_compute_into(
434 data,
435 length,
436 first,
437 kernel,
438 dst_plus_tcf,
439 dst_minus_tcf,
440 );
441 Ok(())
442}
443
444#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
445#[inline]
446pub fn trend_continuation_factor_into(
447 input: &TrendContinuationFactorInput,
448 out_plus_tcf: &mut [f64],
449 out_minus_tcf: &mut [f64],
450) -> Result<(), TrendContinuationFactorError> {
451 trend_continuation_factor_into_slice(out_plus_tcf, out_minus_tcf, input, Kernel::Auto)
452}
453
/// Inclusive `(start, end, step)` sweep over the window length for batch runs.
#[derive(Debug, Clone)]
pub struct TrendContinuationFactorBatchRange {
    pub length: (usize, usize, usize),
}

impl Default for TrendContinuationFactorBatchRange {
    /// Lengths 35 through 200, step 1.
    fn default() -> Self {
        let (start, end, step) = (35, 200, 1);
        TrendContinuationFactorBatchRange {
            length: (start, end, step),
        }
    }
}
466
467#[derive(Clone, Debug, Default)]
468pub struct TrendContinuationFactorBatchBuilder {
469 range: TrendContinuationFactorBatchRange,
470 kernel: Kernel,
471}
472
473impl TrendContinuationFactorBatchBuilder {
474 pub fn new() -> Self {
475 Self::default()
476 }
477
478 pub fn kernel(mut self, kernel: Kernel) -> Self {
479 self.kernel = kernel;
480 self
481 }
482
483 #[inline]
484 pub fn length_range(mut self, start: usize, end: usize, step: usize) -> Self {
485 self.range.length = (start, end, step);
486 self
487 }
488
489 #[inline]
490 pub fn length_static(mut self, length: usize) -> Self {
491 self.range.length = (length, length, 0);
492 self
493 }
494
495 pub fn apply_slice(
496 self,
497 data: &[f64],
498 ) -> Result<TrendContinuationFactorBatchOutput, TrendContinuationFactorError> {
499 trend_continuation_factor_batch_with_kernel(data, &self.range, self.kernel)
500 }
501
502 pub fn apply_candles(
503 self,
504 candles: &Candles,
505 source: &str,
506 ) -> Result<TrendContinuationFactorBatchOutput, TrendContinuationFactorError> {
507 self.apply_slice(source_type(candles, source))
508 }
509}
510
511#[derive(Clone, Debug)]
512pub struct TrendContinuationFactorBatchOutput {
513 pub plus_tcf: Vec<f64>,
514 pub minus_tcf: Vec<f64>,
515 pub combos: Vec<TrendContinuationFactorParams>,
516 pub rows: usize,
517 pub cols: usize,
518}
519
520impl TrendContinuationFactorBatchOutput {
521 pub fn row_for_params(&self, params: &TrendContinuationFactorParams) -> Option<usize> {
522 self.combos
523 .iter()
524 .position(|combo| combo.length.unwrap_or(35) == params.length.unwrap_or(35))
525 }
526
527 pub fn plus_tcf_for(&self, params: &TrendContinuationFactorParams) -> Option<&[f64]> {
528 self.row_for_params(params).map(|row| {
529 let start = row * self.cols;
530 &self.plus_tcf[start..start + self.cols]
531 })
532 }
533
534 pub fn minus_tcf_for(&self, params: &TrendContinuationFactorParams) -> Option<&[f64]> {
535 self.row_for_params(params).map(|row| {
536 let start = row * self.cols;
537 &self.minus_tcf[start..start + self.cols]
538 })
539 }
540}
541
542fn axis_usize(range: (usize, usize, usize)) -> Result<Vec<usize>, TrendContinuationFactorError> {
543 let (start, end, step) = range;
544 if step == 0 || start == end {
545 return Ok(vec![start]);
546 }
547
548 let mut out = Vec::new();
549 if start < end {
550 let mut value = start;
551 while value <= end {
552 out.push(value);
553 match value.checked_add(step) {
554 Some(next) if next > value => value = next,
555 _ => break,
556 }
557 }
558 } else {
559 let mut value = start;
560 while value >= end {
561 out.push(value);
562 if value < end + step {
563 break;
564 }
565 value = value.saturating_sub(step);
566 if value == 0 {
567 break;
568 }
569 }
570 }
571
572 if out.is_empty() {
573 return Err(TrendContinuationFactorError::InvalidRange { start, end, step });
574 }
575 Ok(out)
576}
577
578pub fn expand_grid_trend_continuation_factor(
579 sweep: &TrendContinuationFactorBatchRange,
580) -> Result<Vec<TrendContinuationFactorParams>, TrendContinuationFactorError> {
581 Ok(axis_usize(sweep.length)?
582 .into_iter()
583 .map(|length| TrendContinuationFactorParams {
584 length: Some(length),
585 })
586 .collect())
587}
588
589pub fn trend_continuation_factor_batch_with_kernel(
590 data: &[f64],
591 sweep: &TrendContinuationFactorBatchRange,
592 kernel: Kernel,
593) -> Result<TrendContinuationFactorBatchOutput, TrendContinuationFactorError> {
594 let batch_kernel = match kernel {
595 Kernel::Auto => Kernel::ScalarBatch,
596 other if other.is_batch() => other,
597 other => return Err(TrendContinuationFactorError::InvalidKernelForBatch(other)),
598 };
599 trend_continuation_factor_batch_impl(data, sweep, batch_kernel.to_non_batch(), true)
600}
601
602pub fn trend_continuation_factor_batch_slice(
603 data: &[f64],
604 sweep: &TrendContinuationFactorBatchRange,
605) -> Result<TrendContinuationFactorBatchOutput, TrendContinuationFactorError> {
606 trend_continuation_factor_batch_impl(data, sweep, Kernel::Scalar, false)
607}
608
609pub fn trend_continuation_factor_batch_par_slice(
610 data: &[f64],
611 sweep: &TrendContinuationFactorBatchRange,
612) -> Result<TrendContinuationFactorBatchOutput, TrendContinuationFactorError> {
613 trend_continuation_factor_batch_impl(data, sweep, Kernel::Scalar, true)
614}
615
616fn trend_continuation_factor_batch_impl(
617 data: &[f64],
618 sweep: &TrendContinuationFactorBatchRange,
619 kernel: Kernel,
620 parallel: bool,
621) -> Result<TrendContinuationFactorBatchOutput, TrendContinuationFactorError> {
622 let combos = expand_grid_trend_continuation_factor(sweep)?;
623 let rows = combos.len();
624 let cols = data.len();
625
626 if cols == 0 {
627 return Err(TrendContinuationFactorError::EmptyInputData);
628 }
629
630 let first = first_valid_index(data).ok_or(TrendContinuationFactorError::AllValuesNaN)?;
631 let max_length = combos
632 .iter()
633 .map(|params| params.length.unwrap_or(35))
634 .max()
635 .unwrap_or(35);
636 let valid = cols - first;
637 if valid <= max_length {
638 return Err(TrendContinuationFactorError::NotEnoughValidData {
639 needed: max_length + 1,
640 valid,
641 });
642 }
643
644 let warmups: Vec<usize> = combos
645 .iter()
646 .map(|params| first + params.length.unwrap_or(35))
647 .collect();
648
649 let mut plus_matrix = make_uninit_matrix(rows, cols);
650 init_matrix_prefixes(&mut plus_matrix, cols, &warmups);
651 let mut minus_matrix = make_uninit_matrix(rows, cols);
652 init_matrix_prefixes(&mut minus_matrix, cols, &warmups);
653
654 let mut plus_guard = ManuallyDrop::new(plus_matrix);
655 let mut minus_guard = ManuallyDrop::new(minus_matrix);
656
657 let plus_mu: &mut [MaybeUninit<f64>] =
658 unsafe { std::slice::from_raw_parts_mut(plus_guard.as_mut_ptr(), plus_guard.len()) };
659 let minus_mu: &mut [MaybeUninit<f64>] =
660 unsafe { std::slice::from_raw_parts_mut(minus_guard.as_mut_ptr(), minus_guard.len()) };
661
662 let do_row = |row: usize,
663 row_plus_mu: &mut [MaybeUninit<f64>],
664 row_minus_mu: &mut [MaybeUninit<f64>]| {
665 let length = combos[row].length.unwrap_or(35);
666 let dst_plus = unsafe {
667 std::slice::from_raw_parts_mut(row_plus_mu.as_mut_ptr() as *mut f64, row_plus_mu.len())
668 };
669 let dst_minus = unsafe {
670 std::slice::from_raw_parts_mut(
671 row_minus_mu.as_mut_ptr() as *mut f64,
672 row_minus_mu.len(),
673 )
674 };
675 trend_continuation_factor_compute_into(data, length, first, kernel, dst_plus, dst_minus);
676 };
677
678 if parallel {
679 #[cfg(not(target_arch = "wasm32"))]
680 plus_mu
681 .par_chunks_mut(cols)
682 .zip(minus_mu.par_chunks_mut(cols))
683 .enumerate()
684 .for_each(|(row, (row_plus_mu, row_minus_mu))| do_row(row, row_plus_mu, row_minus_mu));
685 #[cfg(target_arch = "wasm32")]
686 for (row, (row_plus_mu, row_minus_mu)) in plus_mu
687 .chunks_mut(cols)
688 .zip(minus_mu.chunks_mut(cols))
689 .enumerate()
690 {
691 do_row(row, row_plus_mu, row_minus_mu);
692 }
693 } else {
694 for (row, (row_plus_mu, row_minus_mu)) in plus_mu
695 .chunks_mut(cols)
696 .zip(minus_mu.chunks_mut(cols))
697 .enumerate()
698 {
699 do_row(row, row_plus_mu, row_minus_mu);
700 }
701 }
702
703 let plus_tcf = unsafe {
704 Vec::from_raw_parts(
705 plus_guard.as_mut_ptr() as *mut f64,
706 plus_guard.len(),
707 plus_guard.capacity(),
708 )
709 };
710 let minus_tcf = unsafe {
711 Vec::from_raw_parts(
712 minus_guard.as_mut_ptr() as *mut f64,
713 minus_guard.len(),
714 minus_guard.capacity(),
715 )
716 };
717
718 Ok(TrendContinuationFactorBatchOutput {
719 plus_tcf,
720 minus_tcf,
721 combos,
722 rows,
723 cols,
724 })
725}
726
727fn trend_continuation_factor_batch_inner_into(
728 data: &[f64],
729 sweep: &TrendContinuationFactorBatchRange,
730 kernel: Kernel,
731 parallel: bool,
732 out_plus: &mut [f64],
733 out_minus: &mut [f64],
734) -> Result<(), TrendContinuationFactorError> {
735 let combos = expand_grid_trend_continuation_factor(sweep)?;
736 let rows = combos.len();
737 let cols = data.len();
738 if rows.checked_mul(cols) != Some(out_plus.len()) || out_minus.len() != out_plus.len() {
739 return Err(TrendContinuationFactorError::OutputLengthMismatch {
740 expected: rows * cols,
741 got: out_plus.len().max(out_minus.len()),
742 });
743 }
744
745 let first = first_valid_index(data).ok_or(TrendContinuationFactorError::AllValuesNaN)?;
746 for (row, params) in combos.iter().enumerate() {
747 let length = params.length.unwrap_or(35);
748 let row_plus = &mut out_plus[row * cols..(row + 1) * cols];
749 let row_minus = &mut out_minus[row * cols..(row + 1) * cols];
750 row_plus.fill(f64::NAN);
751 row_minus.fill(f64::NAN);
752 if cols - first <= length {
753 return Err(TrendContinuationFactorError::NotEnoughValidData {
754 needed: length + 1,
755 valid: cols - first,
756 });
757 }
758 }
759
760 let do_row = |row: usize, row_plus: &mut [f64], row_minus: &mut [f64]| {
761 let length = combos[row].length.unwrap_or(35);
762 trend_continuation_factor_compute_into(data, length, first, kernel, row_plus, row_minus);
763 };
764
765 if parallel {
766 #[cfg(not(target_arch = "wasm32"))]
767 out_plus
768 .par_chunks_mut(cols)
769 .zip(out_minus.par_chunks_mut(cols))
770 .enumerate()
771 .for_each(|(row, (row_plus, row_minus))| do_row(row, row_plus, row_minus));
772 #[cfg(target_arch = "wasm32")]
773 for (row, (row_plus, row_minus)) in out_plus
774 .chunks_mut(cols)
775 .zip(out_minus.chunks_mut(cols))
776 .enumerate()
777 {
778 do_row(row, row_plus, row_minus);
779 }
780 } else {
781 for (row, (row_plus, row_minus)) in out_plus
782 .chunks_mut(cols)
783 .zip(out_minus.chunks_mut(cols))
784 .enumerate()
785 {
786 do_row(row, row_plus, row_minus);
787 }
788 }
789
790 Ok(())
791}
792
/// Python binding: returns `(plus_tcf, minus_tcf)` NumPy arrays for one length.
///
/// Raises `ValueError` on invalid input or parameters.
#[cfg(feature = "python")]
#[pyfunction(name = "trend_continuation_factor")]
#[pyo3(signature = (data, length=35, kernel=None))]
pub fn trend_continuation_factor_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    length: usize,
    kernel: Option<&str>,
) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
    let series = data.as_slice()?;
    let kernel = validate_kernel(kernel, false)?;
    let params = TrendContinuationFactorParams {
        length: Some(length),
    };
    let input = TrendContinuationFactorInput::from_slice(series, params);

    // The computation runs with the GIL released.
    let computed = py.allow_threads(|| trend_continuation_factor_with_kernel(&input, kernel));
    let TrendContinuationFactorOutput {
        plus_tcf,
        minus_tcf,
    } = computed.map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok((plus_tcf.into_pyarray(py), minus_tcf.into_pyarray(py)))
}
818
/// Python wrapper around the incremental TCF stream.
#[cfg(feature = "python")]
#[pyclass(name = "TrendContinuationFactorStream")]
pub struct TrendContinuationFactorStreamPy {
    stream: TrendContinuationFactorStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl TrendContinuationFactorStreamPy {
    /// Creates a stream; raises `ValueError` for a zero length.
    #[new]
    #[pyo3(signature = (length=35))]
    fn new(length: usize) -> PyResult<Self> {
        let params = TrendContinuationFactorParams {
            length: Some(length),
        };
        match TrendContinuationFactorStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    /// Feeds one price; a non-finite price resets the stream. Returns
    /// `(plus_tcf, minus_tcf)` once warmed up, otherwise `None`.
    fn update(&mut self, value: f64) -> Option<(f64, f64)> {
        self.stream.update_reset_on_nan(value)
    }
}
842
/// Python binding for a length sweep.
///
/// Returns a dict with `plus_tcf` / `minus_tcf` as `(rows, cols)` arrays plus
/// `lengths`, `rows` and `cols`. Raises `ValueError` on invalid input or parameters.
#[cfg(feature = "python")]
#[pyfunction(name = "trend_continuation_factor_batch")]
#[pyo3(signature = (data, length_range, kernel=None))]
pub fn trend_continuation_factor_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    length_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let series = data.as_slice()?;
    let sweep = TrendContinuationFactorBatchRange {
        length: length_range,
    };
    let combos = match expand_grid_trend_continuation_factor(&sweep) {
        Ok(combos) => combos,
        Err(e) => return Err(PyValueError::new_err(e.to_string())),
    };
    let (rows, cols) = (combos.len(), series.len());
    let total = match rows.checked_mul(cols) {
        Some(total) => total,
        None => return Err(PyValueError::new_err("rows*cols overflow")),
    };

    // Created uninitialized; the batch routine overwrites every element on success,
    // and the arrays are simply dropped on error.
    let plus_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let minus_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_plus = unsafe { plus_arr.as_slice_mut()? };
    let out_minus = unsafe { minus_arr.as_slice_mut()? };
    let kernel = validate_kernel(kernel, true)?;

    let computed = py.allow_threads(|| {
        let batch_kernel = if matches!(kernel, Kernel::Auto) {
            detect_best_batch_kernel()
        } else {
            kernel
        };
        trend_continuation_factor_batch_inner_into(
            series,
            &sweep,
            batch_kernel.to_non_batch(),
            true,
            out_plus,
            out_minus,
        )
    });
    computed.map_err(|e| PyValueError::new_err(e.to_string()))?;

    let lengths: Vec<u64> = combos
        .iter()
        .map(|params| params.length.unwrap_or(35) as u64)
        .collect();

    let dict = PyDict::new(py);
    dict.set_item("plus_tcf", plus_arr.reshape((rows, cols))?)?;
    dict.set_item("minus_tcf", minus_arr.reshape((rows, cols))?)?;
    dict.set_item("lengths", lengths.into_pyarray(py))?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
900
/// Registers the TCF functions and stream class on the Python module `m`.
#[cfg(feature = "python")]
pub fn register_trend_continuation_factor_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(trend_continuation_factor_py, m)?)?;
    m.add_function(wrap_pyfunction!(trend_continuation_factor_batch_py, m)?)?;
    m.add_class::<TrendContinuationFactorStreamPy>()
}
908
/// JS-side configuration for a batch sweep.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Clone, Debug, Serialize, Deserialize)]
struct TrendContinuationFactorBatchConfig {
    /// Expected to hold `[start, end, step]`.
    length_range: Vec<usize>,
}

/// Serialized result of a batch sweep returned to JS.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Clone, Debug, Serialize, Deserialize)]
struct TrendContinuationFactorBatchJsOutput {
    plus_tcf: Vec<f64>,
    minus_tcf: Vec<f64>,
    rows: usize,
    cols: usize,
    combos: Vec<TrendContinuationFactorParams>,
}

/// Serialized single-series result returned to JS.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Clone, Debug, Serialize, Deserialize)]
struct TrendContinuationFactorJsOutput {
    plus_tcf: Vec<f64>,
    minus_tcf: Vec<f64>,
}
931
/// wasm binding: computes both series for one length and returns them as a JS object.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "trend_continuation_factor_js")]
pub fn trend_continuation_factor_js(data: &[f64], length: usize) -> Result<JsValue, JsValue> {
    let params = TrendContinuationFactorParams {
        length: Some(length),
    };
    let input = TrendContinuationFactorInput::from_slice(data, params);
    let TrendContinuationFactorOutput {
        plus_tcf,
        minus_tcf,
    } = match trend_continuation_factor(&input) {
        Ok(output) => output,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };
    let payload = TrendContinuationFactorJsOutput {
        plus_tcf,
        minus_tcf,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {e}")))
}
949
/// wasm binding: runs a length sweep described by `config` (`{ length_range: [start, end, step] }`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "trend_continuation_factor_batch_js")]
pub fn trend_continuation_factor_batch_js(
    data: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let config: TrendContinuationFactorBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {e}")))?;
    let (start, end, step) = match *config.length_range.as_slice() {
        [start, end, step] => (start, end, step),
        _ => {
            return Err(JsValue::from_str(
                "Invalid config: length_range must have exactly 3 elements [start, end, step]",
            ))
        }
    };
    let sweep = TrendContinuationFactorBatchRange {
        length: (start, end, step),
    };
    let TrendContinuationFactorBatchOutput {
        plus_tcf,
        minus_tcf,
        combos,
        rows,
        cols,
    } = trend_continuation_factor_batch_slice(data, &sweep)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    let payload = TrendContinuationFactorBatchJsOutput {
        plus_tcf,
        minus_tcf,
        rows,
        cols,
        combos,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {e}")))
}
981
/// Allocates room for `len * 2` f64 values (plus series followed by minus series)
/// and leaks it to JS. Returns a null pointer if `len * 2` overflows.
///
/// Must be released with `trend_continuation_factor_free` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trend_continuation_factor_alloc(len: usize) -> *mut f64 {
    let capacity = match len.checked_mul(2) {
        Some(capacity) => capacity,
        None => return std::ptr::null_mut(),
    };
    let mut buffer = Vec::<f64>::with_capacity(capacity);
    let ptr = buffer.as_mut_ptr();
    std::mem::forget(buffer);
    ptr
}

/// Releases a buffer obtained from `trend_continuation_factor_alloc(len)`.
///
/// A null pointer (including the alloc overflow case) is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trend_continuation_factor_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    let capacity = match len.checked_mul(2) {
        Some(capacity) => capacity,
        None => return,
    };
    // SAFETY: `ptr` came from `trend_continuation_factor_alloc(len)`, whose Vec was
    // created by `Vec::with_capacity(len * 2)` (exact capacity for non-zero-sized
    // types). Length 0 avoids asserting that possibly never-written elements are
    // initialized; f64 has no destructor, so nothing is skipped.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, capacity));
    }
}
998
/// wasm binding: reads `len` prices at `in_ptr` and writes `len` plus values followed
/// by `len` minus values starting at `out_ptr`.
///
/// In-place use (overlapping input and output) is supported: the input is copied
/// first so no shared and mutable views of the same memory coexist.
///
/// # Errors
/// Null pointers, `len * 2` overflow, or any computation error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trend_continuation_factor_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    length: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to trend_continuation_factor_into",
        ));
    }
    let out_len = len.checked_mul(2).ok_or_else(|| {
        JsValue::from_str("trend_continuation_factor_into: output length overflow")
    })?;

    // Byte ranges of input and output; saturating keeps the check conservative.
    let elem = std::mem::size_of::<f64>();
    let in_start = in_ptr as usize;
    let in_end = in_start.saturating_add(len.saturating_mul(elem));
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(out_len.saturating_mul(elem));
    let overlaps = in_start < out_end && out_start < in_end;

    let owned_input: Vec<f64>;
    let data: &[f64] = if overlaps {
        // SAFETY: the caller guarantees `in_ptr` addresses `len` readable f64 values;
        // this temporary view ends before the mutable output view is created.
        owned_input = unsafe { std::slice::from_raw_parts(in_ptr, len) }.to_vec();
        &owned_input
    } else {
        // SAFETY: as above, and the input does not overlap the output, so this view
        // never aliases the mutable slice below.
        unsafe { std::slice::from_raw_parts(in_ptr, len) }
    };
    // SAFETY: the caller guarantees `out_ptr` addresses `2 * len` writable f64 values.
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, out_len) };
    let (out_plus, out_minus) = out.split_at_mut(len);

    let input = TrendContinuationFactorInput::from_slice(
        data,
        TrendContinuationFactorParams {
            length: Some(length),
        },
    );
    trend_continuation_factor_into_slice(out_plus, out_minus, &input, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))
}
1026
/// wasm binding: computes from a JS-provided slice into `2 * data.len()` f64 values
/// at `out_ptr` (plus series first, then minus series).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "trend_continuation_factor_into_host")]
pub fn trend_continuation_factor_into_host(
    data: &[f64],
    out_ptr: *mut f64,
    length: usize,
) -> Result<(), JsValue> {
    if out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to trend_continuation_factor_into_host",
        ));
    }
    let n = data.len();
    let params = TrendContinuationFactorParams {
        length: Some(length),
    };
    let input = TrendContinuationFactorInput::from_slice(data, params);
    // SAFETY: the caller guarantees `out_ptr` addresses `2 * n` writable f64 values.
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, n * 2) };
    let (out_plus, out_minus) = out.split_at_mut(n);
    trend_continuation_factor_into_slice(out_plus, out_minus, &input, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))
}
1052
/// wasm binding: runs a length sweep over `len` prices at `in_ptr`, writing
/// `rows * len` plus values followed by `rows * len` minus values at `out_ptr`.
///
/// Returns the number of rows. Overlapping input/output is supported by copying the
/// input first.
///
/// # Errors
/// Null pointers, size overflow, invalid sweep, or any computation error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trend_continuation_factor_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    length_start: usize,
    length_end: usize,
    length_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to trend_continuation_factor_batch_into",
        ));
    }
    let sweep = TrendContinuationFactorBatchRange {
        length: (length_start, length_end, length_step),
    };
    let combos = expand_grid_trend_continuation_factor(&sweep)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();

    let series_len = rows.checked_mul(len).ok_or_else(|| {
        JsValue::from_str("trend_continuation_factor_batch_into: rows*len overflow")
    })?;
    let out_len = series_len.checked_mul(2).ok_or_else(|| {
        JsValue::from_str("trend_continuation_factor_batch_into: rows*len overflow")
    })?;

    // Byte ranges of input and output; saturating keeps the check conservative.
    let elem = std::mem::size_of::<f64>();
    let in_start = in_ptr as usize;
    let in_end = in_start.saturating_add(len.saturating_mul(elem));
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(out_len.saturating_mul(elem));
    let overlaps = in_start < out_end && out_start < in_end;

    let owned_input: Vec<f64>;
    let data: &[f64] = if overlaps {
        // SAFETY: the caller guarantees `in_ptr` addresses `len` readable f64 values;
        // this temporary view ends before the mutable output view is created.
        owned_input = unsafe { std::slice::from_raw_parts(in_ptr, len) }.to_vec();
        &owned_input
    } else {
        // SAFETY: as above, and the input does not overlap the output.
        unsafe { std::slice::from_raw_parts(in_ptr, len) }
    };
    // SAFETY: the caller guarantees `out_ptr` addresses `2 * rows * len` writable f64 values.
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, out_len) };
    let (out_plus, out_minus) = out.split_at_mut(series_len);

    trend_continuation_factor_batch_inner_into(
        data,
        &sweep,
        Kernel::Scalar,
        false,
        out_plus,
        out_minus,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;
    Ok(rows)
}
1090
#[cfg(test)]
mod tests {
    use super::*;
    use crate::indicators::dispatch::{
        compute_cpu_batch, IndicatorBatchRequest, IndicatorDataRef, IndicatorParamSet, ParamKV,
        ParamValue,
    };

    /// Deterministic synthetic series of `len` points: a 100.0 baseline plus a sine
    /// oscillation (amplitude 2.5, phase step 0.17) plus a linear drift of 0.03 per step.
    fn sample_data(len: usize) -> Vec<f64> {
        (0..len)
            .map(|i| {
                let t = i as f64;
                100.0 + (t * 0.17).sin() * 2.5 + t * 0.03
            })
            .collect()
    }

    /// Wraps `data` and `length` into a slice-backed indicator input.
    fn slice_input(data: &[f64], length: usize) -> TrendContinuationFactorInput<'_> {
        TrendContinuationFactorInput::from_slice(
            data,
            TrendContinuationFactorParams {
                length: Some(length),
            },
        )
    }

    /// Splits a bar-to-bar change into `(upward, downward)` magnitudes. Values that are
    /// neither strictly positive nor strictly negative (zero, NaN) map to `(0.0, 0.0)`.
    fn split_change(change: f64) -> (f64, f64) {
        if change > 0.0 {
            (change, 0.0)
        } else if change < 0.0 {
            (0.0, -change)
        } else {
            (0.0, 0.0)
        }
    }

    /// Advances a change factor. It is zero when there was no move in this direction;
    /// otherwise it is the move added to the previous factor, seeded with 1.0 when no
    /// previous factor exists.
    fn next_factor(magnitude: f64, previous: Option<f64>) -> f64 {
        if magnitude == 0.0 {
            0.0
        } else {
            magnitude + previous.unwrap_or(1.0)
        }
    }

    /// Reference implementation used to cross-check the optimized indicator.
    ///
    /// Each bar contributes a plus term (`up - minus_factor`) and a minus term
    /// (`down - plus_factor`). The outputs are rolling sums of the last `length` terms,
    /// and entries stay NaN until the window is full.
    fn naive_tcf(data: &[f64], length: usize) -> (Vec<f64>, Vec<f64>) {
        let mut plus_out = vec![f64::NAN; data.len()];
        let mut minus_out = vec![f64::NAN; data.len()];
        if data.len() <= length {
            return (plus_out, minus_out);
        }

        let mut prev_plus_cf: Option<f64> = None;
        let mut prev_minus_cf: Option<f64> = None;
        // Ring buffer of the most recent `length` (plus, minus) terms.
        let mut window = vec![(0.0_f64, 0.0_f64); length];
        let mut slot = 0usize;
        let mut filled = 0usize;
        let mut acc_plus = 0.0;
        let mut acc_minus = 0.0;

        for (step, pair) in data.windows(2).enumerate() {
            let bar = step + 1;
            let (up, down) = split_change(pair[1] - pair[0]);
            let plus_cf = next_factor(up, prev_plus_cf);
            let minus_cf = next_factor(down, prev_minus_cf);
            prev_plus_cf = Some(plus_cf);
            prev_minus_cf = Some(minus_cf);

            let plus_term = up - minus_cf;
            let minus_term = down - plus_cf;

            if filled < length {
                window[filled] = (plus_term, minus_term);
                acc_plus += plus_term;
                acc_minus += minus_term;
                filled += 1;
                if filled < length {
                    // Window is still warming up: no output for this bar.
                    continue;
                }
            } else {
                // Evict the oldest term and slide the running sums forward.
                let (old_plus, old_minus) =
                    std::mem::replace(&mut window[slot], (plus_term, minus_term));
                acc_plus += plus_term - old_plus;
                acc_minus += minus_term - old_minus;
                slot = (slot + 1) % length;
            }
            plus_out[bar] = acc_plus;
            minus_out[bar] = acc_minus;
        }

        (plus_out, minus_out)
    }

    /// Asserts that both series have the same length and agree element-wise. A NaN must
    /// be matched by a NaN, and finite values must differ by at most 1e-10.
    fn assert_close(a: &[f64], b: &[f64]) {
        assert_eq!(a.len(), b.len());
        for (i, (&x, &y)) in a.iter().zip(b.iter()).enumerate() {
            if x.is_nan() || y.is_nan() {
                assert!(x.is_nan() && y.is_nan(), "nan mismatch at {i}: {} vs {}", x, y);
            } else {
                assert!((x - y).abs() <= 1e-10, "mismatch at {i}: {} vs {}", x, y);
            }
        }
    }

    #[test]
    fn trend_continuation_factor_matches_naive() {
        let data = sample_data(256);
        let out = trend_continuation_factor(&slice_input(&data, 35)).expect("indicator");
        let (plus_ref, minus_ref) = naive_tcf(&data, 35);
        assert_close(&out.plus_tcf, &plus_ref);
        assert_close(&out.minus_tcf, &minus_ref);
    }

    #[test]
    fn trend_continuation_factor_into_matches_api() {
        let data = sample_data(192);
        let input = slice_input(&data, 20);
        let baseline = trend_continuation_factor(&input).expect("baseline");
        let mut plus_out = vec![0.0; data.len()];
        let mut minus_out = vec![0.0; data.len()];
        trend_continuation_factor_into(&input, &mut plus_out, &mut minus_out).expect("into");
        assert_close(&baseline.plus_tcf, &plus_out);
        assert_close(&baseline.minus_tcf, &minus_out);
    }

    #[test]
    fn trend_continuation_factor_stream_matches_batch() {
        let data = sample_data(192);
        let batch = trend_continuation_factor(&slice_input(&data, 18)).expect("batch");
        let mut stream = TrendContinuationFactorStream::try_new(TrendContinuationFactorParams {
            length: Some(18),
        })
        .expect("stream");
        // Bars where the stream has not produced a value yet are recorded as NaN.
        let (plus, minus): (Vec<f64>, Vec<f64>) = data
            .iter()
            .map(|&value| {
                stream
                    .update_reset_on_nan(value)
                    .unwrap_or((f64::NAN, f64::NAN))
            })
            .unzip();
        assert_close(&batch.plus_tcf, &plus);
        assert_close(&batch.minus_tcf, &minus);
    }

    #[test]
    fn trend_continuation_factor_batch_single_param_matches_single() {
        let data = sample_data(192);
        let cols = data.len();
        let sweep = TrendContinuationFactorBatchRange {
            length: (12, 12, 0),
        };
        let batch = trend_continuation_factor_batch_with_kernel(&data, &sweep, Kernel::ScalarBatch)
            .expect("batch");
        let direct = trend_continuation_factor_with_kernel(&slice_input(&data, 12), Kernel::Scalar)
            .expect("direct");
        assert_eq!(batch.rows, 1);
        assert_eq!(batch.cols, cols);
        assert_close(&batch.plus_tcf[..cols], &direct.plus_tcf);
        assert_close(&batch.minus_tcf[..cols], &direct.minus_tcf);
    }

    #[test]
    fn trend_continuation_factor_rejects_invalid_length() {
        let data = sample_data(32);
        let err = trend_continuation_factor(&slice_input(&data, 0)).unwrap_err();
        assert!(matches!(
            err,
            TrendContinuationFactorError::InvalidLength { .. }
        ));
    }

    #[test]
    fn trend_continuation_factor_dispatch_matches_direct() {
        let data = sample_data(160);
        let params = [ParamKV {
            key: "length",
            value: ParamValue::Int(16),
        }];
        let combos = [IndicatorParamSet { params: &params }];

        let out = compute_cpu_batch(IndicatorBatchRequest {
            indicator_id: "trend_continuation_factor",
            output_id: Some("plus_tcf"),
            data: IndicatorDataRef::Slice { values: &data },
            combos: &combos,
            kernel: Kernel::Auto,
        })
        .expect("dispatch");

        let direct = trend_continuation_factor(&slice_input(&data, 16)).expect("direct");
        assert_eq!(out.rows, 1);
        assert_eq!(out.cols, data.len());
        let values = out.values_f64.expect("values");
        assert_close(&values, &direct.plus_tcf);
    }
}