1use crate::utilities::data_loader::{source_type, Candles};
2#[cfg(all(feature = "python", feature = "cuda"))]
3use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
4use crate::utilities::enums::Kernel;
5use crate::utilities::helpers::{
6 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
7 make_uninit_matrix,
8};
9#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
10use core::arch::x86_64::*;
11#[cfg(all(feature = "python", feature = "cuda"))]
12use cust::memory::DeviceBuffer;
13#[cfg(not(target_arch = "wasm32"))]
14use rayon::prelude::*;
15use std::error::Error;
16use thiserror::Error;
17
/// Writes the numeric TTM trend signal into `out`: `1.0` where `close` is
/// strictly above the rolling `period`-sample mean of `source`, `0.0`
/// otherwise.
///
/// Writing starts at the first full window, index `first + period - 1`, and
/// runs up to `min(source.len(), close.len())`. Entries before the first full
/// window are left untouched, so the caller decides how to mark warm-up
/// (e.g. with NaN).
///
/// The caller must ensure `period >= 1`, that `source[first..first + period]`
/// exists, and that `out` is long enough. Otherwise indexing panics.
#[inline(always)]
fn ttm_numeric_compute_into(
    source: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    let signal = |c: f64, mean: f64| if c > mean { 1.0 } else { 0.0 };
    let scale = 1.0 / (period as f64);
    let end = source.len().min(close.len());

    // Seed the rolling sum with the first complete window, in order.
    let mut window_sum = source[first..first + period]
        .iter()
        .fold(0.0, |acc, &x| acc + x);

    let seed = first + period - 1;
    out[seed] = signal(close[seed], window_sum * scale);

    // Slide the window: add the newest sample, drop the one leaving it.
    for i in (seed + 1)..end {
        window_sum += source[i] - source[i - period];
        out[i] = signal(close[i], window_sum * scale);
    }
}
41
42#[inline(always)]
43fn ttm_prepare<'a>(
44 input: &'a TtmTrendInput<'a>,
45 kernel: Kernel,
46) -> Result<(&'a [f64], &'a [f64], usize, usize, Kernel), TtmTrendError> {
47 let (source, close) = input.as_slices();
48 let len = source.len().min(close.len());
49 if len == 0 {
50 return Err(TtmTrendError::EmptyInputData);
51 }
52 let first = source
53 .iter()
54 .zip(close)
55 .position(|(&s, &c)| !s.is_nan() && !c.is_nan())
56 .ok_or(TtmTrendError::AllValuesNaN)?;
57 let period = input.get_period();
58 if period == 0 || period > len {
59 return Err(TtmTrendError::InvalidPeriod {
60 period,
61 data_len: len,
62 });
63 }
64 if len - first < period {
65 return Err(TtmTrendError::NotEnoughValidData {
66 needed: period,
67 valid: len - first,
68 });
69 }
70 let chosen = match kernel {
71 Kernel::Auto => detect_best_kernel(),
72 k => k,
73 };
74 Ok((source, close, period, first, chosen))
75}
76
77#[inline(always)]
78fn ttm_numeric_compute_with_params(
79 dst: &mut [f64],
80 source: &[f64],
81 close: &[f64],
82 period: usize,
83 first: usize,
84 _kernel: Kernel,
85) -> Result<(), TtmTrendError> {
86 let len = source.len().min(close.len());
87 if dst.len() != len {
88 return Err(TtmTrendError::OutputLengthMismatch {
89 expected: len,
90 got: dst.len(),
91 });
92 }
93
94 let warmup = first + period - 1;
95 for v in &mut dst[..warmup] {
96 *v = f64::NAN;
97 }
98 ttm_numeric_compute_into(source, close, period, first, dst);
99 Ok(())
100}
101
102#[inline(always)]
103fn ttm_numeric_into_slice(
104 dst: &mut [f64],
105 input: &TtmTrendInput,
106 kern: Kernel,
107) -> Result<(), TtmTrendError> {
108 let (source, close, period, first, chosen) = ttm_prepare(input, kern)?;
109 ttm_numeric_compute_with_params(dst, source, close, period, first, chosen)
110}
111
112#[derive(Debug, Clone)]
113pub enum TtmTrendData<'a> {
114 Candles {
115 candles: &'a Candles,
116 source: &'a str,
117 },
118 Slices {
119 source: &'a [f64],
120 close: &'a [f64],
121 },
122}
123
124#[derive(Debug, Clone)]
125pub struct TtmTrendOutput {
126 pub values: Vec<bool>,
127}
128
129#[derive(Debug, Clone)]
130pub struct TtmTrendParams {
131 pub period: Option<usize>,
132}
133
134impl Default for TtmTrendParams {
135 fn default() -> Self {
136 Self { period: Some(5) }
137 }
138}
139
140#[derive(Debug, Clone)]
141pub struct TtmTrendInput<'a> {
142 pub data: TtmTrendData<'a>,
143 pub params: TtmTrendParams,
144}
145
146impl<'a> TtmTrendInput<'a> {
147 #[inline]
148 pub fn from_candles(c: &'a Candles, s: &'a str, p: TtmTrendParams) -> Self {
149 Self {
150 data: TtmTrendData::Candles {
151 candles: c,
152 source: s,
153 },
154 params: p,
155 }
156 }
157 #[inline]
158 pub fn from_slices(source: &'a [f64], close: &'a [f64], p: TtmTrendParams) -> Self {
159 Self {
160 data: TtmTrendData::Slices { source, close },
161 params: p,
162 }
163 }
164 #[inline]
165 pub fn with_default_candles(c: &'a Candles) -> Self {
166 Self::from_candles(c, "hl2", TtmTrendParams::default())
167 }
168 #[inline]
169 pub fn get_period(&self) -> usize {
170 self.params.period.unwrap_or(5)
171 }
172 #[inline(always)]
173 pub fn as_slices(&self) -> (&[f64], &[f64]) {
174 match &self.data {
175 TtmTrendData::Slices { source, close } => (source, close),
176 TtmTrendData::Candles { candles, source } => {
177 (source_type(candles, source), source_type(candles, "close"))
178 }
179 }
180 }
181 #[inline(always)]
182 pub fn as_ref(&self) -> (&[f64], &[f64]) {
183 match &self.data {
184 TtmTrendData::Slices { source, close } => (*source, *close),
185 TtmTrendData::Candles { candles, source } => {
186 (source_type(candles, source), source_type(candles, "close"))
187 }
188 }
189 }
190}
191
192#[derive(Copy, Clone, Debug)]
193pub struct TtmTrendBuilder {
194 period: Option<usize>,
195 kernel: Kernel,
196}
197
198impl Default for TtmTrendBuilder {
199 fn default() -> Self {
200 Self {
201 period: None,
202 kernel: Kernel::Auto,
203 }
204 }
205}
206
207impl TtmTrendBuilder {
208 #[inline(always)]
209 pub fn new() -> Self {
210 Self::default()
211 }
212 #[inline(always)]
213 pub fn period(mut self, n: usize) -> Self {
214 self.period = Some(n);
215 self
216 }
217 #[inline(always)]
218 pub fn kernel(mut self, k: Kernel) -> Self {
219 self.kernel = k;
220 self
221 }
222 #[inline(always)]
223 pub fn apply(self, c: &Candles) -> Result<TtmTrendOutput, TtmTrendError> {
224 let p = TtmTrendParams {
225 period: self.period,
226 };
227 let i = TtmTrendInput::from_candles(c, "hl2", p);
228 ttm_trend_with_kernel(&i, self.kernel)
229 }
230 #[inline(always)]
231 pub fn apply_slices(self, src: &[f64], close: &[f64]) -> Result<TtmTrendOutput, TtmTrendError> {
232 let p = TtmTrendParams {
233 period: self.period,
234 };
235 let i = TtmTrendInput::from_slices(src, close, p);
236 ttm_trend_with_kernel(&i, self.kernel)
237 }
238 #[inline(always)]
239 pub fn into_stream(self) -> Result<TtmTrendStream, TtmTrendError> {
240 let p = TtmTrendParams {
241 period: self.period,
242 };
243 TtmTrendStream::try_new(p)
244 }
245}
246
247#[derive(Debug, Error)]
248pub enum TtmTrendError {
249 #[error("ttm_trend: Input data slice is empty.")]
250 EmptyInputData,
251 #[error("ttm_trend: All values are NaN.")]
252 AllValuesNaN,
253 #[error("ttm_trend: Invalid period: period = {period}, data length = {data_len}")]
254 InvalidPeriod { period: usize, data_len: usize },
255 #[error("ttm_trend: Not enough valid data: needed = {needed}, valid = {valid}")]
256 NotEnoughValidData { needed: usize, valid: usize },
257 #[error("ttm_trend: Output length mismatch: expected = {expected}, got = {got}")]
258 OutputLengthMismatch { expected: usize, got: usize },
259 #[error("ttm_trend: Invalid range: start={start}, end={end}, step={step}")]
260 InvalidRange {
261 start: usize,
262 end: usize,
263 step: usize,
264 },
265 #[error("ttm_trend: Invalid kernel type for batch operation: {0:?}")]
266 InvalidKernelForBatch(Kernel),
267}
268
269#[inline]
270pub fn ttm_trend(input: &TtmTrendInput) -> Result<TtmTrendOutput, TtmTrendError> {
271 ttm_trend_with_kernel(input, Kernel::Auto)
272}
273
274pub fn ttm_trend_with_kernel(
275 input: &TtmTrendInput,
276 kernel: Kernel,
277) -> Result<TtmTrendOutput, TtmTrendError> {
278 let (source, close, period, first, chosen) = ttm_prepare(input, kernel)?;
279 let len = source.len().min(close.len());
280
281 let mut values = vec![false; len];
282
283 #[allow(unused_variables)]
284 {
285 match chosen {
286 Kernel::Scalar | Kernel::ScalarBatch => {
287 ttm_trend_scalar(source, close, period, first, &mut values)
288 }
289 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
290 Kernel::Avx2 | Kernel::Avx2Batch => {
291 ttm_trend_avx2(source, close, period, first, &mut values)
292 }
293 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
294 Kernel::Avx512 | Kernel::Avx512Batch => {
295 ttm_trend_avx512(source, close, period, first, &mut values)
296 }
297 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
298 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
299 ttm_trend_scalar(source, close, period, first, &mut values)
300 }
301 _ => unreachable!(),
302 }
303 }
304
305 Ok(TtmTrendOutput { values })
306}
307
/// Scalar TTM trend kernel.
///
/// From the first full window onward, sets
/// `out[i] = close[i] > mean(source[i + 1 - period ..= i])`, with the mean
/// maintained as a rolling sum.
///
/// Only the first `n = min(source.len(), close.len(), out.len())` entries of
/// `out` are touched. Entries before the first full window
/// (`first + period - 1`) are set to `false`. If no full window fits in `n`
/// samples, every touched entry is `false`; the invalid `period == 0` is
/// treated the same way.
///
/// `first` is the index of the first usable sample. Callers pass the first
/// index where both series are non-NaN.
#[inline]
pub fn ttm_trend_scalar(
    source: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    out: &mut [bool],
) {
    let n = source.len().min(close.len()).min(out.len());
    if n == 0 {
        return;
    }
    let out = &mut out[..n];

    // `period == 0` used to underflow `first + period - 1`: it panicked in
    // debug builds and produced garbage in release builds. The window check
    // is written as a subtraction so it cannot overflow either.
    if period == 0 || first >= n || n - first < period {
        out.fill(false);
        return;
    }
    let warmup_end = first + period - 1;

    out[..first].fill(false);

    if period == 1 {
        // The mean of a single sample is the sample itself.
        for ((o, &c), &s) in out[first..]
            .iter_mut()
            .zip(&close[first..n])
            .zip(&source[first..n])
        {
            *o = c > s;
        }
        return;
    }

    // Accumulate the first window; its leading entries are warm-up.
    let mut sum = 0.0;
    for k in first..warmup_end {
        sum += source[k];
        out[k] = false;
    }
    sum += source[warmup_end];

    let inv_period = 1.0 / (period as f64);
    out[warmup_end] = close[warmup_end] > sum * inv_period;

    // Slide the window one sample at a time.
    for idx in (warmup_end + 1)..n {
        sum += source[idx] - source[idx - period];
        out[idx] = close[idx] > sum * inv_period;
    }
}
361
/// AVX-512 TTM trend kernel.
///
/// Only `period == 1` (a plain `close > source` comparison) has a vectorised
/// path here, processing 8 lanes per step. Every other period is delegated to
/// `ttm_trend_scalar`. In the vector path, entries before `first` are set to
/// `false`. If `first` is not below `min(source.len(), close.len(), out.len())`,
/// nothing is written.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn ttm_trend_avx512(
    source: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    out: &mut [bool],
) {
    if period != 1 {
        ttm_trend_scalar(source, close, period, first, out);
        return;
    }

    const LANES: usize = 8;
    let n = source.len().min(close.len()).min(out.len());
    if first >= n {
        return;
    }
    out[..first].fill(false);

    let mut i = first;
    // SAFETY: each iteration reads and writes `LANES` elements starting at `i`
    // with `i + LANES <= n`, and `n` is bounded by the lengths of `source`,
    // `close` and `out`.
    unsafe {
        while i + LANES <= n {
            let s = _mm512_loadu_pd(source.as_ptr().add(i));
            let c = _mm512_loadu_pd(close.as_ptr().add(i));
            let mask: u8 = _mm512_cmp_pd_mask(c, s, _CMP_GT_OQ);
            for lane in 0..LANES {
                *out.get_unchecked_mut(i + lane) = (mask >> lane) & 1 != 0;
            }
            i += LANES;
        }
    }

    // Remaining tail shorter than one vector.
    for ((o, &c), &s) in out[i..n].iter_mut().zip(&close[i..n]).zip(&source[i..n]) {
        *o = c > s;
    }
}
408
409#[inline]
410pub fn ttm_trend_avx2(
411 source: &[f64],
412 close: &[f64],
413 period: usize,
414 first: usize,
415 out: &mut [bool],
416) {
417 if period == 1 {
418 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
419 unsafe {
420 let n = source.len().min(close.len()).min(out.len());
421 if first >= n {
422 return;
423 }
424
425 for i in 0..first {
426 *out.get_unchecked_mut(i) = false;
427 }
428 let mut i = first;
429 const W: usize = 4;
430 while i + W <= n {
431 let s = _mm256_loadu_pd(source.as_ptr().add(i));
432 let c = _mm256_loadu_pd(close.as_ptr().add(i));
433 let m = _mm256_cmp_pd(c, s, _CMP_GT_OQ);
434 let bits = _mm256_movemask_pd(m) as i32;
435 (*out.get_unchecked_mut(i + 0)) = (bits & (1 << 0)) != 0;
436 (*out.get_unchecked_mut(i + 1)) = (bits & (1 << 1)) != 0;
437 (*out.get_unchecked_mut(i + 2)) = (bits & (1 << 2)) != 0;
438 (*out.get_unchecked_mut(i + 3)) = (bits & (1 << 3)) != 0;
439 i += W;
440 }
441 while i < n {
442 *out.get_unchecked_mut(i) = *close.get_unchecked(i) > *source.get_unchecked(i);
443 i += 1;
444 }
445 }
446 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
447 {
448 ttm_trend_scalar(source, close, period, first, out)
449 }
450 return;
451 }
452 ttm_trend_scalar(source, close, period, first, out)
453}
454
/// Short-period AVX-512 entry point. It currently runs the scalar kernel.
///
/// # Safety
/// No preconditions beyond those of `ttm_trend_scalar`. The body performs no
/// unsafe operations.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn ttm_trend_avx512_short(
    source: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    out: &mut [bool],
) {
    ttm_trend_scalar(source, close, period, first, out);
}

