Skip to main content

vector_ta/indicators/
zig_zag_channels.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::Candles;
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel,
19};
20#[cfg(feature = "python")]
21use crate::utilities::kernel_validation::validate_kernel;
22#[cfg(not(target_arch = "wasm32"))]
23use rayon::prelude::*;
24use std::collections::VecDeque;
25use std::error::Error;
26use thiserror::Error;
27
28#[derive(Debug, Clone)]
29pub enum ZigZagChannelsData<'a> {
30    Candles {
31        candles: &'a Candles,
32    },
33    Slices {
34        open: &'a [f64],
35        high: &'a [f64],
36        low: &'a [f64],
37        close: &'a [f64],
38    },
39}
40
/// Result of a single zig-zag channels run.
///
/// Each vector has the same length as the input series. Positions not covered by
/// any zig-zag leg hold NaN.
#[derive(Debug, Clone)]
pub struct ZigZagChannelsOutput {
    /// Linear interpolation between consecutive pivots.
    pub middle: Vec<f64>,
    /// `middle` shifted up by the leg's maximum candle-body excursion.
    pub upper: Vec<f64>,
    /// `middle` shifted down by the leg's maximum candle-body excursion.
    pub lower: Vec<f64>,
}

/// Tunable parameters; `None` fields fall back to `length = 100`, `extend = true`.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct ZigZagChannelsParams {
    /// Number of bars after a candidate used to confirm it as a pivot.
    pub length: Option<usize>,
    /// Whether to extend the last leg to the final bar's close.
    pub extend: Option<bool>,
}

impl Default for ZigZagChannelsParams {
    fn default() -> Self {
        let length = Some(100);
        let extend = Some(true);
        Self { length, extend }
    }
}
66
67#[derive(Debug, Clone)]
68pub struct ZigZagChannelsInput<'a> {
69    pub data: ZigZagChannelsData<'a>,
70    pub params: ZigZagChannelsParams,
71}
72
73impl<'a> ZigZagChannelsInput<'a> {
74    #[inline]
75    pub fn from_candles(candles: &'a Candles, params: ZigZagChannelsParams) -> Self {
76        Self {
77            data: ZigZagChannelsData::Candles { candles },
78            params,
79        }
80    }
81
82    #[inline]
83    pub fn from_slices(
84        open: &'a [f64],
85        high: &'a [f64],
86        low: &'a [f64],
87        close: &'a [f64],
88        params: ZigZagChannelsParams,
89    ) -> Self {
90        Self {
91            data: ZigZagChannelsData::Slices {
92                open,
93                high,
94                low,
95                close,
96            },
97            params,
98        }
99    }
100
101    #[inline]
102    pub fn with_default_candles(candles: &'a Candles) -> Self {
103        Self::from_candles(candles, ZigZagChannelsParams::default())
104    }
105
106    #[inline]
107    pub fn get_length(&self) -> usize {
108        self.params.length.unwrap_or(100)
109    }
110
111    #[inline]
112    pub fn get_extend(&self) -> bool {
113        self.params.extend.unwrap_or(true)
114    }
115}
116
117#[derive(Copy, Clone, Debug)]
118pub struct ZigZagChannelsBuilder {
119    length: Option<usize>,
120    extend: Option<bool>,
121    kernel: Kernel,
122}
123
124impl Default for ZigZagChannelsBuilder {
125    fn default() -> Self {
126        Self {
127            length: None,
128            extend: None,
129            kernel: Kernel::Auto,
130        }
131    }
132}
133
134impl ZigZagChannelsBuilder {
135    #[inline(always)]
136    pub fn new() -> Self {
137        Self::default()
138    }
139
140    #[inline(always)]
141    pub fn length(mut self, value: usize) -> Self {
142        self.length = Some(value);
143        self
144    }
145
146    #[inline(always)]
147    pub fn extend(mut self, value: bool) -> Self {
148        self.extend = Some(value);
149        self
150    }
151
152    #[inline(always)]
153    pub fn kernel(mut self, value: Kernel) -> Self {
154        self.kernel = value;
155        self
156    }
157
158    #[inline(always)]
159    pub fn apply(self, candles: &Candles) -> Result<ZigZagChannelsOutput, ZigZagChannelsError> {
160        zig_zag_channels_with_kernel(
161            &ZigZagChannelsInput::from_candles(
162                candles,
163                ZigZagChannelsParams {
164                    length: self.length,
165                    extend: self.extend,
166                },
167            ),
168            self.kernel,
169        )
170    }
171
172    #[inline(always)]
173    pub fn apply_slices(
174        self,
175        open: &[f64],
176        high: &[f64],
177        low: &[f64],
178        close: &[f64],
179    ) -> Result<ZigZagChannelsOutput, ZigZagChannelsError> {
180        zig_zag_channels_with_kernel(
181            &ZigZagChannelsInput::from_slices(
182                open,
183                high,
184                low,
185                close,
186                ZigZagChannelsParams {
187                    length: self.length,
188                    extend: self.extend,
189                },
190            ),
191            self.kernel,
192        )
193    }
194}
195
196#[derive(Debug, Error)]
197pub enum ZigZagChannelsError {
198    #[error("zig_zag_channels: Input data slice is empty.")]
199    EmptyInputData,
200    #[error(
201        "zig_zag_channels: Input length mismatch: open = {open_len}, high = {high_len}, low = {low_len}, close = {close_len}"
202    )]
203    InputLengthMismatch {
204        open_len: usize,
205        high_len: usize,
206        low_len: usize,
207        close_len: usize,
208    },
209    #[error("zig_zag_channels: All values are NaN.")]
210    AllValuesNaN,
211    #[error("zig_zag_channels: Invalid length: {length}")]
212    InvalidLength { length: usize },
213    #[error("zig_zag_channels: Not enough valid data: needed = {needed}, valid = {valid}")]
214    NotEnoughValidData { needed: usize, valid: usize },
215    #[error("zig_zag_channels: Output length mismatch: expected = {expected}, got = {got}")]
216    OutputLengthMismatch { expected: usize, got: usize },
217    #[error("zig_zag_channels: Invalid range: start={start}, end={end}, step={step}")]
218    InvalidRange {
219        start: usize,
220        end: usize,
221        step: usize,
222    },
223    #[error("zig_zag_channels: Invalid kernel for batch: {0:?}")]
224    InvalidKernelForBatch(Kernel),
225    #[error(
226        "zig_zag_channels: Output length mismatch: dst = {dst_len}, expected = {expected_len}"
227    )]
228    MismatchedOutputLen { dst_len: usize, expected_len: usize },
229    #[error("zig_zag_channels: Invalid input: {msg}")]
230    InvalidInput { msg: String },
231}
232
/// A confirmed zig-zag pivot.
#[derive(Debug, Clone, Copy)]
struct PivotState {
    /// Bar at which the pivot was confirmed; the pivot bar itself is
    /// `confirm_idx - length`.
    confirm_idx: usize,
    /// Pivot price (`high` for tops, `low` for bottoms).
    value: f64,
}

/// `true` when all four OHLC values of a bar are finite (not NaN or infinite).
#[inline(always)]
fn is_valid_ohlc(open: f64, high: f64, low: f64, close: f64) -> bool {
    [open, high, low, close].iter().all(|v| v.is_finite())
}

