1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::{cuda_available, CudaHalftrend};
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::utilities::dlpack_cuda::{make_device_array_py, DeviceArrayF32Py};
5#[cfg(feature = "python")]
6use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
7#[cfg(feature = "python")]
8use pyo3::exceptions::PyValueError;
9#[cfg(feature = "python")]
10use pyo3::prelude::*;
11#[cfg(feature = "python")]
12use pyo3::types::PyDict;
13
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use serde::{Deserialize, Serialize};
16#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
17use wasm_bindgen::prelude::*;
18
19use crate::utilities::data_loader::{source_type, CandleFieldFlags, Candles};
20use crate::utilities::enums::Kernel;
21use crate::utilities::helpers::{
22 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
23 make_uninit_matrix,
24};
25#[cfg(feature = "python")]
26use crate::utilities::kernel_validation::validate_kernel;
27
28use crate::indicators::atr::{atr, AtrInput, AtrOutput, AtrParams};
29use crate::indicators::moving_averages::sma::{sma, SmaInput, SmaOutput, SmaParams};
30
31#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
32use core::arch::x86_64::*;
33
34#[cfg(not(target_arch = "wasm32"))]
35use rayon::prelude::*;
36
37use std::collections::{BTreeSet, HashMap};
38use std::convert::AsRef;
39use std::error::Error;
40use std::mem::MaybeUninit;
41use thiserror::Error;
42
/// Per-bar outputs of the HalfTrend indicator, one entry per input bar.
///
/// Entries before the warmup bar are NaN.
#[derive(Debug, Clone)]
pub struct HalfTrendOutput {
    /// The active HalfTrend line: the up-line while `trend == 0.0`, the down-line while
    /// `trend == 1.0`.
    pub halftrend: Vec<f64>,
    /// Trend state: 0.0 while the up-line is active, 1.0 while the down-line is active.
    pub trend: Vec<f64>,
    /// Upper channel: `halftrend + atr * channel_deviation / 2`.
    pub atr_high: Vec<f64>,
    /// Lower channel: `halftrend - atr * channel_deviation / 2`.
    pub atr_low: Vec<f64>,
    /// `halftrend - atr / 2` on bars where the trend switches to 0.0 (up), NaN elsewhere.
    pub buy_signal: Vec<f64>,
    /// `halftrend + atr / 2` on bars where the trend switches to 1.0 (down), NaN elsewhere.
    pub sell_signal: Vec<f64>,
}
52
/// Parameters for the HalfTrend indicator.
///
/// `None` fields fall back to the defaults used by `HalfTrendInput`'s getters:
/// amplitude 2, channel_deviation 2.0, atr_period 100.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct HalfTrendParams {
    /// Window length for the high/low SMAs and for the rolling high/low extremes.
    pub amplitude: Option<usize>,
    /// Multiplier applied to half the ATR to build the `atr_high` / `atr_low` channel;
    /// must be finite and > 0.
    pub channel_deviation: Option<f64>,
    /// Length of the ATR.
    pub atr_period: Option<usize>,
}
63
64impl Default for HalfTrendParams {
65 fn default() -> Self {
66 Self {
67 amplitude: Some(2),
68 channel_deviation: Some(2.0),
69 atr_period: Some(100),
70 }
71 }
72}
73
/// Borrowed price source for HalfTrend: either a `Candles` set (its `high`, `low` and
/// `close` fields are used) or three explicit slices.
#[derive(Debug, Clone)]
pub enum HalfTrendData<'a> {
    /// Read `high`, `low` and `close` from a candle set.
    Candles(&'a Candles),
    /// Explicit series; the computing functions require all three to have equal length.
    Slices {
        high: &'a [f64],
        low: &'a [f64],
        close: &'a [f64],
    },
}
83
/// Input bundle for the HalfTrend functions: the price source plus parameters.
#[derive(Debug, Clone)]
pub struct HalfTrendInput<'a> {
    /// Where high/low/close are read from.
    pub data: HalfTrendData<'a>,
    /// Indicator parameters; unset values fall back to defaults via the getters.
    pub params: HalfTrendParams,
}
89
90impl<'a> HalfTrendInput<'a> {
91 #[inline]
92 pub fn from_candles(c: &'a Candles, p: HalfTrendParams) -> Self {
93 Self {
94 data: HalfTrendData::Candles(c),
95 params: p,
96 }
97 }
98
99 #[inline]
100 pub fn from_slices(h: &'a [f64], l: &'a [f64], c: &'a [f64], p: HalfTrendParams) -> Self {
101 Self {
102 data: HalfTrendData::Slices {
103 high: h,
104 low: l,
105 close: c,
106 },
107 params: p,
108 }
109 }
110
111 #[inline]
112 pub fn with_default_candles(c: &'a Candles) -> Self {
113 Self::from_candles(c, HalfTrendParams::default())
114 }
115
116 #[inline]
117 pub fn as_slices(&self) -> (&[f64], &[f64], &[f64]) {
118 match &self.data {
119 HalfTrendData::Candles(c) => (&c.high, &c.low, &c.close),
120 HalfTrendData::Slices { high, low, close } => (*high, *low, *close),
121 }
122 }
123
124 #[inline]
125 pub fn get_amplitude(&self) -> usize {
126 self.params.amplitude.unwrap_or(2)
127 }
128
129 #[inline]
130 pub fn get_channel_deviation(&self) -> f64 {
131 self.params.channel_deviation.unwrap_or(2.0)
132 }
133
134 #[inline]
135 pub fn get_atr_period(&self) -> usize {
136 self.params.atr_period.unwrap_or(100)
137 }
138}
139
/// Fluent builder for HalfTrend computations; unset parameters use the indicator defaults.
#[derive(Copy, Clone, Debug)]
pub struct HalfTrendBuilder {
    /// See `HalfTrendParams::amplitude`.
    amplitude: Option<usize>,
    /// See `HalfTrendParams::channel_deviation`.
    channel_deviation: Option<f64>,
    /// See `HalfTrendParams::atr_period`.
    atr_period: Option<usize>,
    /// Kernel passed to `halftrend_with_kernel`.
    kernel: Kernel,
}
147
148impl Default for HalfTrendBuilder {
149 fn default() -> Self {
150 Self {
151 amplitude: None,
152 channel_deviation: None,
153 atr_period: None,
154 kernel: Kernel::Auto,
155 }
156 }
157}
158
159impl HalfTrendBuilder {
160 #[inline(always)]
161 pub fn new() -> Self {
162 Self::default()
163 }
164
165 #[inline(always)]
166 pub fn amplitude(mut self, val: usize) -> Self {
167 self.amplitude = Some(val);
168 self
169 }
170
171 #[inline(always)]
172 pub fn channel_deviation(mut self, val: f64) -> Self {
173 self.channel_deviation = Some(val);
174 self
175 }
176
177 #[inline(always)]
178 pub fn atr_period(mut self, val: usize) -> Self {
179 self.atr_period = Some(val);
180 self
181 }
182
183 #[inline(always)]
184 pub fn kernel(mut self, k: Kernel) -> Self {
185 self.kernel = k;
186 self
187 }
188
189 #[inline(always)]
190 pub fn apply(self, c: &Candles) -> Result<HalfTrendOutput, HalfTrendError> {
191 let p = HalfTrendParams {
192 amplitude: self.amplitude,
193 channel_deviation: self.channel_deviation,
194 atr_period: self.atr_period,
195 };
196 let i = HalfTrendInput::from_candles(c, p);
197 halftrend_with_kernel(&i, self.kernel)
198 }
199
200 #[inline(always)]
201 pub fn apply_slices(
202 self,
203 h: &[f64],
204 l: &[f64],
205 c: &[f64],
206 ) -> Result<HalfTrendOutput, HalfTrendError> {
207 let p = HalfTrendParams {
208 amplitude: self.amplitude,
209 channel_deviation: self.channel_deviation,
210 atr_period: self.atr_period,
211 };
212 let i = HalfTrendInput::from_slices(h, l, c, p);
213 halftrend_with_kernel(&i, self.kernel)
214 }
215
216 #[inline(always)]
217 pub fn into_stream(self) -> Result<HalfTrendStream, HalfTrendError> {
218 let p = HalfTrendParams {
219 amplitude: self.amplitude,
220 channel_deviation: self.channel_deviation,
221 atr_period: self.atr_period,
222 };
223 HalfTrendStream::try_new(p)
224 }
225
226 pub fn with_default_candles(c: &Candles) -> Result<HalfTrendOutput, HalfTrendError> {
227 Self::new().apply(c)
228 }
229
230 pub fn with_default_slices(
231 h: &[f64],
232 l: &[f64],
233 c: &[f64],
234 ) -> Result<HalfTrendOutput, HalfTrendError> {
235 Self::new().apply_slices(h, l, c)
236 }
237
238 #[inline(always)]
239 pub fn apply_candles(self, c: &Candles) -> Result<HalfTrendOutput, HalfTrendError> {
240 self.apply(c)
241 }
242
243 #[inline(always)]
244 pub fn apply_slice_triplet(
245 self,
246 h: &[f64],
247 l: &[f64],
248 c: &[f64],
249 ) -> Result<HalfTrendOutput, HalfTrendError> {
250 self.apply_slices(h, l, c)
251 }
252}
253
/// Errors produced by the HalfTrend functions.
#[derive(Debug, Error)]
pub enum HalfTrendError {
    /// One of the high/low/close series is empty.
    #[error("halftrend: Input data slice is empty.")]
    EmptyInputData,

    /// No usable bar was found in the input.
    #[error("halftrend: All values are NaN.")]
    AllValuesNaN,

    /// `amplitude` or `atr_period` is zero or longer than the data.
    /// Also returned when high/low/close differ in length (`period` then holds the high length).
    #[error("halftrend: Invalid period: period = {period}, data_len = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },

    /// Fewer valid bars after the first valid one than `max(amplitude, atr_period)`.
    #[error("halftrend: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },

    /// Wraps the message of an error returned by the `atr` indicator.
    #[error("halftrend: ATR calculation failed: {0}")]
    AtrError(String),

    /// Wraps the message of an error returned by the `sma` indicator.
    #[error("halftrend: SMA calculation failed: {0}")]
    SmaError(String),

    /// `channel_deviation` is not finite or is <= 0.
    #[error("halftrend: Invalid channel_deviation: {channel_deviation}")]
    InvalidChannelDeviation { channel_deviation: f64 },

    /// A caller-provided output buffer does not match the input length.
    #[error("halftrend: Output length mismatch: expected {expected}, got {got}")]
    OutputLengthMismatch { expected: usize, got: usize },