/// Long-period AVX-512 entry point. It currently runs the scalar kernel.
///
/// # Safety
/// No preconditions beyond those of `ttm_trend_scalar`. The body performs no
/// unsafe operations.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn ttm_trend_avx512_long(
    source: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    out: &mut [bool],
) {
    ttm_trend_scalar(source, close, period, first, out);
}
478
479#[derive(Debug, Clone)]
480pub struct TtmTrendStream {
481 period: usize,
482 inv_period: f64,
483 buffer: Vec<f64>,
484 sum: f64,
485 head: usize,
486 len: usize,
487 pow2_mask: usize,
488}
489
490impl TtmTrendStream {
491 #[inline(always)]
492 pub fn try_new(params: TtmTrendParams) -> Result<Self, TtmTrendError> {
493 let period = params.period.unwrap_or(5);
494 if period == 0 {
495 return Err(TtmTrendError::InvalidPeriod {
496 period,
497 data_len: 0,
498 });
499 }
500 let pow2_mask = if period.is_power_of_two() {
501 period - 1
502 } else {
503 0
504 };
505 Ok(Self {
506 period,
507 inv_period: 1.0 / (period as f64),
508 buffer: vec![0.0; period],
509 sum: 0.0,
510 head: 0,
511 len: 0,
512 pow2_mask,
513 })
514 }
515
516 #[inline(always)]
517 fn bump(&mut self) {
518 if self.pow2_mask != 0 {
519 self.head = (self.head + 1) & self.pow2_mask;
520 } else {
521 let h = self.head + 1;
522 self.head = if h == self.period { 0 } else { h };
523 }
524 }
525
526 #[inline(always)]
527 pub fn update(&mut self, src_val: f64, close_val: f64) -> Option<bool> {
528 if self.period == 1 {
529 let old = self.buffer[self.head];
530 self.buffer[self.head] = src_val;
531 self.sum += src_val - old;
532 self.len = 1;
533
534 return Some(close_val > src_val);
535 }
536
537 let old = self.buffer[self.head];
538 self.buffer[self.head] = src_val;
539
540 self.sum += src_val - old;
541 self.bump();
542
543 if self.len < self.period {
544 self.len += 1;
545 if self.len < self.period {
546 return None;
547 }
548 }
549
550 let avg = self.sum * self.inv_period;
551 Some(close_val > avg)
552 }
553}
554
555#[derive(Clone, Debug)]
556pub struct TtmTrendBatchRange {
557 pub period: (usize, usize, usize),
558}
559
560impl Default for TtmTrendBatchRange {
561 fn default() -> Self {
562 Self {
563 period: (5, 254, 1),
564 }
565 }
566}
567
568#[derive(Clone, Debug, Default)]
569pub struct TtmTrendBatchBuilder {
570 range: TtmTrendBatchRange,
571 kernel: Kernel,
572}
573
574impl TtmTrendBatchBuilder {
575 pub fn new() -> Self {
576 Self::default()
577 }
578 pub fn kernel(mut self, k: Kernel) -> Self {
579 self.kernel = k;
580 self
581 }
582 #[inline]
583 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
584 self.range.period = (start, end, step);
585 self
586 }
587 #[inline]
588 pub fn period_static(mut self, p: usize) -> Self {
589 self.range.period = (p, p, 0);
590 self
591 }
592 pub fn apply_slices(
593 self,
594 source: &[f64],
595 close: &[f64],
596 ) -> Result<TtmTrendBatchOutput, TtmTrendError> {
597 ttm_trend_batch_with_kernel(source, close, &self.range, self.kernel)
598 }
599 pub fn with_default_slices(
600 source: &[f64],
601 close: &[f64],
602 k: Kernel,
603 ) -> Result<TtmTrendBatchOutput, TtmTrendError> {
604 TtmTrendBatchBuilder::new()
605 .kernel(k)
606 .apply_slices(source, close)
607 }
608 pub fn apply_candles(
609 self,
610 c: &Candles,
611 src: &str,
612 ) -> Result<TtmTrendBatchOutput, TtmTrendError> {
613 let source = source_type(c, src);
614 let close = source_type(c, "close");
615 self.apply_slices(source, close)
616 }
617 pub fn with_default_candles(c: &Candles) -> Result<TtmTrendBatchOutput, TtmTrendError> {
618 TtmTrendBatchBuilder::new()
619 .kernel(Kernel::Auto)
620 .apply_candles(c, "hl2")
621 }
622}
623
624pub fn ttm_trend_batch_with_kernel(
625 source: &[f64],
626 close: &[f64],
627 sweep: &TtmTrendBatchRange,
628 k: Kernel,
629) -> Result<TtmTrendBatchOutput, TtmTrendError> {
630 let kernel = match k {
631 Kernel::Auto => detect_best_batch_kernel(),
632 other if other.is_batch() => other,
633 other => return Err(TtmTrendError::InvalidKernelForBatch(other)),
634 };
635 let simd = match kernel {
636 Kernel::Avx512Batch => Kernel::Avx512,
637 Kernel::Avx2Batch => Kernel::Avx2,
638 Kernel::ScalarBatch => Kernel::Scalar,
639 _ => Kernel::Scalar,
640 };
641 ttm_trend_batch_par_slice(source, close, sweep, simd)
642}
643
644#[derive(Clone, Debug)]
645pub struct TtmTrendBatchOutput {
646 pub values: Vec<bool>,
647 pub combos: Vec<TtmTrendParams>,
648 pub rows: usize,
649 pub cols: usize,
650}
651
652impl TtmTrendBatchOutput {
653 pub fn row_for_params(&self, p: &TtmTrendParams) -> Option<usize> {
654 self.combos
655 .iter()
656 .position(|c| c.period.unwrap_or(5) == p.period.unwrap_or(5))
657 }
658 pub fn values_for(&self, p: &TtmTrendParams) -> Option<&[bool]> {
659 self.row_for_params(p).and_then(|row| {
660 let start = row.checked_mul(self.cols)?;
661 self.values.get(start..start + self.cols)
662 })
663 }
664}
665
666#[inline(always)]
667fn expand_grid(r: &TtmTrendBatchRange) -> Result<Vec<TtmTrendParams>, TtmTrendError> {
668 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, TtmTrendError> {
669 if step == 0 || start == end {
670 return Ok(vec![start]);
671 }
672 if start < end {
673 let st = step.max(1);
674 let v: Vec<usize> = (start..=end).step_by(st).collect();
675 if v.is_empty() {
676 return Err(TtmTrendError::InvalidRange { start, end, step });
677 }
678 return Ok(v);
679 }
680
681 let mut v = Vec::new();
682 let mut x = start as isize;
683 let end_i = end as isize;
684 let st = (step as isize).max(1);
685 while x >= end_i {
686 v.push(x as usize);
687 x -= st;
688 }
689 if v.is_empty() {
690 return Err(TtmTrendError::InvalidRange { start, end, step });
691 }
692 Ok(v)
693 }
694 let periods = axis_usize(r.period)?;
695 if periods.is_empty() {
696 return Err(TtmTrendError::InvalidRange {
697 start: r.period.0,
698 end: r.period.1,
699 step: r.period.2,
700 });
701 }
702 let mut out = Vec::with_capacity(periods.len());
703 for &p in &periods {
704 out.push(TtmTrendParams { period: Some(p) });
705 }
706 Ok(out)
707}
708
709#[inline(always)]
710pub fn ttm_trend_batch_slice(
711 source: &[f64],
712 close: &[f64],
713 sweep: &TtmTrendBatchRange,
714 kern: Kernel,
715) -> Result<TtmTrendBatchOutput, TtmTrendError> {
716 ttm_trend_batch_inner(source, close, sweep, kern, false)
717}
718
719#[inline(always)]
720pub fn ttm_trend_batch_par_slice(
721 source: &[f64],
722 close: &[f64],
723 sweep: &TtmTrendBatchRange,
724 kern: Kernel,
725) -> Result<TtmTrendBatchOutput, TtmTrendError> {
726 ttm_trend_batch_inner(source, close, sweep, kern, true)
727}
728
729#[inline(always)]
730fn ttm_batch_inner_f64(
731 source: &[f64],
732 close: &[f64],
733 sweep: &TtmTrendBatchRange,
734 kern: Kernel,
735 parallel: bool,
736) -> Result<(Vec<f64>, Vec<TtmTrendParams>, usize, usize), TtmTrendError> {
737 let combos = expand_grid(sweep)?;
738
739 let len = source.len().min(close.len());
740 if len == 0 {
741 return Err(TtmTrendError::EmptyInputData);
742 }
743 let first = source
744 .iter()
745 .zip(close)
746 .position(|(&s, &c)| !s.is_nan() && !c.is_nan())
747 .ok_or(TtmTrendError::AllValuesNaN)?;
748 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
749 if len - first < max_p {
750 return Err(TtmTrendError::NotEnoughValidData {
751 needed: max_p,
752 valid: len - first,
753 });
754 }
755
756 let rows = combos.len();
757 let cols = len;
758
759 let mut buf_mu = make_uninit_matrix(rows, cols);
760 let warm: Vec<usize> = combos
761 .iter()
762 .map(|c| first + c.period.unwrap() - 1)
763 .collect();
764 init_matrix_prefixes(&mut buf_mu, cols, &warm);
765
766 let mut guard = core::mem::ManuallyDrop::new(buf_mu);
767 let values: &mut [f64] =
768 unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
769
770 let mut psum = vec![0.0f64; len];
771 if first < len {
772 psum[first] = source[first];
773 for i in (first + 1)..len {
774 psum[i] = psum[i - 1] + source[i];
775 }
776 }
777
778 let do_row = |row: usize, row_dst: &mut [f64]| {
779 let p = combos[row].period.unwrap();
780 let warm_i = warm[row];
781 if warm_i >= len {
782 return;
783 }
784
785 let inv_p = 1.0 / (p as f64);
786 let mut i = warm_i;
787 row_dst[i] = if close[i] > psum[i] * inv_p { 1.0 } else { 0.0 };
788 i += 1;
789
790 while i < len {
791 let sum = psum[i] - psum[i - p];
792 row_dst[i] = if close[i] > sum * inv_p { 1.0 } else { 0.0 };
793 i += 1;
794 }
795 };
796
797 if parallel {
798 #[cfg(not(target_arch = "wasm32"))]
799 values
800 .par_chunks_mut(cols)
801 .enumerate()
802 .for_each(|(r, ch)| do_row(r, ch));
803 #[cfg(target_arch = "wasm32")]
804 for (r, ch) in values.chunks_mut(cols).enumerate() {
805 do_row(r, ch);
806 }
807 } else {
808 for (r, ch) in values.chunks_mut(cols).enumerate() {
809 do_row(r, ch);
810 }
811 }
812
813 let values_vec = unsafe {
814 Vec::from_raw_parts(
815 guard.as_mut_ptr() as *mut f64,
816 guard.len(),
817 guard.capacity(),
818 )
819 };
820 core::mem::forget(guard);
821 Ok((values_vec, combos, rows, cols))
822}
823
824#[inline(always)]
825fn ttm_trend_batch_inner(
826 source: &[f64],
827 close: &[f64],
828 sweep: &TtmTrendBatchRange,
829 kern: Kernel,
830 parallel: bool,
831) -> Result<TtmTrendBatchOutput, TtmTrendError> {
832 let (vals_f64, combos, rows, cols) = ttm_batch_inner_f64(source, close, sweep, kern, parallel)?;
833
834 let values: Vec<bool> = vals_f64.into_iter().map(|v| v == 1.0).collect();
835 Ok(TtmTrendBatchOutput {
836 values,
837 combos,
838 rows,
839 cols,
840 })
841}
842
843#[inline(always)]
844fn ttm_trend_batch_inner_into_f64(
845 source: &[f64],
846 close: &[f64],
847 sweep: &TtmTrendBatchRange,
848 kern: Kernel,
849 parallel: bool,
850 out: &mut [f64],
851) -> Result<Vec<TtmTrendParams>, TtmTrendError> {
852 let combos = expand_grid(sweep)?;
853 let len = source.len().min(close.len());
854 if len == 0 {
855 return Err(TtmTrendError::EmptyInputData);
856 }
857 let first = source
858 .iter()
859 .zip(close.iter())
860 .position(|(&s, &c)| !s.is_nan() && !c.is_nan())
861 .ok_or(TtmTrendError::AllValuesNaN)?;
862 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
863 if len - first < max_p {
864 return Err(TtmTrendError::NotEnoughValidData {
865 needed: max_p,
866 valid: len - first,
867 });
868 }
869 let rows = combos.len();
870 let cols = len;
871 let expected = rows.checked_mul(cols).ok_or(TtmTrendError::InvalidRange {
872 start: sweep.period.0,
873 end: sweep.period.1,
874 step: sweep.period.2,
875 })?;
876 if out.len() != expected {
877 return Err(TtmTrendError::OutputLengthMismatch {
878 expected,
879 got: out.len(),
880 });
881 }
882
883 for (row, combo) in combos.iter().enumerate() {
884 let warmup = first + combo.period.unwrap() - 1;
885 let row_start = row * cols;
886 for i in 0..warmup.min(cols) {
887 out[row_start + i] = f64::NAN;
888 }
889 }
890
891 let do_row = |row: usize, out_row: &mut [f64]| {
892 let period = combos[row].period.unwrap();
893 ttm_numeric_compute_into(source, close, period, first, out_row);
894 };
895
896 if parallel {
897 #[cfg(not(target_arch = "wasm32"))]
898 {
899 out.par_chunks_mut(cols)
900 .enumerate()
901 .for_each(|(row, slice)| do_row(row, slice));
902 }
903
904 #[cfg(target_arch = "wasm32")]
905 {
906 for (row, slice) in out.chunks_mut(cols).enumerate() {
907 do_row(row, slice);
908 }
909 }
910 } else {
911 for (row, slice) in out.chunks_mut(cols).enumerate() {
912 do_row(row, slice);
913 }
914 }
915
916 Ok(combos)
917}
918
919#[inline(always)]
920fn ttm_trend_batch_inner_into(
921 source: &[f64],
922 close: &[f64],
923 sweep: &TtmTrendBatchRange,
924 kern: Kernel,
925 parallel: bool,
926 out: &mut [bool],
927) -> Result<Vec<TtmTrendParams>, TtmTrendError> {
928 let len = source.len().min(close.len());
929 let combos = expand_grid(sweep)?;
930 let rows = combos.len();
931 let cols = len;
932 let expected = rows.checked_mul(cols).ok_or(TtmTrendError::InvalidRange {
933 start: sweep.period.0,
934 end: sweep.period.1,
935 step: sweep.period.2,
936 })?;
937 if out.len() != expected {
938 return Err(TtmTrendError::OutputLengthMismatch {
939 expected,
940 got: out.len(),
941 });
942 }
943
944 let mut tmp = vec![f64::NAN; expected];
945 let result = ttm_trend_batch_inner_into_f64(source, close, sweep, kern, parallel, &mut tmp)?;
946
947 for (i, &v) in tmp.iter().enumerate() {
948 out[i] = v == 1.0;
949 }
950
951 Ok(result)
952}
953
/// Fills one boolean row with the TTM trend for `period`.
///
/// The whole of `out` is first cleared to `false`. From index
/// `first + period - 1` up to `min(source.len(), close.len())`, each entry is
/// set to `close[i] > rolling mean of source`.
///
/// # Safety
/// The body performs no unsafe operations. Callers must still ensure
/// `period >= 1`, that `source[first..first + period]` exists, and that
/// `out` covers the written range; otherwise indexing panics.
#[inline(always)]
unsafe fn ttm_trend_row_scalar(
    source: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [bool],
) {
    out.fill(false);

    let inv_p = 1.0 / (period as f64);
    let mut running = source[first..first + period]
        .iter()
        .fold(0.0, |acc, &v| acc + v);

    let seed = first + period - 1;
    out[seed] = close[seed] > running * inv_p;

    let limit = source.len().min(close.len());
    for idx in (seed + 1)..limit {
        running += source[idx] - source[idx - period];
        out[idx] = close[idx] > running * inv_p;
    }
}
977
/// AVX2 row entry point. No dedicated AVX2 row kernel exists yet, so this
/// runs `ttm_trend_row_scalar`.
///
/// # Safety
/// Same requirements as `ttm_trend_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn ttm_trend_row_avx2(
    source: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [bool],
) {
    ttm_trend_row_scalar(source, close, first, period, out);
}