/// Length of the longest streak of consecutive bars whose OHLC values are all finite.
///
/// Iterates only as far as the shortest of the four slices.
#[inline(always)]
fn longest_valid_run(open: &[f64], high: &[f64], low: &[f64], close: &[f64]) -> usize {
    let bars = open.iter().zip(high).zip(low).zip(close);
    let (longest, _streak) =
        bars.fold((0usize, 0usize), |(longest, streak), (((o, h), l), c)| {
            if is_valid_ohlc(*o, *h, *l, *c) {
                let streak = streak + 1;
                (longest.max(streak), streak)
            } else {
                (longest, 0)
            }
        });
    longest
}
263
264#[inline(always)]
265fn input_slices<'a>(
266    input: &'a ZigZagChannelsInput<'a>,
267) -> Result<(&'a [f64], &'a [f64], &'a [f64], &'a [f64]), ZigZagChannelsError> {
268    match &input.data {
269        ZigZagChannelsData::Candles { candles } => Ok((
270            candles.open.as_slice(),
271            candles.high.as_slice(),
272            candles.low.as_slice(),
273            candles.close.as_slice(),
274        )),
275        ZigZagChannelsData::Slices {
276            open,
277            high,
278            low,
279            close,
280        } => Ok((open, high, low, close)),
281    }
282}
283
284#[inline(always)]
285fn validate_common(
286    open: &[f64],
287    high: &[f64],
288    low: &[f64],
289    close: &[f64],
290    length: usize,
291) -> Result<(), ZigZagChannelsError> {
292    if open.is_empty() || high.is_empty() || low.is_empty() || close.is_empty() {
293        return Err(ZigZagChannelsError::EmptyInputData);
294    }
295    if open.len() != high.len() || open.len() != low.len() || open.len() != close.len() {
296        return Err(ZigZagChannelsError::InputLengthMismatch {
297            open_len: open.len(),
298            high_len: high.len(),
299            low_len: low.len(),
300            close_len: close.len(),
301        });
302    }
303    if length == 0 {
304        return Err(ZigZagChannelsError::InvalidLength { length });
305    }
306
307    let longest = longest_valid_run(open, high, low, close);
308    if longest == 0 {
309        return Err(ZigZagChannelsError::AllValuesNaN);
310    }
311
312    let needed = length
313        .checked_add(1)
314        .ok_or_else(|| ZigZagChannelsError::InvalidInput {
315            msg: "zig_zag_channels: length overflow".to_string(),
316        })?;
317    if longest < needed {
318        return Err(ZigZagChannelsError::NotEnoughValidData {
319            needed,
320            valid: longest,
321        });
322    }
323    Ok(())
324}
325
/// Channel offsets `(up, down)` for the leg from `start_idx` to `end_idx`.
///
/// For every bar in `start_idx + 1 ..= end_idx`, the candle body (`max`/`min` of open
/// and close) is compared with a reference line from `start_value` to `end_value`. The
/// result is the largest excursion above the line (`up`) and below it (`down`), each
/// clamped at zero. Returns `(0.0, 0.0)` for empty or reversed ranges. A one-bar leg
/// compares the single body with `end_value`.
///
/// NOTE(review): the reference line here spans bars `start_idx + 1 ..= end_idx`, while
/// `fill_segment` draws the middle line over `start_idx ..= end_idx`. The two are one bar
/// apart in interpolation; confirm this matches the reference implementation.
#[inline(always)]
fn compute_segment_offsets(
    open: &[f64],
    close: &[f64],
    start_idx: usize,
    end_idx: usize,
    start_value: f64,
    end_value: f64,
) -> (f64, f64) {
    if start_idx >= end_idx {
        return (0.0, 0.0);
    }

    // (top, bottom) of the candle body at bar `i`.
    let body = |i: usize| (open[i].max(close[i]), open[i].min(close[i]));

    let steps = end_idx - start_idx - 1;
    if steps == 0 {
        // Single bar: the general formula would divide by zero.
        let (top, bottom) = body(end_idx);
        return ((top - end_value).max(0.0), (end_value - bottom).max(0.0));
    }

    let denom = steps as f64;
    let span = end_value - start_value;
    let (up, down) = ((start_idx + 1)..=end_idx).fold((0.0f64, 0.0f64), |(up, down), i| {
        let frac = (i - start_idx - 1) as f64 / denom;
        let point = start_value + frac * span;
        let (top, bottom) = body(i);
        (up.max(top - point), down.max(point - bottom))
    });
    (up.max(0.0), down.max(0.0))
}
361
/// Writes one zig-zag leg into the output buffers.
///
/// `middle` is linearly interpolated from `start_value` at `start_idx` to `end_value`
/// at `end_idx` (inclusive). `upper` adds `up_offset` to it and `lower` subtracts
/// `dn_offset`. A single-bar leg writes `start_value`. A reversed range
/// (`end_idx < start_idx`) writes nothing.
#[inline(always)]
fn fill_segment(
    middle: &mut [f64],
    upper: &mut [f64],
    lower: &mut [f64],
    start_idx: usize,
    end_idx: usize,
    start_value: f64,
    end_value: f64,
    up_offset: f64,
    dn_offset: f64,
) {
    if start_idx > end_idx {
        return;
    }

    let mut write = |i: usize, value: f64| {
        middle[i] = value;
        upper[i] = value + up_offset;
        lower[i] = value - dn_offset;
    };

    if start_idx == end_idx {
        write(start_idx, start_value);
        return;
    }

    let denom = (end_idx - start_idx) as f64;
    let span = end_value - start_value;
    for (offset, i) in (start_idx..=end_idx).enumerate() {
        let t = offset as f64 / denom;
        write(i, start_value + t * span);
    }
}
395
396fn compute_run(
397    open: &[f64],
398    high: &[f64],
399    low: &[f64],
400    close: &[f64],
401    length: usize,
402    extend: bool,
403    middle: &mut [f64],
404    upper: &mut [f64],
405    lower: &mut [f64],
406) {
407    let n = close.len();
408    if n <= length {
409        return;
410    }
411
412    let mut max_deque: VecDeque<usize> = VecDeque::with_capacity(length);
413    let mut min_deque: VecDeque<usize> = VecDeque::with_capacity(length);
414    let mut os = 0usize;
415    let mut last_top: Option<PivotState> = None;
416    let mut last_bottom: Option<PivotState> = None;
417
418    for idx in 0..n {
419        let current_close = close[idx];
420        while let Some(&back) = max_deque.back() {
421            if close[back] <= current_close {
422                max_deque.pop_back();
423            } else {
424                break;
425            }
426        }
427        max_deque.push_back(idx);
428
429        while let Some(&back) = min_deque.back() {
430            if close[back] >= current_close {
431                min_deque.pop_back();
432            } else {
433                break;
434            }
435        }
436        min_deque.push_back(idx);
437
438        if idx < length {
439            continue;
440        }
441
442        let window_start = idx + 1 - length;
443        while let Some(&front) = max_deque.front() {
444            if front < window_start {
445                max_deque.pop_front();
446            } else {
447                break;
448            }
449        }
450        while let Some(&front) = min_deque.front() {
451            if front < window_start {
452                min_deque.pop_front();
453            } else {
454                break;
455            }
456        }
457
458        let candidate = idx - length;
459        let upper_close = close[*max_deque.front().expect("window max present")];
460        let lower_close = close[*min_deque.front().expect("window min present")];
461        let prev_os = os;
462        let candidate_close = close[candidate];
463
464        if candidate_close > upper_close {
465            os = 0;
466        } else if candidate_close < lower_close {
467            os = 1;
468        }
469
470        if os == 1 && prev_os != 1 {
471            let end_idx = candidate;
472            let end_value = low[end_idx];
473            if let Some(prev_top) = last_top {
474                let start_idx = prev_top.confirm_idx - length;
475                let start_value = prev_top.value;
476                let (up_offset, dn_offset) = compute_segment_offsets(
477                    open,
478                    close,
479                    start_idx,
480                    end_idx,
481                    start_value,
482                    end_value,
483                );
484                fill_segment(
485                    middle,
486                    upper,
487                    lower,
488                    start_idx,
489                    end_idx,
490                    start_value,
491                    end_value,
492                    up_offset,
493                    dn_offset,
494                );
495            }
496            last_bottom = Some(PivotState {
497                confirm_idx: idx,
498                value: end_value,
499            });
500        }
501
502        if os == 0 && prev_os != 0 {
503            let end_idx = candidate;
504            let end_value = high[end_idx];
505            if let Some(prev_bottom) = last_bottom {
506                let start_idx = prev_bottom.confirm_idx - length;
507                let start_value = prev_bottom.value;
508                let (up_offset, dn_offset) = compute_segment_offsets(
509                    open,
510                    close,
511                    start_idx,
512                    end_idx,
513                    start_value,
514                    end_value,
515                );
516                fill_segment(
517                    middle,
518                    upper,
519                    lower,
520                    start_idx,
521                    end_idx,
522                    start_value,
523                    end_value,
524                    up_offset,
525                    dn_offset,
526                );
527            }
528            last_top = Some(PivotState {
529                confirm_idx: idx,
530                value: end_value,
531            });
532        }
533    }
534
535    if !extend {
536        return;
537    }
538
539    let end_idx = n - 1;
540    let end_value = close[end_idx];
541    if os == 1 {
542        if let Some(prev_bottom) = last_bottom {
543            let start_idx = prev_bottom.confirm_idx - length;
544            let start_value = prev_bottom.value;
545            let (up_offset, dn_offset) =
546                compute_segment_offsets(open, close, start_idx, end_idx, start_value, end_value);
547            fill_segment(
548                middle,
549                upper,
550                lower,
551                start_idx,
552                end_idx,
553                start_value,
554                end_value,
555                up_offset,
556                dn_offset,
557            );
558        }
559    } else if let Some(prev_top) = last_top {
560        let start_idx = prev_top.confirm_idx - length;
561        let start_value = prev_top.value;
562        let (up_offset, dn_offset) =
563            compute_segment_offsets(open, close, start_idx, end_idx, start_value, end_value);
564        fill_segment(
565            middle,
566            upper,
567            lower,
568            start_idx,
569            end_idx,
570            start_value,
571            end_value,
572            up_offset,
573            dn_offset,
574        );
575    }
576}
577
578fn compute_row(
579    open: &[f64],
580    high: &[f64],
581    low: &[f64],
582    close: &[f64],
583    length: usize,
584    extend: bool,
585    middle: &mut [f64],
586    upper: &mut [f64],
587    lower: &mut [f64],
588) {
589    let mut idx = 0usize;
590    while idx < close.len() {
591        while idx < close.len() && !is_valid_ohlc(open[idx], high[idx], low[idx], close[idx]) {
592            idx += 1;
593        }
594        if idx >= close.len() {
595            break;
596        }
597        let seg_start = idx;
598        idx += 1;
599        while idx < close.len() && is_valid_ohlc(open[idx], high[idx], low[idx], close[idx]) {
600            idx += 1;
601        }
602        let seg_end = idx;
603        if seg_end - seg_start >= length + 1 {
604            compute_run(
605                &open[seg_start..seg_end],
606                &high[seg_start..seg_end],
607                &low[seg_start..seg_end],
608                &close[seg_start..seg_end],
609                length,
610                extend,
611                &mut middle[seg_start..seg_end],
612                &mut upper[seg_start..seg_end],
613                &mut lower[seg_start..seg_end],
614            );
615        }
616    }
617}
618
619#[inline]
620pub fn zig_zag_channels(
621    input: &ZigZagChannelsInput,
622) -> Result<ZigZagChannelsOutput, ZigZagChannelsError> {
623    zig_zag_channels_with_kernel(input, Kernel::Auto)
624}
625
626pub fn zig_zag_channels_with_kernel(
627    input: &ZigZagChannelsInput,
628    kernel: Kernel,
629) -> Result<ZigZagChannelsOutput, ZigZagChannelsError> {
630    let (open, high, low, close) = input_slices(input)?;
631    let length = input.get_length();
632    let extend = input.get_extend();
633    validate_common(open, high, low, close, length)?;
634
635    let _chosen = match kernel {
636        Kernel::Auto => detect_best_kernel(),
637        other => other,
638    };
639
640    let mut middle = alloc_with_nan_prefix(close.len(), 0);
641    let mut upper = alloc_with_nan_prefix(close.len(), 0);
642    let mut lower = alloc_with_nan_prefix(close.len(), 0);
643    middle.fill(f64::NAN);
644    upper.fill(f64::NAN);
645    lower.fill(f64::NAN);
646
647    compute_row(
648        open,
649        high,
650        low,
651        close,
652        length,
653        extend,
654        &mut middle,
655        &mut upper,
656        &mut lower,
657    );
658
659    Ok(ZigZagChannelsOutput {
660        middle,
661        upper,
662        lower,
663    })
664}
665
666pub fn zig_zag_channels_into_slice(
667    out_middle: &mut [f64],
668    out_upper: &mut [f64],
669    out_lower: &mut [f64],
670    input: &ZigZagChannelsInput,
671    kernel: Kernel,
672) -> Result<(), ZigZagChannelsError> {
673    let (open, high, low, close) = input_slices(input)?;
674    let length = input.get_length();
675    let extend = input.get_extend();
676    validate_common(open, high, low, close, length)?;
677
678    if out_middle.len() != close.len() {
679        return Err(ZigZagChannelsError::OutputLengthMismatch {
680            expected: close.len(),
681            got: out_middle.len(),
682        });
683    }
684    if out_upper.len() != close.len() {
685        return Err(ZigZagChannelsError::OutputLengthMismatch {
686            expected: close.len(),
687            got: out_upper.len(),
688        });
689    }
690    if out_lower.len() != close.len() {
691        return Err(ZigZagChannelsError::OutputLengthMismatch {
692            expected: close.len(),
693            got: out_lower.len(),
694        });
695    }
696
697    let _chosen = match kernel {
698        Kernel::Auto => detect_best_kernel(),
699        other => other,
700    };
701
702    out_middle.fill(f64::NAN);
703    out_upper.fill(f64::NAN);
704    out_lower.fill(f64::NAN);
705    compute_row(
706        open, high, low, close, length, extend, out_middle, out_upper, out_lower,
707    );
708    Ok(())
709}
710
711#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
712pub fn zig_zag_channels_into(
713    input: &ZigZagChannelsInput,
714    out_middle: &mut [f64],
715    out_upper: &mut [f64],
716    out_lower: &mut [f64],
717) -> Result<(), ZigZagChannelsError> {
718    zig_zag_channels_into_slice(out_middle, out_upper, out_lower, input, Kernel::Auto)
719}
720
/// Parameter sweep for batch evaluation.
///
/// `length` is `(start, end, step)` and is expanded by `expand_grid_checked`;
/// `extend` is applied to every combination.
#[derive(Debug, Clone, Copy)]
pub struct ZigZagChannelsBatchRange {
    pub length: (usize, usize, usize),
    pub extend: bool,
}