    /// Invalid parameter sweep range. Not raised by the single-series functions in this
    /// section; presumably used by batch APIs elsewhere in the module.
    #[error("halftrend: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: String,
        end: String,
        step: String,
    },

    /// A non-batch kernel was requested for a batch computation (presumably raised by batch
    /// APIs elsewhere in the module).
    #[error("halftrend: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(crate::utilities::enums::Kernel),
}
290
/// Returns the index of the first bar at which `high`, `low` and `close` are all non-NaN,
/// or `usize::MAX` if there is no such bar.
///
/// Requiring all three series to be valid matters when their leading NaN runs differ: a bar
/// with only one valid series would otherwise seed the ATR/SMA computations with NaN and
/// poison every subsequent output. If the slices differ in length, only the common prefix
/// is inspected.
#[inline(always)]
fn first_valid_ohlc(high: &[f64], low: &[f64], close: &[f64]) -> usize {
    high.iter()
        .zip(low)
        .zip(close)
        .position(|((h, l), c)| !h.is_nan() && !l.is_nan() && !c.is_nan())
        .unwrap_or(usize::MAX)
}
298
299pub fn halftrend(input: &HalfTrendInput) -> Result<HalfTrendOutput, HalfTrendError> {
300 halftrend_with_kernel(input, Kernel::Auto)
301}
302
303pub fn halftrend_with_kernel(
304 input: &HalfTrendInput,
305 kernel: Kernel,
306) -> Result<HalfTrendOutput, HalfTrendError> {
307 let mut chosen = match kernel {
308 Kernel::Auto => Kernel::Scalar,
309 k => k,
310 };
311
312 let (high, low, close) = input.as_slices();
313
314 if high.is_empty() || low.is_empty() || close.is_empty() {
315 return Err(HalfTrendError::EmptyInputData);
316 }
317
318 let len = high.len();
319 if len != low.len() || len != close.len() {
320 return Err(HalfTrendError::InvalidPeriod {
321 period: len,
322 data_len: high.len().max(low.len()).max(close.len()),
323 });
324 }
325
326 let amplitude = input.get_amplitude();
327 let channel_deviation = input.get_channel_deviation();
328 let atr_period = input.get_atr_period();
329
330 if amplitude == 0 || amplitude > len {
331 return Err(HalfTrendError::InvalidPeriod {
332 period: amplitude,
333 data_len: len,
334 });
335 }
336
337 if !(channel_deviation.is_finite()) || channel_deviation <= 0.0 {
338 return Err(HalfTrendError::InvalidChannelDeviation { channel_deviation });
339 }
340
341 if atr_period == 0 || atr_period > len {
342 return Err(HalfTrendError::InvalidPeriod {
343 period: atr_period,
344 data_len: len,
345 });
346 }
347
348 if matches!(kernel, Kernel::Auto)
349 && amplitude == 2
350 && channel_deviation == 2.0
351 && atr_period == 100
352 {
353 chosen = Kernel::Scalar;
354 }
355
356 let first = first_valid_ohlc(high, low, close);
357 if first == usize::MAX {
358 return Err(HalfTrendError::AllValuesNaN);
359 }
360
361 let warmup_span = amplitude.max(atr_period);
362 if len - first < warmup_span {
363 return Err(HalfTrendError::NotEnoughValidData {
364 needed: warmup_span,
365 valid: len - first,
366 });
367 }
368 let warm = first + warmup_span - 1;
369
370 if chosen == Kernel::Scalar && amplitude == 2 && channel_deviation == 2.0 && atr_period == 100 {
371 let mut halftrend = alloc_with_nan_prefix(len, warm);
372 let mut trend = alloc_with_nan_prefix(len, warm);
373 let mut atr_high = alloc_with_nan_prefix(len, warm);
374 let mut atr_low = alloc_with_nan_prefix(len, warm);
375 let mut buy_signal = alloc_with_nan_prefix(len, warm);
376 let mut sell_signal = alloc_with_nan_prefix(len, warm);
377
378 unsafe {
379 halftrend_scalar_classic(
380 high,
381 low,
382 close,
383 amplitude,
384 channel_deviation,
385 atr_period,
386 first,
387 warm,
388 &mut halftrend,
389 &mut trend,
390 &mut atr_high,
391 &mut atr_low,
392 &mut buy_signal,
393 &mut sell_signal,
394 )?;
395 }
396
397 return Ok(HalfTrendOutput {
398 halftrend,
399 trend,
400 atr_high,
401 atr_low,
402 buy_signal,
403 sell_signal,
404 });
405 }
406
407 let mut halftrend = alloc_with_nan_prefix(len, warm);
408 let mut trend = alloc_with_nan_prefix(len, warm);
409 let mut atr_high = alloc_with_nan_prefix(len, warm);
410 let mut atr_low = alloc_with_nan_prefix(len, warm);
411 let mut buy_signal = alloc_with_nan_prefix(len, warm);
412 let mut sell_signal = alloc_with_nan_prefix(len, warm);
413
414 let atr_input = AtrInput::from_slices(
415 high,
416 low,
417 close,
418 AtrParams {
419 length: Some(atr_period),
420 },
421 );
422 let AtrOutput { values: atr_values } =
423 atr(&atr_input).map_err(|e| HalfTrendError::AtrError(e.to_string()))?;
424
425 let sma_high_input = SmaInput::from_slice(
426 high,
427 SmaParams {
428 period: Some(amplitude),
429 },
430 );
431 let SmaOutput { values: highma } =
432 sma(&sma_high_input).map_err(|e| HalfTrendError::SmaError(e.to_string()))?;
433
434 let sma_low_input = SmaInput::from_slice(
435 low,
436 SmaParams {
437 period: Some(amplitude),
438 },
439 );
440 let SmaOutput { values: lowma } =
441 sma(&sma_low_input).map_err(|e| HalfTrendError::SmaError(e.to_string()))?;
442
443 halftrend_compute_into(
444 high,
445 low,
446 close,
447 amplitude,
448 channel_deviation,
449 &atr_values,
450 &highma,
451 &lowma,
452 warm,
453 chosen,
454 &mut halftrend,
455 &mut trend,
456 &mut atr_high,
457 &mut atr_low,
458 &mut buy_signal,
459 &mut sell_signal,
460 );
461
462 Ok(HalfTrendOutput {
463 halftrend,
464 trend,
465 atr_high,
466 atr_low,
467 buy_signal,
468 sell_signal,
469 })
470}
471
/// Dispatches the generic HalfTrend kernel for precomputed ATR / SMA series.
///
/// `Kernel::Avx512` and `Kernel::Avx2` select the target-feature entry points only when the
/// crate is built with the `nightly-avx` feature on x86_64. Every other kernel value, and those
/// two when the feature is off, falls through to `halftrend_scalar`. Outputs are written for
/// bars `start_idx..` only.
#[inline(always)]
fn halftrend_compute_into(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    amplitude: usize,
    channel_deviation: f64,
    atr_values: &[f64],
    highma: &[f64],
    lowma: &[f64],
    start_idx: usize,
    kernel: Kernel,
    halftrend: &mut [f64],
    trend: &mut [f64],
    atr_high: &mut [f64],
    atr_low: &mut [f64],
    buy_signal: &mut [f64],
    sell_signal: &mut [f64],
) {
    match kernel {
        // NOTE(review): the AVX entry points are `target_feature` functions; calling them on a
        // CPU lacking the feature is undefined behavior. This dispatcher does not check CPU
        // support itself — presumably callers validate the requested kernel.
        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
        Kernel::Avx512 => unsafe {
            halftrend_avx512(
                high,
                low,
                close,
                amplitude,
                channel_deviation,
                atr_values,
                highma,
                lowma,
                start_idx,
                halftrend,
                trend,
                atr_high,
                atr_low,
                buy_signal,
                sell_signal,
            )
        },
        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
        Kernel::Avx2 => unsafe {
            halftrend_avx2(
                high,
                low,
                close,
                amplitude,
                channel_deviation,
                atr_values,
                highma,
                lowma,
                start_idx,
                halftrend,
                trend,
                atr_high,
                atr_low,
                buy_signal,
                sell_signal,
            )
        },
        // Scalar, Auto, batch kernels, and AVX kernels when `nightly-avx` is off.
        _ => halftrend_scalar(
            high,
            low,
            close,
            amplitude,
            channel_deviation,
            atr_values,
            highma,
            lowma,
            start_idx,
            halftrend,
            trend,
            atr_high,
            atr_low,
            buy_signal,
            sell_signal,
        ),
    }
}
551
552#[inline]
553#[inline(always)]
554pub unsafe fn halftrend_scalar_classic(
555 high: &[f64],
556 low: &[f64],
557 close: &[f64],
558 amplitude: usize,
559 channel_deviation: f64,
560 atr_period: usize,
561 first: usize,
562 warm: usize,
563 halftrend: &mut [f64],
564 trend: &mut [f64],
565 atr_high: &mut [f64],
566 atr_low: &mut [f64],
567 buy_signal: &mut [f64],
568 sell_signal: &mut [f64],
569) -> Result<(), HalfTrendError> {
570 let len = high.len();
571 let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
572
573 let alpha = 1.0 / atr_period as f64;
574 let atr_warm = first + atr_period - 1;
575 let sma_warm = first + amplitude - 1;
576
577 let mut sum_tr = 0.0;
578 for i in first..=atr_warm.min(len - 1) {
579 let tr = if i == first {
580 high[i] - low[i]
581 } else {
582 let hl = high[i] - low[i];
583 let hc = (high[i] - close[i - 1]).abs();
584 let lc = (low[i] - close[i - 1]).abs();
585 hl.max(hc).max(lc)
586 };
587 sum_tr += tr;
588 }
589 let mut rma = sum_tr / atr_period as f64;
590
591 let mut sum_high = 0.0;
592 let mut sum_low = 0.0;
593 for i in first..=sma_warm.min(len - 1) {
594 sum_high += high[i];
595 sum_low += low[i];
596 }
597
598 for i in (sma_warm + 1)..=warm.min(len - 1) {
599 sum_high = sum_high - high[i - amplitude] + high[i];
600 sum_low = sum_low - low[i - amplitude] + low[i];
601 }
602 let inv_amp = 1.0 / amplitude as f64;
603
604 let mut current_trend = 0i32;
605 let mut next_trend = 0i32;
606 let mut up = 0.0f64;
607 let mut down = 0.0f64;
608 let mut max_low_price = if warm > 0 { low[warm - 1] } else { low[0] };
609 let mut min_high_price = if warm > 0 { high[warm - 1] } else { high[0] };
610
611 let ch_half = channel_deviation * 0.5;
612 for i in warm..len {
613 buy_signal[i] = qnan;
614 sell_signal[i] = qnan;
615
616 let highma_i = sum_high * inv_amp;
617 let lowma_i = sum_low * inv_amp;
618
619 let high_price = if high[i] > high[i - 1] {
620 high[i]
621 } else {
622 high[i - 1]
623 };
624 let low_price = if low[i] < low[i - 1] {
625 low[i]
626 } else {
627 low[i - 1]
628 };
629
630 let prev_low = low[i - 1];
631 let prev_high = high[i - 1];
632
633 if next_trend == 1 {
634 if low_price > max_low_price {
635 max_low_price = low_price;
636 }
637 if highma_i < max_low_price && close[i] < prev_low {
638 current_trend = 1;
639 next_trend = 0;
640 min_high_price = high_price;
641 }
642 } else {
643 if high_price < min_high_price {
644 min_high_price = high_price;
645 }
646 if lowma_i > min_high_price && close[i] > prev_high {
647 current_trend = 0;
648 next_trend = 1;
649 max_low_price = low_price;
650 }
651 }
652
653 let a = rma;
654 let atr2 = 0.5 * a;
655 let dev = a.mul_add(ch_half, 0.0);
656
657 if current_trend == 0 {
658 if i > warm && trend[i - 1] != 0.0 {
659 up = down;
660 buy_signal[i] = up - atr2;
661 } else {
662 up = if i == warm || up == 0.0 {
663 max_low_price
664 } else if max_low_price > up {
665 max_low_price
666 } else {
667 up
668 };
669 }
670 halftrend[i] = up;
671 atr_high[i] = up + dev;
672 atr_low[i] = up - dev;
673 trend[i] = 0.0;
674 } else {
675 if i > warm && trend[i - 1] != 1.0 {
676 down = up;
677 sell_signal[i] = down + atr2;
678 } else {
679 down = if i == warm || down == 0.0 {
680 min_high_price
681 } else if min_high_price < down {
682 min_high_price
683 } else {
684 down
685 };
686 }
687 halftrend[i] = down;
688 atr_high[i] = down + dev;
689 atr_low[i] = down - dev;
690 trend[i] = 1.0;
691 }
692
693 let ni = i + 1;
694 if ni < len {
695 sum_high = sum_high - high[ni - amplitude] + high[ni];
696 sum_low = sum_low - low[ni - amplitude] + low[ni];
697
698 let hl = high[ni] - low[ni];
699 let hc = (high[ni] - close[ni - 1]).abs();
700 let lc = (low[ni] - close[ni - 1]).abs();
701 let tr = hl.max(hc).max(lc);
702 rma += alpha * (tr - rma);
703 }
704 }
705
706 Ok(())
707}
708
/// Portable HalfTrend kernel driven by precomputed ATR and SMA series.
///
/// Writes outputs for bars `start_idx..high.len()`; entries before `start_idx` are left
/// untouched, and nothing is written when `start_idx >= high.len()`.
///
/// `high_price` / `low_price` are the max of `high` / min of `low` over the trailing
/// `amplitude` bars. The window is shorter near the start of the data when
/// `start_idx + 1 < amplitude`. Both are maintained with monotonic deques, so each bar costs
/// amortized O(1).
///
/// On output:
/// * `trend` is 0.0 while the up-line is active and 1.0 while the down-line is active;
/// * `halftrend` holds the active line, and `atr_high` / `atr_low` are that line
///   ± `atr * channel_deviation / 2`;
/// * `buy_signal` / `sell_signal` hold the line ∓ `atr / 2` on trend-flip bars and NaN otherwise.
///
/// # Panics
/// Panics if `low`, `close`, `atr_values`, `highma`, `lowma` or any output slice is shorter
/// than `high` (bounds-checked indexing).
pub fn halftrend_scalar(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    amplitude: usize,
    channel_deviation: f64,
    atr_values: &[f64],
    highma: &[f64],
    lowma: &[f64],
    start_idx: usize,
    halftrend: &mut [f64],
    trend: &mut [f64],
    atr_high: &mut [f64],
    atr_low: &mut [f64],
    buy_signal: &mut [f64],
    sell_signal: &mut [f64],
) {
    // Fixed-capacity ring buffer holding a monotonic deque of (bar index, value) pairs.
    // On push, back entries that can no longer be the window extreme are dropped
    // (`back <= v` for MAX, `back >= v` for MIN), so the front holds the window max/min.
    // Capacity equals the window length; it is never exceeded because expired entries are
    // evicted before each push.
    struct MonoDeque<const MAX: bool> {
        idx: Vec<usize>,
        val: Vec<f64>,
        head: usize,
        tail: usize,
        len: usize,
    }

    impl<const MAX: bool> MonoDeque<MAX> {
        fn new(cap: usize) -> Self {
            Self {
                idx: vec![0; cap],
                val: vec![0.0; cap],
                head: 0,
                tail: 0,
                len: 0,
            }
        }

        #[inline(always)]
        fn wrap_inc(&self, p: usize) -> usize {
            if p + 1 == self.val.len() {
                0
            } else {
                p + 1
            }
        }

        #[inline(always)]
        fn wrap_dec(&self, p: usize) -> usize {
            if p == 0 {
                self.val.len() - 1
            } else {
                p - 1
            }
        }

        #[inline(always)]
        fn push(&mut self, k: usize, v: f64) {
            while self.len > 0 {
                let back = self.wrap_dec(self.tail);
                let dominated = if MAX {
                    self.val[back] <= v
                } else {
                    self.val[back] >= v
                };
                if !dominated {
                    break;
                }
                self.tail = back;
                self.len -= 1;
            }
            self.val[self.tail] = v;
            self.idx[self.tail] = k;
            self.tail = self.wrap_inc(self.tail);
            self.len += 1;
        }

        #[inline(always)]
        fn evict_before(&mut self, start: usize) {
            while self.len > 0 && self.idx[self.head] < start {
                self.head = self.wrap_inc(self.head);
                self.len -= 1;
            }
        }

        #[inline(always)]
        fn front(&self) -> f64 {
            debug_assert!(self.len > 0);
            self.val[self.head]
        }
    }

    let len = high.len();
    // Nothing to write; this also guards the `low[0]` / `low[start_idx - 1]` reads below.
    if start_idx >= len {
        return;
    }
    let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
    let cap = amplitude.max(1);

    // Prime both deques with the window ending at `start_idx`, truncated at bar 0 if needed.
    let mut window_high = MonoDeque::<true>::new(cap);
    let mut window_low = MonoDeque::<false>::new(cap);
    for k in (start_idx + 1).saturating_sub(cap)..=start_idx {
        window_high.push(k, high[k]);
        window_low.push(k, low[k]);
    }

    let mut current_trend = 0i32;
    let mut next_trend = 0i32;
    let mut up = 0.0f64;
    let mut down = 0.0f64;
    let (mut max_low_price, mut min_high_price) = if start_idx > 0 {
        (low[start_idx - 1], high[start_idx - 1])
    } else {
        (low[0], high[0])
    };

    let ch_half = channel_deviation * 0.5;

    for i in start_idx..len {
        buy_signal[i] = qnan;
        sell_signal[i] = qnan;

        if i > start_idx {
            let wstart = (i + 1).saturating_sub(cap);
            window_high.evict_before(wstart);
            window_low.evict_before(wstart);
            window_high.push(i, high[i]);
            window_low.push(i, low[i]);
        }

        let high_price = window_high.front();
        let low_price = window_low.front();

        let prev_low = if i > 0 { low[i - 1] } else { low[0] };
        let prev_high = if i > 0 { high[i - 1] } else { high[0] };

        if next_trend == 1 {
            if low_price > max_low_price {
                max_low_price = low_price;
            }
            if highma[i] < max_low_price && close[i] < prev_low {
                current_trend = 1;
                next_trend = 0;
                min_high_price = high_price;
            }
        } else {
            if high_price < min_high_price {
                min_high_price = high_price;
            }
            if lowma[i] > min_high_price && close[i] > prev_high {
                current_trend = 0;
                next_trend = 1;
                max_low_price = low_price;
            }
        }

        let a = atr_values[i];
        let atr2 = 0.5 * a;
        let dev = a * ch_half;

        if current_trend == 0 {
            if i > start_idx && trend[i - 1] != 0.0 {
                // Flip to up-trend: the up-line starts where the down-line ended.
                up = down;
                buy_signal[i] = up - atr2;
            } else {
                up = if i == start_idx || up == 0.0 {
                    max_low_price
                } else if max_low_price > up {
                    max_low_price
                } else {
                    up
                };
            }
            halftrend[i] = up;
            atr_high[i] = up + dev;
            atr_low[i] = up - dev;
            trend[i] = 0.0;
        } else {
            if i > start_idx && trend[i - 1] != 1.0 {
                // Flip to down-trend: the down-line starts where the up-line ended.
                down = up;
                sell_signal[i] = down + atr2;
            } else {
                down = if i == start_idx || down == 0.0 {
                    min_high_price
                } else if min_high_price < down {
                    min_high_price
                } else {
                    down
                };
            }
            halftrend[i] = down;
            atr_high[i] = down + dev;
            atr_low[i] = down - dev;
            trend[i] = 1.0;
        }
    }
}
922
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
/// AVX2/FMA entry point for the generic HalfTrend kernel.
///
/// There is no hand-written SIMD here: the body forwards to `halftrend_scalar`. The function
/// exists so the dispatcher has a `target_feature(avx2, fma)` entry point.
///
/// # Safety
/// The CPU must support AVX2 and FMA.
unsafe fn halftrend_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    amplitude: usize,
    channel_deviation: f64,
    atr_values: &[f64],
    highma: &[f64],
    lowma: &[f64],
    start_idx: usize,
    halftrend: &mut [f64],
    trend: &mut [f64],
    atr_high: &mut [f64],
    atr_low: &mut [f64],
    buy_signal: &mut [f64],
    sell_signal: &mut [f64],
) {
    halftrend_scalar(
        high, low, close, amplitude, channel_deviation, atr_values, highma, lowma, start_idx,
        halftrend, trend, atr_high, atr_low, buy_signal, sell_signal,
    )
}
960
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
/// AVX-512F/FMA entry point for the generic HalfTrend kernel.
///
/// There is no hand-written SIMD here: the body forwards to `halftrend_scalar`. The function
/// exists so the dispatcher has a `target_feature(avx512f, fma)` entry point.
///
/// # Safety
/// The CPU must support AVX-512F and FMA.
unsafe fn halftrend_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    amplitude: usize,
    channel_deviation: f64,
    atr_values: &[f64],
    highma: &[f64],
    lowma: &[f64],
    start_idx: usize,
    halftrend: &mut [f64],
    trend: &mut [f64],
    atr_high: &mut [f64],
    atr_low: &mut [f64],
    buy_signal: &mut [f64],
    sell_signal: &mut [f64],
) {
    halftrend_scalar(
        high, low, close, amplitude, channel_deviation, atr_values, highma, lowma, start_idx,
        halftrend, trend, atr_high, atr_low, buy_signal, sell_signal,
    )
}
998
999#[inline]
1000pub fn halftrend_into_slices(
1001 out_halftrend: &mut [f64],
1002 out_trend: &mut [f64],
1003 out_atr_high: &mut [f64],
1004 out_atr_low: &mut [f64],
1005 out_buy_signal: &mut [f64],
1006 out_sell_signal: &mut [f64],
1007 input: &HalfTrendInput,
1008) -> Result<(), HalfTrendError> {
1009 let (high, low, close) = input.as_slices();
1010
1011 if high.is_empty() || low.is_empty() || close.is_empty() {
1012 return Err(HalfTrendError::EmptyInputData);
1013 }
1014
1015 let len = high.len();
1016 if out_halftrend.len() != len
1017 || out_trend.len() != len
1018 || out_atr_high.len() != len
1019 || out_atr_low.len() != len
1020 || out_buy_signal.len() != len
1021 || out_sell_signal.len() != len
1022 {
1023 return Err(HalfTrendError::OutputLengthMismatch {
1024 expected: len,
1025 got: out_halftrend.len(),
1026 });
1027 }
1028
1029 let amplitude = input.get_amplitude();
1030 let atr_period = input.get_atr_period();
1031 let channel_deviation = input.get_channel_deviation();
1032
1033 let first = first_valid_ohlc(high, low, close);
1034 if first == usize::MAX {
1035 return Err(HalfTrendError::AllValuesNaN);
1036 }
1037
1038 let warmup_span = amplitude.max(atr_period);
1039 if len - first < warmup_span {
1040 return Err(HalfTrendError::NotEnoughValidData {
1041 needed: warmup_span,
1042 valid: len - first,
1043 });
1044 }
1045 let warm = first + warmup_span - 1;
1046
1047 for v in [
1048 &mut *out_halftrend,
1049 out_trend,
1050 out_atr_high,
1051 out_atr_low,
1052 out_buy_signal,
1053 out_sell_signal,
1054 ] {
1055 for x in &mut v[..warm] {
1056 *x = f64::NAN;
1057 }
1058 }
1059
1060 let atr_out = atr(&AtrInput::from_slices(
1061 high,
1062 low,
1063 close,
1064 AtrParams {
1065 length: Some(atr_period),
1066 },
1067 ))
1068 .map_err(|e| HalfTrendError::AtrError(e.to_string()))?
1069 .values;
1070 let highma = sma(&SmaInput::from_slice(
1071 high,
1072 SmaParams {
1073 period: Some(amplitude),
1074 },
1075 ))
1076 .map_err(|e| HalfTrendError::SmaError(e.to_string()))?
1077 .values;
1078 let lowma = sma(&SmaInput::from_slice(
1079 low,
1080 SmaParams {
1081 period: Some(amplitude),
1082 },
1083 ))
1084 .map_err(|e| HalfTrendError::SmaError(e.to_string()))?
1085 .values;
1086
1087 halftrend_scalar(
1088 high,
1089 low,
1090 close,
1091 amplitude,
1092 channel_deviation,
1093 &atr_out,
1094 &highma,
1095 &lowma,
1096 warm,
1097 out_halftrend,
1098 out_trend,
1099 out_atr_high,
1100 out_atr_low,
1101 out_buy_signal,
1102 out_sell_signal,
1103 );
1104
1105 Ok(())
1106}
1107
1108#[inline]
1109pub fn halftrend_into_slices_kernel(
1110 out_halftrend: &mut [f64],
1111 out_trend: &mut [f64],
1112 out_atr_high: &mut [f64],
1113 out_atr_low: &mut [f64],
1114 out_buy_signal: &mut [f64],
1115 out_sell_signal: &mut [f64],
1116 input: &HalfTrendInput,
1117 kern: Kernel,
1118) -> Result<(), HalfTrendError> {
1119 let (high, low, close) = input.as_slices();
1120 if high.is_empty() || low.is_empty() || close.is_empty() {
1121 return Err(HalfTrendError::EmptyInputData);
1122 }
1123 let len = high.len();
1124 if out_halftrend.len() != len
1125 || out_trend.len() != len
1126 || out_atr_high.len() != len
1127 || out_atr_low.len() != len
1128 || out_buy_signal.len() != len
1129 || out_sell_signal.len() != len
1130 {
1131 return Err(HalfTrendError::OutputLengthMismatch {
1132 expected: len,
1133 got: out_halftrend.len(),
1134 });
1135 }
1136
1137 let amplitude = input.get_amplitude();
1138 let atr_period = input.get_atr_period();
1139 let ch = input.get_channel_deviation();
1140
1141 if amplitude == 0 || amplitude > len {
1142 return Err(HalfTrendError::InvalidPeriod {
1143 period: amplitude,
1144 data_len: len,
1145 });
1146 }
1147 if atr_period == 0 || atr_period > len {
1148 return Err(HalfTrendError::InvalidPeriod {
1149 period: atr_period,
1150 data_len: len,
1151 });
1152 }
1153 if !(ch.is_finite()) || ch <= 0.0 {
1154 return Err(HalfTrendError::InvalidChannelDeviation {
1155 channel_deviation: ch,
1156 });
1157 }
1158
1159 let first = first_valid_ohlc(high, low, close);
1160 if first == usize::MAX {
1161 return Err(HalfTrendError::AllValuesNaN);
1162 }
1163 let warmup_span = amplitude.max(atr_period);
1164 if len - first < warmup_span {
1165 return Err(HalfTrendError::NotEnoughValidData {
1166 needed: warmup_span,
1167 valid: len - first,
1168 });
1169 }
1170 let warm = first + warmup_span - 1;
1171
1172 let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
1173 for x in &mut out_halftrend[..warm] {
1174 *x = qnan;
1175 }
1176 for x in &mut out_trend[..warm] {
1177 *x = qnan;
1178 }
1179 for x in &mut out_atr_high[..warm] {
1180 *x = qnan;
1181 }
1182 for x in &mut out_atr_low[..warm] {
1183 *x = qnan;
1184 }
1185 for x in &mut out_buy_signal[..warm] {
1186 *x = qnan;
1187 }
1188 for x in &mut out_sell_signal[..warm] {
1189 *x = qnan;
1190 }
1191
1192 let mut chosen = match kern {
1193 Kernel::Auto => Kernel::Scalar,
1194 k => k,
1195 };
1196
1197 if matches!(kern, Kernel::Auto) && amplitude == 2 && ch == 2.0 && atr_period == 100 {
1198 chosen = Kernel::Scalar;
1199 }
1200
1201 if chosen == Kernel::Scalar && amplitude == 2 && ch == 2.0 && atr_period == 100 {
1202 unsafe {
1203 halftrend_scalar_classic(
1204 high,
1205 low,
1206 close,
1207 amplitude,
1208 ch,
1209 atr_period,
1210 first,
1211 warm,
1212 out_halftrend,
1213 out_trend,
1214 out_atr_high,
1215 out_atr_low,
1216 out_buy_signal,
1217 out_sell_signal,
1218 )?;
1219 }
1220 return Ok(());
1221 }
1222
1223 let AtrOutput { values: av } = atr(&AtrInput::from_slices(
1224 high,
1225 low,
1226 close,
1227 AtrParams {
1228 length: Some(atr_period),
1229 },
1230 ))
1231 .map_err(|e| HalfTrendError::AtrError(e.to_string()))?;
1232 let SmaOutput { values: hma } = sma(&SmaInput::from_slice(
1233 high,
1234 SmaParams {
1235 period: Some(amplitude),
1236 },
1237 ))
1238 .map_err(|e| HalfTrendError::SmaError(e.to_string()))?;
1239 let SmaOutput { values: lma } = sma(&SmaInput::from_slice(
1240 low,
1241 SmaParams {
1242 period: Some(amplitude),
1243 },
1244 ))
1245 .map_err(|e| HalfTrendError::SmaError(e.to_string()))?;
1246
1247 halftrend_compute_into(
1248 high,
1249 low,
1250 close,
1251 amplitude,
1252 ch,
1253 &av,
1254 &hma,
1255 &lma,
1256 warm,
1257 chosen,
1258 out_halftrend,
1259 out_trend,
1260 out_atr_high,
1261 out_atr_low,
1262 out_buy_signal,
1263 out_sell_signal,
1264 );
1265
1266 Ok(())
1267}
1268
1269#[inline]
1270pub fn halftrend_into_slice(
1271 out_halftrend: &mut [f64],
1272 out_trend: &mut [f64],
1273 out_atr_high: &mut [f64],
1274 out_atr_low: &mut [f64],
1275 out_buy_signal: &mut [f64],
1276 out_sell_signal: &mut [f64],
1277 input: &HalfTrendInput,
1278) -> Result<(), HalfTrendError> {
1279 halftrend_into_slices_kernel(
1280 out_halftrend,
1281 out_trend,
1282 out_atr_high,
1283 out_atr_low,
1284 out_buy_signal,
1285 out_sell_signal,
1286 input,
1287 Kernel::Auto,
1288 )
1289}
1290
1291#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1292#[inline]
1293pub fn halftrend_into(
1294 input: &HalfTrendInput,
1295 out_halftrend: &mut [f64],
1296 out_trend: &mut [f64],
1297 out_atr_high: &mut [f64],
1298 out_atr_low: &mut [f64],
1299 out_buy_signal: &mut [f64],
1300 out_sell_signal: &mut [f64],
1301) -> Result<(), HalfTrendError> {
1302 halftrend_into_slices_kernel(
1303 out_halftrend,
1304 out_trend,
1305 out_atr_high,
1306 out_atr_low,
1307 out_buy_signal,
1308 out_sell_signal,
1309 input,
1310 Kernel::Auto,
1311 )
1312}
1313
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;
    use std::error::Error;

    /// `halftrend_into` must reproduce the allocating `halftrend` API exactly
    /// (NaN positions included) on a small synthetic series.
    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
    #[test]
    fn test_halftrend_into_matches_api() -> Result<(), Box<dyn Error>> {
        let len = 256usize;
        let mut high = Vec::with_capacity(len);
        let mut low = Vec::with_capacity(len);
        let mut close = Vec::with_capacity(len);
        for i in 0..len {
            let t = i as f64;
            let c = 100.0 + 0.1 * t + (t * 0.03).sin();
            close.push(c);
            high.push(c + 1.0 + (t * 0.01).cos() * 0.1);
            low.push(c - 1.0 - (t * 0.02).sin() * 0.1);
        }

        let input = HalfTrendInput::from_slices(&high, &low, &close, HalfTrendParams::default());

        let base = halftrend(&input)?;

        let mut ht = vec![0.0; len];
        let mut tr = vec![0.0; len];
        let mut ah = vec![0.0; len];
        let mut al = vec![0.0; len];
        let mut bs = vec![0.0; len];
        let mut ss = vec![0.0; len];

        halftrend_into(&input, &mut ht, &mut tr, &mut ah, &mut al, &mut bs, &mut ss)?;

        // Equal, both NaN, or within 1e-12.
        fn eq_or_both_nan(a: f64, b: f64) -> bool {
            (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-12)
        }

        assert_eq!(base.halftrend.len(), ht.len());
        assert_eq!(base.trend.len(), tr.len());
        assert_eq!(base.atr_high.len(), ah.len());
        assert_eq!(base.atr_low.len(), al.len());
        assert_eq!(base.buy_signal.len(), bs.len());
        assert_eq!(base.sell_signal.len(), ss.len());

        for i in 0..len {
            assert!(
                eq_or_both_nan(base.halftrend[i], ht[i]),
                "halftrend mismatch at {}",
                i
            );
            assert!(
                eq_or_both_nan(base.trend[i], tr[i]),
                "trend mismatch at {}",
                i
            );
            assert!(
                eq_or_both_nan(base.atr_high[i], ah[i]),
                "atr_high mismatch at {}",
                i
            );
            assert!(
                eq_or_both_nan(base.atr_low[i], al[i]),
                "atr_low mismatch at {}",
                i
            );
            assert!(
                eq_or_both_nan(base.buy_signal[i], bs[i]),
                "buy_signal mismatch at {}",
                i
            );
            assert!(
                eq_or_both_nan(base.sell_signal[i], ss[i]),
                "sell_signal mismatch at {}",
                i
            );
        }

        Ok(())
    }

    /// Reference values on the bundled 4h Bitfinex CSV with default parameters.
    fn check_halftrend_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input = HalfTrendInput::from_candles(&candles, HalfTrendParams::default());
        let output = halftrend_with_kernel(&input, kernel)?;

        let test_indices = vec![15570, 15571, 15574, 15575, 15576];
        let expected_halftrend = vec![59763.0, 59763.0, 59763.0, 59310.0, 59310.0];
        let expected_trend = vec![0.0, 0.0, 1.0, 1.0, 1.0];

        for (i, &idx) in test_indices.iter().enumerate() {
            assert!(
                (output.halftrend[idx] - expected_halftrend[i]).abs() < 1.0,
                "[{}] HalfTrend mismatch at index {}: expected {}, got {}",
                test_name,
                idx,
                expected_halftrend[i],
                output.halftrend[idx]
            );
            assert!(
                (output.trend[idx] - expected_trend[i]).abs() < 0.01,
                "[{}] Trend mismatch at index {}: expected {}, got {}",
                test_name,
                idx,
                expected_trend[i],
                output.trend[idx]
            );
        }
        Ok(())
    }

    /// Empty candles must yield `EmptyInputData`.
    fn check_halftrend_empty_data(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let candles = Candles {
            timestamp: vec![],
            high: vec![],
            low: vec![],
            close: vec![],
            open: vec![],
            volume: vec![],
            fields: CandleFieldFlags {
                open: true,
                high: true,
                low: true,
                close: true,
                volume: true,
            },
            hl2: vec![],
            hlc3: vec![],
            ohlc4: vec![],
            hlcc4: vec![],
        };

        let input = HalfTrendInput::from_candles(&candles, HalfTrendParams::default());
        let result = halftrend_with_kernel(&input, kernel);

        assert!(
            matches!(result, Err(HalfTrendError::EmptyInputData)),
            "[{}] Expected EmptyInputData error",
            test_name
        );
        Ok(())
    }

    /// All-NaN candles must yield `AllValuesNaN`.
    fn check_halftrend_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let candles = Candles {
            timestamp: vec![0; 100],
            high: vec![f64::NAN; 100],
            low: vec![f64::NAN; 100],
            close: vec![f64::NAN; 100],
            open: vec![f64::NAN; 100],
            volume: vec![f64::NAN; 100],
            fields: CandleFieldFlags {
                open: true,
                high: true,
                low: true,
                close: true,
                volume: true,
            },
            hl2: vec![f64::NAN; 100],
            hlc3: vec![f64::NAN; 100],
            ohlc4: vec![f64::NAN; 100],
            hlcc4: vec![f64::NAN; 100],
        };

        let input = HalfTrendInput::from_candles(&candles, HalfTrendParams::default());
        let result = halftrend_with_kernel(&input, kernel);

        assert!(
            matches!(result, Err(HalfTrendError::AllValuesNaN)),
            "[{}] Expected AllValuesNaN error",
            test_name
        );
        Ok(())
    }

    /// An ATR period longer than the data (100 vs 10 bars) must yield `InvalidPeriod`.
    fn check_halftrend_invalid_period(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let high = vec![1.0; 10];
        let low = vec![1.0; 10];
        let close = vec![1.0; 10];
        let candles = Candles {
            timestamp: vec![0; 10],
            high: high.clone(),
            low: low.clone(),
            close: close.clone(),
            open: vec![1.0; 10],
            volume: vec![1.0; 10],
            fields: CandleFieldFlags {
                open: true,
                high: true,
                low: true,
                close: true,
                volume: true,
            },
            hl2: high
                .iter()
                .zip(low.iter())
                .map(|(h, l)| (h + l) / 2.0)
                .collect(),
            hlc3: high
                .iter()
                .zip(low.iter())
                .zip(close.iter())
                .map(|((h, l), c)| (h + l + c) / 3.0)
                .collect(),
            ohlc4: vec![1.0; 10],
            hlcc4: high
                .iter()
                .zip(low.iter())
                .zip(close.iter())
                .map(|((h, l), c)| (h + l + c + c) / 4.0)
                .collect(),
        };

        let params = HalfTrendParams {
            amplitude: Some(20),
            channel_deviation: Some(2.0),
            atr_period: Some(100),
        };

        let input = HalfTrendInput::from_candles(&candles, params);
        let result = halftrend_with_kernel(&input, kernel);

        assert!(
            matches!(result, Err(HalfTrendError::InvalidPeriod { .. })),
            "[{}] Expected InvalidPeriod error",
            test_name
        );
        Ok(())
    }

    /// Generates one `#[test]` per check function and kernel.
    ///
    /// Each check's `Result` is unwrapped, so an `Err` (e.g. a missing data file or a
    /// failing `halftrend_with_kernel` call) fails the test instead of being silently
    /// discarded. The `#[cfg]` gates sit inside each repetition so they apply to every
    /// generated function, not only to the first item of the expansion.
    macro_rules! generate_all_halftrend_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar).unwrap();
                    }
                )*
                $(
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2).unwrap();
                    }
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512).unwrap();
                    }
                )*
                $(
                    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
                    #[test]
                    fn [<$test_fn _simd128_f64>]() {
                        $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar).unwrap();
                    }
                )*
            }
        }
    }

    /// Default-parameter input built from candles produces one value per bar.
    fn check_halftrend_default_candles(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;
        let input = HalfTrendInput::with_default_candles(&c);
        let out = halftrend_with_kernel(&input, kernel)?;
        assert_eq!(out.halftrend.len(), c.close.len());
        Ok(())
    }

    /// No NaN in the HalfTrend line after the `max(amplitude, atr_period) - 1` warm-up.
    fn check_halftrend_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;
        let input = HalfTrendInput::from_candles(&c, HalfTrendParams::default());
        let out = halftrend_with_kernel(&input, kernel)?;
        let a = HalfTrendParams::default().amplitude.unwrap_or(2);
        let p = HalfTrendParams::default().atr_period.unwrap_or(100);
        let warm = a.max(p) - 1;
        for &v in &out.halftrend[warm.min(out.halftrend.len())..] {
            assert!(!v.is_nan(), "[{}] Found NaN after warmup", test_name);
        }
        Ok(())
    }

    /// Drives `HalfTrendStream` bar by bar alongside the batch API. Only the output
    /// length is compared.
    fn check_halftrend_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;
        let p = HalfTrendParams::default();

        let batch = halftrend_with_kernel(&HalfTrendInput::from_candles(&c, p.clone()), kernel)?;
        let mut s = HalfTrendStream::try_new(p)?;
        let mut ht = Vec::with_capacity(c.close.len());
        let mut tr = Vec::with_capacity(c.close.len());
        let mut ah = Vec::with_capacity(c.close.len());
        let mut al = Vec::with_capacity(c.close.len());
        let mut bs = Vec::with_capacity(c.close.len());
        let mut ss = Vec::with_capacity(c.close.len());
        for i in 0..c.close.len() {
            match s.update(c.high[i], c.low[i], c.close[i]) {
                Some(o) => {
                    ht.push(o.halftrend);
                    tr.push(o.trend);
                    ah.push(o.atr_high);
                    al.push(o.atr_low);
                    bs.push(o.buy_signal.unwrap_or(f64::NAN));
                    ss.push(o.sell_signal.unwrap_or(f64::NAN));
                }
                None => {
                    ht.push(f64::NAN);
                    tr.push(f64::NAN);
                    ah.push(f64::NAN);
                    al.push(f64::NAN);
                    bs.push(f64::NAN);
                    ss.push(f64::NAN);
                }
            }
        }
        assert_eq!(batch.halftrend.len(), ht.len());

        Ok(())
    }

    /// Debug builds: no output cell may carry one of the allocator poison bit patterns.
    #[cfg(debug_assertions)]
    fn check_halftrend_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;
        let out = halftrend_with_kernel(&HalfTrendInput::with_default_candles(&c), kernel)?;
        let poison = [
            0x1111_1111_1111_1111u64,
            0x2222_2222_2222_2222u64,
            0x3333_3333_3333_3333u64,
        ];
        for (name, vec) in [
            ("halftrend", &out.halftrend),
            ("trend", &out.trend),
            ("atr_high", &out.atr_high),
            ("atr_low", &out.atr_low),
            ("buy", &out.buy_signal),
            ("sell", &out.sell_signal),
        ] {
            for (i, &v) in vec.iter().enumerate() {
                if v.is_nan() {
                    continue;
                }
                let b = v.to_bits();
                for p in poison {
                    assert_ne!(b, p, "[{}] poison in {} at {}", test_name, name, i);
                }
            }
        }
        Ok(())
    }

    /// Release builds: poison checking is disabled.
    #[cfg(not(debug_assertions))]
    fn check_halftrend_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    /// All-`None` parameters fall back to defaults and still produce full-length output.
    fn check_halftrend_partial_params(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;
        let input = HalfTrendInput::from_candles(
            &c,
            HalfTrendParams {
                amplitude: None,
                channel_deviation: None,
                atr_period: None,
            },
        );
        let out = halftrend_with_kernel(&input, kernel)?;
        assert_eq!(out.halftrend.len(), c.close.len());
        Ok(())
    }

    /// A single valid bar at index 5 of 10 is not enough for period 9.
    fn check_halftrend_not_enough_valid(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let n = 10;
        let mut c = Candles {
            timestamp: vec![0; n],
            high: vec![f64::NAN; n],
            low: vec![f64::NAN; n],
            close: vec![f64::NAN; n],
            open: vec![f64::NAN; n],
            volume: vec![f64::NAN; n],
            fields: CandleFieldFlags {
                open: true,
                high: true,
                low: true,
                close: true,
                volume: true,
            },
            hl2: vec![f64::NAN; n],
            hlc3: vec![f64::NAN; n],
            ohlc4: vec![f64::NAN; n],
            hlcc4: vec![f64::NAN; n],
        };
        c.high[5] = 1.0;
        c.low[5] = 1.0;
        c.close[5] = 1.0;
        let p = HalfTrendParams {
            amplitude: Some(9),
            channel_deviation: Some(2.0),
            atr_period: Some(9),
        };
        let r = halftrend_with_kernel(&HalfTrendInput::from_candles(&c, p), kernel);
        assert!(matches!(r, Err(HalfTrendError::NotEnoughValidData { .. })));
        Ok(())
    }

    /// A channel deviation of 0.0 is rejected.
    fn check_halftrend_invalid_chdev(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let h = [1.0, 1.0, 1.0];
        let l = [1.0, 1.0, 1.0];
        let c = [1.0, 1.0, 1.0];
        let inp = HalfTrendInput::from_slices(
            &h,
            &l,
            &c,
            HalfTrendParams {
                amplitude: Some(2),
                channel_deviation: Some(0.0),
                atr_period: Some(2),
            },
        );
        let r = halftrend_with_kernel(&inp, kernel);
        assert!(matches!(
            r,
            Err(HalfTrendError::InvalidChannelDeviation { .. })
        ));
        Ok(())
    }

    generate_all_halftrend_tests!(
        check_halftrend_accuracy,
        check_halftrend_empty_data,
        check_halftrend_all_nan,
        check_halftrend_invalid_period,
        check_halftrend_default_candles,
        check_halftrend_nan_handling,
        check_halftrend_streaming,
        check_halftrend_no_poison,
        check_halftrend_partial_params,
        check_halftrend_not_enough_valid,
        check_halftrend_invalid_chdev
    );

    /// The default batch sweep contains a full-length row for the default parameters.
    fn check_batch_default_row(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let output = HalfTrendBatchBuilder::new()
            .kernel(kernel)
            .apply_candles(&c)?;

        let def = HalfTrendParams::default();
        let row = output.halftrend_for(&def).expect("default row missing");

        assert_eq!(row.len(), c.close.len());
        Ok(())
    }

    /// A 3 x 3 x 3 sweep yields 27 combinations/rows and one column per bar.
    fn check_batch_sweep(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let output = HalfTrendBatchBuilder::new()
            .kernel(kernel)
            .amplitude_range(2, 4, 1)
            .channel_deviation_range(1.5, 2.5, 0.5)
            .atr_period_range(50, 150, 50)
            .apply_candles(&c)?;

        let expected_combos = 3 * 3 * 3;
        assert_eq!(output.combos.len(), expected_combos);
        assert_eq!(output.rows, expected_combos);
        assert_eq!(output.cols, c.close.len());

        Ok(())
    }

    /// Generates one `#[test]` per batch kernel. As above, the check's `Result` is
    /// unwrapped so errors fail the test instead of being discarded.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test] fn [<$fn_name _scalar>]() {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch).unwrap();
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx2>]() {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch).unwrap();
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx512>]() {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch).unwrap();
                }
                #[test] fn [<$fn_name _auto_detect>]() {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto).unwrap();
                }
            }
        };
    }

    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_sweep);
}
1853
/// Incremental (bar-by-bar) HalfTrend calculator.
///
/// Keeps O(`amplitude`) state:
/// - two circular monotonic deques, one for the rolling highest high and one for the
///   lowest low;
/// - two circular ring buffers with running sums, one for the high SMA and one for the
///   low SMA;
/// - an inner `AtrStream`;
/// - the HalfTrend trend state machine.
#[derive(Debug, Clone)]
pub struct HalfTrendStream {
    /// Window length for the rolling extremes and the SMAs; also the capacity of every
    /// circular buffer below.
    amplitude: usize,
    /// ATR length handed to the inner `AtrStream`.
    atr_period: usize,
    /// `channel_deviation / 2`, so the channel half-width is `atr * ch_half`.
    ch_half: f64,
    /// `1 / amplitude`, used to turn the running sums into averages.
    inv_amp: f64,