/// AVX-512 row entry point. It dispatches on period length: up to 32 goes to
/// the "short" variant, longer periods to the "long" one.
///
/// # Safety
/// Same requirements as `ttm_trend_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn ttm_trend_row_avx512(
    source: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [bool],
) {
    // Longest period routed to the short variant.
    const SHORT_MAX: usize = 32;
    match period {
        p if p <= SHORT_MAX => ttm_trend_row_avx512_short(source, close, first, period, out),
        _ => ttm_trend_row_avx512_long(source, close, first, period, out),
    }
}

/// Short-period AVX-512 row variant. It currently runs the scalar row kernel.
///
/// # Safety
/// Same requirements as `ttm_trend_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn ttm_trend_row_avx512_short(
    source: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [bool],
) {
    ttm_trend_row_scalar(source, close, first, period, out);
}

/// Long-period AVX-512 row variant. It currently runs the scalar row kernel.
///
/// # Safety
/// Same requirements as `ttm_trend_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn ttm_trend_row_avx512_long(
    source: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [bool],
) {
    ttm_trend_row_scalar(source, close, first, period, out);
}
1029
/// Registers the TTM trend Python functions and classes on `m`.
///
/// Always registers `ttm_trend_py`, `ttm_trend_batch_py` and
/// `TtmTrendStreamPy`. CUDA builds additionally register the device-array
/// class and the two CUDA entry points.
#[cfg(feature = "python")]
pub fn register_ttm_trend_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(ttm_trend_py, m)?)?;
    m.add_function(wrap_pyfunction!(ttm_trend_batch_py, m)?)?;
    m.add_class::<TtmTrendStreamPy>()?;

    // The enclosing function already requires `python`, so only `cuda` needs checking.
    #[cfg(feature = "cuda")]
    {
        m.add_class::<TtmTrendDeviceArrayF32Py>()?;
        m.add_function(wrap_pyfunction!(ttm_trend_cuda_batch_dev_py, m)?)?;
        m.add_function(wrap_pyfunction!(
            ttm_trend_cuda_many_series_one_param_dev_py,
            m
        )?)?;
    }

    Ok(())
}
1046
1047#[cfg(all(feature = "python", feature = "cuda"))]
1048use crate::cuda::cuda_available;
1049#[cfg(all(feature = "python", feature = "cuda"))]
1050use crate::cuda::moving_averages::DeviceArrayF32;
1051#[cfg(all(feature = "python", feature = "cuda"))]
1052use crate::cuda::ttm_trend_wrapper::CudaTtmTrend;
1053#[cfg(all(feature = "python", feature = "cuda"))]
1054use cust::context::Context;
1055#[cfg(all(feature = "python", feature = "cuda"))]
1056use std::sync::Arc;
1057
// Python handle to a device-resident f32 result matrix. Plain `//` comments
// are used so the exposed Python `__doc__` stays unchanged.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(
    module = "ta_indicators.cuda",
    name = "TtmTrendDeviceArrayF32",
    unsendable
)]
pub struct TtmTrendDeviceArrayF32Py {
    // The device buffer with its `rows` x `cols` shape.
    pub(crate) inner: DeviceArrayF32,
    // Held alongside the buffer; presumably keeps the CUDA context alive
    // while the allocation exists. Confirm against CudaTtmTrend.
    pub(crate) _ctx: Arc<Context>,
    // Device ordinal reported by `__dlpack_device__`.
    pub(crate) device_id: u32,
}
1069
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl TtmTrendDeviceArrayF32Py {
    // CUDA Array Interface (version 3) view: a C-contiguous `(rows, cols)`
    // float32 matrix. Keys are inserted in a fixed order.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let rows = self.inner.rows;
        let cols = self.inner.cols;
        let elem_bytes = std::mem::size_of::<f32>();

        let iface = PyDict::new(py);
        iface.set_item("shape", (rows, cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (cols * elem_bytes, elem_bytes))?;
        iface.set_item("data", (self.inner.device_ptr() as usize, false))?;
        iface.set_item("version", 3)?;
        Ok(iface)
    }

    // (device type, device id). Device type 2 is kDLCUDA in DLPack.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    // Exports the buffer as a DLPack capsule. This moves the device buffer
    // out of `self`, leaving an empty 0x0 array behind, so later exports see
    // no data.
    #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        let zero_stream = matches!(
            stream.as_ref().and_then(|s| s.extract::<usize>(py).ok()),
            Some(0)
        );
        if zero_stream {
            return Err(PyValueError::new_err(
                "__dlpack__ stream=0 is invalid for CUDA",
            ));
        }

        let (kdl, alloc_dev) = self.__dlpack_device__();
        let wants_copy = copy
            .as_ref()
            .and_then(|c| c.extract::<bool>(py).ok())
            .unwrap_or(false);

        let requested = dl_device
            .as_ref()
            .and_then(|d| d.extract::<(i32, i32)>(py).ok());
        if let Some(target) = requested {
            if target != (kdl, alloc_dev) {
                let msg = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch; cross-device copy not supported for TtmTrendDeviceArrayF32"
                };
                return Err(PyValueError::new_err(msg));
            }
        }

        if wants_copy {
            return Err(PyValueError::new_err(
                "copy=True not supported for TtmTrendDeviceArrayF32",
            ));
        }

        let placeholder =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let DeviceArrayF32 { buf, rows, cols } = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: placeholder,
                rows: 0,
                cols: 0,
            },
        );

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));
        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1158
