1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::Candles;
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel,
19};
20#[cfg(feature = "python")]
21use crate::utilities::kernel_validation::validate_kernel;
22#[cfg(not(target_arch = "wasm32"))]
23use rayon::prelude::*;
24use std::collections::VecDeque;
25use std::error::Error;
26use thiserror::Error;
27
/// Borrowed OHLC input for the zig-zag channels calculation.
///
/// The data is either taken from a [`Candles`] value or supplied as four
/// separate price slices.
#[derive(Debug, Clone)]
pub enum ZigZagChannelsData<'a> {
    /// Use the `open`, `high`, `low` and `close` series stored in `candles`.
    Candles {
        candles: &'a Candles,
    },
    /// Use four caller-provided series. Their lengths are checked for
    /// equality when the indicator is evaluated, not when this value is built.
    Slices {
        open: &'a [f64],
        high: &'a [f64],
        low: &'a [f64],
        close: &'a [f64],
    },
}
40
41#[derive(Debug, Clone)]
42pub struct ZigZagChannelsOutput {
43 pub middle: Vec<f64>,
44 pub upper: Vec<f64>,
45 pub lower: Vec<f64>,
46}
47
/// Optional parameters of the zig-zag channels indicator.
///
/// A `None` field falls back to the same value used by [`Default`]:
/// `length = 100` and `extend = true`.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct ZigZagChannelsParams {
    /// Number of following bars a close is compared against when detecting pivots.
    pub length: Option<usize>,
    /// Whether a final segment is drawn from the latest pivot to the last bar.
    pub extend: Option<bool>,
}