    /// Incremental ATR source.
    atr_stream: crate::indicators::atr::AtrStream,

    // Rolling-max deque over highs (sample indices and values), stored circularly.
    max_idx: Vec<usize>,
    max_val: Vec<f64>,
    // Rolling-min deque over lows (sample indices and values), stored circularly.
    min_idx: Vec<usize>,
    min_val: Vec<f64>,
    // Head (oldest), tail (next free slot) and length of the max deque.
    max_head: usize,
    max_tail: usize,
    max_cnt: usize,
    // Head (oldest), tail (next free slot) and length of the min deque.
    min_head: usize,
    min_tail: usize,
    min_cnt: usize,

    // Last `amplitude` highs/lows, used for the SMA running sums.
    ring_high: Vec<f64>,
    ring_low: Vec<f64>,
    // Next slot to overwrite in the ring buffers.
    ring_pos: usize,
    // Number of occupied ring slots (saturates at `amplitude`).
    filled: usize,
    // Running sums of the ring contents.
    high_sum: f64,
    low_sum: f64,

    // Number of finite bars consumed so far (also the index given to the deques).
    i: usize,
    // `max(amplitude, atr_period)`: bars required before any output.
    warmup_need: usize,

    // 0 emits the `up` line (trend 0.0), 1 emits the `down` line (trend 1.0).
    current_trend: i32,
    // Direction the state machine is watching for next.
    next_trend: i32,
    // `current_trend` of the previous emitted bar; -1 before the first one.
    last_trend: i8,
    // Channel extremes tracked by the state machine; NaN until seeded.
    max_low_price: f64,
    min_high_price: f64,
    // HalfTrend lines for each direction; 0.0 means "not set yet".
    up: f64,
    down: f64,