#[cfg(all(feature = "python", feature = "cuda"))]
impl TtmTrendDeviceArrayF32Py {
    /// Wraps a device array, the CUDA context handle returned alongside it,
    /// and the ordinal of the device that owns it.
    pub fn new_from_rust(inner: DeviceArrayF32, ctx_guard: Arc<Context>, device_id: u32) -> Self {
        TtmTrendDeviceArrayF32Py {
            inner,
            device_id,
            _ctx: ctx_guard,
        }
    }
}
1169
// Maps a CUDA-wrapper error to a Python `ValueError` carrying its text.
#[cfg(all(feature = "python", feature = "cuda"))]
fn ttm_cuda_value_error<E: ToString>(e: E) -> PyErr {
    PyValueError::new_err(e.to_string())
}

// Batch TTM trend on the GPU for a period sweep. The result stays on the device.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "ttm_trend_cuda_batch_dev")]
#[pyo3(signature = (source_f32, close_f32, period_range, device_id=0))]
pub fn ttm_trend_cuda_batch_dev_py(
    py: Python<'_>,
    source_f32: PyReadonlyArray1<'_, f32>,
    close_f32: PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<TtmTrendDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let (src, cls) = (source_f32.as_slice()?, close_f32.as_slice()?);
    let sweep = TtmTrendBatchRange {
        period: period_range,
    };

    // The GPU work runs with the GIL released.
    let (inner, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaTtmTrend::new(device_id).map_err(ttm_cuda_value_error)?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let arr = cuda
            .ttm_trend_batch_dev(src, cls, &sweep)
            .map_err(ttm_cuda_value_error)?;
        Ok((arr, ctx, dev_id))
    })?;
    Ok(TtmTrendDeviceArrayF32Py::new_from_rust(inner, ctx, dev_id))
}