impl Default for ZigZagChannelsParams {
    /// Returns the parameter set with `length = 100` and `extend = true`.
    fn default() -> Self {
        let length = Some(100);
        let extend = Some(true);
        ZigZagChannelsParams { length, extend }
    }
}
66
/// Complete input of one indicator evaluation: price data plus parameters.
#[derive(Debug, Clone)]
pub struct ZigZagChannelsInput<'a> {
    /// Borrowed OHLC series.
    pub data: ZigZagChannelsData<'a>,
    /// Indicator parameters; unset fields fall back to the defaults.
    pub params: ZigZagChannelsParams,
}
72
73impl<'a> ZigZagChannelsInput<'a> {
74 #[inline]
75 pub fn from_candles(candles: &'a Candles, params: ZigZagChannelsParams) -> Self {
76 Self {
77 data: ZigZagChannelsData::Candles { candles },
78 params,
79 }
80 }
81
82 #[inline]
83 pub fn from_slices(
84 open: &'a [f64],
85 high: &'a [f64],
86 low: &'a [f64],
87 close: &'a [f64],
88 params: ZigZagChannelsParams,
89 ) -> Self {
90 Self {
91 data: ZigZagChannelsData::Slices {
92 open,
93 high,
94 low,
95 close,
96 },
97 params,
98 }
99 }
100
101 #[inline]
102 pub fn with_default_candles(candles: &'a Candles) -> Self {
103 Self::from_candles(candles, ZigZagChannelsParams::default())
104 }
105
106 #[inline]
107 pub fn get_length(&self) -> usize {
108 self.params.length.unwrap_or(100)
109 }
110
111 #[inline]
112 pub fn get_extend(&self) -> bool {
113 self.params.extend.unwrap_or(true)
114 }
115}
116
117#[derive(Copy, Clone, Debug)]
118pub struct ZigZagChannelsBuilder {
119 length: Option<usize>,
120 extend: Option<bool>,
121 kernel: Kernel,
122}
123
124impl Default for ZigZagChannelsBuilder {
125 fn default() -> Self {
126 Self {
127 length: None,
128 extend: None,
129 kernel: Kernel::Auto,
130 }
131 }
132}
133
134impl ZigZagChannelsBuilder {
135 #[inline(always)]
136 pub fn new() -> Self {
137 Self::default()
138 }
139
140 #[inline(always)]
141 pub fn length(mut self, value: usize) -> Self {
142 self.length = Some(value);
143 self
144 }
145
146 #[inline(always)]
147 pub fn extend(mut self, value: bool) -> Self {
148 self.extend = Some(value);
149 self
150 }
151
152 #[inline(always)]
153 pub fn kernel(mut self, value: Kernel) -> Self {
154 self.kernel = value;
155 self
156 }
157
158 #[inline(always)]
159 pub fn apply(self, candles: &Candles) -> Result<ZigZagChannelsOutput, ZigZagChannelsError> {
160 zig_zag_channels_with_kernel(
161 &ZigZagChannelsInput::from_candles(
162 candles,
163 ZigZagChannelsParams {
164 length: self.length,
165 extend: self.extend,
166 },
167 ),
168 self.kernel,
169 )
170 }
171
172 #[inline(always)]
173 pub fn apply_slices(
174 self,
175 open: &[f64],
176 high: &[f64],
177 low: &[f64],
178 close: &[f64],
179 ) -> Result<ZigZagChannelsOutput, ZigZagChannelsError> {
180 zig_zag_channels_with_kernel(
181 &ZigZagChannelsInput::from_slices(
182 open,
183 high,
184 low,
185 close,
186 ZigZagChannelsParams {
187 length: self.length,
188 extend: self.extend,
189 },
190 ),
191 self.kernel,
192 )
193 }
194}
195
/// Errors reported by the zig-zag channels entry points in this file.
#[derive(Debug, Error)]
pub enum ZigZagChannelsError {
    /// At least one of the four input series is empty.
    #[error("zig_zag_channels: Input data slice is empty.")]
    EmptyInputData,
    /// The four input series do not all have the same length.
    #[error(
        "zig_zag_channels: Input length mismatch: open = {open_len}, high = {high_len}, low = {low_len}, close = {close_len}"
    )]
    InputLengthMismatch {
        open_len: usize,
        high_len: usize,
        low_len: usize,
        close_len: usize,
    },
    /// No bar has four finite OHLC values (NaN and infinities both count as invalid).
    #[error("zig_zag_channels: All values are NaN.")]
    AllValuesNaN,
    /// `length` was zero.
    #[error("zig_zag_channels: Invalid length: {length}")]
    InvalidLength { length: usize },
    /// The longest run of consecutive valid bars is shorter than `length + 1`.
    #[error("zig_zag_channels: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// An output slice passed to the single-run `*_into_slice` path has the wrong length.
    #[error("zig_zag_channels: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// The batch length range could not be expanded (zero bound, or `start > end` with a non-zero step).
    #[error("zig_zag_channels: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// The kernel is not one of the variants accepted by the batch functions.
    #[error("zig_zag_channels: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// An output slice passed to the batch path does not hold `rows * cols` values.
    #[error(
        "zig_zag_channels: Output length mismatch: dst = {dst_len}, expected = {expected_len}"
    )]
    MismatchedOutputLen { dst_len: usize, expected_len: usize },
    /// Any other invalid input, such as an arithmetic overflow while sizing buffers.
    #[error("zig_zag_channels: Invalid input: {msg}")]
    InvalidInput { msg: String },
}
232
/// A confirmed pivot remembered while scanning one run of valid bars.
#[derive(Debug, Clone, Copy)]
struct PivotState {
    /// Run-relative index of the bar at which the pivot was confirmed; the
    /// pivot bar itself lies `length` bars earlier.
    confirm_idx: usize,
    /// Price of the pivot (the bar's low for a bottom, its high for a top).
    value: f64,
}
238
/// Returns `true` only when all four prices of a bar are finite
/// (neither NaN nor infinite).
#[inline(always)]
fn is_valid_ohlc(open: f64, high: f64, low: f64, close: f64) -> bool {
    [open, high, low, close]
        .iter()
        .all(|price| price.is_finite())
}
243
244#[inline(always)]
245fn longest_valid_run(open: &[f64], high: &[f64], low: &[f64], close: &[f64]) -> usize {
246 let mut best = 0usize;
247 let mut cur = 0usize;
248 for (((&o, &h), &l), &c) in open
249 .iter()
250 .zip(high.iter())
251 .zip(low.iter())
252 .zip(close.iter())
253 {
254 if is_valid_ohlc(o, h, l, c) {
255 cur += 1;
256 best = best.max(cur);
257 } else {
258 cur = 0;
259 }
260 }
261 best
262}
263
264#[inline(always)]
265fn input_slices<'a>(
266 input: &'a ZigZagChannelsInput<'a>,
267) -> Result<(&'a [f64], &'a [f64], &'a [f64], &'a [f64]), ZigZagChannelsError> {
268 match &input.data {
269 ZigZagChannelsData::Candles { candles } => Ok((
270 candles.open.as_slice(),
271 candles.high.as_slice(),
272 candles.low.as_slice(),
273 candles.close.as_slice(),
274 )),
275 ZigZagChannelsData::Slices {
276 open,
277 high,
278 low,
279 close,
280 } => Ok((open, high, low, close)),
281 }
282}
283
284#[inline(always)]
285fn validate_common(
286 open: &[f64],
287 high: &[f64],
288 low: &[f64],
289 close: &[f64],
290 length: usize,
291) -> Result<(), ZigZagChannelsError> {
292 if open.is_empty() || high.is_empty() || low.is_empty() || close.is_empty() {
293 return Err(ZigZagChannelsError::EmptyInputData);
294 }
295 if open.len() != high.len() || open.len() != low.len() || open.len() != close.len() {
296 return Err(ZigZagChannelsError::InputLengthMismatch {
297 open_len: open.len(),
298 high_len: high.len(),
299 low_len: low.len(),
300 close_len: close.len(),
301 });
302 }
303 if length == 0 {
304 return Err(ZigZagChannelsError::InvalidLength { length });
305 }
306
307 let longest = longest_valid_run(open, high, low, close);
308 if longest == 0 {
309 return Err(ZigZagChannelsError::AllValuesNaN);
310 }
311
312 let needed = length
313 .checked_add(1)
314 .ok_or_else(|| ZigZagChannelsError::InvalidInput {
315 msg: "zig_zag_channels: length overflow".to_string(),
316 })?;
317 if longest < needed {
318 return Err(ZigZagChannelsError::NotEnoughValidData {
319 needed,
320 valid: longest,
321 });
322 }
323 Ok(())
324}
325
/// Computes how far candle bodies stick out above and below the straight line
/// joining two pivots.
///
/// Returns `(up_offset, dn_offset)`, both non-negative. The reference line is
/// sampled on bars `start_idx + 1 ..= end_idx`, starting at `start_value` on
/// the first of them and reaching `end_value` on `end_idx`. A bar's body is
/// the span between `min(open, close)` and `max(open, close)`.
#[inline(always)]
fn compute_segment_offsets(
    open: &[f64],
    close: &[f64],
    start_idx: usize,
    end_idx: usize,
    start_value: f64,
    end_value: f64,
) -> (f64, f64) {
    if end_idx <= start_idx {
        return (0.0, 0.0);
    }

    // (body top, body bottom) of a bar.
    let body = |i: usize| (open[i].max(close[i]), open[i].min(close[i]));

    // Adjacent pivots: only the end bar is measured, against the end value.
    if end_idx - start_idx == 1 {
        let (top, bottom) = body(end_idx);
        return ((top - end_value).max(0.0), (end_value - bottom).max(0.0));
    }

    let steps = (end_idx - start_idx - 1) as f64;
    let delta = end_value - start_value;

    let (up, dn) = ((start_idx + 1)..=end_idx).fold((0.0f64, 0.0f64), |(up, dn), i| {
        let ordinal = (i - start_idx - 1) as f64;
        let line = start_value + (ordinal / steps) * delta;
        let (top, bottom) = body(i);
        (up.max(top - line), dn.max(line - bottom))
    });

    (up.max(0.0), dn.max(0.0))
}
361
/// Writes one channel segment into the three output buffers.
///
/// For every bar in `start_idx ..= end_idx` the middle line is the linear
/// interpolation from `start_value` to `end_value`; `upper` adds `up_offset`
/// and `lower` subtracts `dn_offset`. Nothing is written when
/// `end_idx < start_idx`; a single-bar segment uses `start_value`.
#[inline(always)]
fn fill_segment(
    middle: &mut [f64],
    upper: &mut [f64],
    lower: &mut [f64],
    start_idx: usize,
    end_idx: usize,
    start_value: f64,
    end_value: f64,
    up_offset: f64,
    dn_offset: f64,
) {
    if end_idx < start_idx {
        return;
    }

    let steps = end_idx - start_idx;
    if steps == 0 {
        middle[start_idx] = start_value;
        upper[start_idx] = start_value + up_offset;
        lower[start_idx] = start_value - dn_offset;
        return;
    }

    let total = steps as f64;
    let delta = end_value - start_value;
    for step in 0..=steps {
        let at = start_idx + step;
        let fraction = step as f64 / total;
        let level = start_value + fraction * delta;
        middle[at] = level;
        upper[at] = level + up_offset;
        lower[at] = level - dn_offset;
    }
}
395
396fn compute_run(
397 open: &[f64],
398 high: &[f64],
399 low: &[f64],
400 close: &[f64],
401 length: usize,
402 extend: bool,
403 middle: &mut [f64],
404 upper: &mut [f64],
405 lower: &mut [f64],
406) {
407 let n = close.len();
408 if n <= length {
409 return;
410 }
411
412 let mut max_deque: VecDeque<usize> = VecDeque::with_capacity(length);
413 let mut min_deque: VecDeque<usize> = VecDeque::with_capacity(length);
414 let mut os = 0usize;
415 let mut last_top: Option<PivotState> = None;
416 let mut last_bottom: Option<PivotState> = None;
417
418 for idx in 0..n {
419 let current_close = close[idx];
420 while let Some(&back) = max_deque.back() {
421 if close[back] <= current_close {
422 max_deque.pop_back();
423 } else {
424 break;
425 }
426 }
427 max_deque.push_back(idx);
428
429 while let Some(&back) = min_deque.back() {
430 if close[back] >= current_close {
431 min_deque.pop_back();
432 } else {
433 break;
434 }
435 }
436 min_deque.push_back(idx);
437
438 if idx < length {
439 continue;
440 }
441
442 let window_start = idx + 1 - length;
443 while let Some(&front) = max_deque.front() {
444 if front < window_start {
445 max_deque.pop_front();
446 } else {
447 break;
448 }
449 }
450 while let Some(&front) = min_deque.front() {
451 if front < window_start {
452 min_deque.pop_front();
453 } else {
454 break;
455 }
456 }
457
458 let candidate = idx - length;
459 let upper_close = close[*max_deque.front().expect("window max present")];
460 let lower_close = close[*min_deque.front().expect("window min present")];
461 let prev_os = os;
462 let candidate_close = close[candidate];
463
464 if candidate_close > upper_close {
465 os = 0;
466 } else if candidate_close < lower_close {
467 os = 1;
468 }
469
470 if os == 1 && prev_os != 1 {
471 let end_idx = candidate;
472 let end_value = low[end_idx];
473 if let Some(prev_top) = last_top {
474 let start_idx = prev_top.confirm_idx - length;
475 let start_value = prev_top.value;
476 let (up_offset, dn_offset) = compute_segment_offsets(
477 open,
478 close,
479 start_idx,
480 end_idx,
481 start_value,
482 end_value,
483 );
484 fill_segment(
485 middle,
486 upper,
487 lower,
488 start_idx,
489 end_idx,
490 start_value,
491 end_value,
492 up_offset,
493 dn_offset,
494 );
495 }
496 last_bottom = Some(PivotState {
497 confirm_idx: idx,
498 value: end_value,
499 });
500 }
501
502 if os == 0 && prev_os != 0 {
503 let end_idx = candidate;
504 let end_value = high[end_idx];
505 if let Some(prev_bottom) = last_bottom {
506 let start_idx = prev_bottom.confirm_idx - length;
507 let start_value = prev_bottom.value;
508 let (up_offset, dn_offset) = compute_segment_offsets(
509 open,
510 close,
511 start_idx,
512 end_idx,
513 start_value,
514 end_value,
515 );
516 fill_segment(
517 middle,
518 upper,
519 lower,
520 start_idx,
521 end_idx,
522 start_value,
523 end_value,
524 up_offset,
525 dn_offset,
526 );
527 }
528 last_top = Some(PivotState {
529 confirm_idx: idx,
530 value: end_value,
531 });
532 }
533 }
534
535 if !extend {
536 return;
537 }
538
539 let end_idx = n - 1;
540 let end_value = close[end_idx];
541 if os == 1 {
542 if let Some(prev_bottom) = last_bottom {
543 let start_idx = prev_bottom.confirm_idx - length;
544 let start_value = prev_bottom.value;
545 let (up_offset, dn_offset) =
546 compute_segment_offsets(open, close, start_idx, end_idx, start_value, end_value);
547 fill_segment(
548 middle,
549 upper,
550 lower,
551 start_idx,
552 end_idx,
553 start_value,
554 end_value,
555 up_offset,
556 dn_offset,
557 );
558 }
559 } else if let Some(prev_top) = last_top {
560 let start_idx = prev_top.confirm_idx - length;
561 let start_value = prev_top.value;
562 let (up_offset, dn_offset) =
563 compute_segment_offsets(open, close, start_idx, end_idx, start_value, end_value);
564 fill_segment(
565 middle,
566 upper,
567 lower,
568 start_idx,
569 end_idx,
570 start_value,
571 end_value,
572 up_offset,
573 dn_offset,
574 );
575 }
576}
577
578fn compute_row(
579 open: &[f64],
580 high: &[f64],
581 low: &[f64],
582 close: &[f64],
583 length: usize,
584 extend: bool,
585 middle: &mut [f64],
586 upper: &mut [f64],
587 lower: &mut [f64],
588) {
589 let mut idx = 0usize;
590 while idx < close.len() {
591 while idx < close.len() && !is_valid_ohlc(open[idx], high[idx], low[idx], close[idx]) {
592 idx += 1;
593 }
594 if idx >= close.len() {
595 break;
596 }
597 let seg_start = idx;
598 idx += 1;
599 while idx < close.len() && is_valid_ohlc(open[idx], high[idx], low[idx], close[idx]) {
600 idx += 1;
601 }
602 let seg_end = idx;
603 if seg_end - seg_start >= length + 1 {
604 compute_run(
605 &open[seg_start..seg_end],
606 &high[seg_start..seg_end],
607 &low[seg_start..seg_end],
608 &close[seg_start..seg_end],
609 length,
610 extend,
611 &mut middle[seg_start..seg_end],
612 &mut upper[seg_start..seg_end],
613 &mut lower[seg_start..seg_end],
614 );
615 }
616 }
617}
618
/// Computes the zig-zag channels for `input` using `Kernel::Auto`.
///
/// # Errors
/// Returns whatever [`ZigZagChannelsError`] `zig_zag_channels_with_kernel`
/// reports for the given input.
#[inline]
pub fn zig_zag_channels(
    input: &ZigZagChannelsInput,
) -> Result<ZigZagChannelsOutput, ZigZagChannelsError> {
    zig_zag_channels_with_kernel(input, Kernel::Auto)
}
625
626pub fn zig_zag_channels_with_kernel(
627 input: &ZigZagChannelsInput,
628 kernel: Kernel,
629) -> Result<ZigZagChannelsOutput, ZigZagChannelsError> {
630 let (open, high, low, close) = input_slices(input)?;
631 let length = input.get_length();
632 let extend = input.get_extend();
633 validate_common(open, high, low, close, length)?;
634
635 let _chosen = match kernel {
636 Kernel::Auto => detect_best_kernel(),
637 other => other,
638 };
639
640 let mut middle = alloc_with_nan_prefix(close.len(), 0);
641 let mut upper = alloc_with_nan_prefix(close.len(), 0);
642 let mut lower = alloc_with_nan_prefix(close.len(), 0);
643 middle.fill(f64::NAN);
644 upper.fill(f64::NAN);
645 lower.fill(f64::NAN);
646
647 compute_row(
648 open,
649 high,
650 low,
651 close,
652 length,
653 extend,
654 &mut middle,
655 &mut upper,
656 &mut lower,
657 );
658
659 Ok(ZigZagChannelsOutput {
660 middle,
661 upper,
662 lower,
663 })
664}
665
666pub fn zig_zag_channels_into_slice(
667 out_middle: &mut [f64],
668 out_upper: &mut [f64],
669 out_lower: &mut [f64],
670 input: &ZigZagChannelsInput,
671 kernel: Kernel,
672) -> Result<(), ZigZagChannelsError> {
673 let (open, high, low, close) = input_slices(input)?;
674 let length = input.get_length();
675 let extend = input.get_extend();
676 validate_common(open, high, low, close, length)?;
677
678 if out_middle.len() != close.len() {
679 return Err(ZigZagChannelsError::OutputLengthMismatch {
680 expected: close.len(),
681 got: out_middle.len(),
682 });
683 }
684 if out_upper.len() != close.len() {
685 return Err(ZigZagChannelsError::OutputLengthMismatch {
686 expected: close.len(),
687 got: out_upper.len(),
688 });
689 }
690 if out_lower.len() != close.len() {
691 return Err(ZigZagChannelsError::OutputLengthMismatch {
692 expected: close.len(),
693 got: out_lower.len(),
694 });
695 }
696
697 let _chosen = match kernel {
698 Kernel::Auto => detect_best_kernel(),
699 other => other,
700 };
701
702 out_middle.fill(f64::NAN);
703 out_upper.fill(f64::NAN);
704 out_lower.fill(f64::NAN);
705 compute_row(
706 open, high, low, close, length, extend, out_middle, out_upper, out_lower,
707 );
708 Ok(())
709}
710
/// Computes the zig-zag channels into caller-provided buffers with
/// `Kernel::Auto`.
///
/// Compiled for every configuration except `wasm32` with the `wasm` feature.
/// Argument order differs from `zig_zag_channels_into_slice`: the input comes
/// first here.
///
/// # Errors
/// Same as `zig_zag_channels_into_slice`.
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
pub fn zig_zag_channels_into(
    input: &ZigZagChannelsInput,
    out_middle: &mut [f64],
    out_upper: &mut [f64],
    out_lower: &mut [f64],
) -> Result<(), ZigZagChannelsError> {
    zig_zag_channels_into_slice(out_middle, out_upper, out_lower, input, Kernel::Auto)
}
720
/// Parameter sweep description for batch evaluation.
#[derive(Debug, Clone, Copy)]
pub struct ZigZagChannelsBatchRange {
    /// `(start, end, step)` for `length`; a `step` of `0` means "only `start`".
    pub length: (usize, usize, usize),
    /// `extend` flag applied to every row of the batch.
    pub extend: bool,
}