    // High/low of the previous finite bar, and whether one exists.
    prev_high: f64,
    prev_low: f64,
    have_prev: bool,
}
1896
impl HalfTrendStream {
    /// Builds a streaming HalfTrend calculator from `params`.
    ///
    /// Unset parameters fall back to `amplitude = 2`, `channel_deviation = 2.0` and
    /// `atr_period = 100`.
    ///
    /// # Errors
    /// - `HalfTrendError::InvalidPeriod` (reported with `data_len: 0`) when `amplitude`
    ///   or `atr_period` is zero.
    /// - `HalfTrendError::InvalidChannelDeviation` when `channel_deviation` is not finite
    ///   or is `<= 0.0`.
    /// - `HalfTrendError::AtrError` when the inner `AtrStream` cannot be constructed.
    #[inline]
    pub fn try_new(params: HalfTrendParams) -> Result<Self, HalfTrendError> {
        let amplitude = params.amplitude.unwrap_or(2);
        let channel_deviation = params.channel_deviation.unwrap_or(2.0);
        let atr_period = params.atr_period.unwrap_or(100);

        if amplitude == 0 {
            return Err(HalfTrendError::InvalidPeriod {
                period: amplitude,
                data_len: 0,
            });
        }
        if atr_period == 0 {
            return Err(HalfTrendError::InvalidPeriod {
                period: atr_period,
                data_len: 0,
            });
        }
        if !(channel_deviation.is_finite()) || channel_deviation <= 0.0 {
            return Err(HalfTrendError::InvalidChannelDeviation { channel_deviation });
        }

        let atr_stream =
            crate::indicators::atr::AtrStream::try_new(crate::indicators::atr::AtrParams {
                length: Some(atr_period),
            })
            .map_err(|e| HalfTrendError::AtrError(e.to_string()))?;

        // Capacity of both deques and both ring buffers (amplitude >= 1 was checked above).
        let cap = amplitude.max(1);

        Ok(Self {
            amplitude,
            atr_period,
            // Pre-halved so that `update` computes `dev = atr * ch_half`.
            ch_half: channel_deviation * 0.5,
            // Reciprocal used to turn the running sums into averages.
            inv_amp: 1.0 / (amplitude as f64),

            atr_stream,

            max_idx: vec![0; cap],
            max_val: vec![0.0; cap],
            min_idx: vec![0; cap],
            min_val: vec![0.0; cap],
            max_head: 0,
            max_tail: 0,
            max_cnt: 0,
            min_head: 0,
            min_tail: 0,
            min_cnt: 0,

            ring_high: vec![0.0; cap],
            ring_low: vec![0.0; cap],
            ring_pos: 0,
            filled: 0,
            high_sum: 0.0,
            low_sum: 0.0,

            i: 0,
            // No output until this many finite bars have been consumed.
            warmup_need: amplitude.max(atr_period),

            current_trend: 0,
            next_trend: 0,
            // -1: no previous emitted bar, so the first output never counts as a flip.
            last_trend: -1,
            // NaN: not seeded yet. `update` seeds them from the previous bar on the first
            // warmed bar.
            max_low_price: f64::NAN,
            min_high_price: f64::NAN,
            // 0.0 is treated as "unset" by `update`.
            up: 0.0,
            down: 0.0,

            prev_high: f64::NAN,
            prev_low: f64::NAN,
            have_prev: false,
        })
    }

    /// Advances a circular index by one, wrapping to 0 at `cap`.
    #[inline(always)]
    fn inc(i: usize, cap: usize) -> usize {
        let j = i + 1;
        if j == cap {
            0
        } else {
            j
        }
    }
    /// Steps a circular index back by one, wrapping from 0 to `cap - 1`.
    #[inline(always)]
    fn dec(i: usize, cap: usize) -> usize {
        if i == 0 {
            cap - 1
        } else {
            i - 1
        }
    }

    /// Pushes `(idx, v)` onto the rolling-max deque.
    ///
    /// First drops tail entries whose value is `<= v`, so stored values stay strictly
    /// decreasing from head to tail and the head is the window maximum.
    #[inline(always)]
    fn q_push_max(&mut self, idx: usize, v: f64) {
        let cap = self.amplitude;
        while self.max_cnt > 0 {
            let back = Self::dec(self.max_tail, cap);
            if self.max_val[back] <= v {
                self.max_tail = back;
                self.max_cnt -= 1;
            } else {
                break;
            }
        }
        self.max_val[self.max_tail] = v;
        self.max_idx[self.max_tail] = idx;
        self.max_tail = Self::inc(self.max_tail, cap);
        self.max_cnt += 1;
    }

    /// Pushes `(idx, v)` onto the rolling-min deque.
    ///
    /// First drops tail entries whose value is `>= v`, so stored values stay strictly
    /// increasing from head to tail and the head is the window minimum.
    #[inline(always)]
    fn q_push_min(&mut self, idx: usize, v: f64) {
        let cap = self.amplitude;
        while self.min_cnt > 0 {
            let back = Self::dec(self.min_tail, cap);
            if self.min_val[back] >= v {
                self.min_tail = back;
                self.min_cnt -= 1;
            } else {
                break;
            }
        }
        self.min_val[self.min_tail] = v;
        self.min_idx[self.min_tail] = idx;
        self.min_tail = Self::inc(self.min_tail, cap);
        self.min_cnt += 1;
    }

    /// Drops head entries whose index falls before the `amplitude`-wide window ending at
    /// `idx`, i.e. indices `< idx + 1 - amplitude`.
    ///
    /// Called before pushing `idx`, so each deque then holds at most `amplitude - 1`
    /// entries and the following push cannot overflow the buffer.
    #[inline(always)]
    fn q_evict(&mut self, idx: usize) {
        let cap = self.amplitude;
        let limit = idx.saturating_sub(self.amplitude - 1);
        while self.max_cnt > 0 && self.max_idx[self.max_head] < limit {
            self.max_head = Self::inc(self.max_head, cap);
            self.max_cnt -= 1;
        }
        while self.min_cnt > 0 && self.min_idx[self.min_head] < limit {
            self.min_head = Self::inc(self.min_head, cap);
            self.min_cnt -= 1;
        }
    }

