Skip to main content

vector_ta/indicators/
keltner.rs

1#[cfg(feature = "python")]
2use crate::utilities::kernel_validation::validate_kernel;
3#[cfg(feature = "python")]
4use numpy::{IntoPyArray, PyArray1};
5#[cfg(feature = "python")]
6use pyo3::exceptions::PyValueError;
7#[cfg(feature = "python")]
8use pyo3::prelude::*;
9#[cfg(feature = "python")]
10use pyo3::types::{PyDict, PyList};
11
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde::{Deserialize, Serialize};
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17use crate::utilities::data_loader::{source_type, Candles};
18use crate::utilities::enums::Kernel;
19use crate::utilities::helpers::{
20    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
21    make_uninit_matrix,
22};
23#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
24use core::arch::x86_64::*;
25#[cfg(not(target_arch = "wasm32"))]
26use rayon::prelude::*;
27use std::convert::AsRef;
28use std::error::Error;
29use thiserror::Error;
30
31#[cfg(all(feature = "python", feature = "cuda"))]
32use crate::cuda::keltner_wrapper::CudaKeltner;
33#[cfg(all(feature = "python", feature = "cuda"))]
34use crate::cuda::moving_averages::DeviceArrayF32;
35#[cfg(all(feature = "python", feature = "cuda"))]
36use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
37#[cfg(all(feature = "python", feature = "cuda"))]
38use cust::context::Context;
39#[cfg(all(feature = "python", feature = "cuda"))]
40use std::sync::Arc;
41
42#[derive(Debug, Clone)]
43pub enum KeltnerData<'a> {
44    Candles {
45        candles: &'a Candles,
46        source: &'a str,
47    },
48    Slice(&'a [f64], &'a [f64], &'a [f64], &'a [f64]),
49}
50
/// The three Keltner Channel bands.
///
/// When produced by `keltner_with_kernel`, each vector has one entry per input bar,
/// and the leading warm-up entries are NaN.
#[derive(Debug, Clone)]
pub struct KeltnerOutput {
    /// Middle band plus `multiplier` times the ATR.
    pub upper_band: Vec<f64>,
    /// Moving average of the source series.
    pub middle_band: Vec<f64>,
    /// Middle band minus `multiplier` times the ATR.
    pub lower_band: Vec<f64>,
}
57
/// Tunable parameters for a Keltner Channel.
///
/// Every field is optional. `KeltnerInput`'s getters substitute period 20,
/// multiplier 2.0 and MA type `"ema"` for unset fields.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct KeltnerParams {
    /// Lookback length used for both the moving average and the ATR.
    pub period: Option<usize>,
    /// Distance of the outer bands from the middle band, in ATR units.
    pub multiplier: Option<f64>,
    /// Name of the moving average used for the middle band.
    pub ma_type: Option<String>,
}
68
69impl Default for KeltnerParams {
70    fn default() -> Self {
71        Self {
72            period: Some(20),
73            multiplier: Some(2.0),
74            ma_type: Some("ema".to_string()),
75        }
76    }
77}
78
79#[derive(Debug, Clone)]
80pub struct KeltnerInput<'a> {
81    pub data: KeltnerData<'a>,
82    pub params: KeltnerParams,
83}
84
85impl<'a> KeltnerInput<'a> {
86    #[inline]
87    pub fn from_candles(candles: &'a Candles, source: &'a str, params: KeltnerParams) -> Self {
88        Self {
89            data: KeltnerData::Candles { candles, source },
90            params,
91        }
92    }
93    #[inline]
94    pub fn from_slice(
95        high: &'a [f64],
96        low: &'a [f64],
97        close: &'a [f64],
98        source: &'a [f64],
99        params: KeltnerParams,
100    ) -> Self {
101        Self {
102            data: KeltnerData::Slice(high, low, close, source),
103            params,
104        }
105    }
106    #[inline]
107    pub fn with_default_candles(candles: &'a Candles) -> Self {
108        Self::from_candles(candles, "close", KeltnerParams::default())
109    }
110    #[inline]
111    pub fn get_period(&self) -> usize {
112        self.params.period.unwrap_or(20)
113    }
114    #[inline]
115    pub fn get_multiplier(&self) -> f64 {
116        self.params.multiplier.unwrap_or(2.0)
117    }
118    #[inline]
119    pub fn get_ma_type(&self) -> &str {
120        self.params.ma_type.as_deref().unwrap_or("ema")
121    }
122}
123
124impl<'a> AsRef<[f64]> for KeltnerInput<'a> {
125    #[inline(always)]
126    fn as_ref(&self) -> &[f64] {
127        match &self.data {
128            KeltnerData::Slice(_, _, _, source) => source,
129            KeltnerData::Candles { candles, source } => source_type(candles, source),
130        }
131    }
132}
133
134#[derive(Clone, Debug)]
135pub struct KeltnerBuilder {
136    period: Option<usize>,
137    multiplier: Option<f64>,
138    ma_type: Option<String>,
139    kernel: Kernel,
140}
141
142impl Default for KeltnerBuilder {
143    fn default() -> Self {
144        Self {
145            period: None,
146            multiplier: None,
147            ma_type: None,
148            kernel: Kernel::Auto,
149        }
150    }
151}
152
153impl KeltnerBuilder {
154    #[inline(always)]
155    pub fn new() -> Self {
156        Self::default()
157    }
158    #[inline(always)]
159    pub fn period(mut self, n: usize) -> Self {
160        self.period = Some(n);
161        self
162    }
163    #[inline(always)]
164    pub fn multiplier(mut self, x: f64) -> Self {
165        self.multiplier = Some(x);
166        self
167    }
168    #[inline(always)]
169    pub fn ma_type(mut self, mt: &str) -> Self {
170        self.ma_type = Some(mt.to_lowercase());
171        self
172    }
173    #[inline(always)]
174    pub fn kernel(mut self, k: Kernel) -> Self {
175        self.kernel = k;
176        self
177    }
178
179    #[inline(always)]
180    pub fn apply(self, c: &Candles) -> Result<KeltnerOutput, KeltnerError> {
181        let p = KeltnerParams {
182            period: self.period,
183            multiplier: self.multiplier,
184            ma_type: self.ma_type,
185        };
186        let i = KeltnerInput::from_candles(c, "close", p);
187        keltner_with_kernel(&i, self.kernel)
188    }
189
190    #[inline(always)]
191    pub fn apply_slice(
192        self,
193        high: &[f64],
194        low: &[f64],
195        close: &[f64],
196        source: &[f64],
197    ) -> Result<KeltnerOutput, KeltnerError> {
198        let p = KeltnerParams {
199            period: self.period,
200            multiplier: self.multiplier,
201            ma_type: self.ma_type,
202        };
203        let i = KeltnerInput::from_slice(high, low, close, source, p);
204        keltner_with_kernel(&i, self.kernel)
205    }
206
207    #[inline(always)]
208    pub fn into_stream(self) -> Result<KeltnerStream, KeltnerError> {
209        let p = KeltnerParams {
210            period: self.period,
211            multiplier: self.multiplier,
212            ma_type: self.ma_type,
213        };
214        KeltnerStream::try_new(p)
215    }
216}
217
218#[derive(Debug, Error)]
219pub enum KeltnerError {
220    #[error("keltner: empty data provided.")]
221    EmptyInputData,
222    #[error("keltner: invalid period: period = {period}, data length = {data_len}")]
223    InvalidPeriod { period: usize, data_len: usize },
224    #[error("keltner: not enough valid data: needed = {needed}, valid = {valid}")]
225    NotEnoughValidData { needed: usize, valid: usize },
226    #[error("keltner: all values are NaN.")]
227    AllValuesNaN,
228    #[error("keltner: output length mismatch: expected = {expected}, got = {got}")]
229    OutputLengthMismatch { expected: usize, got: usize },
230    #[error("keltner: invalid range: start={start}, end={end}, step={step}")]
231    InvalidRange {
232        start: String,
233        end: String,
234        step: String,
235    },
236    #[error("keltner: invalid kernel for batch: {0:?}")]
237    InvalidKernelForBatch(Kernel),
238    #[error("keltner: invalid input: {0}")]
239    InvalidInput(String),
240    #[error("keltner: MA error: {0}")]
241    MaError(String),
242}
243
244#[inline]
245pub fn keltner(input: &KeltnerInput) -> Result<KeltnerOutput, KeltnerError> {
246    keltner_with_kernel(input, Kernel::Auto)
247}
248
249pub fn keltner_with_kernel(
250    input: &KeltnerInput,
251    kernel: Kernel,
252) -> Result<KeltnerOutput, KeltnerError> {
253    let (high, low, close, source_slice): (&[f64], &[f64], &[f64], &[f64]) = match &input.data {
254        KeltnerData::Candles { candles, source } => (
255            candles
256                .select_candle_field("high")
257                .map_err(|e| KeltnerError::MaError(e.to_string()))?,
258            candles
259                .select_candle_field("low")
260                .map_err(|e| KeltnerError::MaError(e.to_string()))?,
261            candles
262                .select_candle_field("close")
263                .map_err(|e| KeltnerError::MaError(e.to_string()))?,
264            source_type(candles, source),
265        ),
266        KeltnerData::Slice(h, l, c, s) => (*h, *l, *c, *s),
267    };
268    let period = input.get_period();
269    let len = close.len();
270    if len == 0 {
271        return Err(KeltnerError::EmptyInputData);
272    }
273    if period == 0 || period > len {
274        return Err(KeltnerError::InvalidPeriod {
275            period,
276            data_len: len,
277        });
278    }
279    let first = close
280        .iter()
281        .position(|x| !x.is_nan())
282        .ok_or(KeltnerError::AllValuesNaN)?;
283
284    if (len - first) < period {
285        return Err(KeltnerError::NotEnoughValidData {
286            needed: period,
287            valid: len - first,
288        });
289    }
290
291    let chosen = match kernel {
292        Kernel::Auto => Kernel::Scalar,
293        other => other,
294    };
295
296    let warm = first + period - 1;
297    let mut upper_band = alloc_with_nan_prefix(len, warm);
298    let mut middle_band = alloc_with_nan_prefix(len, warm);
299    let mut lower_band = alloc_with_nan_prefix(len, warm);
300
301    unsafe {
302        match chosen {
303            Kernel::Scalar | Kernel::ScalarBatch => keltner_scalar(
304                high,
305                low,
306                close,
307                source_slice,
308                period,
309                input.get_multiplier(),
310                input.get_ma_type(),
311                first,
312                &mut upper_band,
313                &mut middle_band,
314                &mut lower_band,
315            ),
316            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
317            Kernel::Avx2 | Kernel::Avx2Batch => keltner_avx2(
318                high,
319                low,
320                close,
321                source_slice,
322                period,
323                input.get_multiplier(),
324                input.get_ma_type(),
325                first,
326                &mut upper_band,
327                &mut middle_band,
328                &mut lower_band,
329            ),
330            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
331            Kernel::Avx512 | Kernel::Avx512Batch => keltner_avx512(
332                high,
333                low,
334                close,
335                source_slice,
336                period,
337                input.get_multiplier(),
338                input.get_ma_type(),
339                first,
340                &mut upper_band,
341                &mut middle_band,
342                &mut lower_band,
343            ),
344            _ => unreachable!(),
345        }
346    }
347    Ok(KeltnerOutput {
348        upper_band,
349        middle_band,
350        lower_band,
351    })
352}
353
354#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
355#[inline(always)]
356pub fn keltner_into(
357    input: &KeltnerInput,
358    upper_dst: &mut [f64],
359    middle_dst: &mut [f64],
360    lower_dst: &mut [f64],
361) -> Result<(), KeltnerError> {
362    keltner_into_slice(upper_dst, middle_dst, lower_dst, input, Kernel::Auto)
363}
364
365#[inline(always)]
366pub fn keltner_into_slice(
367    upper_dst: &mut [f64],
368    middle_dst: &mut [f64],
369    lower_dst: &mut [f64],
370    input: &KeltnerInput,
371    kernel: Kernel,
372) -> Result<(), KeltnerError> {
373    let (high, low, close, source_slice): (&[f64], &[f64], &[f64], &[f64]) = match &input.data {
374        KeltnerData::Candles { candles, source } => (
375            candles
376                .select_candle_field("high")
377                .map_err(|e| KeltnerError::MaError(e.to_string()))?,
378            candles
379                .select_candle_field("low")
380                .map_err(|e| KeltnerError::MaError(e.to_string()))?,
381            candles
382                .select_candle_field("close")
383                .map_err(|e| KeltnerError::MaError(e.to_string()))?,
384            source_type(candles, source),
385        ),
386        KeltnerData::Slice(h, l, c, s) => (*h, *l, *c, *s),
387    };
388
389    let period = input.get_period();
390    let len = close.len();
391
392    if len == 0 {
393        return Err(KeltnerError::EmptyInputData);
394    }
395
396    if upper_dst.len() != len || middle_dst.len() != len || lower_dst.len() != len {
397        return Err(KeltnerError::OutputLengthMismatch {
398            expected: len,
399            got: upper_dst.len().max(middle_dst.len()).max(lower_dst.len()),
400        });
401    }
402
403    if period == 0 || period > len {
404        return Err(KeltnerError::InvalidPeriod {
405            period,
406            data_len: len,
407        });
408    }
409
410    let first = close
411        .iter()
412        .position(|x| !x.is_nan())
413        .ok_or(KeltnerError::AllValuesNaN)?;
414
415    if (len - first) < period {
416        return Err(KeltnerError::NotEnoughValidData {
417            needed: period,
418            valid: len - first,
419        });
420    }
421
422    let chosen = match kernel {
423        Kernel::Auto => Kernel::Scalar,
424        other => other,
425    };
426
427    unsafe {
428        match chosen {
429            Kernel::Scalar | Kernel::ScalarBatch => keltner_scalar(
430                high,
431                low,
432                close,
433                source_slice,
434                period,
435                input.get_multiplier(),
436                input.get_ma_type(),
437                first,
438                upper_dst,
439                middle_dst,
440                lower_dst,
441            ),
442            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
443            Kernel::Avx2 | Kernel::Avx2Batch => keltner_avx2(
444                high,
445                low,
446                close,
447                source_slice,
448                period,
449                input.get_multiplier(),
450                input.get_ma_type(),
451                first,
452                upper_dst,
453                middle_dst,
454                lower_dst,
455            ),
456            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
457            Kernel::Avx512 | Kernel::Avx512Batch => keltner_avx512(
458                high,
459                low,
460                close,
461                source_slice,
462                period,
463                input.get_multiplier(),
464                input.get_ma_type(),
465                first,
466                upper_dst,
467                middle_dst,
468                lower_dst,
469            ),
470            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
471            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
472                keltner_scalar(
473                    high,
474                    low,
475                    close,
476                    source_slice,
477                    period,
478                    input.get_multiplier(),
479                    input.get_ma_type(),
480                    first,
481                    upper_dst,
482                    middle_dst,
483                    lower_dst,
484                )
485            }
486            _ => unreachable!(),
487        }
488    }
489
490    let warm = first + period - 1;
491    for i in 0..warm {
492        upper_dst[i] = f64::NAN;
493        middle_dst[i] = f64::NAN;
494        lower_dst[i] = f64::NAN;
495    }
496
497    Ok(())
498}
499
/// Reference ("classic") Keltner Channel kernel with a simple-moving-average middle band.
///
/// For every index `i` in `first + period - 1 .. close.len()`:
///
/// * `middle[i]` is the mean of `source[i + 1 - period ..= i]`;
/// * `upper[i]` / `lower[i]` are `middle[i] ± multiplier * atr`. They are written only when
///   `atr` is not NaN.
///
/// `atr` is a running average of the true range. It is seeded with the mean of the `period`
/// true ranges in `first ..= first + period - 1`, then updated with
/// `atr += (tr - atr) / period` for each later bar. The seed window starts at `first`, not at
/// index 0, so a NaN prefix in the price series no longer turns every outer-band value into NaN.
///
/// Indices before `first + period - 1` are left untouched. Nothing is written when
/// `period == 0` or fewer than `period` bars are available from `first`. Arithmetic uses
/// plain multiply/add, not `mul_add`.
///
/// # Panics
///
/// Panics if `high`, `low`, `source` or any output slice is shorter than `close` (when at
/// least one row is produced).
#[inline]
pub fn keltner_scalar_classic_sma(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    source: &[f64],
    period: usize,
    multiplier: f64,
    first: usize,
    upper: &mut [f64],
    middle: &mut [f64],
    lower: &mut [f64],
) {
    let len = close.len();
    // `period == 0` would underflow the warm-up index below.
    if period == 0 || first + period > len {
        return;
    }
    let warm = first + period - 1;
    let pf = period as f64;
    let alpha = 1.0 / pf;

    // True range of bar `i`. Bar 0 has no previous close, so only its high-low span counts.
    // `f64::max` ignores NaN operands, so a NaN previous close degrades to high-low as well.
    let true_range = |i: usize| -> f64 {
        let hl = high[i] - low[i];
        if i == 0 {
            return hl;
        }
        let prev_close = close[i - 1];
        hl.max((high[i] - prev_close).abs())
            .max((low[i] - prev_close).abs())
    };

    // Writes one output row; the outer bands are skipped while the ATR is undefined.
    let mut emit = |i: usize, mid: f64, atr: f64| {
        middle[i] = mid;
        if !atr.is_nan() {
            upper[i] = mid + multiplier * atr;
            lower[i] = mid - multiplier * atr;
        }
    };

    // Seed both the ATR and the SMA window from bars `first..=warm`.
    let mut tr_sum = 0.0;
    let mut src_sum = 0.0;
    for i in first..=warm {
        tr_sum += true_range(i);
        src_sum += source[i];
    }
    let mut atr = tr_sum / pf;
    emit(warm, src_sum / pf, atr);

    for i in (warm + 1)..len {
        atr += alpha * (true_range(i) - atr);
        // Slide the SMA window: add the newest value, drop the one `period` bars back.
        src_sum += source[i] - source[i - period];
        emit(i, src_sum / pf, atr);
    }
}
576
/// Reference ("classic") Keltner Channel kernel with an exponential-moving-average middle band.
///
/// For every index `i` in `first + period - 1 .. close.len()`:
///
/// * `middle[i]` is an EMA of `source` with smoothing factor `2 / (period + 1)`. It is seeded
///   with the plain mean of `source[first ..= first + period - 1]`.
/// * `upper[i]` / `lower[i]` are `middle[i] ± multiplier * atr`. They are written only when
///   `atr` is not NaN.
///
/// `atr` is a running average of the true range. It is seeded with the mean of the `period`
/// true ranges starting at `first`, then updated with `atr += (tr - atr) / period`. The seed
/// window starts at `first`, not at index 0, so a NaN prefix in the price series no longer
/// turns every outer-band value into NaN.
///
/// Indices before `first + period - 1` are left untouched. Nothing is written when
/// `period == 0` or fewer than `period` bars are available from `first`. Arithmetic uses
/// plain multiply/add, not `mul_add`.
///
/// # Panics
///
/// Panics if `high`, `low`, `source` or any output slice is shorter than `close` (when at
/// least one row is produced).
pub fn keltner_scalar_classic_ema(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    source: &[f64],
    period: usize,
    multiplier: f64,
    first: usize,
    upper: &mut [f64],
    middle: &mut [f64],
    lower: &mut [f64],
) {
    let len = close.len();
    // `period == 0` would underflow the warm-up index below.
    if period == 0 || first + period > len {
        return;
    }
    let warm = first + period - 1;
    let pf = period as f64;
    let alpha = 1.0 / pf;
    let ema_alpha = 2.0 / (pf + 1.0);
    let ema_alpha_1 = 1.0 - ema_alpha;

    // True range of bar `i`. Bar 0 has no previous close, so only its high-low span counts.
    // `f64::max` ignores NaN operands, so a NaN previous close degrades to high-low as well.
    let true_range = |i: usize| -> f64 {
        let hl = high[i] - low[i];
        if i == 0 {
            return hl;
        }
        let prev_close = close[i - 1];
        hl.max((high[i] - prev_close).abs())
            .max((low[i] - prev_close).abs())
    };

    // Writes one output row; the outer bands are skipped while the ATR is undefined.
    let mut emit = |i: usize, mid: f64, atr: f64| {
        middle[i] = mid;
        if !atr.is_nan() {
            upper[i] = mid + multiplier * atr;
            lower[i] = mid - multiplier * atr;
        }
    };

    // Seed both the ATR and the EMA from bars `first..=warm`.
    let mut tr_sum = 0.0;
    let mut src_sum = 0.0;
    for i in first..=warm {
        tr_sum += true_range(i);
        src_sum += source[i];
    }
    let mut atr = tr_sum / pf;
    let mut ema = src_sum / pf;
    emit(warm, ema, atr);

    for i in (warm + 1)..len {
        atr += alpha * (true_range(i) - atr);
        ema = ema_alpha * source[i] + ema_alpha_1 * ema;
        emit(i, ema, atr);
    }
}
652
653pub fn keltner_scalar(
654    high: &[f64],
655    low: &[f64],
656    close: &[f64],
657    source: &[f64],
658    period: usize,
659    multiplier: f64,
660    ma_type: &str,
661    first: usize,
662    upper: &mut [f64],
663    middle: &mut [f64],
664    lower: &mut [f64],
665) {
666    let len = close.len();
667    let warm = first + period - 1;
668    if warm >= len {
669        return;
670    }
671
672    let pf = period as f64;
673    let rma_alpha = 1.0 / pf;
674
675    let mut atr: f64;
676    unsafe {
677        atr = *high.get_unchecked(0) - *low.get_unchecked(0);
678
679        let mut i = 1usize;
680        while i < period {
681            let hi = *high.get_unchecked(i);
682            let lo = *low.get_unchecked(i);
683            let pc = *close.get_unchecked(i - 1);
684            let hl = hi - lo;
685            let hc = (hi - pc).abs();
686            let lc = (lo - pc).abs();
687            atr += hl.max(hc).max(lc);
688            i += 1;
689        }
690        atr /= pf;
691
692        let mut k = period;
693        while k <= warm {
694            let hi = *high.get_unchecked(k);
695            let lo = *low.get_unchecked(k);
696            let pc = *close.get_unchecked(k - 1);
697            let hl = hi - lo;
698            let hc = (hi - pc).abs();
699            let lc = (lo - pc).abs();
700            let tr = hl.max(hc).max(lc);
701            atr = (tr - atr).mul_add(rma_alpha, atr);
702            k += 1;
703        }
704    }
705
706    let m = multiplier;
707
708    if ma_type.eq_ignore_ascii_case("ema") {
709        let mut ema: f64 = 0.0;
710        unsafe {
711            let mut j = 0usize;
712            while j < period {
713                ema += *source.get_unchecked(first + j);
714                j += 1;
715            }
716        }
717        ema /= pf;
718
719        middle[warm] = ema;
720        upper[warm] = m.mul_add(atr, ema);
721        lower[warm] = (-m).mul_add(atr, ema);
722
723        let ema_alpha = 2.0 / (pf + 1.0);
724
725        unsafe {
726            let mut i = warm + 1;
727            while i < len {
728                let hi = *high.get_unchecked(i);
729                let lo = *low.get_unchecked(i);
730                let pc = *close.get_unchecked(i - 1);
731                let hl = hi - lo;
732                let hc = (hi - pc).abs();
733                let lc = (lo - pc).abs();
734                let tr = hl.max(hc).max(lc);
735                atr = (tr - atr).mul_add(rma_alpha, atr);
736
737                let xi = *source.get_unchecked(i);
738                ema = (xi - ema).mul_add(ema_alpha, ema);
739
740                middle[i] = ema;
741                upper[i] = m.mul_add(atr, ema);
742                lower[i] = (-m).mul_add(atr, ema);
743                i += 1;
744            }
745        }
746        return;
747    }
748
749    if ma_type.eq_ignore_ascii_case("sma") {
750        let mut sum: f64 = 0.0;
751        unsafe {
752            let mut j = 0usize;
753            while j < period {
754                sum += *source.get_unchecked(first + j);
755                j += 1;
756            }
757        }
758        let mut mid = sum / pf;
759        middle[warm] = mid;
760        upper[warm] = m.mul_add(atr, mid);
761        lower[warm] = (-m).mul_add(atr, mid);
762
763        unsafe {
764            let mut i = warm + 1;
765            while i < len {
766                let hi = *high.get_unchecked(i);
767                let lo = *low.get_unchecked(i);
768                let pc = *close.get_unchecked(i - 1);
769                let hl = hi - lo;
770                let hc = (hi - pc).abs();
771                let lc = (lo - pc).abs();
772                let tr = hl.max(hc).max(lc);
773                atr = (tr - atr).mul_add(rma_alpha, atr);
774
775                let new_x = *source.get_unchecked(i);
776                let old_x = *source.get_unchecked(i - period);
777                sum += new_x - old_x;
778                mid = sum / pf;
779
780                middle[i] = mid;
781                upper[i] = m.mul_add(atr, mid);
782                lower[i] = (-m).mul_add(atr, mid);
783                i += 1;
784            }
785        }
786        return;
787    }
788
789    let mut atr = crate::utilities::helpers::alloc_with_nan_prefix(len, warm);
790    let alpha = 1.0 / (period as f64);
791    let mut sum_tr = 0.0;
792    let mut rma = f64::NAN;
793
794    for i in 0..len {
795        let tr = if i == 0 {
796            high[0] - low[0]
797        } else {
798            let hl = high[i] - low[i];
799            let hc = (high[i] - close[i - 1]).abs();
800            let lc = (low[i] - close[i - 1]).abs();
801            hl.max(hc).max(lc)
802        };
803        if i < period {
804            sum_tr += tr;
805            if i == period - 1 {
806                rma = sum_tr / (period as f64);
807                atr[i] = rma;
808            }
809        } else {
810            rma += alpha * (tr - rma);
811            atr[i] = rma;
812        }
813    }
814
815    let mut ma_values = crate::utilities::helpers::alloc_with_nan_prefix(len, warm);
816
817    match ma_type {
818        "ema" => {
819            use crate::indicators::moving_averages::ema::{
820                ema_into_slice, EmaData, EmaInput, EmaParams,
821            };
822            let ema_input = EmaInput {
823                data: EmaData::Slice(source),
824                params: EmaParams {
825                    period: Some(period),
826                },
827            };
828            let _ = ema_into_slice(&mut ma_values, &ema_input, Kernel::Auto);
829        }
830        "sma" => {
831            use crate::indicators::moving_averages::sma::{
832                sma_into_slice, SmaData, SmaInput, SmaParams,
833            };
834            let sma_input = SmaInput {
835                data: SmaData::Slice(source),
836                params: SmaParams {
837                    period: Some(period),
838                },
839            };
840            let _ = sma_into_slice(&mut ma_values, &sma_input, Kernel::Auto);
841        }
842        _ => {
843            if let Ok(result) = crate::indicators::moving_averages::ma::ma(
844                ma_type,
845                crate::indicators::moving_averages::ma::MaData::Slice(source),
846                period,
847            ) {
848                ma_values.copy_from_slice(&result);
849            }
850        }
851    }
852
853    for i in warm..len {
854        let ma_v = ma_values[i];
855        let atr_v = atr[i];
856        if ma_v.is_nan() || atr_v.is_nan() {
857            continue;
858        }
859        middle[i] = ma_v;
860        upper[i] = multiplier.mul_add(atr_v, ma_v);
861        lower[i] = (-multiplier).mul_add(atr_v, ma_v);
862    }
863}
864
/// AVX2 entry point for the single-series Keltner kernel.
///
/// There is no dedicated AVX2 implementation: every argument is forwarded unchanged to
/// `keltner_scalar`, so the results are identical to the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn keltner_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    source: &[f64],
    period: usize,
    multiplier: f64,
    ma_type: &str,
    first: usize,
    upper: &mut [f64],
    middle: &mut [f64],
    lower: &mut [f64],
) {
    keltner_scalar(
        high,
        low,
        close,
        source,
        period,
        multiplier,
        ma_type,
        first,
        upper,
        middle,
        lower,
    );
}
884
/// AVX-512 entry point for the single-series Keltner kernel.
///
/// Periods up to 32 go to `keltner_avx512_short` and longer periods to
/// `keltner_avx512_long`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn keltner_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    source: &[f64],
    period: usize,
    multiplier: f64,
    ma_type: &str,
    first: usize,
    upper: &mut [f64],
    middle: &mut [f64],
    lower: &mut [f64],
) {
    let short_period = period <= 32;
    // SAFETY: both variants are declared `unsafe` but currently only forward their
    // arguments to the safe `keltner_scalar`; they add no extra preconditions.
    unsafe {
        if short_period {
            keltner_avx512_short(
                high, low, close, source, period, multiplier, ma_type, first, upper, middle, lower,
            );
        } else {
            keltner_avx512_long(
                high, low, close, source, period, multiplier, ma_type, first, upper, middle, lower,
            );
        }
    }
}
914
/// AVX-512 variant used by `keltner_avx512` for periods of 32 or less.
///
/// Currently forwards every argument unchanged to `keltner_scalar`.
///
/// # Safety
///
/// Declared `unsafe` alongside the other SIMD entry points. The present body only calls
/// the safe scalar kernel, so callers need only satisfy `keltner_scalar`'s requirements.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn keltner_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    source: &[f64],
    period: usize,
    multiplier: f64,
    ma_type: &str,
    first: usize,
    upper: &mut [f64],
    middle: &mut [f64],
    lower: &mut [f64],
) {
    keltner_scalar(
        high,
        low,
        close,
        source,
        period,
        multiplier,
        ma_type,
        first,
        upper,
        middle,
        lower,
    );
}
934
/// AVX-512 variant used by `keltner_avx512` for periods above 32.
///
/// Currently forwards every argument unchanged to `keltner_scalar`.
///
/// # Safety
///
/// Declared `unsafe` alongside the other SIMD entry points. The present body only calls
/// the safe scalar kernel, so callers need only satisfy `keltner_scalar`'s requirements.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn keltner_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    source: &[f64],
    period: usize,
    multiplier: f64,
    ma_type: &str,
    first: usize,
    upper: &mut [f64],
    middle: &mut [f64],
    lower: &mut [f64],
) {
    keltner_scalar(
        high,
        low,
        close,
        source,
        period,
        multiplier,
        ma_type,
        first,
        upper,
        middle,
        lower,
    );
}
954
955#[inline(always)]
956pub fn keltner_row_scalar(
957    high: &[f64],
958    low: &[f64],
959    close: &[f64],
960    source: &[f64],
961    period: usize,
962    multiplier: f64,
963    ma_type: &str,
964    first: usize,
965    upper: &mut [f64],
966    middle: &mut [f64],
967    lower: &mut [f64],
968) {
969    keltner_scalar(
970        high, low, close, source, period, multiplier, ma_type, first, upper, middle, lower,
971    );
972}
973
/// Row-level AVX2 kernel (one parameter combination); delegates to `keltner_avx2`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn keltner_row_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    source: &[f64],
    period: usize,
    multiplier: f64,
    ma_type: &str,
    first: usize,
    upper: &mut [f64],
    middle: &mut [f64],
    lower: &mut [f64],
) {
    keltner_avx2(
        high,
        low,
        close,
        source,
        period,
        multiplier,
        ma_type,
        first,
        upper,
        middle,
        lower,
    );
}
993
/// Row-level AVX-512 kernel (one parameter combination); delegates to `keltner_avx512`,
/// which chooses the short or long variant by period.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn keltner_row_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    source: &[f64],
    period: usize,
    multiplier: f64,
    ma_type: &str,
    first: usize,
    upper: &mut [f64],
    middle: &mut [f64],
    lower: &mut [f64],
) {
    keltner_avx512(
        high,
        low,
        close,
        source,
        period,
        multiplier,
        ma_type,
        first,
        upper,
        middle,
        lower,
    );
}
1013
/// Row-level wrapper around `keltner_avx512_short`.
///
/// # Safety
///
/// Same contract as `keltner_avx512_short`, to which every argument is forwarded.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn keltner_row_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    source: &[f64],
    period: usize,
    multiplier: f64,
    ma_type: &str,
    first: usize,
    upper: &mut [f64],
    middle: &mut [f64],
    lower: &mut [f64],
) {
    keltner_avx512_short(
        high,
        low,
        close,
        source,
        period,
        multiplier,
        ma_type,
        first,
        upper,
        middle,
        lower,
    );
}
1033
/// Row-level wrapper around `keltner_avx512_long`.
///
/// # Safety
///
/// Same contract as `keltner_avx512_long`, to which every argument is forwarded.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn keltner_row_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    source: &[f64],
    period: usize,
    multiplier: f64,
    ma_type: &str,
    first: usize,
    upper: &mut [f64],
    middle: &mut [f64],
    lower: &mut [f64],
) {
    keltner_avx512_long(
        high,
        low,
        close,
        source,
        period,
        multiplier,
        ma_type,
        first,
        upper,
        middle,
        lower,
    );
}
1053
/// Parameter sweep for batch Keltner computation.
///
/// Each field is a `(start, end, step)` triple, as set by the batch builder's
/// `period_range` / `multiplier_range`.
#[derive(Clone, Debug)]
pub struct KeltnerBatchRange {
    /// `(start, end, step)` for the lookback period.
    pub period: (usize, usize, usize),
    /// `(start, end, step)` for the band multiplier.
    pub multiplier: (f64, f64, f64),
}
1059
1060impl Default for KeltnerBatchRange {
1061    fn default() -> Self {
1062        Self {
1063            period: (20, 269, 1),
1064            multiplier: (2.0, 2.0, 0.0),
1065        }
1066    }
1067}
1068
1069#[derive(Clone, Debug)]
1070pub struct KeltnerBatchBuilder {
1071    range: KeltnerBatchRange,
1072    kernel: Kernel,
1073}
1074
1075impl Default for KeltnerBatchBuilder {
1076    fn default() -> Self {
1077        Self {
1078            range: KeltnerBatchRange::default(),
1079            kernel: Kernel::Auto,
1080        }
1081    }
1082}
1083impl KeltnerBatchBuilder {
1084    pub fn new() -> Self {
1085        Self::default()
1086    }
1087    pub fn kernel(mut self, k: Kernel) -> Self {
1088        self.kernel = k;
1089        self
1090    }
1091    #[inline]
1092    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1093        self.range.period = (start, end, step);
1094        self
1095    }
1096    #[inline]
1097    pub fn period_static(mut self, p: usize) -> Self {
1098        self.range.period = (p, p, 0);
1099        self
1100    }
1101    #[inline]
1102    pub fn multiplier_range(mut self, start: f64, end: f64, step: f64) -> Self {
1103        self.range.multiplier = (start, end, step);
1104        self
1105    }
1106    #[inline]
1107    pub fn multiplier_static(mut self, m: f64) -> Self {
1108        self.range.multiplier = (m, m, 0.0);
1109        self
1110    }
1111    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<KeltnerBatchOutput, KeltnerError> {
1112        let h = c
1113            .select_candle_field("high")
1114            .map_err(|e| KeltnerError::MaError(e.to_string()))?;
1115        let l = c
1116            .select_candle_field("low")
1117            .map_err(|e| KeltnerError::MaError(e.to_string()))?;
1118        let cl = c
1119            .select_candle_field("close")
1120            .map_err(|e| KeltnerError::MaError(e.to_string()))?;
1121        let src_v = source_type(c, src);
1122        self.apply_slice(&h, &l, &cl, src_v)
1123    }
1124    pub fn apply_slice(
1125        self,
1126        high: &[f64],
1127        low: &[f64],
1128        close: &[f64],
1129        source: &[f64],
1130    ) -> Result<KeltnerBatchOutput, KeltnerError> {
1131        keltner_batch_with_kernel(high, low, close, source, &self.range, self.kernel)
1132    }
1133
1134    pub fn with_default_slice(
1135        high: &[f64],
1136        low: &[f64],
1137        close: &[f64],
1138        source: &[f64],
1139        k: Kernel,
1140    ) -> Result<KeltnerBatchOutput, KeltnerError> {
1141        KeltnerBatchBuilder::new()
1142            .kernel(k)
1143            .apply_slice(high, low, close, source)
1144    }
1145}
1146
1147#[derive(Clone, Debug)]
1148pub struct KeltnerBatchOutput {
1149    pub upper_band: Vec<f64>,
1150    pub middle_band: Vec<f64>,
1151    pub lower_band: Vec<f64>,
1152    pub combos: Vec<KeltnerParams>,
1153    pub rows: usize,
1154    pub cols: usize,
1155}
1156impl KeltnerBatchOutput {
1157    pub fn row_for_params(&self, p: &KeltnerParams) -> Option<usize> {
1158        self.combos.iter().position(|c| {
1159            c.period.unwrap_or(20) == p.period.unwrap_or(20)
1160                && (c.multiplier.unwrap_or(2.0) - p.multiplier.unwrap_or(2.0)).abs() < 1e-12
1161        })
1162    }
1163    pub fn values_for(&self, p: &KeltnerParams) -> Option<(&[f64], &[f64], &[f64])> {
1164        self.row_for_params(p).map(|row| {
1165            let start = row * self.cols;
1166            (
1167                &self.upper_band[start..start + self.cols],
1168                &self.middle_band[start..start + self.cols],
1169                &self.lower_band[start..start + self.cols],
1170            )
1171        })
1172    }
1173}
1174
1175fn expand_grid(r: &KeltnerBatchRange) -> Result<Vec<KeltnerParams>, KeltnerError> {
1176    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, KeltnerError> {
1177        if step == 0 || start == end {
1178            return Ok(vec![start]);
1179        }
1180        if start < end {
1181            return Ok((start..=end).step_by(step.max(1)).collect());
1182        }
1183        let mut v = Vec::new();
1184        let mut x = start as isize;
1185        let end_i = end as isize;
1186        let st = (step as isize).max(1);
1187        while x >= end_i {
1188            v.push(x as usize);
1189            x -= st;
1190        }
1191        if v.is_empty() {
1192            return Err(KeltnerError::InvalidRange {
1193                start: start.to_string(),
1194                end: end.to_string(),
1195                step: step.to_string(),
1196            });
1197        }
1198        Ok(v)
1199    }
1200    fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, KeltnerError> {
1201        if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
1202            return Ok(vec![start]);
1203        }
1204        if start < end {
1205            let mut v = Vec::new();
1206            let mut x = start;
1207            let st = step.abs();
1208            while x <= end + 1e-12 {
1209                v.push(x);
1210                x += st;
1211            }
1212            if v.is_empty() {
1213                return Err(KeltnerError::InvalidRange {
1214                    start: start.to_string(),
1215                    end: end.to_string(),
1216                    step: step.to_string(),
1217                });
1218            }
1219            return Ok(v);
1220        }
1221        let mut v = Vec::new();
1222        let mut x = start;
1223        let st = step.abs();
1224        while x + 1e-12 >= end {
1225            v.push(x);
1226            x -= st;
1227        }
1228        if v.is_empty() {
1229            return Err(KeltnerError::InvalidRange {
1230                start: start.to_string(),
1231                end: end.to_string(),
1232                step: step.to_string(),
1233            });
1234        }
1235        Ok(v)
1236    }
1237
1238    let periods = axis_usize(r.period)?;
1239    let mults = axis_f64(r.multiplier)?;
1240
1241    let cap = periods
1242        .len()
1243        .checked_mul(mults.len())
1244        .ok_or_else(|| KeltnerError::InvalidRange {
1245            start: "rows".into(),
1246            end: "cols".into(),
1247            step: "rows*cols".into(),
1248        })?;
1249
1250    let mut out = Vec::with_capacity(cap);
1251    for &p in &periods {
1252        for &m in &mults {
1253            out.push(KeltnerParams {
1254                period: Some(p),
1255                multiplier: Some(m),
1256                ma_type: None,
1257            });
1258        }
1259    }
1260    Ok(out)
1261}
1262
1263pub fn keltner_batch_with_kernel(
1264    high: &[f64],
1265    low: &[f64],
1266    close: &[f64],
1267    source: &[f64],
1268    sweep: &KeltnerBatchRange,
1269    k: Kernel,
1270) -> Result<KeltnerBatchOutput, KeltnerError> {
1271    let kernel = match k {
1272        Kernel::Auto => {
1273            let best = detect_best_batch_kernel();
1274            if best == Kernel::Avx512Batch {
1275                Kernel::Avx2Batch
1276            } else {
1277                best
1278            }
1279        }
1280        other if other.is_batch() => other,
1281        _ => {
1282            return Err(KeltnerError::InvalidKernelForBatch(k));
1283        }
1284    };
1285    let simd = match kernel {
1286        Kernel::Avx512Batch => Kernel::Avx512,
1287        Kernel::Avx2Batch => Kernel::Avx2,
1288        Kernel::ScalarBatch => Kernel::Scalar,
1289        _ => unreachable!(),
1290    };
1291    keltner_batch_par_slice(high, low, close, source, sweep, simd)
1292}
1293
1294pub fn keltner_batch_slice(
1295    high: &[f64],
1296    low: &[f64],
1297    close: &[f64],
1298    source: &[f64],
1299    sweep: &KeltnerBatchRange,
1300    kern: Kernel,
1301) -> Result<KeltnerBatchOutput, KeltnerError> {
1302    keltner_batch_inner(high, low, close, source, sweep, kern, false, None)
1303}
1304pub fn keltner_batch_par_slice(
1305    high: &[f64],
1306    low: &[f64],
1307    close: &[f64],
1308    source: &[f64],
1309    sweep: &KeltnerBatchRange,
1310    kern: Kernel,
1311) -> Result<KeltnerBatchOutput, KeltnerError> {
1312    keltner_batch_inner(high, low, close, source, sweep, kern, true, None)
1313}
1314fn keltner_batch_inner(
1315    high: &[f64],
1316    low: &[f64],
1317    close: &[f64],
1318    source: &[f64],
1319    sweep: &KeltnerBatchRange,
1320    kern: Kernel,
1321    parallel: bool,
1322    ma_type: Option<&str>,
1323) -> Result<KeltnerBatchOutput, KeltnerError> {
1324    let combos = expand_grid(sweep)?;
1325    if combos.is_empty() {
1326        return Err(KeltnerError::InvalidRange {
1327            start: "range".into(),
1328            end: "range".into(),
1329            step: "empty".into(),
1330        });
1331    }
1332    let len = close.len();
1333    let first = close
1334        .iter()
1335        .position(|x| !x.is_nan())
1336        .ok_or(KeltnerError::AllValuesNaN)?;
1337    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1338    if len - first < max_p {
1339        return Err(KeltnerError::NotEnoughValidData {
1340            needed: max_p,
1341            valid: len - first,
1342        });
1343    }
1344    let rows = combos.len();
1345    let cols = len;
1346
1347    rows.checked_mul(cols)
1348        .ok_or_else(|| KeltnerError::InvalidRange {
1349            start: rows.to_string(),
1350            end: cols.to_string(),
1351            step: "rows*cols".into(),
1352        })?;
1353
1354    let warm: Vec<usize> = combos
1355        .iter()
1356        .map(|c| {
1357            let p = c.period.unwrap();
1358            first
1359                .checked_add(p.saturating_sub(1))
1360                .ok_or_else(|| KeltnerError::InvalidRange {
1361                    start: first.to_string(),
1362                    end: p.to_string(),
1363                    step: "first+period-1".into(),
1364                })
1365        })
1366        .collect::<Result<Vec<_>, _>>()?;
1367
1368    let mut upper_mu = make_uninit_matrix(rows, cols);
1369    let mut middle_mu = make_uninit_matrix(rows, cols);
1370    let mut lower_mu = make_uninit_matrix(rows, cols);
1371
1372    init_matrix_prefixes(&mut upper_mu, cols, &warm);
1373    init_matrix_prefixes(&mut middle_mu, cols, &warm);
1374    init_matrix_prefixes(&mut lower_mu, cols, &warm);
1375
1376    let mut upper_guard = core::mem::ManuallyDrop::new(upper_mu);
1377    let mut middle_guard = core::mem::ManuallyDrop::new(middle_mu);
1378    let mut lower_guard = core::mem::ManuallyDrop::new(lower_mu);
1379
1380    let upper: &mut [f64] = unsafe {
1381        core::slice::from_raw_parts_mut(upper_guard.as_mut_ptr() as *mut f64, upper_guard.len())
1382    };
1383    let middle: &mut [f64] = unsafe {
1384        core::slice::from_raw_parts_mut(middle_guard.as_mut_ptr() as *mut f64, middle_guard.len())
1385    };
1386    let lower: &mut [f64] = unsafe {
1387        core::slice::from_raw_parts_mut(lower_guard.as_mut_ptr() as *mut f64, lower_guard.len())
1388    };
1389
1390    let mut tr: Vec<f64> = vec![0.0; cols];
1391    unsafe {
1392        let mut i = 0usize;
1393        while i < cols {
1394            let val = if i == 0 {
1395                *high.get_unchecked(0) - *low.get_unchecked(0)
1396            } else {
1397                let hi = *high.get_unchecked(i);
1398                let lo = *low.get_unchecked(i);
1399                let pc = *close.get_unchecked(i - 1);
1400                let hl = hi - lo;
1401                let hc = (hi - pc).abs();
1402                let lc = (lo - pc).abs();
1403                hl.max(hc).max(lc)
1404            };
1405            *tr.get_unchecked_mut(i) = val;
1406            i += 1;
1407        }
1408    }
1409
1410    let ps: Option<Vec<f64>> = if ma_type.unwrap_or("ema").eq_ignore_ascii_case("sma") {
1411        let mut buf = vec![0.0; cols + 1];
1412        unsafe {
1413            let mut i = 0usize;
1414            while i < cols {
1415                let prev = *buf.get_unchecked(i);
1416                let xi = *source.get_unchecked(i);
1417                *buf.get_unchecked_mut(i + 1) = prev + xi;
1418                i += 1;
1419            }
1420        }
1421        Some(buf)
1422    } else {
1423        None
1424    };
1425
1426    let ma = ma_type.unwrap_or("ema");
1427    let do_row = |row: usize, up: &mut [f64], mid: &mut [f64], low_out: &mut [f64]| {
1428        let period = combos[row].period.unwrap();
1429        let mult = combos[row].multiplier.unwrap();
1430        let row_warm = warm[row];
1431
1432        if row_warm >= cols {
1433            return;
1434        }
1435
1436        let pf = period as f64;
1437        let alpha_rma = 1.0 / pf;
1438
1439        let mut atr = 0.0f64;
1440        unsafe {
1441            let mut j = 0usize;
1442            while j < period {
1443                atr += *tr.get_unchecked(j);
1444                j += 1;
1445            }
1446        }
1447        atr /= pf;
1448        let mut k = period;
1449        unsafe {
1450            while k <= row_warm {
1451                let tri = *tr.get_unchecked(k);
1452                atr = (tri - atr).mul_add(alpha_rma, atr);
1453                k += 1;
1454            }
1455        }
1456
1457        if ma.eq_ignore_ascii_case("ema") {
1458            let mut acc = 0.0f64;
1459            unsafe {
1460                let mut j = 0usize;
1461                while j < period {
1462                    acc += *source.get_unchecked(first + j);
1463                    j += 1;
1464                }
1465            }
1466            let mut ema = acc / pf;
1467            unsafe {
1468                *mid.get_unchecked_mut(row_warm) = ema;
1469                *up.get_unchecked_mut(row_warm) = mult.mul_add(atr, ema);
1470                *low_out.get_unchecked_mut(row_warm) = (-mult).mul_add(atr, ema);
1471            }
1472
1473            let alpha_ema = 2.0 / (pf + 1.0);
1474            unsafe {
1475                let mut i = row_warm + 1;
1476                while i < cols {
1477                    let tri = *tr.get_unchecked(i);
1478                    atr = (tri - atr).mul_add(alpha_rma, atr);
1479
1480                    let xi = *source.get_unchecked(i);
1481                    ema = (xi - ema).mul_add(alpha_ema, ema);
1482
1483                    *mid.get_unchecked_mut(i) = ema;
1484                    *up.get_unchecked_mut(i) = mult.mul_add(atr, ema);
1485                    *low_out.get_unchecked_mut(i) = (-mult).mul_add(atr, ema);
1486                    i += 1;
1487                }
1488            }
1489        } else if ma.eq_ignore_ascii_case("sma") {
1490            let ps = ps.as_ref().expect("prefix sums computed for SMA");
1491
1492            let start = row_warm + 1 - period;
1493            let end = row_warm + 1;
1494            let mut sm = unsafe { (*ps.get_unchecked(end) - *ps.get_unchecked(start)) / pf };
1495            unsafe {
1496                *mid.get_unchecked_mut(row_warm) = sm;
1497                *up.get_unchecked_mut(row_warm) = mult.mul_add(atr, sm);
1498                *low_out.get_unchecked_mut(row_warm) = (-mult).mul_add(atr, sm);
1499            }
1500
1501            unsafe {
1502                let mut i = row_warm + 1;
1503                while i < cols {
1504                    let tri = *tr.get_unchecked(i);
1505                    atr = (tri - atr).mul_add(alpha_rma, atr);
1506                    let s = (*ps.get_unchecked(i + 1) - *ps.get_unchecked(i + 1 - period)) / pf;
1507                    sm = s;
1508                    *mid.get_unchecked_mut(i) = sm;
1509                    *up.get_unchecked_mut(i) = mult.mul_add(atr, sm);
1510                    *low_out.get_unchecked_mut(i) = (-mult).mul_add(atr, sm);
1511                    i += 1;
1512                }
1513            }
1514        } else {
1515            keltner_row_scalar(
1516                high, low, close, source, period, mult, ma, first, up, mid, low_out,
1517            );
1518        }
1519    };
1520    if parallel {
1521        #[cfg(not(target_arch = "wasm32"))]
1522        {
1523            upper
1524                .par_chunks_mut(cols)
1525                .zip(middle.par_chunks_mut(cols))
1526                .zip(lower.par_chunks_mut(cols))
1527                .enumerate()
1528                .for_each(|(row, ((u, m), l))| do_row(row, u, m, l));
1529        }
1530
1531        #[cfg(target_arch = "wasm32")]
1532        {
1533            for ((row, u), (m, l)) in upper
1534                .chunks_mut(cols)
1535                .enumerate()
1536                .zip(middle.chunks_mut(cols).zip(lower.chunks_mut(cols)))
1537            {
1538                do_row(row, u, m, l);
1539            }
1540        }
1541    } else {
1542        for ((row, u), (m, l)) in upper
1543            .chunks_mut(cols)
1544            .enumerate()
1545            .zip(middle.chunks_mut(cols).zip(lower.chunks_mut(cols)))
1546        {
1547            do_row(row, u, m, l);
1548        }
1549    }
1550
1551    let upper = unsafe {
1552        let ptr = upper_guard.as_mut_ptr() as *mut f64;
1553        let len = upper_guard.len();
1554        let cap = upper_guard.capacity();
1555        core::mem::forget(upper_guard);
1556        Vec::from_raw_parts(ptr, len, cap)
1557    };
1558
1559    let middle = unsafe {
1560        let ptr = middle_guard.as_mut_ptr() as *mut f64;
1561        let len = middle_guard.len();
1562        let cap = middle_guard.capacity();
1563        core::mem::forget(middle_guard);
1564        Vec::from_raw_parts(ptr, len, cap)
1565    };
1566
1567    let lower = unsafe {
1568        let ptr = lower_guard.as_mut_ptr() as *mut f64;
1569        let len = lower_guard.len();
1570        let cap = lower_guard.capacity();
1571        core::mem::forget(lower_guard);
1572        Vec::from_raw_parts(ptr, len, cap)
1573    };
1574
1575    Ok(KeltnerBatchOutput {
1576        upper_band: upper,
1577        middle_band: middle,
1578        lower_band: lower,
1579        combos,
1580        rows,
1581        cols,
1582    })
1583}
1584
1585#[derive(Debug, Clone)]
1586pub struct KeltnerStream {
1587    period: usize,
1588    rcp_period: f64,
1589    multiplier: f64,
1590
1591    ma_impl: MaImpl,
1592
1593    atr: f64,
1594    atr_sum: f64,
1595    rma_alpha: f64,
1596
1597    count: usize,
1598    prev_close: f64,
1599}
1600
1601#[derive(Debug, Clone)]
1602enum MaImpl {
1603    Ema {
1604        alpha: f64,
1605        value: f64,
1606        seed_sum: f64,
1607    },
1608
1609    Sma {
1610        buffer: Vec<f64>,
1611        sum: f64,
1612        idx: usize,
1613        filled: bool,
1614    },
1615}
1616
1617impl KeltnerStream {
1618    pub fn try_new(params: KeltnerParams) -> Result<Self, KeltnerError> {
1619        let period = params.period.unwrap_or(20);
1620        let multiplier = params.multiplier.unwrap_or(2.0);
1621        let ma_type = params
1622            .ma_type
1623            .unwrap_or_else(|| "ema".to_string())
1624            .to_lowercase();
1625
1626        if period == 0 {
1627            return Err(KeltnerError::InvalidPeriod {
1628                period,
1629                data_len: 0,
1630            });
1631        }
1632
1633        let pf = period as f64;
1634        let rcp = 1.0 / pf;
1635
1636        let ma_impl = if ma_type == "sma" {
1637            MaImpl::Sma {
1638                buffer: vec![0.0; period],
1639                sum: 0.0,
1640                idx: 0,
1641                filled: false,
1642            }
1643        } else {
1644            MaImpl::Ema {
1645                alpha: 2.0 / (pf + 1.0),
1646                value: 0.0,
1647                seed_sum: 0.0,
1648            }
1649        };
1650
1651        Ok(Self {
1652            period,
1653            rcp_period: rcp,
1654            multiplier,
1655            ma_impl,
1656            atr: 0.0,
1657            atr_sum: 0.0,
1658            rma_alpha: rcp,
1659            count: 0,
1660            prev_close: f64::NAN,
1661        })
1662    }
1663
1664    #[inline(always)]
1665    pub fn update(
1666        &mut self,
1667        high: f64,
1668        low: f64,
1669        close: f64,
1670        source: f64,
1671    ) -> Option<(f64, f64, f64)> {
1672        let tr = if self.count == 0 {
1673            high - low
1674        } else {
1675            let hl = high - low;
1676            let hc = (high - self.prev_close).abs();
1677            let lc = (low - self.prev_close).abs();
1678            hl.max(hc).max(lc)
1679        };
1680
1681        self.prev_close = close;
1682        self.count += 1;
1683
1684        if self.count < self.period {
1685            self.atr_sum += tr;
1686
1687            match &mut self.ma_impl {
1688                MaImpl::Ema { seed_sum, .. } => {
1689                    *seed_sum += source;
1690                }
1691                MaImpl::Sma {
1692                    buffer, sum, idx, ..
1693                } => {
1694                    *sum += source;
1695                    buffer[*idx] = source;
1696                    *idx = (*idx + 1) % self.period;
1697                }
1698            }
1699            return None;
1700        }
1701
1702        if self.count == self.period {
1703            self.atr = (self.atr_sum + tr) * self.rcp_period;
1704
1705            let mid = match &mut self.ma_impl {
1706                MaImpl::Ema {
1707                    value, seed_sum, ..
1708                } => {
1709                    *seed_sum += source;
1710                    *value = *seed_sum * self.rcp_period;
1711                    *value
1712                }
1713                MaImpl::Sma {
1714                    buffer,
1715                    sum,
1716                    idx,
1717                    filled,
1718                } => {
1719                    *sum += source;
1720                    buffer[*idx] = source;
1721                    *idx = (*idx + 1) % self.period;
1722                    *filled = true;
1723                    *sum * self.rcp_period
1724                }
1725            };
1726
1727            let up = self.multiplier.mul_add(self.atr, mid);
1728            let lo = (-self.multiplier).mul_add(self.atr, mid);
1729            return Some((up, mid, lo));
1730        }
1731
1732        self.atr = (tr - self.atr).mul_add(self.rma_alpha, self.atr);
1733
1734        let mid = match &mut self.ma_impl {
1735            MaImpl::Ema { alpha, value, .. } => {
1736                *value = (source - *value).mul_add(*alpha, *value);
1737                *value
1738            }
1739            MaImpl::Sma {
1740                buffer, sum, idx, ..
1741            } => {
1742                let old = buffer[*idx];
1743                buffer[*idx] = source;
1744                *sum += source - old;
1745                *idx = (*idx + 1) % self.period;
1746                *sum * self.rcp_period
1747            }
1748        };
1749
1750        let up = self.multiplier.mul_add(self.atr, mid);
1751        let lo = (-self.multiplier).mul_add(self.atr, mid);
1752        Some((up, mid, lo))
1753    }
1754}
1755
1756#[cfg(test)]
1757mod tests {
1758    use super::*;
1759    use crate::skip_if_unsupported;
1760    use crate::utilities::data_loader::read_candles_from_csv;
1761    use crate::utilities::enums::Kernel;
1762    #[cfg(feature = "proptest")]
1763    use proptest::prelude::*;
1764
1765    fn check_keltner_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1766        skip_if_unsupported!(kernel, test_name);
1767        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1768        let candles = read_candles_from_csv(file_path)?;
1769
1770        let params = KeltnerParams {
1771            period: Some(20),
1772            multiplier: Some(2.0),
1773            ma_type: Some("ema".to_string()),
1774        };
1775        let input = KeltnerInput::from_candles(&candles, "close", params);
1776        let result = keltner_with_kernel(&input, kernel)?;
1777
1778        assert_eq!(result.upper_band.len(), candles.close.len());
1779        assert_eq!(result.middle_band.len(), candles.close.len());
1780        assert_eq!(result.lower_band.len(), candles.close.len());
1781
1782        let last_five_index = candles.close.len().saturating_sub(5);
1783        let expected_upper = [
1784            61619.504155205745,
1785            61503.56119134791,
1786            61387.47897150178,
1787            61286.61078267451,
1788            61206.25688331261,
1789        ];
1790        let expected_middle = [
1791            59758.339871629956,
1792            59703.35512195091,
1793            59640.083205574636,
1794            59593.884805043715,
1795            59504.46720456336,
1796        ];
1797        let expected_lower = [
1798            57897.17558805417,
1799            57903.14905255391,
1800            57892.68743964749,
1801            57901.158827412924,
1802            57802.67752581411,
1803        ];
1804        let last_five_upper = &result.upper_band[last_five_index..];
1805        let last_five_middle = &result.middle_band[last_five_index..];
1806        let last_five_lower = &result.lower_band[last_five_index..];
1807        for i in 0..5 {
1808            let diff_u = (last_five_upper[i] - expected_upper[i]).abs();
1809            let diff_m = (last_five_middle[i] - expected_middle[i]).abs();
1810            let diff_l = (last_five_lower[i] - expected_lower[i]).abs();
1811            assert!(
1812                diff_u < 1e-1,
1813                "Upper band mismatch at index {}: expected {}, got {}",
1814                i,
1815                expected_upper[i],
1816                last_five_upper[i]
1817            );
1818            assert!(
1819                diff_m < 1e-1,
1820                "Middle band mismatch at index {}: expected {}, got {}",
1821                i,
1822                expected_middle[i],
1823                last_five_middle[i]
1824            );
1825            assert!(
1826                diff_l < 1e-1,
1827                "Lower band mismatch at index {}: expected {}, got {}",
1828                i,
1829                expected_lower[i],
1830                last_five_lower[i]
1831            );
1832        }
1833        Ok(())
1834    }
1835
1836    fn check_keltner_default_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1837        skip_if_unsupported!(kernel, test_name);
1838        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1839        let candles = read_candles_from_csv(file_path)?;
1840        let default_params = KeltnerParams::default();
1841        let input = KeltnerInput::from_candles(&candles, "close", default_params);
1842        let result = keltner_with_kernel(&input, kernel)?;
1843        assert_eq!(result.upper_band.len(), candles.close.len());
1844        assert_eq!(result.middle_band.len(), candles.close.len());
1845        assert_eq!(result.lower_band.len(), candles.close.len());
1846        Ok(())
1847    }
1848
1849    fn check_keltner_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1850        skip_if_unsupported!(kernel, test_name);
1851        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1852        let candles = read_candles_from_csv(file_path)?;
1853        let params = KeltnerParams {
1854            period: Some(0),
1855            multiplier: Some(2.0),
1856            ma_type: Some("ema".to_string()),
1857        };
1858        let input = KeltnerInput::from_candles(&candles, "close", params);
1859        let result = keltner_with_kernel(&input, kernel);
1860        assert!(result.is_err());
1861        if let Err(e) = result {
1862            assert!(e.to_string().contains("invalid period"));
1863        }
1864        Ok(())
1865    }
1866
1867    fn check_keltner_large_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1868        skip_if_unsupported!(kernel, test_name);
1869        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1870        let candles = read_candles_from_csv(file_path)?;
1871        let params = KeltnerParams {
1872            period: Some(999999),
1873            multiplier: Some(2.0),
1874            ma_type: Some("ema".to_string()),
1875        };
1876        let input = KeltnerInput::from_candles(&candles, "close", params);
1877        let result = keltner_with_kernel(&input, kernel);
1878        assert!(result.is_err());
1879        if let Err(e) = result {
1880            assert!(e.to_string().contains("invalid period"));
1881        }
1882        Ok(())
1883    }
1884
1885    fn check_keltner_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1886        skip_if_unsupported!(kernel, test_name);
1887        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1888        let candles = read_candles_from_csv(file_path)?;
1889        let params = KeltnerParams::default();
1890        let input = KeltnerInput::from_candles(&candles, "close", params);
1891        let result = keltner_with_kernel(&input, kernel)?;
1892        assert_eq!(result.middle_band.len(), candles.close.len());
1893        if result.middle_band.len() > 240 {
1894            for (i, &val) in result.middle_band[240..].iter().enumerate() {
1895                assert!(
1896                    !val.is_nan(),
1897                    "[{}] Found unexpected NaN at out-index {}",
1898                    test_name,
1899                    240 + i
1900                );
1901            }
1902        }
1903        Ok(())
1904    }
1905
1906    fn check_keltner_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1907        skip_if_unsupported!(kernel, test_name);
1908
1909        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1910        let candles = read_candles_from_csv(file_path)?;
1911        let period = 20;
1912        let multiplier = 2.0;
1913
1914        let params = KeltnerParams {
1915            period: Some(period),
1916            multiplier: Some(multiplier),
1917            ma_type: Some("ema".to_string()),
1918        };
1919        let input = KeltnerInput::from_candles(&candles, "close", params.clone());
1920        let batch_output = keltner_with_kernel(&input, kernel)?;
1921
1922        let mut stream = KeltnerStream::try_new(params)?;
1923        let mut upper_stream = Vec::with_capacity(candles.close.len());
1924        let mut middle_stream = Vec::with_capacity(candles.close.len());
1925        let mut lower_stream = Vec::with_capacity(candles.close.len());
1926
1927        for i in 0..candles.close.len() {
1928            let hi = candles.high[i];
1929            let lo = candles.low[i];
1930            let cl = candles.close[i];
1931            let src = candles.close[i];
1932            match stream.update(hi, lo, cl, src) {
1933                Some((up, mid, low)) => {
1934                    upper_stream.push(up);
1935                    middle_stream.push(mid);
1936                    lower_stream.push(low);
1937                }
1938                None => {
1939                    upper_stream.push(f64::NAN);
1940                    middle_stream.push(f64::NAN);
1941                    lower_stream.push(f64::NAN);
1942                }
1943            }
1944        }
1945        assert_eq!(batch_output.upper_band.len(), upper_stream.len());
1946        assert_eq!(batch_output.middle_band.len(), middle_stream.len());
1947        assert_eq!(batch_output.lower_band.len(), lower_stream.len());
1948        for (i, (&b, &s)) in batch_output
1949            .middle_band
1950            .iter()
1951            .zip(middle_stream.iter())
1952            .enumerate()
1953        {
1954            if b.is_nan() && s.is_nan() {
1955                continue;
1956            }
1957            let diff = (b - s).abs();
1958            assert!(
1959                diff < 1e-8,
1960                "[{}] Keltner streaming mismatch at idx {}: batch={}, stream={}, diff={}",
1961                test_name,
1962                i,
1963                b,
1964                s,
1965                diff
1966            );
1967        }
1968        Ok(())
1969    }
1970
1971    #[cfg(debug_assertions)]
1972    fn check_keltner_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1973        skip_if_unsupported!(kernel, test_name);
1974
1975        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1976        let candles = read_candles_from_csv(file_path)?;
1977
1978        let test_params = vec![
1979            KeltnerParams::default(),
1980            KeltnerParams {
1981                period: Some(2),
1982                multiplier: Some(1.0),
1983                ma_type: Some("ema".to_string()),
1984            },
1985            KeltnerParams {
1986                period: Some(5),
1987                multiplier: Some(0.5),
1988                ma_type: Some("ema".to_string()),
1989            },
1990            KeltnerParams {
1991                period: Some(10),
1992                multiplier: Some(1.5),
1993                ma_type: Some("sma".to_string()),
1994            },
1995            KeltnerParams {
1996                period: Some(20),
1997                multiplier: Some(3.0),
1998                ma_type: Some("ema".to_string()),
1999            },
2000            KeltnerParams {
2001                period: Some(50),
2002                multiplier: Some(2.5),
2003                ma_type: Some("sma".to_string()),
2004            },
2005            KeltnerParams {
2006                period: Some(100),
2007                multiplier: Some(1.0),
2008                ma_type: Some("ema".to_string()),
2009            },
2010            KeltnerParams {
2011                period: Some(14),
2012                multiplier: Some(2.0),
2013                ma_type: Some("sma".to_string()),
2014            },
2015            KeltnerParams {
2016                period: Some(7),
2017                multiplier: Some(1.0),
2018                ma_type: Some("ema".to_string()),
2019            },
2020            KeltnerParams {
2021                period: Some(21),
2022                multiplier: Some(1.5),
2023                ma_type: Some("ema".to_string()),
2024            },
2025            KeltnerParams {
2026                period: Some(30),
2027                multiplier: Some(2.0),
2028                ma_type: Some("sma".to_string()),
2029            },
2030            KeltnerParams {
2031                period: Some(3),
2032                multiplier: Some(0.75),
2033                ma_type: Some("ema".to_string()),
2034            },
2035        ];
2036
2037        for (param_idx, params) in test_params.iter().enumerate() {
2038            let input = KeltnerInput::from_candles(&candles, "close", params.clone());
2039            let output = keltner_with_kernel(&input, kernel)?;
2040
2041            for (band_name, band_values) in [
2042                ("upper", &output.upper_band),
2043                ("middle", &output.middle_band),
2044                ("lower", &output.lower_band),
2045            ] {
2046                for (i, &val) in band_values.iter().enumerate() {
2047                    if val.is_nan() {
2048                        continue;
2049                    }
2050
2051                    let bits = val.to_bits();
2052
2053                    if bits == 0x11111111_11111111 {
2054                        panic!(
2055							"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
2056							 in {} band with params: period={}, multiplier={}, ma_type={} (param set {})",
2057							test_name, val, bits, i, band_name,
2058							params.period.unwrap_or(20),
2059							params.multiplier.unwrap_or(2.0),
2060							params.ma_type.as_deref().unwrap_or("ema"),
2061							param_idx
2062						);
2063                    }
2064
2065                    if bits == 0x22222222_22222222 {
2066                        panic!(
2067							"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
2068							 in {} band with params: period={}, multiplier={}, ma_type={} (param set {})",
2069							test_name, val, bits, i, band_name,
2070							params.period.unwrap_or(20),
2071							params.multiplier.unwrap_or(2.0),
2072							params.ma_type.as_deref().unwrap_or("ema"),
2073							param_idx
2074						);
2075                    }
2076
2077                    if bits == 0x33333333_33333333 {
2078                        panic!(
2079							"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
2080							 in {} band with params: period={}, multiplier={}, ma_type={} (param set {})",
2081							test_name, val, bits, i, band_name,
2082							params.period.unwrap_or(20),
2083							params.multiplier.unwrap_or(2.0),
2084							params.ma_type.as_deref().unwrap_or("ema"),
2085							param_idx
2086						);
2087                    }
2088                }
2089            }
2090        }
2091
2092        Ok(())
2093    }
2094
2095    #[cfg(not(debug_assertions))]
2096    fn check_keltner_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2097        Ok(())
2098    }
2099
2100    fn check_batch_default_row(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2101        skip_if_unsupported!(kernel, test_name);
2102
2103        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2104        let c = read_candles_from_csv(file)?;
2105        let output = KeltnerBatchBuilder::new()
2106            .kernel(kernel)
2107            .apply_candles(&c, "close")?;
2108
2109        let def = KeltnerParams::default();
2110        let (upper, middle, lower) = output.values_for(&def).expect("default row missing");
2111
2112        assert_eq!(upper.len(), c.close.len());
2113        assert_eq!(middle.len(), c.close.len());
2114        assert_eq!(lower.len(), c.close.len());
2115
2116        Ok(())
2117    }
2118
2119    #[cfg(debug_assertions)]
2120    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2121        skip_if_unsupported!(kernel, test);
2122
2123        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2124        let c = read_candles_from_csv(file)?;
2125
2126        let test_configs = vec![
2127            (2, 10, 2, 0.5, 2.5, 0.5),
2128            (5, 25, 5, 1.0, 3.0, 1.0),
2129            (30, 60, 15, 2.0, 2.0, 0.0),
2130            (2, 5, 1, 1.5, 2.5, 0.25),
2131            (10, 30, 10, 0.75, 2.25, 0.75),
2132            (14, 21, 7, 1.0, 2.0, 0.5),
2133            (20, 20, 0, 0.5, 3.0, 0.5),
2134        ];
2135
2136        for (cfg_idx, &(p_start, p_end, p_step, m_start, m_end, m_step)) in
2137            test_configs.iter().enumerate()
2138        {
2139            let output = KeltnerBatchBuilder::new()
2140                .kernel(kernel)
2141                .period_range(p_start, p_end, p_step)
2142                .multiplier_range(m_start, m_end, m_step)
2143                .apply_candles(&c, "close")?;
2144
2145            for (band_name, band_values) in [
2146                ("upper", &output.upper_band),
2147                ("middle", &output.middle_band),
2148                ("lower", &output.lower_band),
2149            ] {
2150                for (idx, &val) in band_values.iter().enumerate() {
2151                    if val.is_nan() {
2152                        continue;
2153                    }
2154
2155                    let bits = val.to_bits();
2156                    let row = idx / output.cols;
2157                    let col = idx % output.cols;
2158                    let combo = &output.combos[row];
2159
2160                    if bits == 0x11111111_11111111 {
2161                        panic!(
2162							"[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2163							 at row {} col {} (flat index {}) in {} band with params: period={}, multiplier={}",
2164							test, cfg_idx, val, bits, row, col, idx, band_name,
2165							combo.period.unwrap_or(20),
2166							combo.multiplier.unwrap_or(2.0)
2167						);
2168                    }
2169
2170                    if bits == 0x22222222_22222222 {
2171                        panic!(
2172							"[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2173							 at row {} col {} (flat index {}) in {} band with params: period={}, multiplier={}",
2174							test, cfg_idx, val, bits, row, col, idx, band_name,
2175							combo.period.unwrap_or(20),
2176							combo.multiplier.unwrap_or(2.0)
2177						);
2178                    }
2179
2180                    if bits == 0x33333333_33333333 {
2181                        panic!(
2182                            "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2183							 at row {} col {} (flat index {}) in {} band with params: period={}, multiplier={}",
2184                            test,
2185                            cfg_idx,
2186                            val,
2187                            bits,
2188                            row,
2189                            col,
2190                            idx,
2191                            band_name,
2192                            combo.period.unwrap_or(20),
2193                            combo.multiplier.unwrap_or(2.0)
2194                        );
2195                    }
2196                }
2197            }
2198        }
2199
2200        Ok(())
2201    }
2202
2203    #[cfg(not(debug_assertions))]
2204    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2205        Ok(())
2206    }
2207
2208    #[cfg(feature = "proptest")]
2209    #[allow(clippy::float_cmp)]
2210    fn check_keltner_property(
2211        test_name: &str,
2212        kernel: Kernel,
2213    ) -> Result<(), Box<dyn std::error::Error>> {
2214        use proptest::prelude::*;
2215        skip_if_unsupported!(kernel, test_name);
2216
2217        let strat = (
2218            2usize..=50,
2219            50usize..500,
2220            0.5f64..3.0f64,
2221            0usize..6,
2222            any::<u64>(),
2223        )
2224            .prop_map(|(period, len, multiplier, scenario, seed)| {
2225                let mut high = Vec::with_capacity(len);
2226                let mut low = Vec::with_capacity(len);
2227                let mut close = Vec::with_capacity(len);
2228
2229                let mut rng_state = seed;
2230                let mut next_random = || -> f64 {
2231                    rng_state = rng_state.wrapping_mul(1664525).wrapping_add(1013904223);
2232                    (rng_state as f64) / (u64::MAX as f64)
2233                };
2234
2235                match scenario {
2236                    0 => {
2237                        let mut prev_close = 100.0;
2238                        for _ in 0..len {
2239                            let volatility = 0.01 + next_random() * 0.04;
2240                            let change = -volatility + next_random() * (2.0 * volatility);
2241                            let new_close = (prev_close * (1.0 + change)).max(0.1f64);
2242                            let high_val = new_close * (1.0 + next_random() * volatility);
2243                            let low_val = new_close * (1.0 - next_random() * volatility);
2244
2245                            high.push(high_val);
2246                            low.push(low_val.min(high_val));
2247                            close.push(new_close);
2248                            prev_close = new_close;
2249                        }
2250                    }
2251                    1 => {
2252                        let start = 100.0;
2253                        for i in 0..len {
2254                            let base = start * (1.0 + 0.01 * i as f64);
2255                            let spread = base * 0.02;
2256                            high.push(base + spread);
2257                            low.push(base - spread);
2258                            close.push(base);
2259                        }
2260                    }
2261                    2 => {
2262                        let start = 100.0;
2263                        for i in 0..len {
2264                            let base = start * (1.0 - 0.005 * i as f64).max(10.0);
2265                            let spread = base * 0.02;
2266                            high.push(base + spread);
2267                            low.push(base - spread);
2268                            close.push(base);
2269                        }
2270                    }
2271                    3 => {
2272                        let mut price = 100.0;
2273                        for i in 0..len {
2274                            let volatility = 0.1 * (1.0 + (i as f64 * 0.1).sin());
2275                            let change = if i % 2 == 0 {
2276                                volatility
2277                            } else {
2278                                -volatility * 0.8
2279                            };
2280                            price = (price * (1.0 + change)).max(1.0);
2281
2282                            let spread = price * volatility;
2283                            high.push(price + spread);
2284                            low.push(price - spread * 0.8);
2285                            close.push(price);
2286                        }
2287                    }
2288                    4 => {
2289                        let constant_price = 50.0;
2290                        high = vec![constant_price; len];
2291                        low = vec![constant_price; len];
2292                        close = vec![constant_price; len];
2293                    }
2294                    _ => {
2295                        let mut price = 1000.0;
2296                        let mut momentum = 0.0;
2297
2298                        for i in 0..len {
2299                            momentum = momentum * 0.9 + (if i % 20 < 10 { 0.001 } else { -0.001 });
2300                            let noise = ((i as f64 * 0.3).sin() * 0.005);
2301                            price = (price * (1.0 + momentum + noise)).max(100.0);
2302
2303                            let daily_range = price * 0.02;
2304                            let high_val = price + daily_range * 0.6;
2305                            let low_val = price - daily_range * 0.4;
2306
2307                            high.push(high_val);
2308                            low.push(low_val);
2309                            close.push(price);
2310                        }
2311                    }
2312                }
2313
2314                let ma_type = if next_random() > 0.5 { "ema" } else { "sma" };
2315                (high, low, close, period, multiplier, ma_type.to_string())
2316            });
2317
2318        proptest::test_runner::TestRunner::default()
2319			.run(&strat, |(high, low, close, period, multiplier, ma_type)| {
2320
2321				let source = close.clone();
2322
2323				let params = KeltnerParams {
2324					period: Some(period),
2325					multiplier: Some(multiplier),
2326					ma_type: Some(ma_type.clone()),
2327				};
2328				let input = KeltnerInput::from_slice(&high, &low, &close, &source, params.clone());
2329
2330				let result = keltner_with_kernel(&input, kernel).unwrap();
2331				let scalar_result = keltner_with_kernel(&input, Kernel::Scalar).unwrap();
2332
2333
2334				prop_assert_eq!(result.upper_band.len(), close.len());
2335				prop_assert_eq!(result.middle_band.len(), close.len());
2336				prop_assert_eq!(result.lower_band.len(), close.len());
2337
2338
2339				let warmup = period - 1;
2340				for i in 0..warmup.min(close.len()) {
2341					prop_assert!(
2342						result.upper_band[i].is_nan(),
2343						"Upper band[{}] should be NaN during warmup", i
2344					);
2345					prop_assert!(
2346						result.middle_band[i].is_nan(),
2347						"Middle band[{}] should be NaN during warmup", i
2348					);
2349					prop_assert!(
2350						result.lower_band[i].is_nan(),
2351						"Lower band[{}] should be NaN during warmup", i
2352					);
2353				}
2354
2355
2356				for i in warmup..close.len() {
2357					let upper = result.upper_band[i];
2358					let middle = result.middle_band[i];
2359					let lower = result.lower_band[i];
2360
2361					let scalar_upper = scalar_result.upper_band[i];
2362					let scalar_middle = scalar_result.middle_band[i];
2363					let scalar_lower = scalar_result.lower_band[i];
2364
2365
2366					if upper.is_nan() || middle.is_nan() || lower.is_nan() {
2367						continue;
2368					}
2369
2370
2371					prop_assert!(
2372						upper >= middle - 1e-10,
2373						"Upper band {} must be >= middle band {} at index {}", upper, middle, i
2374					);
2375					prop_assert!(
2376						middle >= lower - 1e-10,
2377						"Middle band {} must be >= lower band {} at index {}", middle, lower, i
2378					);
2379
2380
2381					let spread = upper - lower;
2382					prop_assert!(
2383						spread >= -1e-10,
2384						"Spread {} must be positive at index {}", spread, i
2385					);
2386
2387
2388					if i >= warmup + period {
2389
2390						let scalar_spread = scalar_upper - scalar_lower;
2391						if scalar_spread > 0.0 && spread > 0.0 {
2392							let ratio = spread / scalar_spread;
2393							prop_assert!(
2394								(ratio - 1.0).abs() < 0.01,
2395								"Spread ratio between kernels should be consistent at index {}: ratio={}", i, ratio
2396							);
2397						}
2398					}
2399
2400
2401					let window_start = i.saturating_sub(period - 1);
2402					let window_high = high[window_start..=i].iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
2403					let window_low = low[window_start..=i].iter().fold(f64::INFINITY, |a, &b| a.min(b));
2404
2405
2406					prop_assert!(
2407						middle <= window_high * 1.05 + 1.0,
2408						"Middle band {} exceeds window high {} at index {}", middle, window_high, i
2409					);
2410					prop_assert!(
2411						middle >= window_low * 0.95 - 1.0,
2412						"Middle band {} below window low {} at index {}", middle, window_low, i
2413					);
2414
2415
2416					let tolerance = 1e-9;
2417					prop_assert!(
2418						(upper - scalar_upper).abs() <= tolerance,
2419						"Upper band kernel mismatch at {}: {} vs {} (diff: {})",
2420						i, upper, scalar_upper, (upper - scalar_upper).abs()
2421					);
2422					prop_assert!(
2423						(middle - scalar_middle).abs() <= tolerance,
2424						"Middle band kernel mismatch at {}: {} vs {} (diff: {})",
2425						i, middle, scalar_middle, (middle - scalar_middle).abs()
2426					);
2427					prop_assert!(
2428						(lower - scalar_lower).abs() <= tolerance,
2429						"Lower band kernel mismatch at {}: {} vs {} (diff: {})",
2430						i, lower, scalar_lower, (lower - scalar_lower).abs()
2431					);
2432
2433
2434					#[cfg(debug_assertions)]
2435					{
2436						let upper_bits = upper.to_bits();
2437						let middle_bits = middle.to_bits();
2438						let lower_bits = lower.to_bits();
2439
2440						prop_assert!(
2441							upper_bits != 0x11111111_11111111 &&
2442							upper_bits != 0x22222222_22222222 &&
2443							upper_bits != 0x33333333_33333333,
2444							"Found poison value in upper band at index {}", i
2445						);
2446						prop_assert!(
2447							middle_bits != 0x11111111_11111111 &&
2448							middle_bits != 0x22222222_22222222 &&
2449							middle_bits != 0x33333333_33333333,
2450							"Found poison value in middle band at index {}", i
2451						);
2452						prop_assert!(
2453							lower_bits != 0x11111111_11111111 &&
2454							lower_bits != 0x22222222_22222222 &&
2455							lower_bits != 0x33333333_33333333,
2456							"Found poison value in lower band at index {}", i
2457						);
2458					}
2459
2460
2461
2462					let all_same = high[window_start..=i].windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10) &&
2463					               low[window_start..=i].windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10) &&
2464					               close[window_start..=i].windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);
2465
2466
2467					let no_spread = high[window_start..=i].iter()
2468						.zip(low[window_start..=i].iter())
2469						.zip(close[window_start..=i].iter())
2470						.all(|((h, l), c)| (h - l).abs() < 1e-10 && (h - c).abs() < 1e-10);
2471
2472					if all_same && no_spread && i >= warmup + period * 3 {
2473
2474
2475						let band_spread = upper - lower;
2476						prop_assert!(
2477							band_spread < 0.01 || band_spread < middle * 0.001,
2478							"Bands should converge for constant prices with no spread, but spread is {} at index {} (middle: {})",
2479							band_spread, i, middle
2480						);
2481					}
2482				}
2483
2484				Ok(())
2485			})
2486			.unwrap();
2487
2488        Ok(())
2489    }
2490
2491    macro_rules! generate_all_keltner_tests {
2492        ($($test_fn:ident),*) => {
2493            paste::paste! {
2494                $(
2495                    #[test]
2496                    fn [<$test_fn _scalar_f64>]() {
2497                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2498                    }
2499                )*
2500                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2501                $(
2502                    #[test]
2503                    fn [<$test_fn _avx2_f64>]() {
2504                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2505                    }
2506                    #[test]
2507                    fn [<$test_fn _avx512_f64>]() {
2508                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2509                    }
2510                )*
2511            }
2512        }
2513    }
2514
2515    generate_all_keltner_tests!(
2516        check_keltner_accuracy,
2517        check_keltner_default_params,
2518        check_keltner_zero_period,
2519        check_keltner_large_period,
2520        check_keltner_nan_handling,
2521        check_keltner_streaming,
2522        check_keltner_no_poison
2523    );
2524
2525    #[cfg(feature = "proptest")]
2526    generate_all_keltner_tests!(check_keltner_property);
2527
2528    macro_rules! gen_batch_tests {
2529        ($fn_name:ident) => {
2530            paste::paste! {
2531                #[test] fn [<$fn_name _scalar>]()      {
2532                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2533                }
2534                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2535                #[test] fn [<$fn_name _avx2>]()        {
2536                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2537                }
2538                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2539                #[test] fn [<$fn_name _avx512>]()      {
2540                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2541                }
2542                #[test] fn [<$fn_name _auto_detect>]() {
2543                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2544                }
2545            }
2546        };
2547    }
2548    gen_batch_tests!(check_batch_default_row);
2549    gen_batch_tests!(check_batch_no_poison);
2550
2551    #[test]
2552    fn test_keltner_into_matches_api() {
2553        use crate::utilities::data_loader::Candles;
2554
2555        let n = 256usize;
2556        let mut ts = Vec::with_capacity(n);
2557        let mut open = Vec::with_capacity(n);
2558        let mut high = Vec::with_capacity(n);
2559        let mut low = Vec::with_capacity(n);
2560        let mut close = Vec::with_capacity(n);
2561        let mut volume = Vec::with_capacity(n);
2562
2563        let mut price = 1000.0f64;
2564        for i in 0..n {
2565            let i_f = i as f64;
2566            let drift = (i_f * 0.001).sin() * 0.5;
2567            let noise = (i_f * 0.07).cos() * 0.8;
2568            price = (price + drift + noise).max(1.0);
2569            let spread = 2.0 + (i % 5) as f64;
2570            let h = price + spread;
2571            let l = price - spread * 0.8;
2572            let o = price - 0.25 * spread;
2573            let c = price + 0.25 * spread;
2574
2575            ts.push(i as i64);
2576            open.push(o);
2577            high.push(h);
2578            low.push(l);
2579            close.push(c);
2580            volume.push(1000.0 + i as f64);
2581        }
2582
2583        let candles = Candles::new(ts, open, high, low, close, volume);
2584        let input = KeltnerInput::from_candles(&candles, "close", KeltnerParams::default());
2585
2586        let base = keltner(&input).expect("keltner baseline failed");
2587
2588        let len = candles.close.len();
2589        let mut up = vec![0.0; len];
2590        let mut mid = vec![0.0; len];
2591        let mut lo = vec![0.0; len];
2592
2593        keltner_into(&input, &mut up, &mut mid, &mut lo).expect("keltner_into failed");
2594
2595        assert_eq!(base.upper_band.len(), up.len());
2596        assert_eq!(base.middle_band.len(), mid.len());
2597        assert_eq!(base.lower_band.len(), lo.len());
2598
2599        fn eq_or_both_nan(a: f64, b: f64) -> bool {
2600            (a.is_nan() && b.is_nan()) || (a == b)
2601        }
2602
2603        for i in 0..len {
2604            assert!(
2605                eq_or_both_nan(base.upper_band[i], up[i]),
2606                "upper mismatch at {}: base={} into={}",
2607                i,
2608                base.upper_band[i],
2609                up[i]
2610            );
2611            assert!(
2612                eq_or_both_nan(base.middle_band[i], mid[i]),
2613                "middle mismatch at {}: base={} into={}",
2614                i,
2615                base.middle_band[i],
2616                mid[i]
2617            );
2618            assert!(
2619                eq_or_both_nan(base.lower_band[i], lo[i]),
2620                "lower mismatch at {}: base={} into={}",
2621                i,
2622                base.lower_band[i],
2623                lo[i]
2624            );
2625        }
2626    }
2627}
2628
#[cfg(all(feature = "python", feature = "cuda"))]
/// Device-resident Keltner output, bundled with the CUDA context and the device
/// ordinal it was produced on.
pub struct KeltnerDeviceArrayF32 {
    /// Device buffer plus its `rows` x `cols` shape.
    pub inner: DeviceArrayF32,
    /// Shared handle to the CUDA context. Presumably held so the context outlives the
    /// buffer (TODO confirm).
    pub context: Arc<Context>,
    /// CUDA device ordinal; reported as the DLPack device id.
    pub device_id: u32,
}
2635
#[cfg(all(feature = "python", feature = "cuda"))]
// Python handle (module `ta_indicators.cuda`) around a device-resident Keltner result.
// Declared `unsendable`, so PyO3 confines the object to the thread that created it.
// (Plain `//` comments are used here so the Python-visible `__doc__` is unchanged.)
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct KeltnerDeviceArrayF32Py {
    // Owned device buffer, context and device id. `__dlpack__` swaps this for an
    // empty placeholder when it hands the buffer off.
    pub(crate) inner: KeltnerDeviceArrayF32,
}
2641
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl KeltnerDeviceArrayF32Py {
    // CUDA Array Interface (version 3) description of the buffer: a row-major 2-D
    // float32 array of shape (rows, cols). Byte strides are (cols * 4, 4). The `false`
    // in `data` marks the buffer as not read-only.
    // NOTE(review): no `stream` key is published, so consumers are not asked to
    // synchronize. Confirm the producing work has completed before this is read.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let inner = &self.inner.inner;
        let d = PyDict::new(py);
        d.set_item("shape", (inner.rows, inner.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item(
            "strides",
            (
                inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        d.set_item("data", (inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    // DLPack device tuple: (2 = kDLCUDA, device ordinal).
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.inner.device_id as i32)
    }

    // Exports the buffer as a DLPack capsule via `export_f32_cuda_dlpack_2d`, moving
    // the device allocation out of this handle. Afterwards this object holds a
    // zero-length placeholder and reports shape (0, 0).
    //
    // Errors (all PyValueError unless extraction of `copy` fails):
    // - `dl_device` parses as (type, id) and differs from this buffer's device:
    //   "device copy not implemented..." if `copy` is truthy, else "dl_device mismatch...".
    // - `copy=True`: copies are not supported.
    // - Allocating the empty placeholder buffer fails.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        use cust::memory::DeviceBuffer;

        let (kdl, alloc_dev) = self.__dlpack_device__();
        // A `dl_device` that does not parse as an (i32, i32) tuple is ignored here.
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    // Lenient read of `copy`: a non-bool value counts as "no copy" here.
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }

        // The consumer's stream hint is accepted but not used; no synchronization
        // happens here.
        let _ = stream;

        // Strict read of `copy`: a non-bool value raises from `extract`.
        if let Some(copy_obj) = copy.as_ref() {
            let do_copy: bool = copy_obj.extract(py)?;
            if do_copy {
                return Err(PyValueError::new_err(
                    "__dlpack__(copy=True) not supported for keltner CUDA buffers",
                ));
            }
        }

        // Build the empty placeholder before touching `self`. If this allocation fails,
        // the handle still owns its original buffer.
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let context = self.inner.context.clone();
        let device_id = self.inner.device_id;
        // Take ownership of the real buffer, leaving a 0x0 placeholder behind.
        let inner = std::mem::replace(
            &mut self.inner,
            KeltnerDeviceArrayF32 {
                inner: DeviceArrayF32 {
                    buf: dummy,
                    rows: 0,
                    cols: 0,
                },
                context,
                device_id,
            },
        );

        let rows = inner.inner.rows;
        let cols = inner.inner.cols;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        // The device buffer is moved into the exporter.
        export_f32_cuda_dlpack_2d(
            py,
            inner.inner.buf,
            rows,
            cols,
            alloc_dev,
            max_version_bound,
        )
    }
}
2741
#[cfg(feature = "python")]
#[pyfunction(name = "keltner")]
#[pyo3(signature = (high, low, close, source, period, multiplier, ma_type="ema", kernel=None))]
pub fn keltner_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    source: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    multiplier: f64,
    ma_type: &str,
    kernel: Option<&str>,
) -> PyResult<(
    Bound<'py, numpy::PyArray1<f64>>,
    Bound<'py, numpy::PyArray1<f64>>,
    Bound<'py, numpy::PyArray1<f64>>,
)> {
    // Python entry point for a single Keltner channel computation. The inputs must be
    // contiguous float64 arrays. The three output arrays are sized from `close`, and
    // errors from the computation are raised as ValueError.
    use numpy::{PyArray1, PyArrayMethods};

    let (high_s, low_s, close_s, source_s) = (
        high.as_slice()?,
        low.as_slice()?,
        close.as_slice()?,
        source.as_slice()?,
    );
    let n = close_s.len();

    // SAFETY: the arrays are created uninitialized. This assumes `keltner_into_slice`
    // writes every element (warmup prefix included) before the arrays are returned
    // (TODO confirm). On error they are dropped without reaching Python.
    let upper_arr = unsafe { PyArray1::<f64>::new(py, [n], false) };
    let middle_arr = unsafe { PyArray1::<f64>::new(py, [n], false) };
    let lower_arr = unsafe { PyArray1::<f64>::new(py, [n], false) };

    // SAFETY: the arrays were just created here, so no other view of their data exists.
    let (upper_out, middle_out, lower_out) = unsafe {
        (
            upper_arr.as_slice_mut()?,
            middle_arr.as_slice_mut()?,
            lower_arr.as_slice_mut()?,
        )
    };

    let input = KeltnerInput::from_slice(
        high_s,
        low_s,
        close_s,
        source_s,
        KeltnerParams {
            period: Some(period),
            multiplier: Some(multiplier),
            ma_type: Some(ma_type.to_string()),
        },
    );
    let chosen_kernel = validate_kernel(kernel, false)?;

    // The computation touches only Rust slices, so the GIL is released for it.
    py.allow_threads(|| {
        keltner_into_slice(upper_out, middle_out, lower_out, &input, chosen_kernel)
    })
    .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok((upper_arr, middle_arr, lower_arr))
}
2789
#[cfg(feature = "python")]
// Python-facing wrapper, exposed to Python as `KeltnerStream`, around the Rust
// `KeltnerStream` incremental calculator. (`//` comments keep Python's `__doc__`
// unchanged.)
#[pyclass(name = "KeltnerStream")]
pub struct KeltnerStreamPy {
    // The wrapped streaming state, advanced by each `update` call from Python.
    stream: KeltnerStream,
}
2795
#[cfg(feature = "python")]
#[pymethods]
impl KeltnerStreamPy {
    // Builds the streaming calculator. Parameter validation is left to
    // `KeltnerStream::try_new`, whose error text is re-raised as a ValueError.
    #[new]
    fn new(period: usize, multiplier: f64, ma_type: &str) -> PyResult<Self> {
        let params = KeltnerParams {
            period: Some(period),
            multiplier: Some(multiplier),
            ma_type: Some(ma_type.to_string()),
        };
        KeltnerStream::try_new(params)
            .map(|stream| Self { stream })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Feeds one bar into the stream. Returns `(upper, middle, lower)` when the stream
    // yields a value, otherwise `None`.
    fn update(&mut self, high: f64, low: f64, close: f64, source: f64) -> Option<(f64, f64, f64)> {
        let Self { stream } = self;
        stream.update(high, low, close, source)
    }
}
2815
/// Python binding for the Keltner Channel parameter sweep.
///
/// Evaluates the combinations produced by `period_range` and `multiplier_range`, which are
/// forwarded unchanged to `KeltnerBatchRange`. The batch kernel is chosen by
/// `validate_kernel`, with `Auto` resolving to `detect_best_batch_kernel()`.
///
/// Returns a dict containing:
/// - `upper`, `middle`, `lower`: float64 arrays of shape `(rows, cols)`;
/// - `periods` (uint64) and `multipliers` (float64): the parameters of each evaluated
///   combination, presumably one entry per row.
///
/// # Errors
/// Raises `ValueError` if the kernel is rejected or the batch computation fails.
/// `reshape` raises if `rows * cols` does not match a band's length.
#[cfg(feature = "python")]
#[pyfunction(name = "keltner_batch")]
#[pyo3(signature = (high, low, close, source, period_range, multiplier_range, kernel=None))]
pub fn keltner_batch_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    source: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    multiplier_range: (f64, f64, f64),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let h = high.as_slice()?;
    let l = low.as_slice()?;
    let c = close.as_slice()?;
    let s = source.as_slice()?;

    let sweep = KeltnerBatchRange {
        period: period_range,
        multiplier: multiplier_range,
    };
    let kern = validate_kernel(kernel, true)?;

    let out = py
        .allow_threads(|| {
            let batch_kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            keltner_batch_par_slice(h, l, c, s, &sweep, batch_kernel)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let (rows, cols) = (out.rows, out.cols);
    let periods: Vec<u64> = out
        .combos
        .iter()
        .map(|p| p.period.unwrap() as u64)
        .collect();
    let multipliers: Vec<f64> = out.combos.iter().map(|p| p.multiplier.unwrap()).collect();

    // Move each band's Vec into NumPy without copying and view it as a (rows, cols) matrix.
    // This avoids a second full-size buffer per band and replaces the previous
    // `copy_from_slice`, which would panic on a size mismatch, with a Python error.
    let dict = PyDict::new(py);
    dict.set_item("upper", out.upper_band.into_pyarray(py).reshape((rows, cols))?)?;
    dict.set_item("middle", out.middle_band.into_pyarray(py).reshape((rows, cols))?)?;
    dict.set_item("lower", out.lower_band.into_pyarray(py).reshape((rows, cols))?)?;
    dict.set_item("periods", periods.into_pyarray(py))?;
    dict.set_item("multipliers", multipliers.into_pyarray(py))?;
    Ok(dict)
}
2894
/// Python binding that runs the Keltner parameter sweep on a CUDA device.
///
/// Inputs are contiguous 1-D float32 arrays of equal length. The returned dict holds:
/// - `upper`, `middle`, `lower`: each a `KeltnerDeviceArrayF32Py` that shares the CUDA
///   context and device id used for the computation;
/// - `rows` and `cols`: taken from the upper-band output.
///
/// # Errors
/// Raises `ValueError` if CUDA is unavailable, the input lengths differ, or the CUDA
/// wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "keltner_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, close_f32, source_f32, period_range, multiplier_range, ma_type="ema", device_id=0))]
pub fn keltner_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    high_f32: numpy::PyReadonlyArray1<'py, f32>,
    low_f32: numpy::PyReadonlyArray1<'py, f32>,
    close_f32: numpy::PyReadonlyArray1<'py, f32>,
    source_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    multiplier_range: (f64, f64, f64),
    ma_type: &str,
    device_id: usize,
) -> PyResult<Bound<'py, PyDict>> {
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let h = high_f32.as_slice()?;
    let l = low_f32.as_slice()?;
    let c = close_f32.as_slice()?;
    let s = source_f32.as_slice()?;
    let n = h.len();
    if [l.len(), c.len(), s.len()].iter().any(|&m| m != n) {
        return Err(PyValueError::new_err("input length mismatch"));
    }

    let sweep = KeltnerBatchRange {
        period: period_range,
        multiplier: multiplier_range,
    };

    // All device work happens with the GIL released.
    let (up, mid, low, rows, cols, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda =
            CudaKeltner::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let res = cuda
            .keltner_batch_dev(h, l, c, s, &sweep, ma_type)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let bands = res.outputs;
        let (rows, cols) = (bands.upper.rows, bands.upper.cols);
        Ok((bands.upper, bands.middle, bands.lower, rows, cols, ctx, dev_id))
    })?;

    // Each band keeps its own handle on the shared context.
    let wrap = |inner| {
        Py::new(
            py,
            KeltnerDeviceArrayF32Py {
                inner: KeltnerDeviceArrayF32 {
                    inner,
                    context: ctx.clone(),
                    device_id: dev_id,
                },
            },
        )
    };

    let dict = PyDict::new(py);
    dict.set_item("upper", wrap(up)?)?;
    dict.set_item("middle", wrap(mid)?)?;
    dict.set_item("lower", wrap(low)?)?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
2987
/// Python binding that runs one Keltner parameter set over many series on a CUDA device.
///
/// The four inputs are flattened time-major float32 buffers, each exactly `cols * rows`
/// long. Returns the `(upper, middle, lower)` device arrays. All three share the CUDA
/// context and device id used for the computation.
///
/// # Errors
/// Raises `ValueError` in four cases:
/// - CUDA is unavailable;
/// - `cols * rows` overflows;
/// - any buffer has the wrong length;
/// - the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "keltner_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, close_tm_f32, source_tm_f32, cols, rows, period, multiplier, ma_type="ema", device_id=0))]
pub fn keltner_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    low_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    close_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    source_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    multiplier: f32,
    ma_type: &str,
    device_id: usize,
) -> PyResult<(
    KeltnerDeviceArrayF32Py,
    KeltnerDeviceArrayF32Py,
    KeltnerDeviceArrayF32Py,
)> {
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let ht = high_tm_f32.as_slice()?;
    let lt = low_tm_f32.as_slice()?;
    let ct = close_tm_f32.as_slice()?;
    let st = source_tm_f32.as_slice()?;

    let expected = cols
        .checked_mul(rows)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;
    let all_sized = [ht, lt, ct, st]
        .iter()
        .all(|series| series.len() == expected);
    if !all_sized {
        return Err(PyValueError::new_err("time-major input length mismatch"));
    }

    // All device work happens with the GIL released.
    let (up, mid, low, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda =
            CudaKeltner::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let bands = cuda
            .keltner_many_series_one_param_time_major_dev(
                ht, lt, ct, st, cols, rows, period, multiplier, ma_type,
            )
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((bands.upper, bands.middle, bands.lower, ctx, dev_id))
    })?;

    // Each band keeps its own handle on the shared context.
    let wrap = |inner| KeltnerDeviceArrayF32Py {
        inner: KeltnerDeviceArrayF32 {
            inner,
            context: ctx.clone(),
            device_id: dev_id,
        },
    };
    Ok((wrap(up), wrap(mid), wrap(low)))
}
3058
/// Flat result returned by the `keltner` WASM binding.
///
/// `keltner_js` fills `values` with the upper, middle and lower bands back to back,
/// reporting `rows = 3` and `cols` = input length.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct KeltnerResult {
    pub values: Vec<f64>,
    pub rows: usize,
    pub cols: usize,
}