impl Default for ZigZagChannelsBatchRange {
    /// A single combination: `length = 100`, `extend = true`.
    fn default() -> Self {
        let length = (100, 100, 0);
        Self { length, extend: true }
    }
}
735
736#[derive(Debug, Clone)]
737pub struct ZigZagChannelsBatchOutput {
738    pub middle: Vec<f64>,
739    pub upper: Vec<f64>,
740    pub lower: Vec<f64>,
741    pub combos: Vec<ZigZagChannelsParams>,
742    pub rows: usize,
743    pub cols: usize,
744}
745
746#[derive(Debug, Clone, Copy)]
747pub struct ZigZagChannelsBatchBuilder {
748    range: ZigZagChannelsBatchRange,
749    kernel: Kernel,
750}
751
752impl Default for ZigZagChannelsBatchBuilder {
753    fn default() -> Self {
754        Self {
755            range: ZigZagChannelsBatchRange::default(),
756            kernel: Kernel::Auto,
757        }
758    }
759}
760
761impl ZigZagChannelsBatchBuilder {
762    #[inline(always)]
763    pub fn new() -> Self {
764        Self::default()
765    }
766
767    #[inline(always)]
768    pub fn kernel(mut self, value: Kernel) -> Self {
769        self.kernel = value;
770        self
771    }
772
773    #[inline(always)]
774    pub fn length_range(mut self, start: usize, end: usize, step: usize) -> Self {
775        self.range.length = (start, end, step);
776        self
777    }
778
779    #[inline(always)]
780    pub fn length_static(mut self, value: usize) -> Self {
781        self.range.length = (value, value, 0);
782        self
783    }
784
785    #[inline(always)]
786    pub fn extend(mut self, value: bool) -> Self {
787        self.range.extend = value;
788        self
789    }
790
791    #[inline(always)]
792    pub fn apply_slices(
793        self,
794        open: &[f64],
795        high: &[f64],
796        low: &[f64],
797        close: &[f64],
798    ) -> Result<ZigZagChannelsBatchOutput, ZigZagChannelsError> {
799        zig_zag_channels_batch_with_kernel(open, high, low, close, &self.range, self.kernel)
800    }
801
802    #[inline(always)]
803    pub fn apply_candles(
804        self,
805        candles: &Candles,
806    ) -> Result<ZigZagChannelsBatchOutput, ZigZagChannelsError> {
807        zig_zag_channels_batch_with_kernel(
808            candles.open.as_slice(),
809            candles.high.as_slice(),
810            candles.low.as_slice(),
811            candles.close.as_slice(),
812            &self.range,
813            self.kernel,
814        )
815    }
816}
817
818#[inline(always)]
819fn expand_grid_checked(
820    range: &ZigZagChannelsBatchRange,
821) -> Result<Vec<ZigZagChannelsParams>, ZigZagChannelsError> {
822    let (start, end, step) = range.length;
823    if start == 0 || end == 0 {
824        return Err(ZigZagChannelsError::InvalidRange { start, end, step });
825    }
826    if step == 0 {
827        return Ok(vec![ZigZagChannelsParams {
828            length: Some(start),
829            extend: Some(range.extend),
830        }]);
831    }
832    if start > end {
833        return Err(ZigZagChannelsError::InvalidRange { start, end, step });
834    }
835
836    let mut out = Vec::new();
837    let mut current = start;
838    loop {
839        out.push(ZigZagChannelsParams {
840            length: Some(current),
841            extend: Some(range.extend),
842        });
843        if current >= end {
844            break;
845        }
846        let next = current.saturating_add(step);
847        if next <= current {
848            return Err(ZigZagChannelsError::InvalidRange { start, end, step });
849        }
850        current = next.min(end);
851        if current == out.last().and_then(|item| item.length).unwrap_or(0) {
852            break;
853        }
854    }
855    Ok(out)
856}
857
858#[inline(always)]
859pub fn expand_grid_zig_zag_channels(range: &ZigZagChannelsBatchRange) -> Vec<ZigZagChannelsParams> {
860    expand_grid_checked(range).unwrap_or_default()
861}
862
863pub fn zig_zag_channels_batch_with_kernel(
864    open: &[f64],
865    high: &[f64],
866    low: &[f64],
867    close: &[f64],
868    sweep: &ZigZagChannelsBatchRange,
869    kernel: Kernel,
870) -> Result<ZigZagChannelsBatchOutput, ZigZagChannelsError> {
871    zig_zag_channels_batch_inner(open, high, low, close, sweep, kernel, true)
872}
873
874pub fn zig_zag_channels_batch_slice(
875    open: &[f64],
876    high: &[f64],
877    low: &[f64],
878    close: &[f64],
879    sweep: &ZigZagChannelsBatchRange,
880    kernel: Kernel,
881) -> Result<ZigZagChannelsBatchOutput, ZigZagChannelsError> {
882    zig_zag_channels_batch_inner(open, high, low, close, sweep, kernel, false)
883}
884
885pub fn zig_zag_channels_batch_par_slice(
886    open: &[f64],
887    high: &[f64],
888    low: &[f64],
889    close: &[f64],
890    sweep: &ZigZagChannelsBatchRange,
891    kernel: Kernel,
892) -> Result<ZigZagChannelsBatchOutput, ZigZagChannelsError> {
893    zig_zag_channels_batch_inner(open, high, low, close, sweep, kernel, true)
894}
895
896fn zig_zag_channels_batch_inner(
897    open: &[f64],
898    high: &[f64],
899    low: &[f64],
900    close: &[f64],
901    sweep: &ZigZagChannelsBatchRange,
902    kernel: Kernel,
903    parallel: bool,
904) -> Result<ZigZagChannelsBatchOutput, ZigZagChannelsError> {
905    match kernel {
906        Kernel::Auto
907        | Kernel::Scalar
908        | Kernel::ScalarBatch
909        | Kernel::Avx2
910        | Kernel::Avx2Batch
911        | Kernel::Avx512
912        | Kernel::Avx512Batch => {}
913        other => return Err(ZigZagChannelsError::InvalidKernelForBatch(other)),
914    }
915
916    let combos = expand_grid_checked(sweep)?;
917    let max_length = combos
918        .iter()
919        .map(|params| params.length.unwrap_or(100))
920        .max()
921        .unwrap_or(0);
922    validate_common(open, high, low, close, max_length)?;
923
924    let rows = combos.len();
925    let cols = close.len();
926    let total = rows
927        .checked_mul(cols)
928        .ok_or_else(|| ZigZagChannelsError::InvalidInput {
929            msg: "zig_zag_channels: rows*cols overflow in batch".to_string(),
930        })?;
931
932    let mut middle = vec![f64::NAN; total];
933    let mut upper = vec![f64::NAN; total];
934    let mut lower = vec![f64::NAN; total];
935    zig_zag_channels_batch_inner_into(
936        open,
937        high,
938        low,
939        close,
940        sweep,
941        kernel,
942        parallel,
943        &mut middle,
944        &mut upper,
945        &mut lower,
946    )?;
947
948    Ok(ZigZagChannelsBatchOutput {
949        middle,
950        upper,
951        lower,
952        combos,
953        rows,
954        cols,
955    })
956}
957
958fn zig_zag_channels_batch_inner_into(
959    open: &[f64],
960    high: &[f64],
961    low: &[f64],
962    close: &[f64],
963    sweep: &ZigZagChannelsBatchRange,
964    kernel: Kernel,
965    parallel: bool,
966    out_middle: &mut [f64],
967    out_upper: &mut [f64],
968    out_lower: &mut [f64],
969) -> Result<Vec<ZigZagChannelsParams>, ZigZagChannelsError> {
970    match kernel {
971        Kernel::Auto
972        | Kernel::Scalar
973        | Kernel::ScalarBatch
974        | Kernel::Avx2
975        | Kernel::Avx2Batch
976        | Kernel::Avx512
977        | Kernel::Avx512Batch => {}
978        other => return Err(ZigZagChannelsError::InvalidKernelForBatch(other)),
979    }
980
981    let combos = expand_grid_checked(sweep)?;
982    let max_length = combos
983        .iter()
984        .map(|params| params.length.unwrap_or(100))
985        .max()
986        .unwrap_or(0);
987    validate_common(open, high, low, close, max_length)?;
988
989    let cols = close.len();
990    let total =
991        combos
992            .len()
993            .checked_mul(cols)
994            .ok_or_else(|| ZigZagChannelsError::InvalidInput {
995                msg: "zig_zag_channels: rows*cols overflow in batch_into".to_string(),
996            })?;
997    if out_middle.len() != total {
998        return Err(ZigZagChannelsError::MismatchedOutputLen {
999            dst_len: out_middle.len(),
1000            expected_len: total,
1001        });
1002    }
1003    if out_upper.len() != total {
1004        return Err(ZigZagChannelsError::MismatchedOutputLen {
1005            dst_len: out_upper.len(),
1006            expected_len: total,
1007        });
1008    }
1009    if out_lower.len() != total {
1010        return Err(ZigZagChannelsError::MismatchedOutputLen {
1011            dst_len: out_lower.len(),
1012            expected_len: total,
1013        });
1014    }
1015
1016    let _chosen = match kernel {
1017        Kernel::Auto => detect_best_batch_kernel(),
1018        other => other,
1019    };
1020
1021    let worker =
1022        |row: usize, middle_row: &mut [f64], upper_row: &mut [f64], lower_row: &mut [f64]| {
1023            middle_row.fill(f64::NAN);
1024            upper_row.fill(f64::NAN);
1025            lower_row.fill(f64::NAN);
1026            let params = &combos[row];
1027            compute_row(
1028                open,
1029                high,
1030                low,
1031                close,
1032                params.length.unwrap_or(100),
1033                params.extend.unwrap_or(true),
1034                middle_row,
1035                upper_row,
1036                lower_row,
1037            );
1038        };
1039
1040    #[cfg(not(target_arch = "wasm32"))]
1041    if parallel && combos.len() > 1 {
1042        out_middle
1043            .par_chunks_mut(cols)
1044            .zip(out_upper.par_chunks_mut(cols))
1045            .zip(out_lower.par_chunks_mut(cols))
1046            .enumerate()
1047            .for_each(|(row, ((middle_row, upper_row), lower_row))| {
1048                worker(row, middle_row, upper_row, lower_row);
1049            });
1050    } else {
1051        for (row, ((middle_row, upper_row), lower_row)) in out_middle
1052            .chunks_mut(cols)
1053            .zip(out_upper.chunks_mut(cols))
1054            .zip(out_lower.chunks_mut(cols))
1055            .enumerate()
1056        {
1057            worker(row, middle_row, upper_row, lower_row);
1058        }
1059    }
1060
1061    #[cfg(target_arch = "wasm32")]
1062    {
1063        let _ = parallel;
1064        for (row, ((middle_row, upper_row), lower_row)) in out_middle
1065            .chunks_mut(cols)
1066            .zip(out_upper.chunks_mut(cols))
1067            .zip(out_lower.chunks_mut(cols))
1068            .enumerate()
1069        {
1070            worker(row, middle_row, upper_row, lower_row);
1071        }
1072    }
1073
1074    Ok(combos)
1075}
1076
// Python binding for a single zig-zag channels run.
//
// The four OHLC NumPy arrays are borrowed as slices, the optional `kernel` name is
// validated for single (non-batch) use, and the indicator runs with the GIL released.
// Returns the tuple `(middle, upper, lower)`; any indicator error is raised as
// `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "zig_zag_channels", signature = (open, high, low, close, length=100, extend=true, kernel=None))]
pub fn zig_zag_channels_py<'py>(
    py: Python<'py>,
    open: PyReadonlyArray1<'py, f64>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    length: usize,
    extend: bool,
    kernel: Option<&str>,
) -> PyResult<(
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
)> {
    // Borrow in the same order as the arguments so the first failing array is reported.
    let (o, h, l, c) = (
        open.as_slice()?,
        high.as_slice()?,
        low.as_slice()?,
        close.as_slice()?,
    );
    let chosen = validate_kernel(kernel, false)?;

    let params = ZigZagChannelsParams {
        length: Some(length),
        extend: Some(extend),
    };
    let request = ZigZagChannelsInput::from_slices(o, h, l, c, params);

    let computed = py.allow_threads(|| zig_zag_channels_with_kernel(&request, chosen));
    let series = computed.map_err(|err| PyValueError::new_err(err.to_string()))?;

    let middle = series.middle.into_pyarray(py);
    let upper = series.upper.into_pyarray(py);
    let lower = series.lower.into_pyarray(py);
    Ok((middle, upper, lower))
}
1117
// Python binding for a `length` parameter sweep.
//
// Runs every `length` in `length_range = (start, end, step)` with the shared `extend`
// flag (GIL released) and returns a dict with:
//   * "middle" / "upper" / "lower": 2-D arrays shaped `(rows, cols)`,
//   * "lengths": 1-D u64 array of the per-row `length`,
//   * "extends": list of the per-row `extend` flags,
//   * "rows" / "cols": the output shape.
// Indicator errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "zig_zag_channels_batch", signature = (open, high, low, close, length_range=(100, 100, 0), extend=true, kernel=None))]
pub fn zig_zag_channels_batch_py<'py>(
    py: Python<'py>,
    open: PyReadonlyArray1<'py, f64>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    length_range: (usize, usize, usize),
    extend: bool,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let o = open.as_slice()?;
    let h = high.as_slice()?;
    let l = low.as_slice()?;
    let c = close.as_slice()?;
    let chosen = validate_kernel(kernel, true)?;

    let sweep = ZigZagChannelsBatchRange {
        length: length_range,
        extend,
    };
    let batch = py
        .allow_threads(|| zig_zag_channels_batch_with_kernel(o, h, l, c, &sweep, chosen))
        .map_err(|err| PyValueError::new_err(err.to_string()))?;

    let shape = (batch.rows, batch.cols);
    // Per-row parameter metadata, using the same fallbacks as the compute path.
    let lengths: Vec<u64> = batch
        .combos
        .iter()
        .map(|combo| combo.length.unwrap_or(100) as u64)
        .collect();
    let extends: Vec<bool> = batch
        .combos
        .iter()
        .map(|combo| combo.extend.unwrap_or(true))
        .collect();

    let result = PyDict::new(py);
    for (key, flat) in [
        ("middle", batch.middle),
        ("upper", batch.upper),
        ("lower", batch.lower),
    ] {
        result.set_item(key, flat.into_pyarray(py).reshape(shape)?)?;
    }
    result.set_item("lengths", lengths.into_pyarray(py))?;
    result.set_item("extends", extends)?;
    result.set_item("rows", shape.0)?;
    result.set_item("cols", shape.1)?;
    Ok(result)
}
1195
/// Adds the zig-zag channels Python functions (`zig_zag_channels` and
/// `zig_zag_channels_batch`) to the given module.
#[cfg(feature = "python")]
pub fn register_zig_zag_channels_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    let single = wrap_pyfunction!(zig_zag_channels_py, m)?;
    m.add_function(single)?;
    let batch = wrap_pyfunction!(zig_zag_channels_batch_py, m)?;
    m.add_function(batch)?;
    Ok(())
}
1202
/// Configuration object deserialized by the `zig_zag_channels_batch_js` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ZigZagChannelsBatchConfig {
    /// Sweep bounds for `length`, expected as exactly `[start, end, step]`
    /// (the batch binding rejects any other element count).
    pub length_range: Vec<usize>,
    /// Shared `extend` flag for every row; the batch binding treats `None` as `true`.
    pub extend: Option<bool>,
}
1209
/// JS binding: computes zig-zag channels for a single `(length, extend)` pair.
///
/// Returns a plain object `{ middle, upper, lower }`, one array per output series.
///
/// # Errors
/// Returns a string `JsValue` when the indicator rejects the inputs/parameters, or
/// when an output series cannot be converted into a JS value. Conversion failures
/// used to panic via `unwrap()`; they are now reported like every other error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = zig_zag_channels_js)]
pub fn zig_zag_channels_js(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    extend: bool,
) -> Result<JsValue, JsValue> {
    let input = ZigZagChannelsInput::from_slices(
        open,
        high,
        low,
        close,
        ZigZagChannelsParams {
            length: Some(length),
            extend: Some(extend),
        },
    );
    let out = zig_zag_channels_with_kernel(&input, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let obj = js_sys::Object::new();
    for (key, series) in [
        ("middle", &out.middle),
        ("upper", &out.upper),
        ("lower", &out.lower),
    ] {
        // Propagate serialization failures instead of panicking inside wasm.
        let value = serde_wasm_bindgen::to_value(series)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        js_sys::Reflect::set(&obj, &JsValue::from_str(key), &value)?;
    }
    Ok(obj.into())
}
1250
/// JS binding for a `length` parameter sweep.
///
/// `config` must deserialize into [`ZigZagChannelsBatchConfig`]: `length_range` is
/// `[start, end, step]` and a missing `extend` is treated as `true`. Returns an object
/// `{ middle, upper, lower, rows, cols, combos }` where each series is a flat array of
/// `rows * cols` values and `combos` lists the parameters of each row.
///
/// # Errors
/// Returns a string `JsValue` for an invalid config, for indicator errors, or when a
/// result cannot be converted into a JS value. Conversion failures previously panicked
/// via `unwrap()`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = zig_zag_channels_batch_js)]
pub fn zig_zag_channels_batch_js(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    /// Serializes `value` into a `JsValue`, reporting failures as an error string.
    fn to_js<T: Serialize + ?Sized>(value: &T) -> Result<JsValue, JsValue> {
        serde_wasm_bindgen::to_value(value).map_err(|e| JsValue::from_str(&e.to_string()))
    }

    let config: ZigZagChannelsBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {e}")))?;
    if config.length_range.len() != 3 {
        return Err(JsValue::from_str(
            "Invalid config: length_range must have exactly 3 elements [start, end, step]",
        ));
    }

    let out = zig_zag_channels_batch_with_kernel(
        open,
        high,
        low,
        close,
        &ZigZagChannelsBatchRange {
            length: (
                config.length_range[0],
                config.length_range[1],
                config.length_range[2],
            ),
            extend: config.extend.unwrap_or(true),
        },
        Kernel::Auto,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let obj = js_sys::Object::new();
    for (key, series) in [
        ("middle", &out.middle),
        ("upper", &out.upper),
        ("lower", &out.lower),
    ] {
        js_sys::Reflect::set(&obj, &JsValue::from_str(key), &to_js(series)?)?;
    }
    js_sys::Reflect::set(
        &obj,
        &JsValue::from_str("rows"),
        &JsValue::from_f64(out.rows as f64),
    )?;
    js_sys::Reflect::set(
        &obj,
        &JsValue::from_str("cols"),
        &JsValue::from_f64(out.cols as f64),
    )?;
    js_sys::Reflect::set(&obj, &JsValue::from_str("combos"), &to_js(&out.combos)?)?;
    Ok(obj.into())
}
1318
/// Allocates an uninitialized buffer of `3 * len` f64 values and hands its ownership
/// to the caller.
///
/// The buffer is sized for the three output series (`middle`, `upper`, `lower`) that
/// `zig_zag_channels_into` writes back to back. Release it with
/// `zig_zag_channels_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn zig_zag_channels_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop keeps the allocation alive after this function returns.
    let mut storage = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(3 * len));
    storage.as_mut_ptr()
}
1327
/// Releases a buffer obtained from `zig_zag_channels_alloc(len)`.
///
/// `len` must be the same value passed to `zig_zag_channels_alloc`. A null pointer is
/// ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn zig_zag_channels_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `zig_zag_channels_alloc(len)`, which leaked a `Vec<f64>`
    // with capacity `3 * len`. The vector is rebuilt with length 0 because the buffer
    // may never have been written: `Vec::from_raw_parts` requires the first `length`
    // elements to be initialized. `f64` has no drop glue, so dropping a zero-length
    // vector with the original capacity frees exactly the original allocation.
    unsafe {
        drop(Vec::<f64>::from_raw_parts(ptr, 0, 3 * len));
    }
}
1337
/// Computes zig-zag channels from raw pointers into a caller-owned buffer.
///
/// Each input pointer must reference `len` readable f64 values. `out_ptr` must
/// reference `3 * len` writable f64 values (e.g. from `zig_zag_channels_alloc(len)`),
/// filled as `[middle (len) | upper (len) | lower (len)]`.
///
/// The output region may overlap an input. In that case the inputs are copied first,
/// so no shared slice ever aliases the mutable output slice (previously that was
/// undefined behavior).
///
/// # Errors
/// Returns a string `JsValue` for null pointers, when `3 * len` overflows `usize`,
/// or when the indicator rejects the inputs/parameters.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn zig_zag_channels_into(
    open_ptr: *const f64,
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    length: usize,
    extend: bool,
) -> Result<(), JsValue> {
    /// True when the f64 ranges `[a, a + a_len)` and `[b, b + b_len)` share memory.
    fn ranges_overlap(a: *const f64, a_len: usize, b: *const f64, b_len: usize) -> bool {
        let size = std::mem::size_of::<f64>();
        let a_start = a as usize;
        let b_start = b as usize;
        let a_end = a_start.saturating_add(a_len.saturating_mul(size));
        let b_end = b_start.saturating_add(b_len.saturating_mul(size));
        a_start < b_end && b_start < a_end
    }

    if open_ptr.is_null()
        || high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || out_ptr.is_null()
    {
        return Err(JsValue::from_str(
            "null pointer passed to zig_zag_channels_into",
        ));
    }

    // Unchecked `3 * len` would silently wrap on wasm32 release builds.
    let total = len
        .checked_mul(3)
        .ok_or_else(|| JsValue::from_str("3*len overflow in zig_zag_channels_into"))?;

    let inputs = [open_ptr, high_ptr, low_ptr, close_ptr];
    let aliased = inputs
        .iter()
        .any(|&p| ranges_overlap(p, len, out_ptr, total));

    // When an input shares memory with the output, snapshot the inputs before any
    // mutable view of the output exists.
    let owned: Option<[Vec<f64>; 4]> = aliased.then(|| {
        // SAFETY: every input pointer is non-null (checked above) and, per this
        // function's contract, valid for `len` reads; no `&mut` view exists yet.
        inputs.map(|p| unsafe { std::slice::from_raw_parts(p, len) }.to_vec())
    });
    let [open, high, low, close]: [&[f64]; 4] = match &owned {
        Some([o, h, l, c]) => [o.as_slice(), h.as_slice(), l.as_slice(), c.as_slice()],
        // SAFETY: non-null, valid for `len` reads, and disjoint from the output range
        // (otherwise `owned` would be `Some`).
        None => inputs.map(|p| unsafe { std::slice::from_raw_parts(p, len) }),
    };

    // SAFETY: `out_ptr` is non-null and, per this function's contract, valid for
    // `total` writes. The input slices above either live in owned copies or were
    // verified not to overlap this range, so the mutable slice aliases nothing live.
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
    let (middle, tail) = out.split_at_mut(len);
    let (upper, lower) = tail.split_at_mut(len);

    let input = ZigZagChannelsInput::from_slices(
        open,
        high,
        low,
        close,
        ZigZagChannelsParams {
            length: Some(length),
            extend: Some(extend),
        },
    );
    zig_zag_channels_into_slice(middle, upper, lower, &input, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))
}
1383
/// Runs a `length` sweep from raw pointers into a caller-owned buffer.
///
/// Expands `(length_start, length_end, length_step)` into `rows` parameter
/// combinations sharing `extend`. Results are written to `out_ptr` as
/// `[middle (rows*len) | upper (rows*len) | lower (rows*len)]`, each block holding
/// one `len`-long row per combination.
///
/// Requirements on the pointers:
/// - each input pointer must reference `len` readable f64 values;
/// - `out_ptr` must reference `3 * rows * len` writable values.
///
/// The output region may overlap an input. In that case the inputs are copied first,
/// so no shared slice ever aliases the mutable output slice (previously that was
/// undefined behavior).
///
/// Returns the number of rows written.
///
/// # Errors
/// Returns a string `JsValue` for null pointers, an invalid sweep, size overflow,
/// or indicator errors.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn zig_zag_channels_batch_into(
    open_ptr: *const f64,
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    length_start: usize,
    length_end: usize,
    length_step: usize,
    extend: bool,
) -> Result<usize, JsValue> {
    /// True when the f64 ranges `[a, a + a_len)` and `[b, b + b_len)` share memory.
    fn ranges_overlap(a: *const f64, a_len: usize, b: *const f64, b_len: usize) -> bool {
        let size = std::mem::size_of::<f64>();
        let a_start = a as usize;
        let b_start = b as usize;
        let a_end = a_start.saturating_add(a_len.saturating_mul(size));
        let b_end = b_start.saturating_add(b_len.saturating_mul(size));
        a_start < b_end && b_start < a_end
    }

    if open_ptr.is_null()
        || high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || out_ptr.is_null()
    {
        return Err(JsValue::from_str(
            "null pointer passed to zig_zag_channels_batch_into",
        ));
    }

    let sweep = ZigZagChannelsBatchRange {
        length: (length_start, length_end, length_step),
        extend,
    };
    let combos = expand_grid_checked(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let split = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow in zig_zag_channels_batch_into"))?;
    let total = split
        .checked_mul(3)
        .ok_or_else(|| JsValue::from_str("3*rows*cols overflow in zig_zag_channels_batch_into"))?;

    let inputs = [open_ptr, high_ptr, low_ptr, close_ptr];
    let aliased = inputs
        .iter()
        .any(|&p| ranges_overlap(p, len, out_ptr, total));

    // When an input shares memory with the output, snapshot the inputs before any
    // mutable view of the output exists.
    let owned: Option<[Vec<f64>; 4]> = aliased.then(|| {
        // SAFETY: every input pointer is non-null (checked above) and, per this
        // function's contract, valid for `len` reads; no `&mut` view exists yet.
        inputs.map(|p| unsafe { std::slice::from_raw_parts(p, len) }.to_vec())
    });
    let [open, high, low, close]: [&[f64]; 4] = match &owned {
        Some([o, h, l, c]) => [o.as_slice(), h.as_slice(), l.as_slice(), c.as_slice()],
        // SAFETY: non-null, valid for `len` reads, and disjoint from the output range
        // (otherwise `owned` would be `Some`).
        None => inputs.map(|p| unsafe { std::slice::from_raw_parts(p, len) }),
    };

    // SAFETY: `out_ptr` is non-null and, per this function's contract, valid for
    // `total` writes. The input slices above either live in owned copies or were
    // verified not to overlap this range, so the mutable slice aliases nothing live.
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
    let (middle, tail) = out.split_at_mut(split);
    let (upper, lower) = tail.split_at_mut(split);

    zig_zag_channels_batch_inner_into(
        open,
        high,
        low,
        close,
        &sweep,
        Kernel::Auto,
        false,
        middle,
        upper,
        lower,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}
1447
/// Unit tests for the zig-zag channels indicator. They compare the production code
/// against a straightforward reference implementation defined in this module, and
/// check consistency between the public entry points.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::indicators::dispatch::{
        compute_cpu, IndicatorComputeRequest, IndicatorDataRef, ParamKV, ParamValue,
    };

    /// Builds a deterministic synthetic OHLC series of `len` bars.
    ///
    /// - `close` is a sum of two sinusoids plus a slow linear drift.
    /// - `open` is `close` perturbed by a third sinusoid.
    /// - `high`/`low` sit 0.75 above/below the candle body (max/min of open and close).
    fn sample_ohlc(len: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
        let close: Vec<f64> = (0..len)
            .map(|i| {
                let x = i as f64;
                100.0 + (x * 0.21).sin() * 7.0 + (x * 0.037).cos() * 2.5 + (x * 0.02)
            })
            .collect();
        let open: Vec<f64> = close
            .iter()
            .enumerate()
            .map(|(i, &c)| c + ((i as f64) * 0.41).sin() * 0.9)
            .collect();
        let high: Vec<f64> = open
            .iter()
            .zip(close.iter())
            .map(|(&o, &c)| o.max(c) + 0.75)
            .collect();
        let low: Vec<f64> = open
            .iter()
            .zip(close.iter())
            .map(|(&o, &c)| o.min(c) - 0.75)
            .collect();
        (open, high, low, close)
    }

    /// Reference channel offsets for one zig-zag leg.
    ///
    /// Returns `(up, dn)`: the largest amount by which a candle body (max/min of open
    /// and close) in `(start_idx, end_idx]` rises above / falls below a straight line
    /// from `start_value` to `end_value`. Both results are clamped at 0.
    /// An empty or reversed range yields `(0.0, 0.0)`. A one-bar leg compares only the
    /// end candle against `end_value`.
    fn naive_offsets(
        open: &[f64],
        close: &[f64],
        start_idx: usize,
        end_idx: usize,
        start_value: f64,
        end_value: f64,
    ) -> (f64, f64) {
        if end_idx <= start_idx {
            return (0.0, 0.0);
        }
        if end_idx == start_idx + 1 {
            let top = open[end_idx].max(close[end_idx]);
            let bottom = open[end_idx].min(close[end_idx]);
            return ((top - end_value).max(0.0), (end_value - bottom).max(0.0));
        }
        let mut up = 0.0f64;
        let mut dn = 0.0f64;
        // The interpolation parameter j / denom runs from 0 at start_idx + 1 to 1 at end_idx.
        let denom = (end_idx - start_idx - 1) as f64;
        for idx in (start_idx + 1)..=end_idx {
            let j = (idx - start_idx - 1) as f64;
            let point = start_value + (j / denom) * (end_value - start_value);
            up = up.max(open[idx].max(close[idx]) - point);
            dn = dn.max(point - open[idx].min(close[idx]));
        }
        (up.max(0.0), dn.max(0.0))
    }

    /// Writes one straight leg from `(start_idx, start_value)` to `(end_idx, end_value)`
    /// into `middle`, with `upper = middle + up` and `lower = middle - dn`.
    ///
    /// NOTE(review): the line here is interpolated over `[start_idx, end_idx]`, whereas
    /// `naive_offsets` interpolates over `(start_idx, end_idx]`. This presumably mirrors
    /// the production code (the equality test below depends on it); confirm before
    /// changing either.
    fn naive_fill(
        middle: &mut [f64],
        upper: &mut [f64],
        lower: &mut [f64],
        start_idx: usize,
        end_idx: usize,
        start_value: f64,
        end_value: f64,
        up: f64,
        dn: f64,
    ) {
        if end_idx < start_idx {
            return;
        }
        if start_idx == end_idx {
            middle[start_idx] = start_value;
            upper[start_idx] = start_value + up;
            lower[start_idx] = start_value - dn;
            return;
        }
        let denom = (end_idx - start_idx) as f64;
        for idx in start_idx..=end_idx {
            let t = (idx - start_idx) as f64 / denom;
            let value = start_value + t * (end_value - start_value);
            middle[idx] = value;
            upper[idx] = value + up;
            lower[idx] = value - dn;
        }
    }

    /// Reference zig-zag over one contiguous run of valid bars.
    ///
    /// Pivot detection:
    /// - For each `current >= length`, bar `candidate = current - length` is compared
    ///   with the closes in `(candidate, current]`.
    /// - A close strictly above all of them sets `os = 0` (top state).
    /// - A close strictly below all of them sets `os = 1` (bottom state).
    /// - Otherwise the state is kept.
    ///
    /// Drawing legs:
    /// - On a switch into the bottom state, the leg from the last top to the
    ///   candidate's `low` is drawn.
    /// - On a switch into the top state, the leg from the last bottom to the
    ///   candidate's `high` is drawn.
    /// - Pivots are stored as `(confirmation bar, value)`; the pivot bar itself is
    ///   `confirm_idx - length`.
    /// - Because `os` starts at 0, a top is only recorded after a bottom has been seen.
    ///
    /// With `extend`, a final leg runs from the latest pivot of the current state to
    /// the last close.
    fn naive_run(
        open: &[f64],
        high: &[f64],
        low: &[f64],
        close: &[f64],
        length: usize,
        extend: bool,
        middle: &mut [f64],
        upper: &mut [f64],
        lower: &mut [f64],
    ) {
        let mut os = 0usize;
        let mut last_top: Option<(usize, f64)> = None;
        let mut last_bottom: Option<(usize, f64)> = None;

        for current in length..close.len() {
            let candidate = current - length;
            // Highest/lowest close in the `length` bars after the candidate.
            let mut hi = f64::NEG_INFINITY;
            let mut lo = f64::INFINITY;
            for idx in (candidate + 1)..=current {
                hi = hi.max(close[idx]);
                lo = lo.min(close[idx]);
            }

            let prev_os = os;
            if close[candidate] > hi {
                os = 0;
            } else if close[candidate] < lo {
                os = 1;
            }

            // Entered the bottom state: close the leg that started at the last top.
            if os == 1 && prev_os != 1 {
                let end_idx = candidate;
                let end_value = low[end_idx];
                if let Some((confirm_idx, start_value)) = last_top {
                    let start_idx = confirm_idx - length;
                    let (up, dn) =
                        naive_offsets(open, close, start_idx, end_idx, start_value, end_value);
                    naive_fill(
                        middle,
                        upper,
                        lower,
                        start_idx,
                        end_idx,
                        start_value,
                        end_value,
                        up,
                        dn,
                    );
                }
                last_bottom = Some((current, end_value));
            }

            // Entered the top state: close the leg that started at the last bottom.
            if os == 0 && prev_os != 0 {
                let end_idx = candidate;
                let end_value = high[end_idx];
                if let Some((confirm_idx, start_value)) = last_bottom {
                    let start_idx = confirm_idx - length;
                    let (up, dn) =
                        naive_offsets(open, close, start_idx, end_idx, start_value, end_value);
                    naive_fill(
                        middle,
                        upper,
                        lower,
                        start_idx,
                        end_idx,
                        start_value,
                        end_value,
                        up,
                        dn,
                    );
                }
                last_top = Some((current, end_value));
            }
        }

        if !extend || close.is_empty() {
            return;
        }
        // Provisional last leg: from the most recent pivot of the current state to the
        // final close.
        let end_idx = close.len() - 1;
        let end_value = close[end_idx];
        if os == 1 {
            if let Some((confirm_idx, start_value)) = last_bottom {
                let start_idx = confirm_idx - length;
                let (up, dn) =
                    naive_offsets(open, close, start_idx, end_idx, start_value, end_value);
                naive_fill(
                    middle,
                    upper,
                    lower,
                    start_idx,
                    end_idx,
                    start_value,
                    end_value,
                    up,
                    dn,
                );
            }
        } else if let Some((confirm_idx, start_value)) = last_top {
            let start_idx = confirm_idx - length;
            let (up, dn) = naive_offsets(open, close, start_idx, end_idx, start_value, end_value);
            naive_fill(
                middle,
                upper,
                lower,
                start_idx,
                end_idx,
                start_value,
                end_value,
                up,
                dn,
            );
        }
    }

    /// Reference implementation over a whole series.
    ///
    /// Splits the input into maximal runs of bars where `is_valid_ohlc` holds, and runs
    /// `naive_run` independently on every run with at least `length + 1` bars. All
    /// other positions stay NaN.
    fn naive_zig_zag_channels(
        open: &[f64],
        high: &[f64],
        low: &[f64],
        close: &[f64],
        length: usize,
        extend: bool,
    ) -> ZigZagChannelsOutput {
        let mut middle = vec![f64::NAN; close.len()];
        let mut upper = vec![f64::NAN; close.len()];
        let mut lower = vec![f64::NAN; close.len()];
        let mut idx = 0usize;
        while idx < close.len() {
            // Skip invalid bars to find the start of the next valid run.
            while idx < close.len() && !is_valid_ohlc(open[idx], high[idx], low[idx], close[idx]) {
                idx += 1;
            }
            if idx >= close.len() {
                break;
            }
            let start = idx;
            idx += 1;
            // Extend the run while bars stay valid; `end` is exclusive.
            while idx < close.len() && is_valid_ohlc(open[idx], high[idx], low[idx], close[idx]) {
                idx += 1;
            }
            let end = idx;
            if end - start >= length + 1 {
                naive_run(
                    &open[start..end],
                    &high[start..end],
                    &low[start..end],
                    &close[start..end],
                    length,
                    extend,
                    &mut middle[start..end],
                    &mut upper[start..end],
                    &mut lower[start..end],
                );
            }
        }
        ZigZagChannelsOutput {
            middle,
            upper,
            lower,
        }
    }

    /// Asserts equal lengths and element-wise closeness within `tol`.
    /// A NaN on either side must be matched by a NaN on the other.
    fn assert_series_close(left: &[f64], right: &[f64], tol: f64) {
        assert_eq!(left.len(), right.len());
        for (a, b) in left.iter().zip(right.iter()) {
            if a.is_nan() || b.is_nan() {
                assert!(a.is_nan() && b.is_nan());
            } else {
                assert!((a - b).abs() <= tol, "left={a} right={b}");
            }
        }
    }

    /// The scalar kernel must reproduce the reference implementation to 1e-12.
    #[test]
    fn zig_zag_channels_matches_naive_reference() -> Result<(), Box<dyn Error>> {
        let (open, high, low, close) = sample_ohlc(256);
        let input = ZigZagChannelsInput::from_slices(
            &open,
            &high,
            &low,
            &close,
            ZigZagChannelsParams {
                length: Some(7),
                extend: Some(true),
            },
        );
        let out = zig_zag_channels_with_kernel(&input, Kernel::Scalar)?;
        let expected = naive_zig_zag_channels(&open, &high, &low, &close, 7, true);
        assert_series_close(&out.middle, &expected.middle, 1e-12);
        assert_series_close(&out.upper, &expected.upper, 1e-12);
        assert_series_close(&out.lower, &expected.lower, 1e-12);
        Ok(())
    }

    /// Writing into caller-provided buffers must match the allocating API. The buffers
    /// start at 0.0, so the NaN warm-up must be written explicitly.
    #[test]
    fn zig_zag_channels_into_matches_api() -> Result<(), Box<dyn Error>> {
        let (open, high, low, close) = sample_ohlc(220);
        let input = ZigZagChannelsInput::from_slices(
            &open,
            &high,
            &low,
            &close,
            ZigZagChannelsParams {
                length: Some(6),
                extend: Some(true),
            },
        );
        let base = zig_zag_channels(&input)?;
        let mut middle = vec![0.0; close.len()];
        let mut upper = vec![0.0; close.len()];
        let mut lower = vec![0.0; close.len()];
        zig_zag_channels_into_slice(&mut middle, &mut upper, &mut lower, &input, Kernel::Auto)?;
        assert_series_close(&base.middle, &middle, 1e-12);
        assert_series_close(&base.upper, &upper, 1e-12);
        assert_series_close(&base.lower, &lower, 1e-12);
        Ok(())
    }

    /// `extend = true` must never remove points that `extend = false` produces.
    ///
    /// NOTE(review): despite the name, this only compares finiteness and counts. It does
    /// not check that values away from the tail are unchanged.
    #[test]
    fn zig_zag_channels_extend_changes_tail_only() -> Result<(), Box<dyn Error>> {
        let (open, high, low, close) = sample_ohlc(180);
        let extend_true = zig_zag_channels(&ZigZagChannelsInput::from_slices(
            &open,
            &high,
            &low,
            &close,
            ZigZagChannelsParams {
                length: Some(8),
                extend: Some(true),
            },
        ))?;
        let extend_false = zig_zag_channels(&ZigZagChannelsInput::from_slices(
            &open,
            &high,
            &low,
            &close,
            ZigZagChannelsParams {
                length: Some(8),
                extend: Some(false),
            },
        ))?;

        let finite_true = extend_true.middle.iter().filter(|v| v.is_finite()).count();
        let finite_false = extend_false.middle.iter().filter(|v| v.is_finite()).count();
        assert!(finite_true >= finite_false);

        for i in 0..close.len() {
            if extend_false.middle[i].is_finite() {
                assert!(extend_true.middle[i].is_finite());
            }
        }
        Ok(())
    }

    /// A single-combination batch (`length` 9..=9) must equal the single-run API.
    #[test]
    fn zig_zag_channels_batch_single_matches_single() -> Result<(), Box<dyn Error>> {
        let (open, high, low, close) = sample_ohlc(192);
        let single = zig_zag_channels(&ZigZagChannelsInput::from_slices(
            &open,
            &high,
            &low,
            &close,
            ZigZagChannelsParams {
                length: Some(9),
                extend: Some(true),
            },
        ))?;
        let batch = zig_zag_channels_batch_with_kernel(
            &open,
            &high,
            &low,
            &close,
            &ZigZagChannelsBatchRange {
                length: (9, 9, 0),
                extend: true,
            },
            Kernel::Auto,
        )?;
        assert_eq!(batch.rows, 1);
        assert_eq!(batch.cols, close.len());
        assert_series_close(&batch.middle, &single.middle, 1e-12);
        assert_series_close(&batch.upper, &single.upper, 1e-12);
        assert_series_close(&batch.lower, &single.lower, 1e-12);
        Ok(())
    }

    /// `length = 0` must be rejected with `InvalidLength`.
    #[test]
    fn zig_zag_channels_rejects_invalid_params() {
        let (open, high, low, close) = sample_ohlc(32);
        let err = zig_zag_channels(&ZigZagChannelsInput::from_slices(
            &open,
            &high,
            &low,
            &close,
            ZigZagChannelsParams {
                length: Some(0),
                extend: Some(true),
            },
        ))
        .unwrap_err();
        assert!(matches!(err, ZigZagChannelsError::InvalidLength { .. }));
    }

    /// The generic dispatcher must resolve `zig_zag_channels` by id and return the
    /// requested `middle` output with one column per input bar.
    #[test]
    fn zig_zag_channels_dispatch_compute_returns_middle() -> Result<(), Box<dyn Error>> {
        let (open, high, low, close) = sample_ohlc(160);
        let out = compute_cpu(IndicatorComputeRequest {
            indicator_id: "zig_zag_channels",
            output_id: Some("middle"),
            data: IndicatorDataRef::Ohlc {
                open: &open,
                high: &high,
                low: &low,
                close: &close,
            },
            params: &[
                ParamKV {
                    key: "length",
                    value: ParamValue::Int(7),
                },
                ParamKV {
                    key: "extend",
                    value: ParamValue::Bool(true),
                },
            ],
            kernel: Kernel::Auto,
        })?;
        assert_eq!(out.output_id, "middle");
        assert_eq!(out.cols, close.len());
        Ok(())
    }
}