    /// Feeds one bar and returns the HalfTrend values for it once warmed up.
    ///
    /// Returns `None` without touching any state if `high`, `low` or `close` is not
    /// finite. Also returns `None` until all of the following hold: `amplitude` finite bars
    /// are buffered, `max(amplitude, atr_period)` finite bars have been consumed, and the
    /// inner ATR stream yields a value.
    ///
    /// `buy_signal` is `Some(up - atr / 2)` only on a bar where the trend switches from 1
    /// to 0. `sell_signal` is `Some(down + atr / 2)` only on a switch from 0 to 1.
    pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<HalfTrendStreamOutput> {
        if !(high.is_finite() && low.is_finite() && close.is_finite()) {
            return None;
        }

        let idx = self.i;

        // Rolling highest high / lowest low over the last `amplitude` finite bars.
        self.q_evict(idx);
        self.q_push_max(idx, high);
        self.q_push_min(idx, low);

        // Ring-buffer update of the running sums used for the high/low SMAs.
        if self.filled == self.amplitude {
            let old_h = self.ring_high[self.ring_pos];
            let old_l = self.ring_low[self.ring_pos];
            self.high_sum -= old_h;
            self.low_sum -= old_l;
        } else {
            self.filled += 1;
        }
        self.ring_high[self.ring_pos] = high;
        self.ring_low[self.ring_pos] = low;
        self.high_sum += high;
        self.low_sum += low;
        self.ring_pos = Self::inc(self.ring_pos, self.amplitude);

        // The ATR stream is advanced on every finite bar, warmed or not.
        let atr_opt = self.atr_stream.update(high, low, close);

        let warmed =
            self.filled == self.amplitude && (idx + 1) >= self.warmup_need && atr_opt.is_some();

        if !warmed {
            // Still remember this bar as the "previous" bar for the first warmed update.
            self.prev_high = high;
            self.prev_low = low;
            self.have_prev = true;
            self.i = idx + 1;
            return None;
        }

        debug_assert!(self.max_cnt > 0 && self.min_cnt > 0);
        let high_price = self.max_val[self.max_head];
        let low_price = self.min_val[self.min_head];
        let atr = atr_opt.unwrap();
        let atr2 = 0.5 * atr;
        // Channel half-width around the HalfTrend line.
        let dev = atr * self.ch_half;

        let prev_low = if self.have_prev { self.prev_low } else { low };
        let prev_high = if self.have_prev { self.prev_high } else { high };
        // Seed the channel extremes from the previous bar (once; they are never NaN again).
        if self.max_low_price.is_nan() {
            self.max_low_price = prev_low;
        }
        if self.min_high_price.is_nan() {
            self.min_high_price = prev_high;
        }

        let highma = self.high_sum * self.inv_amp;
        let lowma = self.low_sum * self.inv_amp;

        // Trend-switch detection. The order of the extreme update and the switch test
        // matters: the test sees the already-updated extreme.
        if self.next_trend == 1 {
            if low_price > self.max_low_price {
                self.max_low_price = low_price;
            }
            if highma < self.max_low_price && close < prev_low {
                self.current_trend = 1;
                self.next_trend = 0;
                self.min_high_price = high_price;
            }
        } else {
            if high_price < self.min_high_price {
                self.min_high_price = high_price;
            }
            if lowma > self.min_high_price && close > prev_high {
                self.current_trend = 0;
                self.next_trend = 1;
                self.max_low_price = low_price;
            }
        }

        let prev_trend = self.last_trend;
        let mut buy_sig: Option<f64> = None;
        let mut sell_sig: Option<f64> = None;

        let (ht, atr_hi, atr_lo, tr_val) = if self.current_trend == 0 {
            if prev_trend == 1 {
                // Flip 1 -> 0: `up` restarts from the last `down` value; emit a buy signal.
                self.up = self.down;
                buy_sig = Some(self.up - atr2);
            } else {
                // Continuing: `up` only ratchets upward (0.0 means not yet set).
                self.up = if self.up == 0.0 {
                    self.max_low_price
                } else if self.max_low_price > self.up {
                    self.max_low_price
                } else {
                    self.up
                };
            }
            let h = self.up;
            (h, h + dev, h - dev, 0.0)
        } else {
            if prev_trend == 0 {
                // Flip 0 -> 1: `down` restarts from the last `up` value; emit a sell signal.
                self.down = self.up;
                sell_sig = Some(self.down + atr2);
            } else {
                // Continuing: `down` only ratchets downward (0.0 means not yet set).
                self.down = if self.down == 0.0 {
                    self.min_high_price
                } else if self.min_high_price < self.down {
                    self.min_high_price
                } else {
                    self.down
                };
            }
            let d = self.down;
            (d, d + dev, d - dev, 1.0)
        };

        self.last_trend = self.current_trend as i8;
        self.prev_high = high;
        self.prev_low = low;
        self.have_prev = true;
        self.i = idx + 1;

        Some(HalfTrendStreamOutput {
            halftrend: ht,
            trend: tr_val,
            atr_high: atr_hi,
            atr_low: atr_lo,
            buy_signal: buy_sig,
            sell_signal: sell_sig,
        })
    }
}
2168
/// One bar of output from [`HalfTrendStream::update`].
#[derive(Debug, Clone)]
pub struct HalfTrendStreamOutput {
    /// Current HalfTrend line (the `up` line in trend 0, the `down` line in trend 1).
    pub halftrend: f64,
    /// Trend state: 0.0 or 1.0.
    pub trend: f64,
    /// `halftrend + atr * channel_deviation / 2`.
    pub atr_high: f64,
    /// `halftrend - atr * channel_deviation / 2`.
    pub atr_low: f64,
    /// `Some(halftrend - atr / 2)` only on a bar where the trend flips from 1 to 0.
    pub buy_signal: Option<f64>,
    /// `Some(halftrend + atr / 2)` only on a bar where the trend flips from 0 to 1.
    pub sell_signal: Option<f64>,
}
2178
/// Parameter sweep for the HalfTrend batch API.
///
/// Each field is a `(start, end, step)` triple. A zero step, or `start == end`, selects
/// the single value `start`.
#[derive(Clone, Debug)]
pub struct HalfTrendBatchRange {
    /// Amplitude axis (window for the rolling extremes and SMAs).
    pub amplitude: (usize, usize, usize),
    /// Channel-deviation axis (multiplier applied to the ATR channel).
    pub channel_deviation: (f64, f64, f64),
    /// ATR-period axis.
    pub atr_period: (usize, usize, usize),
}
2185
2186impl Default for HalfTrendBatchRange {
2187 fn default() -> Self {
2188 Self {
2189 amplitude: (2, 2, 0),
2190 channel_deviation: (2.0, 2.0, 0.0),
2191 atr_period: (100, 349, 1),
2192 }
2193 }
2194}
2195
/// Fluent builder for HalfTrend parameter sweeps.
///
/// Starts from `HalfTrendBatchRange::default()` and `Kernel::default()`. The `apply_*`
/// methods run the sweep.
#[derive(Clone, Debug, Default)]
pub struct HalfTrendBatchBuilder {
    /// Sweep definition.
    range: HalfTrendBatchRange,
    /// Kernel forwarded to the batch runner. It must be `Kernel::Auto` or a batch kernel,
    /// otherwise the runner returns `InvalidKernelForBatch`.
    kernel: Kernel,
}
2201
2202impl HalfTrendBatchBuilder {
2203 pub fn new() -> Self {
2204 Self::default()
2205 }
2206
2207 pub fn kernel(mut self, k: Kernel) -> Self {
2208 self.kernel = k;
2209 self
2210 }
2211
2212 pub fn amplitude_range(mut self, start: usize, end: usize, step: usize) -> Self {
2213 self.range.amplitude = (start, end, step);
2214 self
2215 }
2216
2217 pub fn amplitude_static(mut self, a: usize) -> Self {
2218 self.range.amplitude = (a, a, 0);
2219 self
2220 }
2221
2222 pub fn channel_deviation_range(mut self, start: f64, end: f64, step: f64) -> Self {
2223 self.range.channel_deviation = (start, end, step);
2224 self
2225 }
2226
2227 pub fn channel_deviation_static(mut self, c: f64) -> Self {
2228 self.range.channel_deviation = (c, c, 0.0);
2229 self
2230 }
2231
2232 pub fn atr_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
2233 self.range.atr_period = (start, end, step);
2234 self
2235 }
2236
2237 pub fn atr_period_static(mut self, p: usize) -> Self {
2238 self.range.atr_period = (p, p, 0);
2239 self
2240 }
2241
2242 pub fn apply_candles(self, c: &Candles) -> Result<HalfTrendBatchOutput, HalfTrendError> {
2243 halftrend_batch_with_kernel(c, &self.range, self.kernel)
2244 }
2245
2246 pub fn apply_slices(
2247 self,
2248 h: &[f64],
2249 l: &[f64],
2250 c: &[f64],
2251 ) -> Result<HalfTrendBatchOutput, HalfTrendError> {
2252 halftrend_batch_with_kernel_slices(h, l, c, &self.range, self.kernel)
2253 }
2254
2255 pub fn with_default_candles(c: &Candles) -> Result<HalfTrendBatchOutput, HalfTrendError> {
2256 HalfTrendBatchBuilder::new()
2257 .kernel(Kernel::Auto)
2258 .apply_candles(c)
2259 }
2260
2261 pub fn with_default_slices(
2262 h: &[f64],
2263 l: &[f64],
2264 c: &[f64],
2265 ) -> Result<HalfTrendBatchOutput, HalfTrendError> {
2266 HalfTrendBatchBuilder::new()
2267 .kernel(Kernel::Auto)
2268 .apply_slices(h, l, c)
2269 }
2270}
2271
/// Result of a HalfTrend parameter sweep.
///
/// Each series is a row-major `rows x cols` matrix flattened into one `Vec`. Row `r`
/// occupies `[r * cols, (r + 1) * cols)` and corresponds to `combos[r]`.
pub struct HalfTrendBatchOutput {
    /// HalfTrend line, one row per combination.
    pub halftrend: Vec<f64>,
    /// Trend state, one row per combination.
    pub trend: Vec<f64>,
    /// Upper ATR channel, one row per combination.
    pub atr_high: Vec<f64>,
    /// Lower ATR channel, one row per combination.
    pub atr_low: Vec<f64>,
    /// Buy-signal series, one row per combination.
    pub buy_signal: Vec<f64>,
    /// Sell-signal series, one row per combination.
    pub sell_signal: Vec<f64>,
    /// Parameter combination for each row, in row order.
    pub combos: Vec<HalfTrendParams>,
    /// Number of rows (`combos.len()`).
    pub rows: usize,
    /// Number of columns (length of the input series).
    pub cols: usize,
}
2283
2284impl HalfTrendBatchOutput {
2285 pub fn row_for_params(&self, p: &HalfTrendParams) -> Option<usize> {
2286 self.combos.iter().position(|c| {
2287 c.amplitude.unwrap_or(2) == p.amplitude.unwrap_or(2)
2288 && (c.channel_deviation.unwrap_or(2.0) - p.channel_deviation.unwrap_or(2.0)).abs()
2289 < 1e-12
2290 && c.atr_period.unwrap_or(100) == p.atr_period.unwrap_or(100)
2291 })
2292 }
2293
2294 pub fn halftrend_for(&self, p: &HalfTrendParams) -> Option<&[f64]> {
2295 self.row_for_params(p).map(|row| {
2296 let start = row * self.cols;
2297 &self.halftrend[start..start + self.cols]
2298 })
2299 }
2300
2301 pub fn trend_for(&self, p: &HalfTrendParams) -> Option<&[f64]> {
2302 self.row_for_params(p).map(|row| {
2303 let start = row * self.cols;
2304 &self.trend[start..start + self.cols]
2305 })
2306 }
2307
2308 pub fn atr_high_for(&self, p: &HalfTrendParams) -> Option<&[f64]> {
2309 self.row_for_params(p).map(|row| {
2310 let start = row * self.cols;
2311 &self.atr_high[start..start + self.cols]
2312 })
2313 }
2314
2315 pub fn atr_low_for(&self, p: &HalfTrendParams) -> Option<&[f64]> {
2316 self.row_for_params(p).map(|row| {
2317 let start = row * self.cols;
2318 &self.atr_low[start..start + self.cols]
2319 })
2320 }
2321
2322 pub fn buy_for(&self, p: &HalfTrendParams) -> Option<&[f64]> {
2323 self.row_for_params(p).map(|row| {
2324 let start = row * self.cols;
2325 &self.buy_signal[start..start + self.cols]
2326 })
2327 }
2328
2329 pub fn sell_for(&self, p: &HalfTrendParams) -> Option<&[f64]> {
2330 self.row_for_params(p).map(|row| {
2331 let start = row * self.cols;
2332 &self.sell_signal[start..start + self.cols]
2333 })
2334 }
2335}
2336
2337fn halftrend_batch_with_kernel(
2338 candles: &Candles,
2339 sweep: &HalfTrendBatchRange,
2340 k: Kernel,
2341) -> Result<HalfTrendBatchOutput, HalfTrendError> {
2342 halftrend_batch_with_kernel_slices(&candles.high, &candles.low, &candles.close, sweep, k)
2343}
2344
2345fn halftrend_batch_with_kernel_slices(
2346 high: &[f64],
2347 low: &[f64],
2348 close: &[f64],
2349 sweep: &HalfTrendBatchRange,
2350 k: Kernel,
2351) -> Result<HalfTrendBatchOutput, HalfTrendError> {
2352 let combos = expand_grid_halftrend(sweep)?;
2353 let rows = combos.len();
2354 let cols = close.len();
2355
2356 if cols == 0 {
2357 return Err(HalfTrendError::EmptyInputData);
2358 }
2359
2360 let batch = match k {
2361 Kernel::Auto => detect_best_batch_kernel(),
2362 other if other.is_batch() => other,
2363 _ => return Err(HalfTrendError::InvalidKernelForBatch(k)),
2364 };
2365 let simd = match batch {
2366 Kernel::Avx512Batch => Kernel::Avx512,
2367 Kernel::Avx2Batch => Kernel::Avx2,
2368 Kernel::ScalarBatch => Kernel::Scalar,
2369 _ => Kernel::Scalar,
2370 };
2371
2372 let _cap = rows
2373 .checked_mul(cols)
2374 .ok_or_else(|| HalfTrendError::InvalidRange {
2375 start: "rows".into(),
2376 end: "cols".into(),
2377 step: "mul".into(),
2378 })?;
2379
2380 let mut mu_ht = make_uninit_matrix(rows, cols);
2381 let mut mu_tr = make_uninit_matrix(rows, cols);
2382 let mut mu_ah = make_uninit_matrix(rows, cols);
2383 let mut mu_al = make_uninit_matrix(rows, cols);
2384 let mut mu_bs = make_uninit_matrix(rows, cols);
2385 let mut mu_ss = make_uninit_matrix(rows, cols);
2386
2387 let first = first_valid_ohlc(high, low, close);
2388 if first == usize::MAX {
2389 return Err(HalfTrendError::AllValuesNaN);
2390 }
2391 let warms: Vec<usize> = combos
2392 .iter()
2393 .map(|p| warmup_from(first, p.amplitude.unwrap(), p.atr_period.unwrap()))
2394 .collect();
2395
2396 init_matrix_prefixes(&mut mu_ht, cols, &warms);
2397 init_matrix_prefixes(&mut mu_tr, cols, &warms);
2398 init_matrix_prefixes(&mut mu_ah, cols, &warms);
2399 init_matrix_prefixes(&mut mu_al, cols, &warms);
2400 init_matrix_prefixes(&mut mu_bs, cols, &warms);
2401 init_matrix_prefixes(&mut mu_ss, cols, &warms);
2402
2403 let dst_ht =
2404 unsafe { core::slice::from_raw_parts_mut(mu_ht.as_mut_ptr() as *mut f64, mu_ht.len()) };
2405 let dst_tr =
2406 unsafe { core::slice::from_raw_parts_mut(mu_tr.as_mut_ptr() as *mut f64, mu_tr.len()) };
2407 let dst_ah =
2408 unsafe { core::slice::from_raw_parts_mut(mu_ah.as_mut_ptr() as *mut f64, mu_ah.len()) };
2409 let dst_al =
2410 unsafe { core::slice::from_raw_parts_mut(mu_al.as_mut_ptr() as *mut f64, mu_al.len()) };
2411 let dst_bs =
2412 unsafe { core::slice::from_raw_parts_mut(mu_bs.as_mut_ptr() as *mut f64, mu_bs.len()) };
2413 let dst_ss =
2414 unsafe { core::slice::from_raw_parts_mut(mu_ss.as_mut_ptr() as *mut f64, mu_ss.len()) };
2415
2416 halftrend_batch_rows_into(
2417 high, low, close, sweep, simd, dst_ht, dst_tr, dst_ah, dst_al, dst_bs, dst_ss,
2418 )?;
2419
2420 let take = |v: Vec<MaybeUninit<f64>>| unsafe {
2421 let ptr = v.as_ptr() as *mut f64;
2422 let len = v.len();
2423 let cap = v.capacity();
2424 core::mem::forget(v);
2425 Vec::from_raw_parts(ptr, len, cap)
2426 };
2427
2428 Ok(HalfTrendBatchOutput {
2429 halftrend: take(mu_ht),
2430 trend: take(mu_tr),
2431 atr_high: take(mu_ah),
2432 atr_low: take(mu_al),
2433 buy_signal: take(mu_bs),
2434 sell_signal: take(mu_ss),
2435 combos,
2436 rows,
2437 cols,
2438 })
2439}
2440
2441fn expand_grid_halftrend(r: &HalfTrendBatchRange) -> Result<Vec<HalfTrendParams>, HalfTrendError> {
2442 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, HalfTrendError> {
2443 if step == 0 || start == end {
2444 return Ok(vec![start]);
2445 }
2446 if start < end {
2447 return Ok((start..=end).step_by(step.max(1)).collect());
2448 }
2449
2450 let mut v = Vec::new();
2451 let mut x = start as isize;
2452 let end_i = end as isize;
2453 let st = (step as isize).max(1);
2454 while x >= end_i {
2455 v.push(x as usize);
2456 x -= st;
2457 }
2458 if v.is_empty() {
2459 return Err(HalfTrendError::InvalidRange {
2460 start: start.to_string(),
2461 end: end.to_string(),
2462 step: step.to_string(),
2463 });
2464 }
2465 Ok(v)
2466 }
2467 fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, HalfTrendError> {
2468 if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
2469 return Ok(vec![start]);
2470 }
2471 if start < end {
2472 let mut v = Vec::new();
2473 let mut x = start;
2474 let st = step.abs();
2475 while x <= end + 1e-12 {
2476 v.push(x);
2477 x += st;
2478 }
2479 if v.is_empty() {
2480 return Err(HalfTrendError::InvalidRange {
2481 start: start.to_string(),
2482 end: end.to_string(),
2483 step: step.to_string(),
2484 });
2485 }
2486 return Ok(v);
2487 }
2488 let mut v = Vec::new();
2489 let mut x = start;
2490 let st = step.abs();
2491 while x + 1e-12 >= end {
2492 v.push(x);
2493 x -= st;
2494 }
2495 if v.is_empty() {
2496 return Err(HalfTrendError::InvalidRange {
2497 start: start.to_string(),
2498 end: end.to_string(),
2499 step: step.to_string(),
2500 });
2501 }
2502 Ok(v)
2503 }
2504
2505 let amplitudes = axis_usize(r.amplitude)?;
2506 let channel_deviations = axis_f64(r.channel_deviation)?;
2507 let atr_periods = axis_usize(r.atr_period)?;
2508
2509 let cap = amplitudes
2510 .len()
2511 .checked_mul(channel_deviations.len())
2512 .and_then(|x| x.checked_mul(atr_periods.len()))
2513 .ok_or_else(|| HalfTrendError::InvalidRange {
2514 start: "cap".into(),
2515 end: "overflow".into(),
2516 step: "mul".into(),
2517 })?;
2518
2519 let mut out = Vec::with_capacity(cap);
2520 for &a in &litudes {
2521 for &c in &channel_deviations {
2522 for &p in &atr_periods {
2523 out.push(HalfTrendParams {
2524 amplitude: Some(a),
2525 channel_deviation: Some(c),
2526 atr_period: Some(p),
2527 });
2528 }
2529 }
2530 }
2531 Ok(out)
2532}
2533
/// Index of the first bar a HalfTrend row computes: `first + max(amplitude, atr_period) - 1`.
///
/// `first` is the index of the first valid bar. Saturating arithmetic keeps
/// zero-length windows (`first + 0 - 1` at `first == 0`) and a `usize::MAX`
/// "no valid data" sentinel from overflowing. Such overflow would panic in debug
/// builds and silently wrap in release builds.
#[inline(always)]
fn warmup_from(first: usize, amplitude: usize, atr_period: usize) -> usize {
    first
        .saturating_add(amplitude.max(atr_period))
        .saturating_sub(1)
}
2538
2539#[inline(always)]
2540fn halftrend_row_into(
2541 high: &[f64],
2542 low: &[f64],
2543 close: &[f64],
2544 amplitude: usize,
2545 ch_dev: f64,
2546 atr: &[f64],
2547 highma: &[f64],
2548 lowma: &[f64],
2549 warm: usize,
2550 out_halftrend: &mut [f64],
2551 out_trend: &mut [f64],
2552 out_atr_high: &mut [f64],
2553 out_atr_low: &mut [f64],
2554 out_buy: &mut [f64],
2555 out_sell: &mut [f64],
2556) {
2557 halftrend_scalar(
2558 high,
2559 low,
2560 close,
2561 amplitude,
2562 ch_dev,
2563 atr,
2564 highma,
2565 lowma,
2566 warm,
2567 out_halftrend,
2568 out_trend,
2569 out_atr_high,
2570 out_atr_low,
2571 out_buy,
2572 out_sell,
2573 );
2574}
2575
/// Sliding-window maximum: `out[i]` is the largest value in
/// `src[i + 1 - win ..= i]` (clamped at index 0). A `win` of 0 acts like 1.
///
/// Keeps a monotonic deque of `(index, value)` pairs. For NaN-free input, the
/// values strictly decrease from front to back, so each element is pushed and
/// popped at most once.
#[inline(always)]
fn rolling_max_series(src: &[f64], win: usize) -> Vec<f64> {
    if src.is_empty() {
        return Vec::new();
    }
    let span = win.max(1);
    let mut window: std::collections::VecDeque<(usize, f64)> =
        std::collections::VecDeque::with_capacity(span);
    let mut out = Vec::with_capacity(src.len());

    for (i, &x) in src.iter().enumerate() {
        // Evict entries that slid out of the window on the left.
        let oldest = (i + 1).saturating_sub(span);
        while window.front().map_or(false, |&(j, _)| j < oldest) {
            window.pop_front();
        }
        // Anything not larger than `x` can never be the maximum again.
        while window.back().map_or(false, |&(_, v)| v <= x) {
            window.pop_back();
        }
        window.push_back((i, x));
        out.push(window[0].1);
    }
    out
}
2628
/// Sliding-window minimum: `out[i]` is the smallest value in
/// `src[i + 1 - win ..= i]` (clamped at index 0). A `win` of 0 acts like 1.
///
/// Keeps a monotonic deque of `(index, value)` pairs. For NaN-free input, the
/// values strictly increase from front to back, so each element is pushed and
/// popped at most once.
#[inline(always)]
fn rolling_min_series(src: &[f64], win: usize) -> Vec<f64> {
    if src.is_empty() {
        return Vec::new();
    }
    let span = win.max(1);
    let mut window: std::collections::VecDeque<(usize, f64)> =
        std::collections::VecDeque::with_capacity(span);
    let mut out = Vec::with_capacity(src.len());

    for (i, &x) in src.iter().enumerate() {
        // Evict entries that slid out of the window on the left.
        let oldest = (i + 1).saturating_sub(span);
        while window.front().map_or(false, |&(j, _)| j < oldest) {
            window.pop_front();
        }
        // Anything not smaller than `x` can never be the minimum again.
        while window.back().map_or(false, |&(_, v)| v >= x) {
            window.pop_back();
        }
        window.push_back((i, x));
        out.push(window[0].1);
    }
    out
}
2681
/// Evaluates one HalfTrend row from precomputed series. It writes indices
/// `warm..close.len()` of every output slice and leaves `0..warm` untouched.
///
/// * `ch_dev` — channel deviation; the ATR bands are `halftrend ± atr * ch_dev / 2`.
/// * `atr` — ATR series for this row's ATR period.
/// * `highma` / `lowma` — moving averages of high / low (the batch caller passes
///   SMAs over the amplitude).
/// * `roll_high` / `roll_low` — rolling max of high / rolling min of low.
/// * `warm` — first index to compute.
///
/// The trend output is `0.0` or `1.0`. `out_buy` is set only on a bar whose trend
/// becomes 0 after a non-zero previous bar, and `out_sell` only on a bar whose
/// trend becomes 1 after a previous bar that was not 1. Every other computed
/// bar gets NaN.
///
/// If `warm >= close.len()` (including empty input) nothing is written.
/// All other slices are assumed to be at least `close.len()` long.
#[inline(always)]
fn halftrend_row_into_precomputed(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    ch_dev: f64,
    atr: &[f64],
    highma: &[f64],
    lowma: &[f64],
    warm: usize,
    roll_high: &[f64],
    roll_low: &[f64],
    out_halftrend: &mut [f64],
    out_trend: &mut [f64],
    out_atr_high: &mut [f64],
    out_atr_low: &mut [f64],
    out_buy: &mut [f64],
    out_sell: &mut [f64],
) {
    let len = close.len();
    // Nothing to compute. Returning here also avoids indexing `low[warm - 1]`
    // (or `low[0]` on empty input) out of bounds below.
    if warm >= len {
        return;
    }
    let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
    let ch_half = ch_dev * 0.5;

    let mut current_trend = 0i32;
    let mut next_trend = 0i32;
    let mut up = 0.0f64;
    let mut down = 0.0f64;
    // Seed the trailing extremes from the bar just before the first computed one.
    let mut max_low_price = if warm > 0 { low[warm - 1] } else { low[0] };
    let mut min_high_price = if warm > 0 { high[warm - 1] } else { high[0] };

    for i in warm..len {
        out_buy[i] = qnan;
        out_sell[i] = qnan;

        let high_price = roll_high[i];
        let low_price = roll_low[i];
        let prev_low = if i > 0 { low[i - 1] } else { low[0] };
        let prev_high = if i > 0 { high[i - 1] } else { high[0] };

        if next_trend == 1 {
            if low_price > max_low_price {
                max_low_price = low_price;
            }
            if highma[i] < max_low_price && close[i] < prev_low {
                current_trend = 1;
                next_trend = 0;
                min_high_price = high_price;
            }
        } else {
            if high_price < min_high_price {
                min_high_price = high_price;
            }
            if lowma[i] > min_high_price && close[i] > prev_high {
                current_trend = 0;
                next_trend = 1;
                max_low_price = low_price;
            }
        }

        let a = atr[i];
        let atr2 = 0.5 * a;
        let dev = a.mul_add(ch_half, 0.0);

        if current_trend == 0 {
            if i > warm && out_trend[i - 1] != 0.0 {
                // Trend just flipped to 0: continue from the previous `down` level.
                up = down;
                out_buy[i] = up - atr2;
            } else {
                up = if i == warm || up == 0.0 {
                    max_low_price
                } else if max_low_price > up {
                    max_low_price
                } else {
                    up
                };
            }
            out_halftrend[i] = up;
            out_atr_high[i] = up + dev;
            out_atr_low[i] = up - dev;
            out_trend[i] = 0.0;
        } else {
            if i > warm && out_trend[i - 1] != 1.0 {
                // Trend just flipped to 1: continue from the previous `up` level.
                down = up;
                out_sell[i] = down + atr2;
            } else {
                down = if i == warm || down == 0.0 {
                    min_high_price
                } else if min_high_price < down {
                    min_high_price
                } else {
                    down
                };
            }
            out_halftrend[i] = down;
            out_atr_high[i] = down + dev;
            out_atr_low[i] = down - dev;
            out_trend[i] = 1.0;
        }
    }
}
2782
2783#[inline(always)]
2784pub fn halftrend_batch_rows_into(
2785 high: &[f64],
2786 low: &[f64],
2787 close: &[f64],
2788 sweep: &HalfTrendBatchRange,
2789 kern: Kernel,
2790 dst_halftrend: &mut [f64],
2791 dst_trend: &mut [f64],
2792 dst_atr_high: &mut [f64],
2793 dst_atr_low: &mut [f64],
2794 dst_buy: &mut [f64],
2795 dst_sell: &mut [f64],
2796) -> Result<Vec<HalfTrendParams>, HalfTrendError> {
2797 let combos = expand_grid_halftrend(sweep)?;
2798
2799 let len = high.len();
2800 let first = first_valid_ohlc(high, low, close);
2801 if first == usize::MAX {
2802 return Err(HalfTrendError::AllValuesNaN);
2803 }
2804
2805 let rows = combos.len();
2806 let cols = len;
2807
2808 use std::collections::BTreeSet;
2809 let uniq_amp: BTreeSet<usize> = combos.iter().map(|p| p.amplitude.unwrap()).collect();
2810 let uniq_atr: BTreeSet<usize> = combos.iter().map(|p| p.atr_period.unwrap()).collect();
2811
2812 use std::collections::HashMap;
2813 let mut hi_map: HashMap<usize, Vec<f64>> = HashMap::new();
2814 let mut lo_map: HashMap<usize, Vec<f64>> = HashMap::new();
2815 for &a in &uniq_amp {
2816 hi_map.insert(
2817 a,
2818 sma(&SmaInput::from_slice(high, SmaParams { period: Some(a) }))
2819 .map_err(|e| HalfTrendError::SmaError(e.to_string()))?
2820 .values,
2821 );
2822 lo_map.insert(
2823 a,
2824 sma(&SmaInput::from_slice(low, SmaParams { period: Some(a) }))
2825 .map_err(|e| HalfTrendError::SmaError(e.to_string()))?
2826 .values,
2827 );
2828 }
2829
2830 let mut roll_high_map: HashMap<usize, Vec<f64>> = HashMap::new();
2831 let mut roll_low_map: HashMap<usize, Vec<f64>> = HashMap::new();
2832 for &a in &uniq_amp {
2833 roll_high_map.insert(a, rolling_max_series(high, a));
2834 roll_low_map.insert(a, rolling_min_series(low, a));
2835 }
2836 let mut atr_map: HashMap<usize, Vec<f64>> = HashMap::new();
2837 for &p in &uniq_atr {
2838 atr_map.insert(
2839 p,
2840 atr(&AtrInput::from_slices(
2841 high,
2842 low,
2843 close,
2844 AtrParams { length: Some(p) },
2845 ))
2846 .map_err(|e| HalfTrendError::AtrError(e.to_string()))?
2847 .values,
2848 );
2849 }
2850
2851 for row in 0..rows {
2852 let prm = &combos[row];
2853 let amp = prm.amplitude.unwrap();
2854 let ap = prm.atr_period.unwrap();
2855 let ch = prm.channel_deviation.unwrap_or(2.0);
2856 let warm = warmup_from(first, amp, ap);
2857
2858 let base = row * cols;
2859 let (ht, tr, ah, al, bs, ss) = (
2860 &mut dst_halftrend[base..base + cols],
2861 &mut dst_trend[base..base + cols],
2862 &mut dst_atr_high[base..base + cols],
2863 &mut dst_atr_low[base..base + cols],
2864 &mut dst_buy[base..base + cols],
2865 &mut dst_sell[base..base + cols],
2866 );
2867
2868 let hma = hi_map.get(&).unwrap().as_slice();
2869 let lma = lo_map.get(&).unwrap().as_slice();
2870 let av = atr_map.get(&ap).unwrap().as_slice();
2871 let rhi = roll_high_map.get(&).unwrap().as_slice();
2872 let rlo = roll_low_map.get(&).unwrap().as_slice();
2873
2874 halftrend_row_into_precomputed(
2875 high, low, close, ch, av, hma, lma, warm, rhi, rlo, ht, tr, ah, al, bs, ss,
2876 );
2877 }
2878
2879 Ok(combos)
2880}
2881
/// Computes HalfTrend for one parameter set and returns a dict with the keys
/// `halftrend`, `trend`, `atr_high`, `atr_low`, `buy_signal` and `sell_signal`.
/// Raises `ValueError` if the computation fails.
#[cfg(feature = "python")]
#[pyfunction]
#[pyo3(name = "halftrend", signature = (
    high,
    low,
    close,
    amplitude,
    channel_deviation,
    atr_period,
    kernel = None
))]
pub fn halftrend_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    amplitude: usize,
    channel_deviation: f64,
    atr_period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let (h, l, c) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let kern = validate_kernel(kernel, false)?;
    let input = HalfTrendInput::from_slices(
        h,
        l,
        c,
        HalfTrendParams {
            amplitude: Some(amplitude),
            channel_deviation: Some(channel_deviation),
            atr_period: Some(atr_period),
        },
    );

    // Release the GIL for the numeric work.
    let out = py
        .allow_threads(|| halftrend_with_kernel(&input, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    let named = [
        ("halftrend", out.halftrend),
        ("trend", out.trend),
        ("atr_high", out.atr_high),
        ("atr_low", out.atr_low),
        ("buy_signal", out.buy_signal),
        ("sell_signal", out.sell_signal),
    ];
    for (key, series) in named {
        dict.set_item(key, series.into_pyarray(py))?;
    }
    Ok(dict)
}
2928
/// Same computation as `halftrend`, but returns a 6-tuple of arrays in the order
/// `(halftrend, trend, atr_high, atr_low, buy_signal, sell_signal)`.
/// Raises `ValueError` if the computation fails.
#[cfg(feature = "python")]
#[pyfunction]
#[pyo3(name = "halftrend_tuple", signature = (
    high,
    low,
    close,
    amplitude,
    channel_deviation,
    atr_period,
    kernel = None
))]
pub fn halftrend_tuple_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    amplitude: usize,
    channel_deviation: f64,
    atr_period: usize,
    kernel: Option<&str>,
) -> PyResult<(
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
)> {
    let (h, l, c) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let kern = validate_kernel(kernel, false)?;
    let input = HalfTrendInput::from_slices(
        h,
        l,
        c,
        HalfTrendParams {
            amplitude: Some(amplitude),
            channel_deviation: Some(channel_deviation),
            atr_period: Some(atr_period),
        },
    );

    // Release the GIL for the numeric work.
    let out = py
        .allow_threads(|| halftrend_with_kernel(&input, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let [halftrend, trend, atr_high, atr_low, buy, sell] = [
        out.halftrend,
        out.trend,
        out.atr_high,
        out.atr_low,
        out.buy_signal,
        out.sell_signal,
    ]
    .map(|series| series.into_pyarray(py));
    Ok((halftrend, trend, atr_high, atr_low, buy, sell))
}
2980
/// Runs a HalfTrend parameter sweep and returns a dict with these keys:
/// * `values` — a `(6 * rows, cols)` array stacking the halftrend, trend,
///   atr_high, atr_low, buy and sell blocks (each `rows x cols`) in that order;
/// * `series` — the names of those six blocks;
/// * `rows`, `cols`;
/// * `amplitudes`, `channel_deviations`, `atr_periods` — per-row parameters.
///
/// Warm-up positions of every row are NaN. Raises `ValueError` on mismatched
/// input lengths, an invalid or empty sweep, size overflow, or indicator errors.
#[cfg(feature = "python")]
#[pyfunction]
#[pyo3(name = "halftrend_batch", signature = (
    high,
    low,
    close,
    amplitude_start = None,
    amplitude_end = None,
    amplitude_step = None,
    channel_deviation_start = None,
    channel_deviation_end = None,
    channel_deviation_step = None,
    atr_period_start = None,
    atr_period_end = None,
    atr_period_step = None,
    kernel = None
))]
pub fn halftrend_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    amplitude_start: Option<usize>,
    amplitude_end: Option<usize>,
    amplitude_step: Option<usize>,
    channel_deviation_start: Option<f64>,
    channel_deviation_end: Option<f64>,
    channel_deviation_step: Option<f64>,
    atr_period_start: Option<usize>,
    atr_period_end: Option<usize>,
    atr_period_step: Option<usize>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let h = high.as_slice()?;
    let l = low.as_slice()?;
    let c = close.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    // The row kernels index all three series by the same position.
    if l.len() != h.len() || c.len() != h.len() {
        return Err(PyValueError::new_err(format!(
            "halftrend_batch: mismatched input lengths: high={}, low={}, close={}",
            h.len(),
            l.len(),
            c.len()
        )));
    }

    // A full (start, end, step) triple defines a sweep; a lone start pins the axis.
    let mut range = HalfTrendBatchRange::default();
    if let (Some(s), Some(e), Some(st)) = (amplitude_start, amplitude_end, amplitude_step) {
        range.amplitude = (s, e, st);
    } else if let Some(v) = amplitude_start {
        range.amplitude = (v, v, 0);
    }
    if let (Some(s), Some(e), Some(st)) = (
        channel_deviation_start,
        channel_deviation_end,
        channel_deviation_step,
    ) {
        range.channel_deviation = (s, e, st);
    } else if let Some(v) = channel_deviation_start {
        range.channel_deviation = (v, v, 0.0);
    }
    if let (Some(s), Some(e), Some(st)) = (atr_period_start, atr_period_end, atr_period_step) {
        range.atr_period = (s, e, st);
    } else if let Some(v) = atr_period_start {
        range.atr_period = (v, v, 0);
    }

    let combos = expand_grid_halftrend(&range).map_err(|e| PyValueError::new_err(e.to_string()))?;
    if combos.is_empty() {
        return Err(PyValueError::new_err("empty sweep"));
    }
    let rows = combos.len();
    let cols = h.len();

    let block = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("halftrend_batch: rows*cols overflow"))?;
    let total_rows = rows
        .checked_mul(6)
        .ok_or_else(|| PyValueError::new_err("halftrend_batch: rows*6 overflow"))?;
    let total_stacked = total_rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("halftrend_batch: stacked size overflow"))?;

    // SAFETY: the array is created uninitialised and C-contiguous. Every element
    // is written below (row kernels for `warm..cols`, NaN for `0..warm`) before
    // the array is handed to Python; on error it is simply dropped.
    let stacked = unsafe { PyArray1::<f64>::new(py, [total_stacked], false) };
    // SAFETY: `stacked` was created just above, so no other view of it exists.
    let dst_stacked = unsafe { stacked.as_slice_mut()? };

    // Write each series straight into its block of the stacked output.
    let (dst_ht, rest) = dst_stacked.split_at_mut(block);
    let (dst_tr, rest) = rest.split_at_mut(block);
    let (dst_ah, rest) = rest.split_at_mut(block);
    let (dst_al, rest) = rest.split_at_mut(block);
    let (dst_bs, dst_ss) = rest.split_at_mut(block);

    let simd = match kern {
        Kernel::Auto => detect_best_batch_kernel(),
        k => k,
    };
    let simd = match simd {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::ScalarBatch => Kernel::Scalar,
        _ => Kernel::Scalar,
    };

    // Validates the data (including the all-NaN case) before any warm-up math.
    py.allow_threads(|| {
        halftrend_batch_rows_into(
            h, l, c, &range, simd, dst_ht, dst_tr, dst_ah, dst_al, dst_bs, dst_ss,
        )
    })
    .map_err(|e| PyValueError::new_err(e.to_string()))?;

    // NaN-fill the warm-up prefix the row kernels leave untouched. Clamp it to
    // the row length so a long warm-up cannot spill into the next row.
    let first = first_valid_ohlc(h, l, c);
    let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
    for (row, p) in combos.iter().enumerate() {
        let w = warmup_from(first, p.amplitude.unwrap(), p.atr_period.unwrap()).min(cols);
        let prefix = row * cols..row * cols + w;
        for dst in [
            &mut *dst_ht,
            &mut *dst_tr,
            &mut *dst_ah,
            &mut *dst_al,
            &mut *dst_bs,
            &mut *dst_ss,
        ] {
            dst[prefix.clone()].fill(qnan);
        }
    }

    let dict = PyDict::new(py);
    dict.set_item("values", stacked.reshape((total_rows, cols))?)?;

    use pyo3::types::PyList;
    let series = PyList::new(
        py,
        vec!["halftrend", "trend", "atr_high", "atr_low", "buy", "sell"],
    )?;
    dict.set_item("series", series)?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    dict.set_item(
        "amplitudes",
        combos
            .iter()
            .map(|p| p.amplitude.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "channel_deviations",
        combos
            .iter()
            .map(|p| p.channel_deviation.unwrap())
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "atr_periods",
        combos
            .iter()
            .map(|p| p.atr_period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    Ok(dict)
}
3168
/// Runs the HalfTrend parameter sweep on the GPU and returns device-resident
/// arrays: `halftrend`, `trend`, `atr_high`, `atr_low`, `buy_signal` and
/// `sell_signal`. Each array holds a clone of the CUDA context `Arc`.
/// Per-row parameters are returned as host arrays: `amplitudes`,
/// `channel_deviations` and `atr_periods`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "halftrend_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, close_f32, amplitude_range, channel_deviation_range=(2.0,2.0,0.0), atr_period_range=(14,14,0), device_id=0))]
pub fn halftrend_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    high_f32: numpy::PyReadonlyArray1<'py, f32>,
    low_f32: numpy::PyReadonlyArray1<'py, f32>,
    close_f32: numpy::PyReadonlyArray1<'py, f32>,
    amplitude_range: (usize, usize, usize),
    channel_deviation_range: (f64, f64, f64),
    atr_period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<Bound<'py, PyDict>> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let (h, l, c) = (
        high_f32.as_slice()?,
        low_f32.as_slice()?,
        close_f32.as_slice()?,
    );
    let sweep = HalfTrendBatchRange {
        amplitude: amplitude_range,
        channel_deviation: channel_deviation_range,
        atr_period: atr_period_range,
    };

    // GPU work runs without the GIL.
    let (batch, ctx_arc, dev_id) = py.allow_threads(|| {
        let cuda =
            CudaHalftrend::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let out = cuda
            .halftrend_batch_dev(h, l, c, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((out, cuda.context_arc(), cuda.device_id()))
    })?;

    // Wraps one device buffer and stores a clone of the context `Arc` on it.
    macro_rules! pinned {
        ($buf:expr) => {{
            let mut arr = make_device_array_py(dev_id as usize, $buf)?;
            arr._ctx = Some(ctx_arc.clone());
            Py::new(py, arr)?
        }};
    }

    let dict = PyDict::new(py);
    dict.set_item("halftrend", pinned!(batch.halftrend))?;
    dict.set_item("trend", pinned!(batch.trend))?;
    dict.set_item("atr_high", pinned!(batch.atr_high))?;
    dict.set_item("atr_low", pinned!(batch.atr_low))?;
    dict.set_item("buy_signal", pinned!(batch.buy))?;
    dict.set_item("sell_signal", pinned!(batch.sell))?;

    let amplitudes: Vec<u64> = batch
        .combos
        .iter()
        .map(|p| p.amplitude.unwrap() as u64)
        .collect();
    let channel_deviations: Vec<_> = batch
        .combos
        .iter()
        .map(|p| p.channel_deviation.unwrap())
        .collect();
    let atr_periods: Vec<u64> = batch
        .combos
        .iter()
        .map(|p| p.atr_period.unwrap() as u64)
        .collect();
    dict.set_item("amplitudes", amplitudes.into_pyarray(py))?;
    dict.set_item("channel_deviations", channel_deviations.into_pyarray(py))?;
    dict.set_item("atr_periods", atr_periods.into_pyarray(py))?;
    Ok(dict)
}
3250
/// Runs HalfTrend with a single parameter set over many time-major series on the
/// GPU. Returns device-resident arrays: `halftrend`, `trend`, `atr_high`,
/// `atr_low`, `buy_signal` and `sell_signal`. Each array holds a clone of the
/// CUDA context `Arc`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "halftrend_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, close_tm_f32, cols, rows, amplitude, channel_deviation, atr_period, device_id=0))]
pub fn halftrend_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    high_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    low_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    close_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    amplitude: usize,
    channel_deviation: f64,
    atr_period: usize,
    device_id: usize,
) -> PyResult<Bound<'py, PyDict>> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let (h, l, c) = (
        high_tm_f32.as_slice()?,
        low_tm_f32.as_slice()?,
        close_tm_f32.as_slice()?,
    );

    // GPU work runs without the GIL.
    let (out, ctx_arc, dev_id) = py.allow_threads(|| {
        let cuda =
            CudaHalftrend::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let out = cuda
            .halftrend_many_series_one_param_time_major_dev(
                h,
                l,
                c,
                cols,
                rows,
                amplitude,
                channel_deviation,
                atr_period,
            )
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((out, cuda.context_arc(), cuda.device_id()))
    })?;

    // Wraps one device buffer and stores a clone of the context `Arc` on it.
    macro_rules! pinned {
        ($buf:expr) => {{
            let mut arr = make_device_array_py(dev_id as usize, $buf)?;
            arr._ctx = Some(ctx_arc.clone());
            Py::new(py, arr)?
        }};
    }

    let dict = PyDict::new(py);
    dict.set_item("halftrend", pinned!(out.halftrend))?;
    dict.set_item("trend", pinned!(out.trend))?;
    dict.set_item("atr_high", pinned!(out.atr_high))?;
    dict.set_item("atr_low", pinned!(out.atr_low))?;
    dict.set_item("buy_signal", pinned!(out.buy))?;
    dict.set_item("sell_signal", pinned!(out.sell))?;
    Ok(dict)
}
3310
/// Python class exposed as `HalfTrendStream`: a thin wrapper that owns a Rust
/// `HalfTrendStream` and forwards bar-by-bar updates to it.
#[cfg(feature = "python")]
#[pyclass(name = "HalfTrendStream")]
pub struct HalfTrendStreamPy {
    // Underlying incremental calculator; advanced by the `update` method.
    stream: HalfTrendStream,
}
3316
#[cfg(feature = "python")]
#[pymethods]
impl HalfTrendStreamPy {
    /// Builds a streaming HalfTrend calculator. Raises `ValueError` when
    /// `HalfTrendStream::try_new` rejects the parameters.
    #[new]
    fn new(amplitude: usize, channel_deviation: f64, atr_period: usize) -> PyResult<Self> {
        HalfTrendStream::try_new(HalfTrendParams {
            amplitude: Some(amplitude),
            channel_deviation: Some(channel_deviation),
            atr_period: Some(atr_period),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one bar and returns the tuple `(halftrend, trend, atr_high, atr_low,
    /// buy_signal, sell_signal)`. Returns `None` whenever the underlying stream
    /// yields no value (presumably during warm-up; see `HalfTrendStream::update`).
    pub fn update(
        &mut self,
        high: f64,
        low: f64,
        close: f64,
    ) -> Option<(f64, f64, f64, f64, Option<f64>, Option<f64>)> {
        let o = self.stream.update(high, low, close)?;
        Some((
            o.halftrend,
            o.trend,
            o.atr_high,
            o.atr_low,
            o.buy_signal,
            o.sell_signal,
        ))
    }
}
3352
/// Flattened multi-series result serialised to JavaScript by the `halftrend`
/// binding: `rows` output series of `cols` values each, concatenated row after
/// row into `values`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct HalfTrendJsResult {
    // All series back to back; length `rows * cols`.
    pub values: Vec<f64>,
    // Number of series stored in `values`.
    pub rows: usize,
    // Length of each series (the input length).
    pub cols: usize,
}
3360
/// JavaScript entry point (`halftrend`). Validates the inputs, computes HalfTrend
/// and returns a `HalfTrendJsResult` with `rows = 6` series concatenated in the
/// order halftrend, trend, atr_high, atr_low, buy, sell.
///
/// Errors (as `JsValue` strings) on:
/// * empty or mismatched inputs;
/// * a non-positive or non-finite `channel_deviation`;
/// * a zero period, or an `atr_period` longer than the data;
/// * all-NaN data, or too few bars where high, low and close are all valid;
/// * any failure of the underlying computation.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "halftrend")]
pub fn halftrend_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    amplitude: usize,
    channel_deviation: f64,
    atr_period: usize,
) -> Result<JsValue, JsValue> {
    if high.is_empty() || low.is_empty() || close.is_empty() {
        return Err(JsValue::from_str("halftrend: Input data slice is empty."));
    }

    let len = high.len();
    if len != low.len() || len != close.len() {
        return Err(JsValue::from_str(&format!(
            "halftrend: Mismatched array lengths: high={}, low={}, close={}",
            high.len(),
            low.len(),
            close.len()
        )));
    }

    // `<= 0.0` alone lets NaN through (every comparison with NaN is false), and
    // NaN or infinity would silently poison the ATR bands.
    if !channel_deviation.is_finite() || channel_deviation <= 0.0 {
        return Err(JsValue::from_str(&format!(
            "halftrend: Invalid channel_deviation: {}",
            channel_deviation
        )));
    }

    if amplitude == 0 {
        return Err(JsValue::from_str(&format!(
            "halftrend: Invalid period: period = {}, data length = {}",
            amplitude, len
        )));
    }

    if atr_period == 0 {
        return Err(JsValue::from_str(&format!(
            "halftrend: Invalid period: period = {}, data length = {}",
            atr_period, len
        )));
    }

    if atr_period > len {
        return Err(JsValue::from_str(&format!(
            "halftrend: Invalid period: period = {}, data length = {}",
            atr_period, len
        )));
    }

    // Bars where all three prices are present.
    let valid_count = high
        .iter()
        .zip(low)
        .zip(close)
        .filter(|((h, l), c)| !h.is_nan() && !l.is_nan() && !c.is_nan())
        .count();

    if valid_count == 0 {
        return Err(JsValue::from_str("halftrend: All values are NaN."));
    }

    let required = amplitude.max(atr_period);
    if valid_count < required {
        return Err(JsValue::from_str(&format!(
            "halftrend: Not enough valid data: needed = {}, valid = {}",
            required, valid_count
        )));
    }

    let params = HalfTrendParams {
        amplitude: Some(amplitude),
        channel_deviation: Some(channel_deviation),
        atr_period: Some(atr_period),
    };
    let input = HalfTrendInput::from_slices(high, low, close, params);

    let cols = len;
    let rows: usize = 6;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("halftrend: rows*cols overflow"))?;
    let mut values = vec![0.0; total];