/// WASM binding: computes the Keltner bands and returns them as a serialized
/// [`KeltnerResult`].
///
/// # Errors
/// Returns a JS error string in three cases:
/// - the four inputs differ in length;
/// - `keltner_into_slice` fails;
/// - serialization fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "keltner")]
pub fn keltner_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    source: &[f64],
    period: usize,
    multiplier: f64,
    ma_type: String,
) -> Result<JsValue, JsValue> {
    let len = close.len();
    let lengths_agree = [high.len(), low.len(), source.len()]
        .iter()
        .all(|&other| other == len);
    if !lengths_agree {
        return Err(JsValue::from_str("Input arrays must have equal length"));
    }

    let input = KeltnerInput::from_slice(
        high,
        low,
        close,
        source,
        KeltnerParams {
            period: Some(period),
            multiplier: Some(multiplier),
            ma_type: Some(ma_type),
        },
    );

    // One buffer laid out as [upper | middle | lower], each `len` long.
    let mut values = vec![0.0f64; 3 * len];
    {
        let (upper, tail) = values.split_at_mut(len);
        let (middle, lower) = tail.split_at_mut(len);
        keltner_into_slice(upper, middle, lower, &input, detect_best_kernel())
            .map_err(|err| JsValue::from_str(&err.to_string()))?;
    }

    let result = KeltnerResult {
        values,
        rows: 3,
        cols: len,
    };
    serde_wasm_bindgen::to_value(&result)
        .map_err(|err| JsValue::from_str(&format!("Serialization error: {}", err)))
}
3105
/// WASM binding: computes the Keltner bands between caller-owned buffers.
///
/// Reads `len` f64 values from each of `high_ptr`, `low_ptr`, `close_ptr` and
/// `source_ptr`. Writes `len` values to each of `upper_ptr`, `middle_ptr` and `lower_ptr`.
///
/// If any output range overlaps an input range or another output range, the bands are
/// first computed into scratch buffers and then copied out. This means no overlapping
/// mutable slices are ever created.
///
/// # Errors
/// Returns a JS error string for null pointers or any error from `keltner_into_slice`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn keltner_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    source_ptr: *const f64,
    upper_ptr: *mut f64,
    middle_ptr: *mut f64,
    lower_ptr: *mut f64,
    len: usize,
    period: usize,
    multiplier: f64,
    ma_type: &str,
) -> Result<(), JsValue> {
    /// True when the `len`-element f64 ranges starting at `a` and `b` share any bytes.
    fn ranges_overlap(a: *const f64, b: *const f64, len: usize) -> bool {
        let bytes = len.saturating_mul(std::mem::size_of::<f64>());
        let (a, b) = (a as usize, b as usize);
        a < b.saturating_add(bytes) && b < a.saturating_add(bytes)
    }

    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || source_ptr.is_null()
        || upper_ptr.is_null()
        || middle_ptr.is_null()
        || lower_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Compare whole address ranges, not just start pointers. Partial overlap, or two
    // outputs sharing memory, would otherwise produce aliasing `&mut` slices (UB).
    let inputs = [high_ptr, low_ptr, close_ptr, source_ptr];
    let outputs = [
        upper_ptr as *const f64,
        middle_ptr as *const f64,
        lower_ptr as *const f64,
    ];
    let input_output_overlap = inputs
        .iter()
        .any(|&i| outputs.iter().any(|&o| ranges_overlap(i, o, len)));
    let output_output_overlap = ranges_overlap(outputs[0], outputs[1], len)
        || ranges_overlap(outputs[0], outputs[2], len)
        || ranges_overlap(outputs[1], outputs[2], len);

    let params = KeltnerParams {
        period: Some(period),
        multiplier: Some(multiplier),
        ma_type: Some(ma_type.to_string()),
    };

    // SAFETY: every pointer is non-null (checked above), and the caller must guarantee each
    // addresses `len` valid f64 values (outputs writable). A mutable output slice is only
    // held alongside other slices when no two of the ranges overlap.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        let source = std::slice::from_raw_parts(source_ptr, len);

        if input_output_overlap || output_output_overlap {
            let mut temp_upper = vec![0.0; len];
            let mut temp_middle = vec![0.0; len];
            let mut temp_lower = vec![0.0; len];
            {
                // Scoped so the input slices are no longer in use when the outputs are written.
                let input = KeltnerInput::from_slice(high, low, close, source, params);
                keltner_into_slice(
                    &mut temp_upper,
                    &mut temp_middle,
                    &mut temp_lower,
                    &input,
                    Kernel::Auto,
                )
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            }

            // One short-lived `&mut` per output, so overlapping outputs never coexist as
            // live mutable borrows.
            std::slice::from_raw_parts_mut(upper_ptr, len).copy_from_slice(&temp_upper);
            std::slice::from_raw_parts_mut(middle_ptr, len).copy_from_slice(&temp_middle);
            std::slice::from_raw_parts_mut(lower_ptr, len).copy_from_slice(&temp_lower);
        } else {
            let input = KeltnerInput::from_slice(high, low, close, source, params);
            let upper_out = std::slice::from_raw_parts_mut(upper_ptr, len);
            let middle_out = std::slice::from_raw_parts_mut(middle_ptr, len);
            let lower_out = std::slice::from_raw_parts_mut(lower_ptr, len);

            keltner_into_slice(upper_out, middle_out, lower_out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
3194
/// Allocates an uninitialised buffer with room for `len` f64 values and hands ownership
/// to the caller. Release it with `keltner_free(ptr, len)`, passing the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn keltner_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop suppresses the Vec's destructor, so the allocation outlives this call.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
3203
/// Frees a buffer previously returned by `keltner_alloc(len)`. A null `ptr` is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn keltner_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr`/`len` must come from `keltner_alloc(len)`, whose Vec had capacity `len`.
    // The length is 0 because the buffer may never have been written, and `from_raw_parts`
    // requires the first `length` elements to be initialised. f64 has no destructor, so
    // only the capacity matters for deallocation.
    unsafe {
        drop(Vec::<f64>::from_raw_parts(ptr, 0, len));
    }
}
3213
/// Configuration object accepted by the `keltner_batch` WASM binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct KeltnerBatchConfig {
    /// Forwarded unchanged as `KeltnerBatchRange::period`.
    pub period_range: (usize, usize, usize),
    /// Forwarded unchanged as `KeltnerBatchRange::multiplier`.
    pub multiplier_range: (f64, f64, f64),
    /// Moving-average type name passed to `keltner_batch_inner`.
    pub ma_type: String,
}

/// Serialized result of the `keltner_batch` WASM binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct KeltnerBatchJsOutput {
    pub upper: Vec<f64>,
    pub middle: Vec<f64>,
    pub lower: Vec<f64>,
    pub combos: Vec<KeltnerParams>,
    pub rows: usize,
    pub cols: usize,
}