impl Default for ZigZagChannelsBatchRange {
    /// A single row with `length = 100` and `extend = true`.
    fn default() -> Self {
        let length = (100, 100, 0);
        ZigZagChannelsBatchRange {
            length,
            extend: true,
        }
    }
}
735
/// Result of a batch evaluation.
///
/// `middle`, `upper` and `lower` each hold `rows * cols` values laid out row
/// by row: row `r` occupies `r * cols .. (r + 1) * cols` and corresponds to
/// `combos[r]`.
#[derive(Debug, Clone)]
pub struct ZigZagChannelsBatchOutput {
    pub middle: Vec<f64>,
    pub upper: Vec<f64>,
    pub lower: Vec<f64>,
    /// Parameter set used for each row.
    pub combos: Vec<ZigZagChannelsParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Length of the input series.
    pub cols: usize,
}
745
746#[derive(Debug, Clone, Copy)]
747pub struct ZigZagChannelsBatchBuilder {
748 range: ZigZagChannelsBatchRange,
749 kernel: Kernel,
750}
751
752impl Default for ZigZagChannelsBatchBuilder {
753 fn default() -> Self {
754 Self {
755 range: ZigZagChannelsBatchRange::default(),
756 kernel: Kernel::Auto,
757 }
758 }
759}
760
761impl ZigZagChannelsBatchBuilder {
762 #[inline(always)]
763 pub fn new() -> Self {
764 Self::default()
765 }
766
767 #[inline(always)]
768 pub fn kernel(mut self, value: Kernel) -> Self {
769 self.kernel = value;
770 self
771 }
772
773 #[inline(always)]
774 pub fn length_range(mut self, start: usize, end: usize, step: usize) -> Self {
775 self.range.length = (start, end, step);
776 self
777 }
778
779 #[inline(always)]
780 pub fn length_static(mut self, value: usize) -> Self {
781 self.range.length = (value, value, 0);
782 self
783 }
784
785 #[inline(always)]
786 pub fn extend(mut self, value: bool) -> Self {
787 self.range.extend = value;
788 self
789 }
790
791 #[inline(always)]
792 pub fn apply_slices(
793 self,
794 open: &[f64],
795 high: &[f64],
796 low: &[f64],
797 close: &[f64],
798 ) -> Result<ZigZagChannelsBatchOutput, ZigZagChannelsError> {
799 zig_zag_channels_batch_with_kernel(open, high, low, close, &self.range, self.kernel)
800 }
801
802 #[inline(always)]
803 pub fn apply_candles(
804 self,
805 candles: &Candles,
806 ) -> Result<ZigZagChannelsBatchOutput, ZigZagChannelsError> {
807 zig_zag_channels_batch_with_kernel(
808 candles.open.as_slice(),
809 candles.high.as_slice(),
810 candles.low.as_slice(),
811 candles.close.as_slice(),
812 &self.range,
813 self.kernel,
814 )
815 }
816}
817
818#[inline(always)]
819fn expand_grid_checked(
820 range: &ZigZagChannelsBatchRange,
821) -> Result<Vec<ZigZagChannelsParams>, ZigZagChannelsError> {
822 let (start, end, step) = range.length;
823 if start == 0 || end == 0 {
824 return Err(ZigZagChannelsError::InvalidRange { start, end, step });
825 }
826 if step == 0 {
827 return Ok(vec![ZigZagChannelsParams {
828 length: Some(start),
829 extend: Some(range.extend),
830 }]);
831 }
832 if start > end {
833 return Err(ZigZagChannelsError::InvalidRange { start, end, step });
834 }
835
836 let mut out = Vec::new();
837 let mut current = start;
838 loop {
839 out.push(ZigZagChannelsParams {
840 length: Some(current),
841 extend: Some(range.extend),
842 });
843 if current >= end {
844 break;
845 }
846 let next = current.saturating_add(step);
847 if next <= current {
848 return Err(ZigZagChannelsError::InvalidRange { start, end, step });
849 }
850 current = next.min(end);
851 if current == out.last().and_then(|item| item.length).unwrap_or(0) {
852 break;
853 }
854 }
855 Ok(out)
856}
857
/// Expands a batch range into its parameter sets.
///
/// Infallible wrapper around `expand_grid_checked`: an invalid range is not
/// reported as an error but results in an empty vector.
#[inline(always)]
pub fn expand_grid_zig_zag_channels(range: &ZigZagChannelsBatchRange) -> Vec<ZigZagChannelsParams> {
    expand_grid_checked(range).unwrap_or_default()
}
862
/// Evaluates the indicator for every parameter set in `sweep`.
///
/// Requests parallel row processing (the `parallel` flag is passed as `true`).
///
/// # Errors
/// Propagates any error from `zig_zag_channels_batch_inner`.
pub fn zig_zag_channels_batch_with_kernel(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    sweep: &ZigZagChannelsBatchRange,
    kernel: Kernel,
) -> Result<ZigZagChannelsBatchOutput, ZigZagChannelsError> {
    zig_zag_channels_batch_inner(open, high, low, close, sweep, kernel, true)
}
873
/// Evaluates the indicator for every parameter set in `sweep`, processing the
/// rows sequentially (the `parallel` flag is passed as `false`).
///
/// # Errors
/// Propagates any error from `zig_zag_channels_batch_inner`.
pub fn zig_zag_channels_batch_slice(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    sweep: &ZigZagChannelsBatchRange,
    kernel: Kernel,
) -> Result<ZigZagChannelsBatchOutput, ZigZagChannelsError> {
    zig_zag_channels_batch_inner(open, high, low, close, sweep, kernel, false)
}
884
/// Evaluates the indicator for every parameter set in `sweep`, requesting
/// parallel row processing (the `parallel` flag is passed as `true`).
///
/// # Errors
/// Propagates any error from `zig_zag_channels_batch_inner`.
pub fn zig_zag_channels_batch_par_slice(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    sweep: &ZigZagChannelsBatchRange,
    kernel: Kernel,
) -> Result<ZigZagChannelsBatchOutput, ZigZagChannelsError> {
    zig_zag_channels_batch_inner(open, high, low, close, sweep, kernel, true)
}
895
896fn zig_zag_channels_batch_inner(
897 open: &[f64],
898 high: &[f64],
899 low: &[f64],
900 close: &[f64],
901 sweep: &ZigZagChannelsBatchRange,
902 kernel: Kernel,
903 parallel: bool,
904) -> Result<ZigZagChannelsBatchOutput, ZigZagChannelsError> {
905 match kernel {
906 Kernel::Auto
907 | Kernel::Scalar
908 | Kernel::ScalarBatch
909 | Kernel::Avx2
910 | Kernel::Avx2Batch
911 | Kernel::Avx512
912 | Kernel::Avx512Batch => {}
913 other => return Err(ZigZagChannelsError::InvalidKernelForBatch(other)),
914 }
915
916 let combos = expand_grid_checked(sweep)?;
917 let max_length = combos
918 .iter()
919 .map(|params| params.length.unwrap_or(100))
920 .max()
921 .unwrap_or(0);
922 validate_common(open, high, low, close, max_length)?;
923
924 let rows = combos.len();
925 let cols = close.len();
926 let total = rows
927 .checked_mul(cols)
928 .ok_or_else(|| ZigZagChannelsError::InvalidInput {
929 msg: "zig_zag_channels: rows*cols overflow in batch".to_string(),
930 })?;
931
932 let mut middle = vec![f64::NAN; total];
933 let mut upper = vec![f64::NAN; total];
934 let mut lower = vec![f64::NAN; total];
935 zig_zag_channels_batch_inner_into(
936 open,
937 high,
938 low,
939 close,
940 sweep,
941 kernel,
942 parallel,
943 &mut middle,
944 &mut upper,
945 &mut lower,
946 )?;
947
948 Ok(ZigZagChannelsBatchOutput {
949 middle,
950 upper,
951 lower,
952 combos,
953 rows,
954 cols,
955 })
956}
957
958fn zig_zag_channels_batch_inner_into(
959 open: &[f64],
960 high: &[f64],
961 low: &[f64],
962 close: &[f64],
963 sweep: &ZigZagChannelsBatchRange,
964 kernel: Kernel,
965 parallel: bool,
966 out_middle: &mut [f64],
967 out_upper: &mut [f64],
968 out_lower: &mut [f64],
969) -> Result<Vec<ZigZagChannelsParams>, ZigZagChannelsError> {
970 match kernel {
971 Kernel::Auto
972 | Kernel::Scalar
973 | Kernel::ScalarBatch
974 | Kernel::Avx2
975 | Kernel::Avx2Batch
976 | Kernel::Avx512
977 | Kernel::Avx512Batch => {}
978 other => return Err(ZigZagChannelsError::InvalidKernelForBatch(other)),
979 }
980
981 let combos = expand_grid_checked(sweep)?;
982 let max_length = combos
983 .iter()
984 .map(|params| params.length.unwrap_or(100))
985 .max()
986 .unwrap_or(0);
987 validate_common(open, high, low, close, max_length)?;
988
989 let cols = close.len();
990 let total =
991 combos
992 .len()
993 .checked_mul(cols)
994 .ok_or_else(|| ZigZagChannelsError::InvalidInput {
995 msg: "zig_zag_channels: rows*cols overflow in batch_into".to_string(),
996 })?;
997 if out_middle.len() != total {
998 return Err(ZigZagChannelsError::MismatchedOutputLen {
999 dst_len: out_middle.len(),
1000 expected_len: total,
1001 });
1002 }
1003 if out_upper.len() != total {
1004 return Err(ZigZagChannelsError::MismatchedOutputLen {
1005 dst_len: out_upper.len(),
1006 expected_len: total,
1007 });
1008 }
1009 if out_lower.len() != total {
1010 return Err(ZigZagChannelsError::MismatchedOutputLen {
1011 dst_len: out_lower.len(),
1012 expected_len: total,
1013 });
1014 }
1015
1016 let _chosen = match kernel {
1017 Kernel::Auto => detect_best_batch_kernel(),
1018 other => other,
1019 };
1020
1021 let worker =
1022 |row: usize, middle_row: &mut [f64], upper_row: &mut [f64], lower_row: &mut [f64]| {
1023 middle_row.fill(f64::NAN);
1024 upper_row.fill(f64::NAN);
1025 lower_row.fill(f64::NAN);
1026 let params = &combos[row];
1027 compute_row(
1028 open,
1029 high,
1030 low,
1031 close,
1032 params.length.unwrap_or(100),
1033 params.extend.unwrap_or(true),
1034 middle_row,
1035 upper_row,
1036 lower_row,
1037 );
1038 };
1039
1040 #[cfg(not(target_arch = "wasm32"))]
1041 if parallel && combos.len() > 1 {
1042 out_middle
1043 .par_chunks_mut(cols)
1044 .zip(out_upper.par_chunks_mut(cols))
1045 .zip(out_lower.par_chunks_mut(cols))
1046 .enumerate()
1047 .for_each(|(row, ((middle_row, upper_row), lower_row))| {
1048 worker(row, middle_row, upper_row, lower_row);
1049 });
1050 } else {
1051 for (row, ((middle_row, upper_row), lower_row)) in out_middle
1052 .chunks_mut(cols)
1053 .zip(out_upper.chunks_mut(cols))
1054 .zip(out_lower.chunks_mut(cols))
1055 .enumerate()
1056 {
1057 worker(row, middle_row, upper_row, lower_row);
1058 }
1059 }
1060
1061 #[cfg(target_arch = "wasm32")]
1062 {
1063 let _ = parallel;
1064 for (row, ((middle_row, upper_row), lower_row)) in out_middle
1065 .chunks_mut(cols)
1066 .zip(out_upper.chunks_mut(cols))
1067 .zip(out_lower.chunks_mut(cols))
1068 .enumerate()
1069 {
1070 worker(row, middle_row, upper_row, lower_row);
1071 }
1072 }
1073
1074 Ok(combos)
1075}
1076
/// Python binding `zig_zag_channels(open, high, low, close, length=100,
/// extend=True, kernel=None)`.
///
/// Returns the tuple `(middle, upper, lower)` as one-dimensional NumPy arrays.
/// The computation itself runs inside `py.allow_threads`. Any indicator error
/// is raised as `ValueError` carrying the error's display text.
#[cfg(feature = "python")]
#[pyfunction(name = "zig_zag_channels", signature = (open, high, low, close, length=100, extend=true, kernel=None))]
pub fn zig_zag_channels_py<'py>(
    py: Python<'py>,
    open: PyReadonlyArray1<'py, f64>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    length: usize,
    extend: bool,
    kernel: Option<&str>,
) -> PyResult<(
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
)> {
    let open_s = open.as_slice()?;
    let high_s = high.as_slice()?;
    let low_s = low.as_slice()?;
    let close_s = close.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let params = ZigZagChannelsParams {
        length: Some(length),
        extend: Some(extend),
    };
    let input = ZigZagChannelsInput::from_slices(open_s, high_s, low_s, close_s, params);

    let result = py.allow_threads(|| zig_zag_channels_with_kernel(&input, kern));
    let ZigZagChannelsOutput {
        middle,
        upper,
        lower,
    } = result.map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok((
        middle.into_pyarray(py),
        upper.into_pyarray(py),
        lower.into_pyarray(py),
    ))
}
1117
/// Python binding `zig_zag_channels_batch(open, high, low, close,
/// length_range=(100, 100, 0), extend=True, kernel=None)`.
///
/// Returns a dict with these keys:
/// * `"middle"`, `"upper"`, `"lower"`: arrays reshaped to `(rows, cols)`;
/// * `"lengths"`: the `length` of every row as unsigned 64-bit integers;
/// * `"extends"`: the `extend` flag of every row;
/// * `"rows"`, `"cols"`: the batch dimensions.
///
/// The computation itself runs inside `py.allow_threads`. Any indicator error
/// is raised as `ValueError` carrying the error's display text.
#[cfg(feature = "python")]
#[pyfunction(name = "zig_zag_channels_batch", signature = (open, high, low, close, length_range=(100, 100, 0), extend=true, kernel=None))]
pub fn zig_zag_channels_batch_py<'py>(
    py: Python<'py>,
    open: PyReadonlyArray1<'py, f64>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    length_range: (usize, usize, usize),
    extend: bool,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let open_s = open.as_slice()?;
    let high_s = high.as_slice()?;
    let low_s = low.as_slice()?;
    let close_s = close.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    let sweep = ZigZagChannelsBatchRange {
        length: length_range,
        extend,
    };
    let ZigZagChannelsBatchOutput {
        middle,
        upper,
        lower,
        combos,
        rows,
        cols,
    } = py
        .allow_threads(|| {
            zig_zag_channels_batch_with_kernel(open_s, high_s, low_s, close_s, &sweep, kern)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    for (key, flat) in [("middle", middle), ("upper", upper), ("lower", lower)] {
        dict.set_item(key, flat.into_pyarray(py).reshape((rows, cols))?)?;
    }

    let lengths: Vec<u64> = combos
        .iter()
        .map(|params| params.length.unwrap_or(100) as u64)
        .collect();
    dict.set_item("lengths", lengths.into_pyarray(py))?;

    let extends: Vec<bool> = combos
        .iter()
        .map(|params| params.extend.unwrap_or(true))
        .collect();
    dict.set_item("extends", extends)?;

    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
1195
/// Adds the two Python functions of this file (`zig_zag_channels_py` and
/// `zig_zag_channels_batch_py`) to module `m`.
#[cfg(feature = "python")]
pub fn register_zig_zag_channels_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(zig_zag_channels_py, m)?)?;
    m.add_function(wrap_pyfunction!(zig_zag_channels_batch_py, m)?)?;
    Ok(())
}
1202
/// Batch configuration deserialized from the JavaScript side.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZigZagChannelsBatchConfig {
    /// Must contain exactly three elements: `[start, end, step]`.
    pub length_range: Vec<usize>,
    /// `extend` flag for every row; treated as `true` when absent.
    pub extend: Option<bool>,
}
1209
/// JavaScript binding: computes the zig-zag channels with `Kernel::Auto`.
///
/// Returns a plain object with the properties `middle`, `upper` and `lower`,
/// each holding the serialized output series.
///
/// # Errors
/// Returns a `JsValue` string with the error's display text when the
/// indicator rejects the input or when an output series cannot be serialized
/// (serialization failures are reported instead of panicking), and forwards
/// any error raised while setting the object's properties.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = zig_zag_channels_js)]
pub fn zig_zag_channels_js(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    extend: bool,
) -> Result<JsValue, JsValue> {
    let input = ZigZagChannelsInput::from_slices(
        open,
        high,
        low,
        close,
        ZigZagChannelsParams {
            length: Some(length),
            extend: Some(extend),
        },
    );
    let out = zig_zag_channels_with_kernel(&input, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let obj = js_sys::Object::new();
    for (key, series) in [
        ("middle", &out.middle),
        ("upper", &out.upper),
        ("lower", &out.lower),
    ] {
        let value = serde_wasm_bindgen::to_value(series)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        js_sys::Reflect::set(&obj, &JsValue::from_str(key), &value)?;
    }
    Ok(obj.into())
}
1250
/// JavaScript binding: batch evaluation with `Kernel::Auto`.
///
/// `config` must deserialize into [`ZigZagChannelsBatchConfig`], and its
/// `length_range` must hold exactly `[start, end, step]`. Returns a plain
/// object with the properties `middle`, `upper`, `lower` (flattened row by
/// row), `rows`, `cols` and `combos`.
///
/// # Errors
/// Returns a `JsValue` string when the config is invalid, when the indicator
/// rejects the input, or when a result cannot be serialized (serialization
/// failures are reported instead of panicking), and forwards any error raised
/// while setting the object's properties.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = zig_zag_channels_batch_js)]
pub fn zig_zag_channels_batch_js(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let config: ZigZagChannelsBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {e}")))?;
    let length = match config.length_range.as_slice() {
        &[start, end, step] => (start, end, step),
        _ => {
            return Err(JsValue::from_str(
                "Invalid config: length_range must have exactly 3 elements [start, end, step]",
            ))
        }
    };

    let out = zig_zag_channels_batch_with_kernel(
        open,
        high,
        low,
        close,
        &ZigZagChannelsBatchRange {
            length,
            extend: config.extend.unwrap_or(true),
        },
        Kernel::Auto,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let serialization_error =
        |e: serde_wasm_bindgen::Error| JsValue::from_str(&e.to_string());

    let obj = js_sys::Object::new();
    for (key, series) in [
        ("middle", &out.middle),
        ("upper", &out.upper),
        ("lower", &out.lower),
    ] {
        let value = serde_wasm_bindgen::to_value(series).map_err(serialization_error)?;
        js_sys::Reflect::set(&obj, &JsValue::from_str(key), &value)?;
    }
    js_sys::Reflect::set(
        &obj,
        &JsValue::from_str("rows"),
        &JsValue::from_f64(out.rows as f64),
    )?;
    js_sys::Reflect::set(
        &obj,
        &JsValue::from_str("cols"),
        &JsValue::from_f64(out.cols as f64),
    )?;
    let combos = serde_wasm_bindgen::to_value(&out.combos).map_err(serialization_error)?;
    js_sys::Reflect::set(&obj, &JsValue::from_str("combos"), &combos)?;
    Ok(obj.into())
}
1318
/// Allocates an uninitialized buffer with capacity for `3 * len` `f64` values and hands
/// ownership of it to the caller as a raw pointer.
///
/// Release the buffer with `zig_zag_channels_free`, passing the same `len`.
///
/// Returns a null pointer when `3 * len` does not fit in `usize`, instead of silently
/// allocating a wrapped (too small) capacity.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn zig_zag_channels_alloc(len: usize) -> *mut f64 {
    match len.checked_mul(3) {
        Some(capacity) => {
            // ManuallyDrop keeps the allocation alive after this function returns;
            // the caller now owns it through the raw pointer.
            let mut buf = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(capacity));
            buf.as_mut_ptr()
        }
        None => std::ptr::null_mut(),
    }
}
1327
/// Releases a buffer previously obtained from `zig_zag_channels_alloc(len)`.
///
/// A null `ptr` is ignored. A `len` whose tripled size overflows `usize` is treated as
/// invalid and nothing is freed.
///
/// # Safety contract (not enforced by the signature)
/// `ptr` must come from `zig_zag_channels_alloc` called with the same `len`, and must not be
/// used or freed again afterwards.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn zig_zag_channels_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    let capacity = match len.checked_mul(3) {
        Some(capacity) => capacity,
        None => return,
    };
    // SAFETY: the caller promises `ptr` was allocated by `zig_zag_channels_alloc(len)`, i.e. a
    // `Vec<f64>` with capacity `3 * len`. The length is given as 0 because the buffer may
    // never have been initialized; `f64` needs no drop, so only the allocation is released.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, capacity));
    }
}
1337
/// Computes the indicator from raw OHLC pointers and writes the three bands into `out_ptr`.
///
/// The output buffer is laid out as three consecutive planes of `len` values each:
/// `middle` at `[0, len)`, `upper` at `[len, 2*len)` and `lower` at `[2*len, 3*len)`.
///
/// # Errors
/// Returns a string `JsValue` when any pointer is null, when `3 * len` overflows `usize`,
/// or when the underlying computation fails.
///
/// # Safety contract (not enforced by the signature)
/// Each input pointer must reference `len` readable `f64` values and `out_ptr` must reference
/// `3 * len` writable `f64` values. NOTE(review): the output region is borrowed mutably while
/// the inputs are borrowed immutably, so it must not overlap any input — confirm callers
/// never compute in place.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn zig_zag_channels_into(
    open_ptr: *const f64,
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    length: usize,
    extend: bool,
) -> Result<(), JsValue> {
    if open_ptr.is_null()
        || high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || out_ptr.is_null()
    {
        return Err(JsValue::from_str(
            "null pointer passed to zig_zag_channels_into",
        ));
    }

    // Checked like the batch variant: a wrapped product would describe a buffer smaller than
    // the three planes that are split out of it below.
    let total = len
        .checked_mul(3)
        .ok_or_else(|| JsValue::from_str("3*len overflow in zig_zag_channels_into"))?;

    // SAFETY: all pointers were checked for null above; the caller guarantees the element
    // counts described in the function documentation.
    unsafe {
        let open = std::slice::from_raw_parts(open_ptr, len);
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        let out = std::slice::from_raw_parts_mut(out_ptr, total);
        let (middle, tail) = out.split_at_mut(len);
        let (upper, lower) = tail.split_at_mut(len);
        let input = ZigZagChannelsInput::from_slices(
            open,
            high,
            low,
            close,
            ZigZagChannelsParams {
                length: Some(length),
                extend: Some(extend),
            },
        );
        zig_zag_channels_into_slice(middle, upper, lower, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))
    }
}
1383
/// Runs a `length` sweep from raw OHLC pointers and writes every row into `out_ptr`.
///
/// The sweep is described by `(length_start, length_end, length_step)` and expanded with
/// `expand_grid_checked`. With `rows` combinations, the output buffer is laid out as three
/// consecutive planes of `rows * len` values: `middle`, then `upper`, then `lower`.
///
/// Returns the number of rows on success.
///
/// # Errors
/// Returns a string `JsValue` when a pointer is null, when the sweep cannot be expanded,
/// when the buffer size overflows `usize`, or when the computation fails.
///
/// # Safety contract (not enforced by the signature)
/// Each input pointer must reference `len` readable `f64` values and `out_ptr` must reference
/// `3 * rows * len` writable `f64` values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn zig_zag_channels_batch_into(
    open_ptr: *const f64,
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    length_start: usize,
    length_end: usize,
    length_step: usize,
    extend: bool,
) -> Result<usize, JsValue> {
    let input_ptrs = [open_ptr, high_ptr, low_ptr, close_ptr];
    if input_ptrs.iter().any(|p| p.is_null()) || out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to zig_zag_channels_batch_into",
        ));
    }

    let range = ZigZagChannelsBatchRange {
        length: (length_start, length_end, length_step),
        extend,
    };
    let row_count = match expand_grid_checked(&range) {
        Ok(combos) => combos.len(),
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    // Size of one output plane, then of all three planes together.
    let plane_len = match row_count.checked_mul(len) {
        Some(n) => n,
        None => {
            return Err(JsValue::from_str(
                "rows*cols overflow in zig_zag_channels_batch_into",
            ))
        }
    };
    let out_len = match plane_len.checked_mul(3) {
        Some(n) => n,
        None => {
            return Err(JsValue::from_str(
                "3*rows*cols overflow in zig_zag_channels_batch_into",
            ))
        }
    };

    // SAFETY: all pointers were checked for null above; the caller guarantees the element
    // counts described in the function documentation.
    unsafe {
        let out = std::slice::from_raw_parts_mut(out_ptr, out_len);
        let (middle, rest) = out.split_at_mut(plane_len);
        let (upper, lower) = rest.split_at_mut(plane_len);
        let computed = zig_zag_channels_batch_inner_into(
            std::slice::from_raw_parts(open_ptr, len),
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
            std::slice::from_raw_parts(close_ptr, len),
            &range,
            Kernel::Auto,
            // NOTE(review): presumably the "parallel" switch of the inner routine — confirm.
            false,
            middle,
            upper,
            lower,
        );
        if let Err(e) = computed {
            return Err(JsValue::from_str(&e.to_string()));
        }
    }

    Ok(row_count)
}
1447
1448#[cfg(test)]
1449mod tests {
1450 use super::*;
1451 use crate::indicators::dispatch::{
1452 compute_cpu, IndicatorComputeRequest, IndicatorDataRef, ParamKV, ParamValue,
1453 };
1454
1455 fn sample_ohlc(len: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
1456 let close: Vec<f64> = (0..len)
1457 .map(|i| {
1458 let x = i as f64;
1459 100.0 + (x * 0.21).sin() * 7.0 + (x * 0.037).cos() * 2.5 + (x * 0.02)
1460 })
1461 .collect();
1462 let open: Vec<f64> = close
1463 .iter()
1464 .enumerate()
1465 .map(|(i, &c)| c + ((i as f64) * 0.41).sin() * 0.9)
1466 .collect();
1467 let high: Vec<f64> = open
1468 .iter()
1469 .zip(close.iter())
1470 .map(|(&o, &c)| o.max(c) + 0.75)
1471 .collect();
1472 let low: Vec<f64> = open
1473 .iter()
1474 .zip(close.iter())
1475 .map(|(&o, &c)| o.min(c) - 0.75)
1476 .collect();
1477 (open, high, low, close)
1478 }
1479
1480 fn naive_offsets(
1481 open: &[f64],
1482 close: &[f64],
1483 start_idx: usize,
1484 end_idx: usize,
1485 start_value: f64,
1486 end_value: f64,
1487 ) -> (f64, f64) {
1488 if end_idx <= start_idx {
1489 return (0.0, 0.0);
1490 }
1491 if end_idx == start_idx + 1 {
1492 let top = open[end_idx].max(close[end_idx]);
1493 let bottom = open[end_idx].min(close[end_idx]);
1494 return ((top - end_value).max(0.0), (end_value - bottom).max(0.0));
1495 }
1496 let mut up = 0.0f64;
1497 let mut dn = 0.0f64;
1498 let denom = (end_idx - start_idx - 1) as f64;
1499 for idx in (start_idx + 1)..=end_idx {
1500 let j = (idx - start_idx - 1) as f64;
1501 let point = start_value + (j / denom) * (end_value - start_value);
1502 up = up.max(open[idx].max(close[idx]) - point);
1503 dn = dn.max(point - open[idx].min(close[idx]));
1504 }
1505 (up.max(0.0), dn.max(0.0))
1506 }
1507
1508 fn naive_fill(
1509 middle: &mut [f64],
1510 upper: &mut [f64],
1511 lower: &mut [f64],
1512 start_idx: usize,
1513 end_idx: usize,
1514 start_value: f64,
1515 end_value: f64,
1516 up: f64,
1517 dn: f64,
1518 ) {
1519 if end_idx < start_idx {
1520 return;
1521 }
1522 if start_idx == end_idx {
1523 middle[start_idx] = start_value;
1524 upper[start_idx] = start_value + up;
1525 lower[start_idx] = start_value - dn;
1526 return;
1527 }
1528 let denom = (end_idx - start_idx) as f64;
1529 for idx in start_idx..=end_idx {
1530 let t = (idx - start_idx) as f64 / denom;
1531 let value = start_value + t * (end_value - start_value);
1532 middle[idx] = value;
1533 upper[idx] = value + up;
1534 lower[idx] = value - dn;
1535 }
1536 }
1537
1538 fn naive_run(
1539 open: &[f64],
1540 high: &[f64],
1541 low: &[f64],
1542 close: &[f64],
1543 length: usize,
1544 extend: bool,
1545 middle: &mut [f64],
1546 upper: &mut [f64],
1547 lower: &mut [f64],
1548 ) {
1549 let mut os = 0usize;
1550 let mut last_top: Option<(usize, f64)> = None;
1551 let mut last_bottom: Option<(usize, f64)> = None;
1552
1553 for current in length..close.len() {
1554 let candidate = current - length;
1555 let mut hi = f64::NEG_INFINITY;
1556 let mut lo = f64::INFINITY;
1557 for idx in (candidate + 1)..=current {
1558 hi = hi.max(close[idx]);
1559 lo = lo.min(close[idx]);
1560 }
1561
1562 let prev_os = os;
1563 if close[candidate] > hi {
1564 os = 0;
1565 } else if close[candidate] < lo {
1566 os = 1;
1567 }
1568
1569 if os == 1 && prev_os != 1 {
1570 let end_idx = candidate;
1571 let end_value = low[end_idx];
1572 if let Some((confirm_idx, start_value)) = last_top {
1573 let start_idx = confirm_idx - length;
1574 let (up, dn) =
1575 naive_offsets(open, close, start_idx, end_idx, start_value, end_value);
1576 naive_fill(
1577 middle,
1578 upper,
1579 lower,
1580 start_idx,
1581 end_idx,
1582 start_value,
1583 end_value,
1584 up,
1585 dn,
1586 );
1587 }
1588 last_bottom = Some((current, end_value));
1589 }
1590
1591 if os == 0 && prev_os != 0 {
1592 let end_idx = candidate;
1593 let end_value = high[end_idx];
1594 if let Some((confirm_idx, start_value)) = last_bottom {
1595 let start_idx = confirm_idx - length;
1596 let (up, dn) =
1597 naive_offsets(open, close, start_idx, end_idx, start_value, end_value);
1598 naive_fill(
1599 middle,
1600 upper,
1601 lower,
1602 start_idx,
1603 end_idx,
1604 start_value,
1605 end_value,
1606 up,
1607 dn,
1608 );
1609 }
1610 last_top = Some((current, end_value));
1611 }
1612 }
1613
1614 if !extend || close.is_empty() {
1615 return;
1616 }
1617 let end_idx = close.len() - 1;
1618 let end_value = close[end_idx];
1619 if os == 1 {
1620 if let Some((confirm_idx, start_value)) = last_bottom {
1621 let start_idx = confirm_idx - length;
1622 let (up, dn) =
1623 naive_offsets(open, close, start_idx, end_idx, start_value, end_value);
1624 naive_fill(
1625 middle,
1626 upper,
1627 lower,
1628 start_idx,
1629 end_idx,
1630 start_value,
1631 end_value,
1632 up,
1633 dn,
1634 );
1635 }
1636 } else if let Some((confirm_idx, start_value)) = last_top {
1637 let start_idx = confirm_idx - length;
1638 let (up, dn) = naive_offsets(open, close, start_idx, end_idx, start_value, end_value);
1639 naive_fill(
1640 middle,
1641 upper,
1642 lower,
1643 start_idx,
1644 end_idx,
1645 start_value,
1646 end_value,
1647 up,
1648 dn,
1649 );
1650 }
1651 }
1652
1653 fn naive_zig_zag_channels(
1654 open: &[f64],
1655 high: &[f64],
1656 low: &[f64],
1657 close: &[f64],
1658 length: usize,
1659 extend: bool,
1660 ) -> ZigZagChannelsOutput {
1661 let mut middle = vec![f64::NAN; close.len()];
1662 let mut upper = vec![f64::NAN; close.len()];
1663 let mut lower = vec![f64::NAN; close.len()];
1664 let mut idx = 0usize;
1665 while idx < close.len() {
1666 while idx < close.len() && !is_valid_ohlc(open[idx], high[idx], low[idx], close[idx]) {
1667 idx += 1;
1668 }
1669 if idx >= close.len() {
1670 break;
1671 }
1672 let start = idx;
1673 idx += 1;
1674 while idx < close.len() && is_valid_ohlc(open[idx], high[idx], low[idx], close[idx]) {
1675 idx += 1;
1676 }
1677 let end = idx;
1678 if end - start >= length + 1 {
1679 naive_run(
1680 &open[start..end],
1681 &high[start..end],
1682 &low[start..end],
1683 &close[start..end],
1684 length,
1685 extend,
1686 &mut middle[start..end],
1687 &mut upper[start..end],
1688 &mut lower[start..end],
1689 );
1690 }
1691 }
1692 ZigZagChannelsOutput {
1693 middle,
1694 upper,
1695 lower,
1696 }
1697 }
1698
1699 fn assert_series_close(left: &[f64], right: &[f64], tol: f64) {
1700 assert_eq!(left.len(), right.len());
1701 for (a, b) in left.iter().zip(right.iter()) {
1702 if a.is_nan() || b.is_nan() {
1703 assert!(a.is_nan() && b.is_nan());
1704 } else {
1705 assert!((a - b).abs() <= tol, "left={a} right={b}");
1706 }
1707 }
1708 }
1709
1710 #[test]
1711 fn zig_zag_channels_matches_naive_reference() -> Result<(), Box<dyn Error>> {
1712 let (open, high, low, close) = sample_ohlc(256);
1713 let input = ZigZagChannelsInput::from_slices(
1714 &open,
1715 &high,
1716 &low,
1717 &close,
1718 ZigZagChannelsParams {
1719 length: Some(7),
1720 extend: Some(true),
1721 },
1722 );
1723 let out = zig_zag_channels_with_kernel(&input, Kernel::Scalar)?;
1724 let expected = naive_zig_zag_channels(&open, &high, &low, &close, 7, true);
1725 assert_series_close(&out.middle, &expected.middle, 1e-12);
1726 assert_series_close(&out.upper, &expected.upper, 1e-12);
1727 assert_series_close(&out.lower, &expected.lower, 1e-12);
1728 Ok(())
1729 }
1730
1731 #[test]
1732 fn zig_zag_channels_into_matches_api() -> Result<(), Box<dyn Error>> {
1733 let (open, high, low, close) = sample_ohlc(220);
1734 let input = ZigZagChannelsInput::from_slices(
1735 &open,
1736 &high,
1737 &low,
1738 &close,
1739 ZigZagChannelsParams {
1740 length: Some(6),
1741 extend: Some(true),
1742 },
1743 );
1744 let base = zig_zag_channels(&input)?;
1745 let mut middle = vec![0.0; close.len()];
1746 let mut upper = vec![0.0; close.len()];
1747 let mut lower = vec![0.0; close.len()];
1748 zig_zag_channels_into_slice(&mut middle, &mut upper, &mut lower, &input, Kernel::Auto)?;
1749 assert_series_close(&base.middle, &middle, 1e-12);
1750 assert_series_close(&base.upper, &upper, 1e-12);
1751 assert_series_close(&base.lower, &lower, 1e-12);
1752 Ok(())
1753 }
1754
1755 #[test]
1756 fn zig_zag_channels_extend_changes_tail_only() -> Result<(), Box<dyn Error>> {
1757 let (open, high, low, close) = sample_ohlc(180);
1758 let extend_true = zig_zag_channels(&ZigZagChannelsInput::from_slices(
1759 &open,
1760 &high,
1761 &low,
1762 &close,
1763 ZigZagChannelsParams {
1764 length: Some(8),
1765 extend: Some(true),
1766 },
1767 ))?;
1768 let extend_false = zig_zag_channels(&ZigZagChannelsInput::from_slices(
1769 &open,
1770 &high,
1771 &low,
1772 &close,
1773 ZigZagChannelsParams {
1774 length: Some(8),
1775 extend: Some(false),
1776 },
1777 ))?;
1778
1779 let finite_true = extend_true.middle.iter().filter(|v| v.is_finite()).count();
1780 let finite_false = extend_false.middle.iter().filter(|v| v.is_finite()).count();
1781 assert!(finite_true >= finite_false);
1782
1783 for i in 0..close.len() {
1784 if extend_false.middle[i].is_finite() {
1785 assert!(extend_true.middle[i].is_finite());
1786 }
1787 }
1788 Ok(())
1789 }
1790
1791 #[test]
1792 fn zig_zag_channels_batch_single_matches_single() -> Result<(), Box<dyn Error>> {
1793 let (open, high, low, close) = sample_ohlc(192);
1794 let single = zig_zag_channels(&ZigZagChannelsInput::from_slices(
1795 &open,
1796 &high,
1797 &low,
1798 &close,
1799 ZigZagChannelsParams {
1800 length: Some(9),
1801 extend: Some(true),
1802 },
1803 ))?;
1804 let batch = zig_zag_channels_batch_with_kernel(
1805 &open,
1806 &high,
1807 &low,
1808 &close,
1809 &ZigZagChannelsBatchRange {
1810 length: (9, 9, 0),
1811 extend: true,
1812 },
1813 Kernel::Auto,
1814 )?;
1815 assert_eq!(batch.rows, 1);
1816 assert_eq!(batch.cols, close.len());
1817 assert_series_close(&batch.middle, &single.middle, 1e-12);
1818 assert_series_close(&batch.upper, &single.upper, 1e-12);
1819 assert_series_close(&batch.lower, &single.lower, 1e-12);
1820 Ok(())
1821 }
1822
1823 #[test]
1824 fn zig_zag_channels_rejects_invalid_params() {
1825 let (open, high, low, close) = sample_ohlc(32);
1826 let err = zig_zag_channels(&ZigZagChannelsInput::from_slices(
1827 &open,
1828 &high,
1829 &low,
1830 &close,
1831 ZigZagChannelsParams {
1832 length: Some(0),
1833 extend: Some(true),
1834 },
1835 ))
1836 .unwrap_err();
1837 assert!(matches!(err, ZigZagChannelsError::InvalidLength { .. }));
1838 }
1839
1840 #[test]
1841 fn zig_zag_channels_dispatch_compute_returns_middle() -> Result<(), Box<dyn Error>> {
1842 let (open, high, low, close) = sample_ohlc(160);
1843 let out = compute_cpu(IndicatorComputeRequest {
1844 indicator_id: "zig_zag_channels",
1845 output_id: Some("middle"),
1846 data: IndicatorDataRef::Ohlc {
1847 open: &open,
1848 high: &high,
1849 low: &low,
1850 close: &close,
1851 },
1852 params: &[
1853 ParamKV {
1854 key: "length",
1855 value: ParamValue::Int(7),
1856 },
1857 ParamKV {
1858 key: "extend",
1859 value: ParamValue::Bool(true),
1860 },
1861 ],
1862 kernel: Kernel::Auto,
1863 })?;
1864 assert_eq!(out.output_id, "middle");
1865 assert_eq!(out.cols, close.len());
1866 Ok(())
1867 }
1868}