    // One contiguous `cols`-long chunk per output series.
    let (ht, rest) = values.split_at_mut(cols);
    let (tr, rest) = rest.split_at_mut(cols);
    let (ah, rest) = rest.split_at_mut(cols);
    let (al, rest) = rest.split_at_mut(cols);
    let (bs, ss) = rest.split_at_mut(cols);

    halftrend_into_slices_kernel(ht, tr, ah, al, bs, ss, &input, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    serde_wasm_bindgen::to_value(&HalfTrendJsResult { values, rows, cols })
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {e}")))
}
3458
/// JavaScript entry point that accepts optional parameters. It wraps the inputs
/// in a `Candles` value, runs `halftrend`, and returns an object with the
/// `Float64Array` properties `halftrend`, `trend`, `atr_high`, `atr_low`,
/// `buy_signal` and `sell_signal`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn halftrend_wasm(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    amplitude: Option<usize>,
    channel_deviation: Option<f64>,
    atr_period: Option<usize>,
) -> Result<JsValue, JsValue> {
    let n = high.len();
    let nan_series = || vec![f64::NAN; n];

    // Derived price series; pairwise zipping stops at the shortest input.
    let hl2: Vec<f64> = high.iter().zip(low).map(|(h, l)| (h + l) / 2.0).collect();
    let hlc3: Vec<f64> = high
        .iter()
        .zip(low)
        .zip(close)
        .map(|((h, l), c)| (h + l + c) / 3.0)
        .collect();
    let hlcc4: Vec<f64> = high
        .iter()
        .zip(low)
        .zip(close)
        .map(|((h, l), c)| (h + l + c + c) / 4.0)
        .collect();

    // Open, volume and ohlc4 are not provided by the caller, so they are NaN.
    let candles = Candles {
        timestamp: vec![0; n],
        high: high.to_vec(),
        low: low.to_vec(),
        close: close.to_vec(),
        open: nan_series(),
        volume: nan_series(),
        fields: CandleFieldFlags {
            open: false,
            high: true,
            low: true,
            close: true,
            volume: false,
        },
        hl2,
        hlc3,
        ohlc4: nan_series(),
        hlcc4,
    };

    let input = HalfTrendInput::from_candles(
        &candles,
        HalfTrendParams {
            amplitude,
            channel_deviation,
            atr_period,
        },
    );
    let output = halftrend(&input).map_err(|e| JsValue::from_str(&e.to_string()))?;

    let result = js_sys::Object::new();
    let named: [(&str, &[f64]); 6] = [
        ("halftrend", &output.halftrend[..]),
        ("trend", &output.trend[..]),
        ("atr_high", &output.atr_high[..]),
        ("atr_low", &output.atr_low[..]),
        ("buy_signal", &output.buy_signal[..]),
        ("sell_signal", &output.sell_signal[..]),
    ];
    for (key, data) in named {
        js_sys::Reflect::set(
            &result,
            &JsValue::from_str(key),
            &js_sys::Float64Array::from(data),
        )?;
    }

    Ok(result.into())
}
3531
/// Allocates uninitialised room for `len` `f64` values and passes ownership to
/// the caller. Release the buffer with `halftrend_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn halftrend_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop keeps the buffer alive after this function returns.
    let mut buffer = core::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