/// WASM binding: runs the Keltner parameter sweep described by `config`.
///
/// `config` must deserialize into a [`KeltnerBatchConfig`]. The flattened bands,
/// parameter combos and `rows`/`cols` come straight from `keltner_batch_inner`.
///
/// # Errors
/// Returns a JS error string if the config is invalid, the batch fails, or
/// serialization fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "keltner_batch")]
pub fn keltner_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    source: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let KeltnerBatchConfig {
        period_range,
        multiplier_range,
        ma_type,
    } = serde_wasm_bindgen::from_value::<KeltnerBatchConfig>(config)
        .map_err(|err| JsValue::from_str(&format!("Invalid config: {}", err)))?;

    let batch = keltner_batch_inner(
        high,
        low,
        close,
        source,
        &KeltnerBatchRange {
            period: period_range,
            multiplier: multiplier_range,
        },
        detect_best_batch_kernel(),
        false,
        Some(&ma_type),
    )
    .map_err(|err| JsValue::from_str(&err.to_string()))?;

    serde_wasm_bindgen::to_value(&KeltnerBatchJsOutput {
        upper: batch.upper_band,
        middle: batch.middle_band,
        lower: batch.lower_band,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    })
    .map_err(|err| JsValue::from_str(&format!("Serialization error: {}", err)))
}
3272
/// WASM binding: computes the Keltner bands into one concatenated output buffer.
///
/// Reads `len` f64 values from each of `h_ptr`, `l_ptr`, `c_ptr` and `s_ptr`. Writes
/// `3 * len` values to `out_ptr`, laid out as `[upper | middle | lower]`.
///
/// If the output range overlaps any input range, results are computed into a scratch
/// buffer and copied out. This means no aliasing mutable slice is created.
///
/// # Errors
/// Returns a JS error string in three cases:
/// - a pointer is null;
/// - `3 * len` overflows;
/// - `keltner_into_slice` fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "keltner_into_concat")]
pub fn keltner_into_concat(
    h_ptr: *const f64,
    l_ptr: *const f64,
    c_ptr: *const f64,
    s_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
    multiplier: f64,
    ma_type: String,
) -> Result<(), JsValue> {
    /// True when `a_len` f64s at `a` and `b_len` f64s at `b` share any bytes.
    fn ranges_overlap(a: *const f64, a_len: usize, b: *const f64, b_len: usize) -> bool {
        let elem = std::mem::size_of::<f64>();
        let (a_start, b_start) = (a as usize, b as usize);
        let a_end = a_start.saturating_add(a_len.saturating_mul(elem));
        let b_end = b_start.saturating_add(b_len.saturating_mul(elem));
        a_start < b_end && b_start < a_end
    }

    if [h_ptr, l_ptr, c_ptr, s_ptr, out_ptr as *const f64]
        .iter()
        .any(|p| p.is_null())
    {
        return Err(JsValue::from_str(
            "null pointer passed to keltner_into_concat",
        ));
    }
    let out_len = len
        .checked_mul(3)
        .ok_or_else(|| JsValue::from_str("len overflow in keltner_into_concat"))?;
    let aliased = [h_ptr, l_ptr, c_ptr, s_ptr]
        .iter()
        .any(|&p| ranges_overlap(p, len, out_ptr as *const f64, out_len));

    let params = KeltnerParams {
        period: Some(period),
        multiplier: Some(multiplier),
        ma_type: Some(ma_type),
    };

    // SAFETY: all pointers are non-null (checked above), and the caller must guarantee each
    // input addresses `len` valid f64 values and `out_ptr` addresses `3 * len` writable ones.
    // The mutable output slice is only held alongside the inputs when they do not overlap.
    unsafe {
        let h = std::slice::from_raw_parts(h_ptr, len);
        let l = std::slice::from_raw_parts(l_ptr, len);
        let c = std::slice::from_raw_parts(c_ptr, len);
        let s = std::slice::from_raw_parts(s_ptr, len);

        if aliased {
            let mut scratch = vec![0.0f64; out_len];
            {
                // Scoped so the input slices are no longer in use when `out_ptr` is written.
                let input = KeltnerInput::from_slice(h, l, c, s, params);
                let (upper, rest) = scratch.split_at_mut(len);
                let (middle, lower) = rest.split_at_mut(len);
                keltner_into_slice(upper, middle, lower, &input, detect_best_kernel())
                    .map_err(|e| JsValue::from_str(&e.to_string()))?;
            }
            std::slice::from_raw_parts_mut(out_ptr, out_len).copy_from_slice(&scratch);
            Ok(())
        } else {
            let input = KeltnerInput::from_slice(h, l, c, s, params);
            let out = std::slice::from_raw_parts_mut(out_ptr, out_len);
            let (upper, rest) = out.split_at_mut(len);
            let (middle, lower) = rest.split_at_mut(len);
            keltner_into_slice(upper, middle, lower, &input, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))
        }
    }
}