// TTM trend on the GPU for many time-major series sharing one period.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "ttm_trend_cuda_many_series_one_param_dev")]
#[pyo3(signature = (source_tm_f32, close_tm_f32, cols, rows, period, device_id=0))]
pub fn ttm_trend_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    source_tm_f32: PyReadonlyArray1<'_, f32>,
    close_tm_f32: PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    device_id: usize,
) -> PyResult<TtmTrendDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let (src_tm, cls_tm) = (source_tm_f32.as_slice()?, close_tm_f32.as_slice()?);

    // The GPU work runs with the GIL released.
    let (inner, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaTtmTrend::new(device_id).map_err(ttm_cuda_value_error)?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let arr = cuda
            .ttm_trend_many_series_one_param_time_major_dev(src_tm, cls_tm, cols, rows, period)
            .map_err(ttm_cuda_value_error)?;
        Ok((arr, ctx, dev_id))
    })?;
    Ok(TtmTrendDeviceArrayF32Py::new_from_rust(inner, ctx, dev_id))
}
1230
#[cfg(test)]
mod tests {
    //! Unit, batch and (optionally) property tests for the TTM trend indicator.
    //!
    //! The `check_*` helpers return `Result` so they can use `?`; the generated
    //! `#[test]` wrappers unwrap that result so an `Err` actually fails the test.

    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;
    use paste::paste;

    #[test]
    fn test_ttm_trend_into_matches_api() -> Result<(), Box<dyn Error>> {
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file)?;
        let source = source_type(&candles, "hl2");
        let close = source_type(&candles, "close");
        let input = TtmTrendInput::from_slices(source, close, TtmTrendParams::default());

        let baseline = ttm_trend(&input)?.values;

        let mut out = vec![0.0f64; close.len()];
        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
        ttm_trend_into(&input, &mut out)?;
        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
        {
            ttm_trend_into_slice_f64(&mut out, &input, Kernel::Auto)?;
        }

        let (_s, _c, p, first, _k) = ttm_prepare(&input, Kernel::Auto)?;
        let warmup_end = first + p - 1;

        // The f64 API encodes the boolean trend as 1.0 / 0.0 with a NaN warmup prefix.
        let mut expected = vec![0.0f64; out.len()];
        for i in 0..out.len() {
            if i < warmup_end {
                expected[i] = f64::NAN;
            } else {
                expected[i] = if baseline[i] { 1.0 } else { 0.0 };
            }
        }

        fn eq_or_both_nan(a: f64, b: f64) -> bool {
            (a.is_nan() && b.is_nan()) || (a == b)
        }

        assert_eq!(out.len(), expected.len());
        for i in 0..out.len() {
            assert!(
                eq_or_both_nan(out[i], expected[i]),
                "Mismatch at index {}: out={:?}, expected={:?}",
                i,
                out[i],
                expected[i]
            );
        }