3540
/// Releases a buffer obtained from `halftrend_alloc(len)`. `len` must be the
/// same value that was passed to `halftrend_alloc`. A null pointer is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn halftrend_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: per the contract above, `ptr` came from `Vec::<f64>::with_capacity(len)`,
    // so `(ptr, capacity = len)` describes that allocation. The length is 0
    // because the buffer may not be fully written, and `from_raw_parts` requires
    // the first `length` elements to be initialised. Deallocation only needs
    // the capacity.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
3548
/// Computes HalfTrend from raw pointers into wasm linear memory, writing all six output
/// series into one caller-owned buffer.
///
/// * `high_ptr`, `low_ptr`, `close_ptr` must each point to `len` readable `f64` values.
/// * `out_ptr` must point to `6 * len` writable `f64` values. It is filled as six consecutive
///   `len`-sized segments, in the order the kernel's six output slices are passed
///   (presumably halftrend, trend, atr_high, atr_low, buy_signal, sell_signal — matching
///   `HalfTrendOutput`).
///
/// The output buffer may overlap an input buffer. In that case the result is computed into
/// a scratch buffer first and then copied, so inputs are never read after being overwritten.
///
/// # Errors
/// Returns a `JsValue` string on:
/// * null pointers or empty input;
/// * a non-positive or NaN `channel_deviation`;
/// * zero periods, or `atr_period > len`;
/// * an output size that overflows `usize`;
/// * any error reported by the HalfTrend kernel.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "halftrend_into")]
pub fn halftrend_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    amplitude: usize,
    channel_deviation: f64,
    atr_period: usize,
) -> Result<(), JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || close_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }

    if len == 0 {
        return Err(JsValue::from_str("halftrend: Input data slice is empty."));
    }

    // Written as `!(x > 0.0)` so NaN is rejected too; `x <= 0.0` is false for NaN.
    if !(channel_deviation > 0.0) {
        return Err(JsValue::from_str(&format!(
            "halftrend: Invalid channel_deviation: {}",
            channel_deviation
        )));
    }

    if amplitude == 0 {
        return Err(JsValue::from_str(&format!(
            "halftrend: Invalid period: period = {}, data length = {}",
            amplitude, len
        )));
    }

    if atr_period == 0 || atr_period > len {
        return Err(JsValue::from_str(&format!(
            "halftrend: Invalid period: period = {}, data length = {}",
            atr_period, len
        )));
    }

    // On wasm32 `usize` is 32 bits, so `6 * len` can genuinely wrap.
    let total = len
        .checked_mul(6)
        .ok_or_else(|| JsValue::from_str("halftrend: output size overflow"))?;

    // Detect whether any input range intersects the output range; overlapping `&[f64]` and
    // `&mut [f64]` views of the same memory would be undefined behavior.
    let out_lo = out_ptr as *const f64;
    let out_hi = out_lo.wrapping_add(total);
    let overlaps_out = |p: *const f64| p < out_hi && out_lo < p.wrapping_add(len);
    let aliased = overlaps_out(high_ptr) || overlaps_out(low_ptr) || overlaps_out(close_ptr);

    // SAFETY: the pointers are non-null (checked above) and, per the caller contract, each
    // points to `len` initialized `f64` values.
    let (h, l, c) = unsafe {
        (
            core::slice::from_raw_parts(high_ptr, len),
            core::slice::from_raw_parts(low_ptr, len),
            core::slice::from_raw_parts(close_ptr, len),
        )
    };

    let input = HalfTrendInput::from_slices(
        h,
        l,
        c,
        HalfTrendParams {
            amplitude: Some(amplitude),
            channel_deviation: Some(channel_deviation),
            atr_period: Some(atr_period),
        },
    );

    // Splits `dst` into the six per-series segments and runs the kernel into them.
    let write_all = |dst: &mut [f64]| -> Result<(), JsValue> {
        let (ht, rest) = dst.split_at_mut(len);
        let (tr, rest) = rest.split_at_mut(len);
        let (ah, rest) = rest.split_at_mut(len);
        let (al, rest) = rest.split_at_mut(len);
        let (bs, ss) = rest.split_at_mut(len);
        halftrend_into_slices_kernel(ht, tr, ah, al, bs, ss, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        Ok(())
    };

    if aliased {
        let mut scratch = vec![0.0; total];
        write_all(&mut scratch)?;
        // SAFETY: `out_ptr` is non-null and points to `total` writable `f64`s per the caller
        // contract. The input slices are not read again after this point, so writing over
        // memory they may share is sound.
        let out = unsafe { core::slice::from_raw_parts_mut(out_ptr, total) };
        out.copy_from_slice(&scratch);
    } else {
        // SAFETY: as above; additionally the output range is disjoint from every input range.
        let out = unsafe { core::slice::from_raw_parts_mut(out_ptr, total) };
        write_all(out)?;
    }

    Ok(())
}
3625
/// Parameter-sweep configuration deserialized from the JS `config` argument of
/// `halftrend_batch`.
///
/// Each field is a `(start, end, step)` triple and is copied verbatim into the matching
/// field of `HalfTrendBatchRange`. Field names are the JS object keys, so renaming them
/// changes the wire format.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct HalfTrendBatchConfig {
    /// `(start, end, step)` for `amplitude`.
    pub amplitude_range: (usize, usize, usize),
    /// `(start, end, step)` for `channel_deviation`.
    pub channel_deviation_range: (f64, f64, f64),
    /// `(start, end, step)` for `atr_period`.
    pub atr_period_range: (usize, usize, usize),
}
3633
/// Result object returned to JS by `halftrend_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct HalfTrendBatchJsOutput {
    /// Six equally sized blocks of `combos.len() * cols` values each, concatenated in the
    /// order halftrend, trend, atr_high, atr_low, buy_signal, sell_signal. The layout inside
    /// a block is whatever `halftrend_batch_rows_into` writes (presumably one `cols`-long row
    /// per combo, row-major — TODO confirm).
    pub values: Vec<f64>,
    /// The parameter combinations produced by `expand_grid_halftrend` for the sweep.
    pub combos: Vec<HalfTrendParams>,
    /// Total row count of `values`: `6 * combos.len()`.
    pub rows: usize,
    /// Number of columns per row: the input series length.
    pub cols: usize,
}
3642
/// Runs a HalfTrend parameter sweep over `high`/`low`/`close` and returns a
/// `HalfTrendBatchJsOutput` serialized to a JS value.
///
/// `config` must deserialize into `HalfTrendBatchConfig`. All six output series are written
/// directly into the single `values` buffer (six consecutive blocks of `combos * cols`),
/// avoiding per-series temporaries and a second copy.
///
/// # Errors
/// Returns a `JsValue` string if:
/// * `config` cannot be deserialized or the sweep cannot be expanded;
/// * the output size overflows `usize`;
/// * the batch kernel fails;
/// * the result cannot be serialized.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "halftrend_batch")]
pub fn halftrend_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let cfg: HalfTrendBatchConfig =
        serde_wasm_bindgen::from_value(config).map_err(|e| JsValue::from_str(&e.to_string()))?;

    let sweep = HalfTrendBatchRange {
        amplitude: cfg.amplitude_range,
        channel_deviation: cfg.channel_deviation_range,
        atr_period: cfg.atr_period_range,
    };

    let combos = expand_grid_halftrend(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows_ind = combos.len();
    let cols = high.len();

    // Size of one per-series block and of the whole buffer; checked because wasm32 `usize`
    // is only 32 bits.
    let block = rows_ind
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("halftrend: output size overflow"))?;
    let total = block
        .checked_mul(6)
        .ok_or_else(|| JsValue::from_str("halftrend: output size overflow"))?;

    let mut values = vec![0.0; total];
    {
        let (ht, rest) = values.split_at_mut(block);
        let (tr, rest) = rest.split_at_mut(block);
        let (ah, rest) = rest.split_at_mut(block);
        let (al, rest) = rest.split_at_mut(block);
        let (bs, ss) = rest.split_at_mut(block);

        halftrend_batch_rows_into(
            high,
            low,
            close,
            &sweep,
            Kernel::Auto,
            ht,
            tr,
            ah,
            al,
            bs,
            ss,
        )
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    let out = HalfTrendBatchJsOutput {
        values,
        combos,
        rows: 6 * rows_ind,
        cols,
    };
    serde_wasm_bindgen::to_value(&out).map_err(|e| JsValue::from_str(&e.to_string()))
}
3702
/// Runs a HalfTrend parameter sweep from raw pointers into wasm linear memory and returns
/// the number of parameter combinations (`rows_ind`).
///
/// * `high_ptr`, `low_ptr`, `close_ptr` must each point to `len` readable `f64` values.
/// * `out_ptr` must point to `6 * rows_ind * len` writable `f64` values. It is filled as six
///   consecutive blocks of `rows_ind * len`, in the order of the six output arguments passed
///   to `halftrend_batch_rows_into`.
///
/// Each `(start, end, step)` triple is forwarded unchanged to `HalfTrendBatchRange`. The
/// output buffer may overlap an input buffer; in that case results go through a scratch
/// buffer before being copied out.
///
/// # Errors
/// Returns a `JsValue` string on:
/// * null pointers;
/// * an invalid sweep;
/// * an output size that overflows `usize`;
/// * any error from the batch kernel.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn halftrend_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    amp_start: usize,
    amp_end: usize,
    amp_step: usize,
    ch_start: f64,
    ch_end: f64,
    ch_step: f64,
    atr_start: usize,
    atr_end: usize,
    atr_step: usize,
) -> Result<usize, JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || close_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to halftrend_batch_into",
        ));
    }

    let sweep = HalfTrendBatchRange {
        amplitude: (amp_start, amp_end, amp_step),
        channel_deviation: (ch_start, ch_end, ch_step),
        atr_period: (atr_start, atr_end, atr_step),
    };
    let combos = expand_grid_halftrend(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows_ind = combos.len();
    let cols = len;

    // Checked because wasm32 `usize` is only 32 bits and the product can wrap.
    let block = rows_ind
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("halftrend: output size overflow"))?;
    let total = block
        .checked_mul(6)
        .ok_or_else(|| JsValue::from_str("halftrend: output size overflow"))?;

    // Overlapping `&[f64]` input views and a `&mut [f64]` output view would be UB, so detect
    // any intersection between the input ranges and the output range up front.
    let out_lo = out_ptr as *const f64;
    let out_hi = out_lo.wrapping_add(total);
    let overlaps_out = |p: *const f64| p < out_hi && out_lo < p.wrapping_add(len);
    let aliased = overlaps_out(high_ptr) || overlaps_out(low_ptr) || overlaps_out(close_ptr);

    // SAFETY: non-null (checked above); per the caller contract each input points to `len`
    // initialized `f64` values.
    let (h, l, c) = unsafe {
        (
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
            std::slice::from_raw_parts(close_ptr, len),
        )
    };

    // Splits `dst` into the six per-series blocks and runs the batch kernel into them.
    let write_all = |dst: &mut [f64]| -> Result<(), JsValue> {
        let (ht, rest) = dst.split_at_mut(block);
        let (tr, rest) = rest.split_at_mut(block);
        let (ah, rest) = rest.split_at_mut(block);
        let (al, rest) = rest.split_at_mut(block);
        let (bs, ss) = rest.split_at_mut(block);
        halftrend_batch_rows_into(h, l, c, &sweep, Kernel::Auto, ht, tr, ah, al, bs, ss)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        Ok(())
    };

    if aliased {
        let mut scratch = vec![0.0; total];
        write_all(&mut scratch)?;
        // SAFETY: `out_ptr` is non-null and points to `total` writable `f64`s per the caller
        // contract; the input slices are not read after this point.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        out.copy_from_slice(&scratch);
    } else {
        // SAFETY: as above; the output range is also disjoint from every input range.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        write_all(out)?;
    }

    Ok(rows_ind)
}