        Ok(())
    }

    fn check_ttm_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file)?;
        let default_params = TtmTrendParams { period: None };
        let input = TtmTrendInput::from_candles(&candles, "hl2", default_params);
        let output = ttm_trend_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    fn check_ttm_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file)?;
        let close = source_type(&candles, "close");
        let params = TtmTrendParams { period: Some(5) };
        let input = TtmTrendInput::from_candles(&candles, "hl2", params);
        let result = ttm_trend_with_kernel(&input, kernel)?;
        assert_eq!(result.values.len(), close.len());
        let expected_last_five = [true, false, false, false, false];
        let start = result.values.len().saturating_sub(5);
        for (i, &val) in result.values[start..].iter().enumerate() {
            assert_eq!(val, expected_last_five[i]);
        }
        Ok(())
    }

    fn check_ttm_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let src = [10.0, 20.0, 30.0];
        let close = [12.0, 22.0, 32.0];
        let params = TtmTrendParams { period: Some(0) };
        let input = TtmTrendInput::from_slices(&src, &close, params);
        let res = ttm_trend_with_kernel(&input, kernel);
        assert!(res.is_err());
        Ok(())
    }

    fn check_ttm_period_exceeds_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let src = [1.0, 2.0, 3.0];
        let close = [1.0, 2.0, 3.0];
        let params = TtmTrendParams { period: Some(10) };
        let input = TtmTrendInput::from_slices(&src, &close, params);
        let res = ttm_trend_with_kernel(&input, kernel);
        assert!(res.is_err());
        Ok(())
    }

    fn check_ttm_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let src = [42.0];
        let close = [43.0];
        let params = TtmTrendParams { period: Some(5) };
        let input = TtmTrendInput::from_slices(&src, &close, params);
        let res = ttm_trend_with_kernel(&input, kernel);
        assert!(res.is_err());
        Ok(())
    }

    fn check_ttm_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let src = [f64::NAN, f64::NAN, f64::NAN];
        let close = [f64::NAN, f64::NAN, f64::NAN];
        let params = TtmTrendParams { period: Some(5) };
        let input = TtmTrendInput::from_slices(&src, &close, params);
        let res = ttm_trend_with_kernel(&input, kernel);
        assert!(res.is_err());
        Ok(())
    }

    fn check_ttm_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file)?;
        let src = source_type(&candles, "hl2");
        let close = source_type(&candles, "close");
        let period = 5;
        let input = TtmTrendInput::from_slices(
            src,
            close,
            TtmTrendParams {
                period: Some(period),
            },
        );
        let batch_output = ttm_trend_with_kernel(&input, kernel)?.values;
        let mut stream = TtmTrendStream::try_new(TtmTrendParams {
            period: Some(period),
        })?;
        // The stream yields `None` during warmup; the batch API reports `false` there.
        let mut stream_values = Vec::with_capacity(close.len());
        for (&s, &c) in src.iter().zip(close.iter()) {
            match stream.update(s, c) {
                Some(v) => stream_values.push(v),
                None => stream_values.push(false),
            }
        }
        assert_eq!(batch_output, stream_values);
        Ok(())
    }

    #[cfg(debug_assertions)]
    fn check_ttm_trend_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let src = source_type(&candles, "hl2");
        let close = source_type(&candles, "close");

        let test_params = vec![
            TtmTrendParams::default(),
            TtmTrendParams { period: Some(1) },
            TtmTrendParams { period: Some(2) },
            TtmTrendParams { period: Some(3) },
            TtmTrendParams { period: Some(7) },
            TtmTrendParams { period: Some(10) },
            TtmTrendParams { period: Some(14) },
            TtmTrendParams { period: Some(20) },
            TtmTrendParams { period: Some(50) },
            TtmTrendParams { period: Some(100) },
            TtmTrendParams { period: Some(200) },
        ];

        for (param_idx, params) in test_params.iter().enumerate() {
            let input = TtmTrendInput::from_slices(src, close, params.clone());
            let output = ttm_trend_with_kernel(&input, kernel)?;

            let first_valid = src
                .iter()
                .zip(close.iter())
                .position(|(&s, &c)| !s.is_nan() && !c.is_nan())
                .unwrap_or(0);
            let period = params.period.unwrap_or(5);
            let warmup_end = first_valid + period - 1;

            for i in 0..warmup_end.min(output.values.len()) {
                if output.values[i] {
                    panic!(
                        "[{}] Found unexpected true value (poison) at index {} in warmup period \
                         with params: period={} (param set {}). \
                         Expected false during warmup (indices 0-{})",
                        test_name,
                        i,
                        period,
                        param_idx,
                        warmup_end - 1
                    );
                }
            }

            if warmup_end < output.values.len() {
                let after_warmup = &output.values[warmup_end..];
                let all_true = after_warmup.iter().all(|&v| v);

                if all_true && after_warmup.len() > 10 {
                    panic!(
                        "[{}] All values after warmup are true, possible poison pattern \
                         with params: period={} (param set {}). This is highly unlikely \
                         for real TTM trend calculations.",
                        test_name, period, param_idx
                    );
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_ttm_trend_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    // Generates one `#[test]` per kernel for each helper. The AVX variants carry their
    // own `cfg` so that *every* generated AVX test is gated, not just the first one.
    // `.unwrap()` makes an `Err` returned by the helper fail the test instead of being
    // silently discarded.
    macro_rules! generate_all_ttm_tests {
        ($($test_fn:ident),*) => {
            paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar).unwrap();
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2).unwrap();
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512).unwrap();
                    }
                )*
            }
        }
    }

    generate_all_ttm_tests!(
        check_ttm_partial_params,
        check_ttm_accuracy,
        check_ttm_zero_period,
        check_ttm_period_exceeds_length,
        check_ttm_very_small_dataset,
        check_ttm_all_nan,
        check_ttm_streaming,
        check_ttm_trend_no_poison
    );

    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;
        let src = source_type(&c, "hl2");
        let close = source_type(&c, "close");
        let output = TtmTrendBatchBuilder::new()
            .kernel(kernel)
            .apply_slices(src, close)?;
        let def = TtmTrendParams::default();
        let row = output.values_for(&def).expect("default row missing");
        assert_eq!(row.len(), close.len());
        let expected = [true, false, false, false, false];
        let start = row.len().saturating_sub(5);
        for (i, &v) in row[start..].iter().enumerate() {
            assert_eq!(v, expected[i]);
        }
        Ok(())
    }

    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;
        let src = source_type(&c, "hl2");
        let close = source_type(&c, "close");

        let test_configs = vec![
            (1, 10, 1),
            (2, 10, 2),
            (5, 25, 5),
            (10, 50, 10),
            (1, 5, 1),
            (20, 100, 20),
            (7, 21, 7),
            (3, 30, 3),
            (15, 15, 0),
        ];

        for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
            let output = TtmTrendBatchBuilder::new()
                .kernel(kernel)
                .period_range(p_start, p_end, p_step)
                .apply_slices(src, close)?;

            let first_valid = src
                .iter()
                .zip(close.iter())
                .position(|(&s, &c)| !s.is_nan() && !c.is_nan())
                .unwrap_or(0);

            for (idx, &val) in output.values.iter().enumerate() {
                let row = idx / output.cols;
                let col = idx % output.cols;
                let combo = &output.combos[row];
                let period = combo.period.unwrap_or(5);
                let warmup_end = first_valid + period - 1;

                if col < warmup_end {
                    if val {
                        panic!(
                            "[{}] Config {}: Found unexpected true value (poison) \
                             at row {} col {} (flat index {}) in warmup period \
                             with params: period={} (warmup ends at col {})",
                            test,
                            cfg_idx,
                            row,
                            col,
                            idx,
                            period,
                            warmup_end - 1
                        );
                    }
                }
            }

            for row in 0..output.rows {
                let start_idx = row * output.cols;
                let row_values = &output.values[start_idx..start_idx + output.cols];
                let period = output.combos[row].period.unwrap_or(5);
                let warmup_end = first_valid + period - 1;

                if warmup_end < row_values.len() {
                    let after_warmup = &row_values[warmup_end..];
                    if after_warmup.len() > 10 && after_warmup.iter().all(|&v| v) {
                        panic!(
                            "[{}] Config {}: Row {} has all true values after warmup, \
                             possible poison pattern with period={}. This is highly unlikely \
                             for real TTM trend calculations.",
                            test, cfg_idx, row, period
                        );
                    }
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    // Batch counterpart of `generate_all_ttm_tests!`; errors from the helper fail the test.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste! {
                #[test] fn [<$fn_name _scalar>]() {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch).unwrap();
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx2>]() {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch).unwrap();
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx512>]() {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch).unwrap();
                }
                #[test] fn [<$fn_name _auto_detect>]() {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto).unwrap();
                }
            }
        };
    }
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);

    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_ttm_trend_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        {
            let source = vec![100.0, 200.0, 300.0, 400.0, 500.0];
            let close = vec![150.0, 250.0, 350.0, 450.0, 550.0];
            let period = 2;
            let params = TtmTrendParams {
                period: Some(period),
            };
            let input = TtmTrendInput::from_slices(&source, &close, params);
            let result = ttm_trend_with_kernel(&input, kernel)?;

            assert!(
                result.values[1],
                "Manual test failed at index 1 for {}",
                test_name
            );

            assert!(
                result.values[2],
                "Manual test failed at index 2 for {}",
                test_name
            );
        }

        let strat = (2usize..=50).prop_flat_map(|period| {
            let data_len = period * 2 + 50;
            (
                (100f64..10000f64),
                prop::collection::vec((-0.02f64..0.02f64), data_len - 1),
                Just(period),
                (0.005f64..0.02f64),
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(
                &strat,
                |(start_price, price_changes, period, spread_factor)| {
                    let mut base_prices = Vec::with_capacity(price_changes.len() + 1);
                    base_prices.push(start_price);

                    let mut current_price = start_price;
                    for &change_pct in &price_changes {
                        current_price *= 1.0 + change_pct;
                        current_price = current_price.max(10.0);
                        base_prices.push(current_price);
                    }

                    let mut source = Vec::with_capacity(base_prices.len());
                    let mut close = Vec::with_capacity(base_prices.len());

                    for (i, &base) in base_prices.iter().enumerate() {
                        let spread = base * spread_factor;
                        let high = base + spread;
                        let low = base - spread;

                        source.push((high + low) / 2.0);

                        let close_ratio = ((i as f64 * 0.3).sin() + 1.0) / 2.0;
                        close.push(low + (high - low) * close_ratio);
                    }
                    let params = TtmTrendParams {
                        period: Some(period),
                    };
                    let input = TtmTrendInput::from_slices(&source, &close, params);

                    let result = ttm_trend_with_kernel(&input, kernel)?;
                    let values = result.values;

                    let ref_result = ttm_trend_with_kernel(&input, Kernel::Scalar)?;
                    let ref_values = ref_result.values;

                    let first_valid = source
                        .iter()
                        .zip(close.iter())
                        .position(|(&s, &c)| !s.is_nan() && !c.is_nan())
                        .unwrap_or(0);
                    let warmup_end = first_valid + period - 1;

                    prop_assert_eq!(values.len(), source.len());
                    prop_assert_eq!(values.len(), close.len());

                    for i in 0..warmup_end.min(values.len()) {
                        prop_assert!(
                            !values[i],
                            "Expected false during warmup at index {} (warmup ends at {})",
                            i,
                            warmup_end - 1
                        );
                    }

                    if warmup_end + 1 < values.len() {
                        // Window ending at `warmup_end + 1`, then rolled forward one step per index.
                        let mut sum = 0.0;
                        for j in (first_valid + 1)..(first_valid + period + 1) {
                            sum += source[j];
                        }

                        for i in (warmup_end + 1)..values.len() {
                            let avg = sum / (period as f64);
                            let expected = close[i] > avg;

                            prop_assert_eq!(
                                values[i], expected,
                                "Calculation mismatch at index {}: close={:.4}, avg={:.4}, expected={}, got={}",
                                i, close[i], avg, expected, values[i]
                            );

                            if i + 1 < source.len() {
                                sum += source[i + 1] - source[i + 1 - period];
                            }
                        }
                    }

                    for i in 0..values.len() {
                        prop_assert_eq!(
                            values[i],
                            ref_values[i],
                            "Kernel mismatch at index {}: {} kernel={}, scalar={}",
                            i,
                            test_name,
                            values[i],
                            ref_values[i]
                        );
                    }

                    if period == 1 {
                        for i in first_valid..values.len() {
                            let expected = close[i] > source[i];
                            prop_assert_eq!(
                                values[i], expected,
                                "Period=1 mismatch at index {}: close={}, source={}, expected={}, got={}",
                                i, close[i], source[i], expected, values[i]
                            );
                        }
                    }

                    let all_source_same = source.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);
                    let all_close_same = close.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);
                    if all_source_same && all_close_same && !source.is_empty() && !close.is_empty()
                    {
                        let expected_const = close[0] > source[0];
                        for i in warmup_end..values.len() {
                            prop_assert_eq!(
                                values[i],
                                expected_const,
                                "Constant input mismatch at index {}: expected={}, got={}",
                                i,
                                expected_const,
                                values[i]
                            );
                        }
                    }

                    let mut transitions = 0;
                    for i in (warmup_end + 1)..values.len() {
                        if values[i] != values[i - 1] {
                            transitions += 1;

                            let mut sum = 0.0;
                            for j in (i + 1 - period)..=i {
                                sum += source[j];
                            }
                            let avg = sum / (period as f64);

                            prop_assert!(
                                (close[i] - avg).abs() < source[i] * 0.1
                                    || (values[i] && close[i] > avg)
                                    || (!values[i] && close[i] <= avg),
                                "Invalid transition at index {}: close={:.4}, avg={:.4}, value={}",
                                i,
                                close[i],
                                avg,
                                values[i]
                            );
                        }
                    }

                    if period == source.len() - 1 && source.len() > 2 {
                        for i in 0..(source.len() - 1) {
                            prop_assert!(
                                !values[i],
                                "Expected false for extreme period at index {} (period={}, len={})",
                                i,
                                period,
                                source.len()
                            );
                        }
                    }

                    let result2 = ttm_trend_with_kernel(&input, kernel)?;
                    for i in 0..values.len() {
                        prop_assert_eq!(
                            values[i],
                            result2.values[i],
                            "Non-deterministic result at index {}",
                            i
                        );
                    }

                    Ok(())
                },
            )
            .unwrap();

        Ok(())
    }

    #[cfg(feature = "proptest")]
    generate_all_ttm_tests!(check_ttm_trend_property);
}
1846
1847#[cfg(feature = "python")]
1848use crate::utilities::kernel_validation::validate_kernel;
1849#[cfg(feature = "python")]
1850use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
1851#[cfg(feature = "python")]
1852use pyo3::exceptions::PyValueError;
1853#[cfg(feature = "python")]
1854use pyo3::prelude::*;
1855#[cfg(feature = "python")]
1856use pyo3::types::PyDict;
1857
1858#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1859use serde::{Deserialize, Serialize};
1860#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1861use wasm_bindgen::prelude::*;
1862
#[cfg(feature = "python")]
#[pyfunction(name = "ttm_trend")]
#[pyo3(signature = (source, close, period, kernel=None))]
/// Computes the TTM trend as a float array: 1.0 where `close` is above the rolling
/// mean of `source` over `period`, 0.0 otherwise, and NaN for the warmup prefix.
///
/// The output length is `min(len(source), len(close))`. `kernel` is validated and
/// passed to `ttm_prepare`, but the numeric path itself is a single scalar loop.
/// Input validation failures are raised as `ValueError`.
pub fn ttm_trend_py<'py>(
    py: Python<'py>,
    source: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let s = source.as_slice()?;
    let c = close.as_slice()?;
    let kern = validate_kernel(kernel, false)?;
    let params = TtmTrendParams {
        period: Some(period),
    };
    let input = TtmTrendInput::from_slices(s, c, params);

    let len = s.len().min(c.len());
    // SAFETY: the array is created uninitialised; on success every element is written
    // below (NaN warmup prefix, then `ttm_numeric_compute_into` fills the rest up to
    // `len`) before it is handed back to Python. On error it is dropped unread.
    let out = unsafe { PyArray1::<f64>::new(py, [len], false) };
    // SAFETY: `out` was created just above, so no other view of its buffer exists.
    let dst = unsafe { out.as_slice_mut()? };

    py.allow_threads(|| {
        // The resolved kernel is intentionally unused: the f64 path has one implementation.
        let (ss, cc, p, first, _kernel) =
            ttm_prepare(&input, kern).map_err(|e| e.to_string())?;
        dst[..first + p - 1].fill(f64::NAN);
        ttm_numeric_compute_into(ss, cc, p, first, dst);
        Ok::<(), String>(())
    })
    .map_err(|e: String| PyValueError::new_err(e))?;

    Ok(out)
}
1897
#[cfg(feature = "python")]
#[pyclass(name = "TtmTrendStream")]
/// Python wrapper (exposed as `TtmTrendStream`) around the incremental TTM trend stream.
pub struct TtmTrendStreamPy {
    // The underlying Rust streaming state; all calls are forwarded to it.
    stream: TtmTrendStream,
}
1903
#[cfg(feature = "python")]
#[pymethods]
impl TtmTrendStreamPy {
    /// Creates a stream for the given `period`; invalid parameters raise `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        TtmTrendStream::try_new(TtmTrendParams {
            period: Some(period),
        })
        .map(|stream| TtmTrendStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one (source, close) pair and returns the stream's result for it
    /// (`None` whenever the underlying stream returns `None`).
    fn update(&mut self, source_val: f64, close_val: f64) -> Option<bool> {
        let inner = &mut self.stream;
        inner.update(source_val, close_val)
    }
}
1921
#[cfg(feature = "python")]
#[pyfunction(name = "ttm_trend_batch")]
#[pyo3(signature = (source, close, period_range, kernel=None))]
/// Evaluates TTM trend for every period in `period_range` (start, end, step).
///
/// Returns a dict with `"values"`, a `(rows, cols)` float array with one row per
/// period combination, and `"periods"`, the period used for each row (as `u64`).
/// Errors from validation or computation are raised as `ValueError`.
pub fn ttm_trend_batch_py<'py>(
    py: Python<'py>,
    source: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let s = source.as_slice()?;
    let c = close.as_slice()?;
    let kern = validate_kernel(kernel, true)?;
    let sweep = TtmTrendBatchRange {
        period: period_range,
    };

    let (vals_f64, combos, rows, cols) = py
        .allow_threads(|| {
            // Resolve `Auto`, then map the batch kernel variant onto its single-series
            // counterpart before calling `ttm_batch_inner_f64`.
            let k = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            let simd = match k {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => k,
            };
            ttm_batch_inner_f64(s, c, &sweep, simd, true).map_err(|e| e.to_string())
        })
        .map_err(|e: String| PyValueError::new_err(e))?;

    let dict = PyDict::new(py);

    // `from_vec` is a safe constructor that takes ownership of the Vec; no `unsafe` needed.
    let arr = PyArray1::<f64>::from_vec(py, vals_f64).reshape((rows, cols))?;
    dict.set_item("values", arr)?;

    let periods: Vec<u64> = combos
        .iter()
        .map(|p| {
            p.period
                .expect("batch combos always carry a concrete period") as u64
        })
        .collect();
    dict.set_item("periods", periods.into_pyarray(py))?;
    Ok(dict)
}
1969
1970#[inline]
1971pub fn ttm_trend_into_slice_f64(
1972 dst: &mut [f64],
1973 input: &TtmTrendInput,
1974 kern: Kernel,
1975) -> Result<(), TtmTrendError> {
1976 ttm_numeric_into_slice(dst, input, kern)
1977}
1978
1979#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1980#[inline]
1981pub fn ttm_trend_into(input: &TtmTrendInput, out: &mut [f64]) -> Result<(), TtmTrendError> {
1982 ttm_trend_into_slice_f64(out, input, Kernel::Auto)
1983}
1984
1985#[inline]
1986pub fn ttm_trend_into_slice(
1987 dst: &mut [bool],
1988 input: &TtmTrendInput,
1989 kern: Kernel,
1990) -> Result<(), TtmTrendError> {
1991 let (source, close, period, first, chosen) = ttm_prepare(input, kern)?;
1992 let len = source.len().min(close.len());
1993 if dst.len() != len {
1994 return Err(TtmTrendError::OutputLengthMismatch {
1995 expected: len,
1996 got: dst.len(),
1997 });
1998 }
1999 match chosen {
2000 Kernel::Scalar | Kernel::ScalarBatch => ttm_trend_scalar(source, close, period, first, dst),
2001 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2002 Kernel::Avx2 | Kernel::Avx2Batch => ttm_trend_avx2(source, close, period, first, dst),
2003 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2004 Kernel::Avx512 | Kernel::Avx512Batch => ttm_trend_avx512(source, close, period, first, dst),
2005 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
2006 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
2007 ttm_trend_scalar(source, close, period, first, dst)
2008 }
2009 _ => unreachable!(),
2010 }
2011 Ok(())
2012}
2013
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// WASM entry point: returns the numeric TTM trend (1.0 / 0.0, NaN warmup prefix)
/// as a freshly allocated vector of length `min(source.len(), close.len())`.
pub fn ttm_trend_js(source: &[f64], close: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let params = TtmTrendParams {
        period: Some(period),
    };
    let input = TtmTrendInput::from_slices(source, close, params);

    let to_js = |e: TtmTrendError| JsValue::from_str(&e.to_string());
    let (src, cls, p, first, _) = ttm_prepare(&input, Kernel::Auto).map_err(to_js)?;

    let n = src.len().min(cls.len());
    let warmup = first + p - 1;
    let mut values = alloc_with_nan_prefix(n, warmup);
    ttm_numeric_compute_into(src, cls, p, first, &mut values);
    Ok(values)
}
2030
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// WASM pointer API: writes the numeric TTM trend for `len` elements into `out_ptr`.
///
/// The caller must supply pointers to `len` valid f64 values each (for example buffers
/// obtained from `ttm_trend_alloc`). `out_ptr` may be the same pointer as `source_ptr`
/// or `close_ptr` (in-place use). In that case the result is computed into a temporary
/// buffer first, so the inputs are never read after being overwritten.
pub fn ttm_trend_into(
    source_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if source_ptr.is_null() || close_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }

    // SAFETY: the caller guarantees `source_ptr` and `close_ptr` each point to `len`
    // initialised f64 values that remain valid for the duration of this call.
    let (s, c) = unsafe {
        (
            std::slice::from_raw_parts(source_ptr, len),
            std::slice::from_raw_parts(close_ptr, len),
        )
    };
    let input = TtmTrendInput::from_slices(
        s,
        c,
        TtmTrendParams {
            period: Some(period),
        },
    );
    let (ss, cc, p, first, _) =
        ttm_prepare(&input, Kernel::Auto).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let n = ss.len().min(cc.len());
    let warmup = first + p - 1;

    let out_const = out_ptr as *const f64;
    if out_const == source_ptr || out_const == close_ptr {
        // In-place call: the rolling sum reads `source[i - period]` after position
        // `i - period` has been written, so compute into scratch space and copy back.
        let mut tmp = alloc_with_nan_prefix(n, warmup);
        ttm_numeric_compute_into(ss, cc, p, first, &mut tmp);
        // SAFETY: `tmp` is a fresh allocation of `n` elements, so it cannot overlap
        // `out_ptr`, which the caller guarantees is valid for `len` (== n) writes.
        unsafe { std::ptr::copy_nonoverlapping(tmp.as_ptr(), out_ptr, n) };
    } else {
        // SAFETY: `out_ptr` is valid for `len` (== n) writes and, having passed the
        // aliasing check above, does not start at either input buffer.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, n) };
        out[..warmup].fill(f64::NAN);
        ttm_numeric_compute_into(ss, cc, p, first, out);
    }
    Ok(())
}
2063
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Allocates an uninitialised buffer with capacity for `len` f64 values and hands its
/// pointer to the caller. The buffer is not freed here; release it with
/// `ttm_trend_free(ptr, len)` using the same `len`.
pub fn ttm_trend_alloc(len: usize) -> *mut f64 {
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2072
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Releases a buffer previously returned by `ttm_trend_alloc(len)`; null is ignored.
pub fn ttm_trend_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` must come from `ttm_trend_alloc(len)` with the same `len`, so it was
    // allocated by a `Vec<f64>` with capacity `len`. A length of 0 is used because the
    // caller may never have written to the buffer and `from_raw_parts` requires the
    // first `length` elements to be initialised. `f64` has no destructor, so only the
    // allocation itself needs to be released.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2082
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
/// Batch configuration deserialised from the JS `config` argument of `ttm_trend_batch`.
pub struct TtmTrendBatchConfig {
    /// Period sweep as `(start, end, step)`; forwarded as `TtmTrendBatchRange::period`.
    pub period_range: (usize, usize, usize),
}
2088
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
/// Result object serialised back to JS by `ttm_trend_batch`.
pub struct TtmTrendBatchJsOutput {
    /// Flat output buffer as returned by `ttm_batch_inner_f64` (presumably row-major,
    /// `rows * cols` long — TODO confirm against that function).
    pub values: Vec<f64>,
    /// Period of each evaluated combination, in the order returned by the batch routine.
    pub periods: Vec<usize>,
    /// Number of rows reported by the batch routine.
    pub rows: usize,
    /// Number of columns reported by the batch routine.
    pub cols: usize,
}
2097
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = ttm_trend_batch)]
/// WASM batch entry point: evaluates TTM trend for every period described by the JS
/// `config` object (`{ period_range: [start, end, step] }`).
///
/// Rejects a zero start period and a range whose end precedes its start before any
/// computation; other failures are reported as JS string errors.
pub fn ttm_trend_batch_js(
    source: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let TtmTrendBatchConfig { period_range } = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let (start, end, _step) = period_range;
    if start == 0 {
        return Err(JsValue::from_str("Invalid period: period must be > 0"));
    }
    if end < start {
        return Err(JsValue::from_str(
            "Invalid period range: end must be >= start",
        ));
    }

    let sweep = TtmTrendBatchRange {
        period: period_range,
    };
    // Map the detected batch kernel onto its single-series counterpart.
    let simd = match detect_best_batch_kernel() {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::ScalarBatch => Kernel::Scalar,
        other => other,
    };

    let (values, combos, rows, cols) = ttm_batch_inner_f64(source, close, &sweep, simd, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    let periods = combos.iter().map(|p| p.period.unwrap()).collect();

    let output = TtmTrendBatchJsOutput {
        values,
        periods,
        rows,
        cols,
    };
    serde_wasm_bindgen::to_value(&output)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}