1#[cfg(feature = "python")]
2use crate::utilities::kernel_validation::validate_kernel;
3#[cfg(feature = "python")]
4use numpy::{IntoPyArray, PyArray1};
5#[cfg(feature = "python")]
6use pyo3::exceptions::PyValueError;
7#[cfg(feature = "python")]
8use pyo3::prelude::*;
9#[cfg(feature = "python")]
10use pyo3::types::{PyDict, PyList};
11
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde::{Deserialize, Serialize};
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17use crate::utilities::data_loader::{source_type, Candles};
18use crate::utilities::enums::Kernel;
19use crate::utilities::helpers::{
20 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
21 make_uninit_matrix,
22};
23#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
24use core::arch::x86_64::*;
25#[cfg(not(target_arch = "wasm32"))]
26use rayon::prelude::*;
27use std::convert::AsRef;
28use std::error::Error;
29use thiserror::Error;
30
31#[cfg(all(feature = "python", feature = "cuda"))]
32use crate::cuda::keltner_wrapper::CudaKeltner;
33#[cfg(all(feature = "python", feature = "cuda"))]
34use crate::cuda::moving_averages::DeviceArrayF32;
35#[cfg(all(feature = "python", feature = "cuda"))]
36use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
37#[cfg(all(feature = "python", feature = "cuda"))]
38use cust::context::Context;
39#[cfg(all(feature = "python", feature = "cuda"))]
40use std::sync::Arc;
41
42#[derive(Debug, Clone)]
43pub enum KeltnerData<'a> {
44 Candles {
45 candles: &'a Candles,
46 source: &'a str,
47 },
48 Slice(&'a [f64], &'a [f64], &'a [f64], &'a [f64]),
49}
50
51#[derive(Debug, Clone)]
52pub struct KeltnerOutput {
53 pub upper_band: Vec<f64>,
54 pub middle_band: Vec<f64>,
55 pub lower_band: Vec<f64>,
56}
57
/// Tunable Keltner Channel parameters.
///
/// `None` fields are resolved by the input getters to the defaults
/// `period = 20`, `multiplier = 2.0`, `ma_type = "ema"`.
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct KeltnerParams {
    /// Lookback used for both the moving average and the ATR.
    pub period: Option<usize>,
    /// Number of ATRs between the middle band and each outer band.
    pub multiplier: Option<f64>,
    /// Moving-average kind for the middle band (for example `"ema"` or `"sma"`).
    pub ma_type: Option<String>,
}

impl Default for KeltnerParams {
    /// Period 20, multiplier 2.0, EMA middle band.
    fn default() -> Self {
        KeltnerParams {
            ma_type: Some(String::from("ema")),
            multiplier: Some(2.0),
            period: Some(20),
        }
    }
}
78
79#[derive(Debug, Clone)]
80pub struct KeltnerInput<'a> {
81 pub data: KeltnerData<'a>,
82 pub params: KeltnerParams,
83}
84
85impl<'a> KeltnerInput<'a> {
86 #[inline]
87 pub fn from_candles(candles: &'a Candles, source: &'a str, params: KeltnerParams) -> Self {
88 Self {
89 data: KeltnerData::Candles { candles, source },
90 params,
91 }
92 }
93 #[inline]
94 pub fn from_slice(
95 high: &'a [f64],
96 low: &'a [f64],
97 close: &'a [f64],
98 source: &'a [f64],
99 params: KeltnerParams,
100 ) -> Self {
101 Self {
102 data: KeltnerData::Slice(high, low, close, source),
103 params,
104 }
105 }
106 #[inline]
107 pub fn with_default_candles(candles: &'a Candles) -> Self {
108 Self::from_candles(candles, "close", KeltnerParams::default())
109 }
110 #[inline]
111 pub fn get_period(&self) -> usize {
112 self.params.period.unwrap_or(20)
113 }
114 #[inline]
115 pub fn get_multiplier(&self) -> f64 {
116 self.params.multiplier.unwrap_or(2.0)
117 }
118 #[inline]
119 pub fn get_ma_type(&self) -> &str {
120 self.params.ma_type.as_deref().unwrap_or("ema")
121 }
122}
123
124impl<'a> AsRef<[f64]> for KeltnerInput<'a> {
125 #[inline(always)]
126 fn as_ref(&self) -> &[f64] {
127 match &self.data {
128 KeltnerData::Slice(_, _, _, source) => source,
129 KeltnerData::Candles { candles, source } => source_type(candles, source),
130 }
131 }
132}
133
134#[derive(Clone, Debug)]
135pub struct KeltnerBuilder {
136 period: Option<usize>,
137 multiplier: Option<f64>,
138 ma_type: Option<String>,
139 kernel: Kernel,
140}
141
142impl Default for KeltnerBuilder {
143 fn default() -> Self {
144 Self {
145 period: None,
146 multiplier: None,
147 ma_type: None,
148 kernel: Kernel::Auto,
149 }
150 }
151}
152
153impl KeltnerBuilder {
154 #[inline(always)]
155 pub fn new() -> Self {
156 Self::default()
157 }
158 #[inline(always)]
159 pub fn period(mut self, n: usize) -> Self {
160 self.period = Some(n);
161 self
162 }
163 #[inline(always)]
164 pub fn multiplier(mut self, x: f64) -> Self {
165 self.multiplier = Some(x);
166 self
167 }
168 #[inline(always)]
169 pub fn ma_type(mut self, mt: &str) -> Self {
170 self.ma_type = Some(mt.to_lowercase());
171 self
172 }
173 #[inline(always)]
174 pub fn kernel(mut self, k: Kernel) -> Self {
175 self.kernel = k;
176 self
177 }
178
179 #[inline(always)]
180 pub fn apply(self, c: &Candles) -> Result<KeltnerOutput, KeltnerError> {
181 let p = KeltnerParams {
182 period: self.period,
183 multiplier: self.multiplier,
184 ma_type: self.ma_type,
185 };
186 let i = KeltnerInput::from_candles(c, "close", p);
187 keltner_with_kernel(&i, self.kernel)
188 }
189
190 #[inline(always)]
191 pub fn apply_slice(
192 self,
193 high: &[f64],
194 low: &[f64],
195 close: &[f64],
196 source: &[f64],
197 ) -> Result<KeltnerOutput, KeltnerError> {
198 let p = KeltnerParams {
199 period: self.period,
200 multiplier: self.multiplier,
201 ma_type: self.ma_type,
202 };
203 let i = KeltnerInput::from_slice(high, low, close, source, p);
204 keltner_with_kernel(&i, self.kernel)
205 }
206
207 #[inline(always)]
208 pub fn into_stream(self) -> Result<KeltnerStream, KeltnerError> {
209 let p = KeltnerParams {
210 period: self.period,
211 multiplier: self.multiplier,
212 ma_type: self.ma_type,
213 };
214 KeltnerStream::try_new(p)
215 }
216}
217
218#[derive(Debug, Error)]
219pub enum KeltnerError {
220 #[error("keltner: empty data provided.")]
221 EmptyInputData,
222 #[error("keltner: invalid period: period = {period}, data length = {data_len}")]
223 InvalidPeriod { period: usize, data_len: usize },
224 #[error("keltner: not enough valid data: needed = {needed}, valid = {valid}")]
225 NotEnoughValidData { needed: usize, valid: usize },
226 #[error("keltner: all values are NaN.")]
227 AllValuesNaN,
228 #[error("keltner: output length mismatch: expected = {expected}, got = {got}")]
229 OutputLengthMismatch { expected: usize, got: usize },
230 #[error("keltner: invalid range: start={start}, end={end}, step={step}")]
231 InvalidRange {
232 start: String,
233 end: String,
234 step: String,
235 },
236 #[error("keltner: invalid kernel for batch: {0:?}")]
237 InvalidKernelForBatch(Kernel),
238 #[error("keltner: invalid input: {0}")]
239 InvalidInput(String),
240 #[error("keltner: MA error: {0}")]
241 MaError(String),
242}
243
244#[inline]
245pub fn keltner(input: &KeltnerInput) -> Result<KeltnerOutput, KeltnerError> {
246 keltner_with_kernel(input, Kernel::Auto)
247}
248
249pub fn keltner_with_kernel(
250 input: &KeltnerInput,
251 kernel: Kernel,
252) -> Result<KeltnerOutput, KeltnerError> {
253 let (high, low, close, source_slice): (&[f64], &[f64], &[f64], &[f64]) = match &input.data {
254 KeltnerData::Candles { candles, source } => (
255 candles
256 .select_candle_field("high")
257 .map_err(|e| KeltnerError::MaError(e.to_string()))?,
258 candles
259 .select_candle_field("low")
260 .map_err(|e| KeltnerError::MaError(e.to_string()))?,
261 candles
262 .select_candle_field("close")
263 .map_err(|e| KeltnerError::MaError(e.to_string()))?,
264 source_type(candles, source),
265 ),
266 KeltnerData::Slice(h, l, c, s) => (*h, *l, *c, *s),
267 };
268 let period = input.get_period();
269 let len = close.len();
270 if len == 0 {
271 return Err(KeltnerError::EmptyInputData);
272 }
273 if period == 0 || period > len {
274 return Err(KeltnerError::InvalidPeriod {
275 period,
276 data_len: len,
277 });
278 }
279 let first = close
280 .iter()
281 .position(|x| !x.is_nan())
282 .ok_or(KeltnerError::AllValuesNaN)?;
283
284 if (len - first) < period {
285 return Err(KeltnerError::NotEnoughValidData {
286 needed: period,
287 valid: len - first,
288 });
289 }
290
291 let chosen = match kernel {
292 Kernel::Auto => Kernel::Scalar,
293 other => other,
294 };
295
296 let warm = first + period - 1;
297 let mut upper_band = alloc_with_nan_prefix(len, warm);
298 let mut middle_band = alloc_with_nan_prefix(len, warm);
299 let mut lower_band = alloc_with_nan_prefix(len, warm);
300
301 unsafe {
302 match chosen {
303 Kernel::Scalar | Kernel::ScalarBatch => keltner_scalar(
304 high,
305 low,
306 close,
307 source_slice,
308 period,
309 input.get_multiplier(),
310 input.get_ma_type(),
311 first,
312 &mut upper_band,
313 &mut middle_band,
314 &mut lower_band,
315 ),
316 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
317 Kernel::Avx2 | Kernel::Avx2Batch => keltner_avx2(
318 high,
319 low,
320 close,
321 source_slice,
322 period,
323 input.get_multiplier(),
324 input.get_ma_type(),
325 first,
326 &mut upper_band,
327 &mut middle_band,
328 &mut lower_band,
329 ),
330 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
331 Kernel::Avx512 | Kernel::Avx512Batch => keltner_avx512(
332 high,
333 low,
334 close,
335 source_slice,
336 period,
337 input.get_multiplier(),
338 input.get_ma_type(),
339 first,
340 &mut upper_band,
341 &mut middle_band,
342 &mut lower_band,
343 ),
344 _ => unreachable!(),
345 }
346 }
347 Ok(KeltnerOutput {
348 upper_band,
349 middle_band,
350 lower_band,
351 })
352}
353
354#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
355#[inline(always)]
356pub fn keltner_into(
357 input: &KeltnerInput,
358 upper_dst: &mut [f64],
359 middle_dst: &mut [f64],
360 lower_dst: &mut [f64],
361) -> Result<(), KeltnerError> {
362 keltner_into_slice(upper_dst, middle_dst, lower_dst, input, Kernel::Auto)
363}
364
365#[inline(always)]
366pub fn keltner_into_slice(
367 upper_dst: &mut [f64],
368 middle_dst: &mut [f64],
369 lower_dst: &mut [f64],
370 input: &KeltnerInput,
371 kernel: Kernel,
372) -> Result<(), KeltnerError> {
373 let (high, low, close, source_slice): (&[f64], &[f64], &[f64], &[f64]) = match &input.data {
374 KeltnerData::Candles { candles, source } => (
375 candles
376 .select_candle_field("high")
377 .map_err(|e| KeltnerError::MaError(e.to_string()))?,
378 candles
379 .select_candle_field("low")
380 .map_err(|e| KeltnerError::MaError(e.to_string()))?,
381 candles
382 .select_candle_field("close")
383 .map_err(|e| KeltnerError::MaError(e.to_string()))?,
384 source_type(candles, source),
385 ),
386 KeltnerData::Slice(h, l, c, s) => (*h, *l, *c, *s),
387 };
388
389 let period = input.get_period();
390 let len = close.len();
391
392 if len == 0 {
393 return Err(KeltnerError::EmptyInputData);
394 }
395
396 if upper_dst.len() != len || middle_dst.len() != len || lower_dst.len() != len {
397 return Err(KeltnerError::OutputLengthMismatch {
398 expected: len,
399 got: upper_dst.len().max(middle_dst.len()).max(lower_dst.len()),
400 });
401 }
402
403 if period == 0 || period > len {
404 return Err(KeltnerError::InvalidPeriod {
405 period,
406 data_len: len,
407 });
408 }
409
410 let first = close
411 .iter()
412 .position(|x| !x.is_nan())
413 .ok_or(KeltnerError::AllValuesNaN)?;
414
415 if (len - first) < period {
416 return Err(KeltnerError::NotEnoughValidData {
417 needed: period,
418 valid: len - first,
419 });
420 }
421
422 let chosen = match kernel {
423 Kernel::Auto => Kernel::Scalar,
424 other => other,
425 };
426
427 unsafe {
428 match chosen {
429 Kernel::Scalar | Kernel::ScalarBatch => keltner_scalar(
430 high,
431 low,
432 close,
433 source_slice,
434 period,
435 input.get_multiplier(),
436 input.get_ma_type(),
437 first,
438 upper_dst,
439 middle_dst,
440 lower_dst,
441 ),
442 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
443 Kernel::Avx2 | Kernel::Avx2Batch => keltner_avx2(
444 high,
445 low,
446 close,
447 source_slice,
448 period,
449 input.get_multiplier(),
450 input.get_ma_type(),
451 first,
452 upper_dst,
453 middle_dst,
454 lower_dst,
455 ),
456 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
457 Kernel::Avx512 | Kernel::Avx512Batch => keltner_avx512(
458 high,
459 low,
460 close,
461 source_slice,
462 period,
463 input.get_multiplier(),
464 input.get_ma_type(),
465 first,
466 upper_dst,
467 middle_dst,
468 lower_dst,
469 ),
470 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
471 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
472 keltner_scalar(
473 high,
474 low,
475 close,
476 source_slice,
477 period,
478 input.get_multiplier(),
479 input.get_ma_type(),
480 first,
481 upper_dst,
482 middle_dst,
483 lower_dst,
484 )
485 }
486 _ => unreachable!(),
487 }
488 }
489
490 let warm = first + period - 1;
491 for i in 0..warm {
492 upper_dst[i] = f64::NAN;
493 middle_dst[i] = f64::NAN;
494 lower_dst[i] = f64::NAN;
495 }
496
497 Ok(())
498}
499
/// Reference ("classic") Keltner Channel with an SMA middle band.
///
/// The ATR is an RMA of the true range (`alpha = 1 / period`), seeded with the mean true
/// range of bars `0..period`; bar 0's true range is `high[0] - low[0]`. The middle band
/// is a rolling `period`-sample mean of `source` starting at `first`.
///
/// Every index from `first + period - 1` onward receives a middle value. The outer
/// bands (`middle ± multiplier * atr`) are written only where the ATR is not NaN.
/// Earlier indices are left untouched, and nothing is written when that warm-up index
/// is past the end of `close`.
///
/// # Panics
/// Panics (slice indexing) if an input or output slice is too short for an index that
/// is read or written.
#[inline]
pub fn keltner_scalar_classic_sma(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    source: &[f64],
    period: usize,
    multiplier: f64,
    first: usize,
    upper: &mut [f64],
    middle: &mut [f64],
    lower: &mut [f64],
) {
    let n = close.len();
    let warm = first + period - 1;
    if warm >= n {
        return;
    }

    let pf = period as f64;
    let alpha = 1.0 / pf;

    // ATR series: NaN until the seed at index `period - 1`, then RMA updates.
    let mut atr_series = vec![f64::NAN; n];
    let mut tr_sum = 0.0;
    let mut rma = f64::NAN;
    for (i, slot) in atr_series.iter_mut().enumerate() {
        let tr = match i {
            0 => high[0] - low[0],
            _ => {
                let hl = high[i] - low[i];
                let hc = (high[i] - close[i - 1]).abs();
                let lc = (low[i] - close[i - 1]).abs();
                hl.max(hc).max(lc)
            }
        };
        if i >= period {
            rma += alpha * (tr - rma);
            *slot = rma;
        } else {
            tr_sum += tr;
            if i + 1 == period {
                rma = tr_sum / pf;
                *slot = rma;
            }
        }
    }

    // Rolling window sum over `source`, seeded with the first `period` valid samples.
    let mut window_sum = (first..first + period).fold(0.0, |acc, j| acc + source[j]);
    for i in warm..n {
        if i != warm {
            window_sum += source[i] - source[i - period];
        }
        let mid = window_sum / pf;
        middle[i] = mid;
        let atr = atr_series[i];
        if !atr.is_nan() {
            upper[i] = mid + multiplier * atr;
            lower[i] = mid - multiplier * atr;
        }
    }
}
576
/// Reference ("classic") Keltner Channel with an EMA middle band.
///
/// The ATR is an RMA of the true range (`alpha = 1 / period`), seeded with the mean true
/// range of bars `0..period`; bar 0's true range is `high[0] - low[0]`. The EMA
/// (`alpha = 2 / (period + 1)`) is seeded with the mean of `source[first..first + period]`.
///
/// Every index from `first + period - 1` onward receives a middle value. The outer
/// bands (`middle ± multiplier * atr`) are written only where the ATR is not NaN.
/// Earlier indices are left untouched, and nothing is written when that warm-up index
/// is past the end of `close`.
///
/// # Panics
/// Panics (slice indexing) if an input or output slice is too short for an index that
/// is read or written.
pub fn keltner_scalar_classic_ema(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    source: &[f64],
    period: usize,
    multiplier: f64,
    first: usize,
    upper: &mut [f64],
    middle: &mut [f64],
    lower: &mut [f64],
) {
    let n = close.len();
    let warm = first + period - 1;
    if warm >= n {
        return;
    }

    let pf = period as f64;
    let alpha = 1.0 / pf;

    // ATR series: NaN until the seed at index `period - 1`, then RMA updates.
    let mut atr_series = vec![f64::NAN; n];
    let mut tr_sum = 0.0;
    let mut rma = f64::NAN;
    for (i, slot) in atr_series.iter_mut().enumerate() {
        let tr = match i {
            0 => high[0] - low[0],
            _ => {
                let hl = high[i] - low[i];
                let hc = (high[i] - close[i - 1]).abs();
                let lc = (low[i] - close[i - 1]).abs();
                hl.max(hc).max(lc)
            }
        };
        if i >= period {
            rma += alpha * (tr - rma);
            *slot = rma;
        } else {
            tr_sum += tr;
            if i + 1 == period {
                rma = tr_sum / pf;
                *slot = rma;
            }
        }
    }

    let ema_alpha = 2.0 / (pf + 1.0);
    let ema_alpha_1 = 1.0 - ema_alpha;

    let mut ema = (first..first + period).fold(0.0, |acc, j| acc + source[j]) / pf;
    for i in warm..n {
        if i != warm {
            ema = ema_alpha * source[i] + ema_alpha_1 * ema;
        }
        middle[i] = ema;
        let atr = atr_series[i];
        if !atr.is_nan() {
            upper[i] = ema + multiplier * atr;
            lower[i] = ema - multiplier * atr;
        }
    }
}
652
653pub fn keltner_scalar(
654 high: &[f64],
655 low: &[f64],
656 close: &[f64],
657 source: &[f64],
658 period: usize,
659 multiplier: f64,
660 ma_type: &str,
661 first: usize,
662 upper: &mut [f64],
663 middle: &mut [f64],
664 lower: &mut [f64],
665) {
666 let len = close.len();
667 let warm = first + period - 1;
668 if warm >= len {
669 return;
670 }
671
672 let pf = period as f64;
673 let rma_alpha = 1.0 / pf;
674
675 let mut atr: f64;
676 unsafe {
677 atr = *high.get_unchecked(0) - *low.get_unchecked(0);
678
679 let mut i = 1usize;
680 while i < period {
681 let hi = *high.get_unchecked(i);
682 let lo = *low.get_unchecked(i);
683 let pc = *close.get_unchecked(i - 1);
684 let hl = hi - lo;
685 let hc = (hi - pc).abs();
686 let lc = (lo - pc).abs();
687 atr += hl.max(hc).max(lc);
688 i += 1;
689 }
690 atr /= pf;
691
692 let mut k = period;
693 while k <= warm {
694 let hi = *high.get_unchecked(k);
695 let lo = *low.get_unchecked(k);
696 let pc = *close.get_unchecked(k - 1);
697 let hl = hi - lo;
698 let hc = (hi - pc).abs();
699 let lc = (lo - pc).abs();
700 let tr = hl.max(hc).max(lc);
701 atr = (tr - atr).mul_add(rma_alpha, atr);
702 k += 1;
703 }
704 }
705
706 let m = multiplier;
707
708 if ma_type.eq_ignore_ascii_case("ema") {
709 let mut ema: f64 = 0.0;
710 unsafe {
711 let mut j = 0usize;
712 while j < period {
713 ema += *source.get_unchecked(first + j);
714 j += 1;
715 }
716 }
717 ema /= pf;
718
719 middle[warm] = ema;
720 upper[warm] = m.mul_add(atr, ema);
721 lower[warm] = (-m).mul_add(atr, ema);
722
723 let ema_alpha = 2.0 / (pf + 1.0);
724
725 unsafe {
726 let mut i = warm + 1;
727 while i < len {
728 let hi = *high.get_unchecked(i);
729 let lo = *low.get_unchecked(i);
730 let pc = *close.get_unchecked(i - 1);
731 let hl = hi - lo;
732 let hc = (hi - pc).abs();
733 let lc = (lo - pc).abs();
734 let tr = hl.max(hc).max(lc);
735 atr = (tr - atr).mul_add(rma_alpha, atr);
736
737 let xi = *source.get_unchecked(i);
738 ema = (xi - ema).mul_add(ema_alpha, ema);
739
740 middle[i] = ema;
741 upper[i] = m.mul_add(atr, ema);
742 lower[i] = (-m).mul_add(atr, ema);
743 i += 1;
744 }
745 }
746 return;
747 }
748
749 if ma_type.eq_ignore_ascii_case("sma") {
750 let mut sum: f64 = 0.0;
751 unsafe {
752 let mut j = 0usize;
753 while j < period {
754 sum += *source.get_unchecked(first + j);
755 j += 1;
756 }
757 }
758 let mut mid = sum / pf;
759 middle[warm] = mid;
760 upper[warm] = m.mul_add(atr, mid);
761 lower[warm] = (-m).mul_add(atr, mid);
762
763 unsafe {
764 let mut i = warm + 1;
765 while i < len {
766 let hi = *high.get_unchecked(i);
767 let lo = *low.get_unchecked(i);
768 let pc = *close.get_unchecked(i - 1);
769 let hl = hi - lo;
770 let hc = (hi - pc).abs();
771 let lc = (lo - pc).abs();
772 let tr = hl.max(hc).max(lc);
773 atr = (tr - atr).mul_add(rma_alpha, atr);
774
775 let new_x = *source.get_unchecked(i);
776 let old_x = *source.get_unchecked(i - period);
777 sum += new_x - old_x;
778 mid = sum / pf;
779
780 middle[i] = mid;
781 upper[i] = m.mul_add(atr, mid);
782 lower[i] = (-m).mul_add(atr, mid);
783 i += 1;
784 }
785 }
786 return;
787 }
788
789 let mut atr = crate::utilities::helpers::alloc_with_nan_prefix(len, warm);
790 let alpha = 1.0 / (period as f64);
791 let mut sum_tr = 0.0;
792 let mut rma = f64::NAN;
793
794 for i in 0..len {
795 let tr = if i == 0 {
796 high[0] - low[0]
797 } else {
798 let hl = high[i] - low[i];
799 let hc = (high[i] - close[i - 1]).abs();
800 let lc = (low[i] - close[i - 1]).abs();
801 hl.max(hc).max(lc)
802 };
803 if i < period {
804 sum_tr += tr;
805 if i == period - 1 {
806 rma = sum_tr / (period as f64);
807 atr[i] = rma;
808 }
809 } else {
810 rma += alpha * (tr - rma);
811 atr[i] = rma;
812 }
813 }
814
815 let mut ma_values = crate::utilities::helpers::alloc_with_nan_prefix(len, warm);
816
817 match ma_type {
818 "ema" => {
819 use crate::indicators::moving_averages::ema::{
820 ema_into_slice, EmaData, EmaInput, EmaParams,
821 };
822 let ema_input = EmaInput {
823 data: EmaData::Slice(source),
824 params: EmaParams {
825 period: Some(period),
826 },
827 };
828 let _ = ema_into_slice(&mut ma_values, &ema_input, Kernel::Auto);
829 }
830 "sma" => {
831 use crate::indicators::moving_averages::sma::{
832 sma_into_slice, SmaData, SmaInput, SmaParams,
833 };
834 let sma_input = SmaInput {
835 data: SmaData::Slice(source),
836 params: SmaParams {
837 period: Some(period),
838 },
839 };
840 let _ = sma_into_slice(&mut ma_values, &sma_input, Kernel::Auto);
841 }
842 _ => {
843 if let Ok(result) = crate::indicators::moving_averages::ma::ma(
844 ma_type,
845 crate::indicators::moving_averages::ma::MaData::Slice(source),
846 period,
847 ) {
848 ma_values.copy_from_slice(&result);
849 }
850 }
851 }
852
853 for i in warm..len {
854 let ma_v = ma_values[i];
855 let atr_v = atr[i];
856 if ma_v.is_nan() || atr_v.is_nan() {
857 continue;
858 }
859 middle[i] = ma_v;
860 upper[i] = multiplier.mul_add(atr_v, ma_v);
861 lower[i] = (-multiplier).mul_add(atr_v, ma_v);
862 }
863}
864
/// AVX2 kernel slot. No AVX2-specific code exists yet, so this runs
/// [`keltner_scalar`] with the same arguments.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn keltner_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    source: &[f64],
    period: usize,
    multiplier: f64,
    ma_type: &str,
    first: usize,
    upper: &mut [f64],
    middle: &mut [f64],
    lower: &mut [f64],
) {
    keltner_scalar(
        high,
        low,
        close,
        source,
        period,
        multiplier,
        ma_type,
        first,
        upper,
        middle,
        lower,
    )
}

/// AVX-512 kernel slot. Periods up to 32 go to the "short" variant and longer periods
/// to the "long" variant; both currently run [`keltner_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn keltner_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    source: &[f64],
    period: usize,
    multiplier: f64,
    ma_type: &str,
    first: usize,
    upper: &mut [f64],
    middle: &mut [f64],
    lower: &mut [f64],
) {
    let short_window = period <= 32;
    // SAFETY: both variants only forward to the safe scalar kernel, so they impose no
    // requirements beyond those `keltner_scalar` checks itself.
    unsafe {
        if short_window {
            keltner_avx512_short(
                high, low, close, source, period, multiplier, ma_type, first, upper, middle,
                lower,
            )
        } else {
            keltner_avx512_long(
                high, low, close, source, period, multiplier, ma_type, first, upper, middle,
                lower,
            )
        }
    }
}

/// Short-period (`period <= 32`) AVX-512 variant; currently runs [`keltner_scalar`].
///
/// # Safety
/// The current body adds no requirements. NOTE(review): `unsafe` is presumably reserved
/// for a future SIMD body that needs AVX-512 support verified by the caller; confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn keltner_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    source: &[f64],
    period: usize,
    multiplier: f64,
    ma_type: &str,
    first: usize,
    upper: &mut [f64],
    middle: &mut [f64],
    lower: &mut [f64],
) {
    keltner_scalar(
        high,
        low,
        close,
        source,
        period,
        multiplier,
        ma_type,
        first,
        upper,
        middle,
        lower,
    )
}

/// Long-period (`period > 32`) AVX-512 variant; currently runs [`keltner_scalar`].
///
/// # Safety
/// The current body adds no requirements. NOTE(review): `unsafe` is presumably reserved
/// for a future SIMD body that needs AVX-512 support verified by the caller; confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn keltner_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    source: &[f64],
    period: usize,
    multiplier: f64,
    ma_type: &str,
    first: usize,
    upper: &mut [f64],
    middle: &mut [f64],
    lower: &mut [f64],
) {
    keltner_scalar(
        high,
        low,
        close,
        source,
        period,
        multiplier,
        ma_type,
        first,
        upper,
        middle,
        lower,
    )
}
954
955#[inline(always)]
956pub fn keltner_row_scalar(
957 high: &[f64],
958 low: &[f64],
959 close: &[f64],
960 source: &[f64],
961 period: usize,
962 multiplier: f64,
963 ma_type: &str,
964 first: usize,
965 upper: &mut [f64],
966 middle: &mut [f64],
967 lower: &mut [f64],
968) {
969 keltner_scalar(
970 high, low, close, source, period, multiplier, ma_type, first, upper, middle, lower,
971 );
972}
973
974#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
975#[inline(always)]
976pub fn keltner_row_avx2(
977 high: &[f64],
978 low: &[f64],
979 close: &[f64],
980 source: &[f64],
981 period: usize,
982 multiplier: f64,
983 ma_type: &str,
984 first: usize,
985 upper: &mut [f64],
986 middle: &mut [f64],
987 lower: &mut [f64],
988) {
989 keltner_avx2(
990 high, low, close, source, period, multiplier, ma_type, first, upper, middle, lower,
991 )
992}
993
994#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
995#[inline(always)]
996pub fn keltner_row_avx512(
997 high: &[f64],
998 low: &[f64],
999 close: &[f64],
1000 source: &[f64],
1001 period: usize,
1002 multiplier: f64,
1003 ma_type: &str,
1004 first: usize,
1005 upper: &mut [f64],
1006 middle: &mut [f64],
1007 lower: &mut [f64],
1008) {
1009 keltner_avx512(
1010 high, low, close, source, period, multiplier, ma_type, first, upper, middle, lower,
1011 )
1012}
1013
1014#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1015#[inline(always)]
1016pub unsafe fn keltner_row_avx512_short(
1017 high: &[f64],
1018 low: &[f64],
1019 close: &[f64],
1020 source: &[f64],
1021 period: usize,
1022 multiplier: f64,
1023 ma_type: &str,
1024 first: usize,
1025 upper: &mut [f64],
1026 middle: &mut [f64],
1027 lower: &mut [f64],
1028) {
1029 keltner_avx512_short(
1030 high, low, close, source, period, multiplier, ma_type, first, upper, middle, lower,
1031 )
1032}
1033
1034#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1035#[inline(always)]
1036pub unsafe fn keltner_row_avx512_long(
1037 high: &[f64],
1038 low: &[f64],
1039 close: &[f64],
1040 source: &[f64],
1041 period: usize,
1042 multiplier: f64,
1043 ma_type: &str,
1044 first: usize,
1045 upper: &mut [f64],
1046 middle: &mut [f64],
1047 lower: &mut [f64],
1048) {
1049 keltner_avx512_long(
1050 high, low, close, source, period, multiplier, ma_type, first, upper, middle, lower,
1051 )
1052}
1053
1054#[derive(Clone, Debug)]
1055pub struct KeltnerBatchRange {
1056 pub period: (usize, usize, usize),
1057 pub multiplier: (f64, f64, f64),
1058}
1059
1060impl Default for KeltnerBatchRange {
1061 fn default() -> Self {
1062 Self {
1063 period: (20, 269, 1),
1064 multiplier: (2.0, 2.0, 0.0),
1065 }
1066 }
1067}
1068
1069#[derive(Clone, Debug)]
1070pub struct KeltnerBatchBuilder {
1071 range: KeltnerBatchRange,
1072 kernel: Kernel,
1073}
1074
1075impl Default for KeltnerBatchBuilder {
1076 fn default() -> Self {
1077 Self {
1078 range: KeltnerBatchRange::default(),
1079 kernel: Kernel::Auto,
1080 }
1081 }
1082}
1083impl KeltnerBatchBuilder {
1084 pub fn new() -> Self {
1085 Self::default()
1086 }
1087 pub fn kernel(mut self, k: Kernel) -> Self {
1088 self.kernel = k;
1089 self
1090 }
1091 #[inline]
1092 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1093 self.range.period = (start, end, step);
1094 self
1095 }
1096 #[inline]
1097 pub fn period_static(mut self, p: usize) -> Self {
1098 self.range.period = (p, p, 0);
1099 self
1100 }
1101 #[inline]
1102 pub fn multiplier_range(mut self, start: f64, end: f64, step: f64) -> Self {
1103 self.range.multiplier = (start, end, step);
1104 self
1105 }
1106 #[inline]
1107 pub fn multiplier_static(mut self, m: f64) -> Self {
1108 self.range.multiplier = (m, m, 0.0);
1109 self
1110 }
1111 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<KeltnerBatchOutput, KeltnerError> {
1112 let h = c
1113 .select_candle_field("high")
1114 .map_err(|e| KeltnerError::MaError(e.to_string()))?;
1115 let l = c
1116 .select_candle_field("low")
1117 .map_err(|e| KeltnerError::MaError(e.to_string()))?;
1118 let cl = c
1119 .select_candle_field("close")
1120 .map_err(|e| KeltnerError::MaError(e.to_string()))?;
1121 let src_v = source_type(c, src);
1122 self.apply_slice(&h, &l, &cl, src_v)
1123 }
1124 pub fn apply_slice(
1125 self,
1126 high: &[f64],
1127 low: &[f64],
1128 close: &[f64],
1129 source: &[f64],
1130 ) -> Result<KeltnerBatchOutput, KeltnerError> {
1131 keltner_batch_with_kernel(high, low, close, source, &self.range, self.kernel)
1132 }
1133
1134 pub fn with_default_slice(
1135 high: &[f64],
1136 low: &[f64],
1137 close: &[f64],
1138 source: &[f64],
1139 k: Kernel,
1140 ) -> Result<KeltnerBatchOutput, KeltnerError> {
1141 KeltnerBatchBuilder::new()
1142 .kernel(k)
1143 .apply_slice(high, low, close, source)
1144 }
1145}
1146
1147#[derive(Clone, Debug)]
1148pub struct KeltnerBatchOutput {
1149 pub upper_band: Vec<f64>,
1150 pub middle_band: Vec<f64>,
1151 pub lower_band: Vec<f64>,
1152 pub combos: Vec<KeltnerParams>,
1153 pub rows: usize,
1154 pub cols: usize,
1155}
1156impl KeltnerBatchOutput {
1157 pub fn row_for_params(&self, p: &KeltnerParams) -> Option<usize> {
1158 self.combos.iter().position(|c| {
1159 c.period.unwrap_or(20) == p.period.unwrap_or(20)
1160 && (c.multiplier.unwrap_or(2.0) - p.multiplier.unwrap_or(2.0)).abs() < 1e-12
1161 })
1162 }
1163 pub fn values_for(&self, p: &KeltnerParams) -> Option<(&[f64], &[f64], &[f64])> {
1164 self.row_for_params(p).map(|row| {
1165 let start = row * self.cols;
1166 (
1167 &self.upper_band[start..start + self.cols],
1168 &self.middle_band[start..start + self.cols],
1169 &self.lower_band[start..start + self.cols],
1170 )
1171 })
1172 }
1173}
1174
1175fn expand_grid(r: &KeltnerBatchRange) -> Result<Vec<KeltnerParams>, KeltnerError> {
1176 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, KeltnerError> {
1177 if step == 0 || start == end {
1178 return Ok(vec![start]);
1179 }
1180 if start < end {
1181 return Ok((start..=end).step_by(step.max(1)).collect());
1182 }
1183 let mut v = Vec::new();
1184 let mut x = start as isize;
1185 let end_i = end as isize;
1186 let st = (step as isize).max(1);
1187 while x >= end_i {
1188 v.push(x as usize);
1189 x -= st;
1190 }
1191 if v.is_empty() {
1192 return Err(KeltnerError::InvalidRange {
1193 start: start.to_string(),
1194 end: end.to_string(),
1195 step: step.to_string(),
1196 });
1197 }
1198 Ok(v)
1199 }
1200 fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, KeltnerError> {
1201 if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
1202 return Ok(vec![start]);
1203 }
1204 if start < end {
1205 let mut v = Vec::new();
1206 let mut x = start;
1207 let st = step.abs();
1208 while x <= end + 1e-12 {
1209 v.push(x);
1210 x += st;
1211 }
1212 if v.is_empty() {
1213 return Err(KeltnerError::InvalidRange {
1214 start: start.to_string(),
1215 end: end.to_string(),
1216 step: step.to_string(),
1217 });
1218 }
1219 return Ok(v);
1220 }
1221 let mut v = Vec::new();
1222 let mut x = start;
1223 let st = step.abs();
1224 while x + 1e-12 >= end {
1225 v.push(x);
1226 x -= st;
1227 }
1228 if v.is_empty() {
1229 return Err(KeltnerError::InvalidRange {
1230 start: start.to_string(),
1231 end: end.to_string(),
1232 step: step.to_string(),
1233 });
1234 }
1235 Ok(v)
1236 }
1237
1238 let periods = axis_usize(r.period)?;
1239 let mults = axis_f64(r.multiplier)?;
1240
1241 let cap = periods
1242 .len()
1243 .checked_mul(mults.len())
1244 .ok_or_else(|| KeltnerError::InvalidRange {
1245 start: "rows".into(),
1246 end: "cols".into(),
1247 step: "rows*cols".into(),
1248 })?;
1249
1250 let mut out = Vec::with_capacity(cap);
1251 for &p in &periods {
1252 for &m in &mults {
1253 out.push(KeltnerParams {
1254 period: Some(p),
1255 multiplier: Some(m),
1256 ma_type: None,
1257 });
1258 }
1259 }
1260 Ok(out)
1261}
1262
1263pub fn keltner_batch_with_kernel(
1264 high: &[f64],
1265 low: &[f64],
1266 close: &[f64],
1267 source: &[f64],
1268 sweep: &KeltnerBatchRange,
1269 k: Kernel,
1270) -> Result<KeltnerBatchOutput, KeltnerError> {
1271 let kernel = match k {
1272 Kernel::Auto => {
1273 let best = detect_best_batch_kernel();
1274 if best == Kernel::Avx512Batch {
1275 Kernel::Avx2Batch
1276 } else {
1277 best
1278 }
1279 }
1280 other if other.is_batch() => other,
1281 _ => {
1282 return Err(KeltnerError::InvalidKernelForBatch(k));
1283 }
1284 };
1285 let simd = match kernel {
1286 Kernel::Avx512Batch => Kernel::Avx512,
1287 Kernel::Avx2Batch => Kernel::Avx2,
1288 Kernel::ScalarBatch => Kernel::Scalar,
1289 _ => unreachable!(),
1290 };
1291 keltner_batch_par_slice(high, low, close, source, sweep, simd)
1292}
1293
1294pub fn keltner_batch_slice(
1295 high: &[f64],
1296 low: &[f64],
1297 close: &[f64],
1298 source: &[f64],
1299 sweep: &KeltnerBatchRange,
1300 kern: Kernel,
1301) -> Result<KeltnerBatchOutput, KeltnerError> {
1302 keltner_batch_inner(high, low, close, source, sweep, kern, false, None)
1303}
1304pub fn keltner_batch_par_slice(
1305 high: &[f64],
1306 low: &[f64],
1307 close: &[f64],
1308 source: &[f64],
1309 sweep: &KeltnerBatchRange,
1310 kern: Kernel,
1311) -> Result<KeltnerBatchOutput, KeltnerError> {
1312 keltner_batch_inner(high, low, close, source, sweep, kern, true, None)
1313}
1314fn keltner_batch_inner(
1315 high: &[f64],
1316 low: &[f64],
1317 close: &[f64],
1318 source: &[f64],
1319 sweep: &KeltnerBatchRange,
1320 kern: Kernel,
1321 parallel: bool,
1322 ma_type: Option<&str>,
1323) -> Result<KeltnerBatchOutput, KeltnerError> {
1324 let combos = expand_grid(sweep)?;
1325 if combos.is_empty() {
1326 return Err(KeltnerError::InvalidRange {
1327 start: "range".into(),
1328 end: "range".into(),
1329 step: "empty".into(),
1330 });
1331 }
1332 let len = close.len();
1333 let first = close
1334 .iter()
1335 .position(|x| !x.is_nan())
1336 .ok_or(KeltnerError::AllValuesNaN)?;
1337 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1338 if len - first < max_p {
1339 return Err(KeltnerError::NotEnoughValidData {
1340 needed: max_p,
1341 valid: len - first,
1342 });
1343 }
1344 let rows = combos.len();
1345 let cols = len;
1346
1347 rows.checked_mul(cols)
1348 .ok_or_else(|| KeltnerError::InvalidRange {
1349 start: rows.to_string(),
1350 end: cols.to_string(),
1351 step: "rows*cols".into(),
1352 })?;
1353
1354 let warm: Vec<usize> = combos
1355 .iter()
1356 .map(|c| {
1357 let p = c.period.unwrap();
1358 first
1359 .checked_add(p.saturating_sub(1))
1360 .ok_or_else(|| KeltnerError::InvalidRange {
1361 start: first.to_string(),
1362 end: p.to_string(),
1363 step: "first+period-1".into(),
1364 })
1365 })
1366 .collect::<Result<Vec<_>, _>>()?;
1367
1368 let mut upper_mu = make_uninit_matrix(rows, cols);
1369 let mut middle_mu = make_uninit_matrix(rows, cols);
1370 let mut lower_mu = make_uninit_matrix(rows, cols);
1371
1372 init_matrix_prefixes(&mut upper_mu, cols, &warm);
1373 init_matrix_prefixes(&mut middle_mu, cols, &warm);
1374 init_matrix_prefixes(&mut lower_mu, cols, &warm);
1375
1376 let mut upper_guard = core::mem::ManuallyDrop::new(upper_mu);
1377 let mut middle_guard = core::mem::ManuallyDrop::new(middle_mu);
1378 let mut lower_guard = core::mem::ManuallyDrop::new(lower_mu);
1379
1380 let upper: &mut [f64] = unsafe {
1381 core::slice::from_raw_parts_mut(upper_guard.as_mut_ptr() as *mut f64, upper_guard.len())
1382 };
1383 let middle: &mut [f64] = unsafe {
1384 core::slice::from_raw_parts_mut(middle_guard.as_mut_ptr() as *mut f64, middle_guard.len())
1385 };
1386 let lower: &mut [f64] = unsafe {
1387 core::slice::from_raw_parts_mut(lower_guard.as_mut_ptr() as *mut f64, lower_guard.len())
1388 };
1389
1390 let mut tr: Vec<f64> = vec![0.0; cols];
1391 unsafe {
1392 let mut i = 0usize;
1393 while i < cols {
1394 let val = if i == 0 {
1395 *high.get_unchecked(0) - *low.get_unchecked(0)
1396 } else {
1397 let hi = *high.get_unchecked(i);
1398 let lo = *low.get_unchecked(i);
1399 let pc = *close.get_unchecked(i - 1);
1400 let hl = hi - lo;
1401 let hc = (hi - pc).abs();
1402 let lc = (lo - pc).abs();
1403 hl.max(hc).max(lc)
1404 };
1405 *tr.get_unchecked_mut(i) = val;
1406 i += 1;
1407 }
1408 }
1409
1410 let ps: Option<Vec<f64>> = if ma_type.unwrap_or("ema").eq_ignore_ascii_case("sma") {
1411 let mut buf = vec![0.0; cols + 1];
1412 unsafe {
1413 let mut i = 0usize;
1414 while i < cols {
1415 let prev = *buf.get_unchecked(i);
1416 let xi = *source.get_unchecked(i);
1417 *buf.get_unchecked_mut(i + 1) = prev + xi;
1418 i += 1;
1419 }
1420 }
1421 Some(buf)
1422 } else {
1423 None
1424 };
1425
1426 let ma = ma_type.unwrap_or("ema");
1427 let do_row = |row: usize, up: &mut [f64], mid: &mut [f64], low_out: &mut [f64]| {
1428 let period = combos[row].period.unwrap();
1429 let mult = combos[row].multiplier.unwrap();
1430 let row_warm = warm[row];
1431
1432 if row_warm >= cols {
1433 return;
1434 }
1435
1436 let pf = period as f64;
1437 let alpha_rma = 1.0 / pf;
1438
1439 let mut atr = 0.0f64;
1440 unsafe {
1441 let mut j = 0usize;
1442 while j < period {
1443 atr += *tr.get_unchecked(j);
1444 j += 1;
1445 }
1446 }
1447 atr /= pf;
1448 let mut k = period;
1449 unsafe {
1450 while k <= row_warm {
1451 let tri = *tr.get_unchecked(k);
1452 atr = (tri - atr).mul_add(alpha_rma, atr);
1453 k += 1;
1454 }
1455 }
1456
1457 if ma.eq_ignore_ascii_case("ema") {
1458 let mut acc = 0.0f64;
1459 unsafe {
1460 let mut j = 0usize;
1461 while j < period {
1462 acc += *source.get_unchecked(first + j);
1463 j += 1;
1464 }
1465 }
1466 let mut ema = acc / pf;
1467 unsafe {
1468 *mid.get_unchecked_mut(row_warm) = ema;
1469 *up.get_unchecked_mut(row_warm) = mult.mul_add(atr, ema);
1470 *low_out.get_unchecked_mut(row_warm) = (-mult).mul_add(atr, ema);
1471 }
1472
1473 let alpha_ema = 2.0 / (pf + 1.0);
1474 unsafe {
1475 let mut i = row_warm + 1;
1476 while i < cols {
1477 let tri = *tr.get_unchecked(i);
1478 atr = (tri - atr).mul_add(alpha_rma, atr);
1479
1480 let xi = *source.get_unchecked(i);
1481 ema = (xi - ema).mul_add(alpha_ema, ema);
1482
1483 *mid.get_unchecked_mut(i) = ema;
1484 *up.get_unchecked_mut(i) = mult.mul_add(atr, ema);
1485 *low_out.get_unchecked_mut(i) = (-mult).mul_add(atr, ema);
1486 i += 1;
1487 }
1488 }
1489 } else if ma.eq_ignore_ascii_case("sma") {
1490 let ps = ps.as_ref().expect("prefix sums computed for SMA");
1491
1492 let start = row_warm + 1 - period;
1493 let end = row_warm + 1;
1494 let mut sm = unsafe { (*ps.get_unchecked(end) - *ps.get_unchecked(start)) / pf };
1495 unsafe {
1496 *mid.get_unchecked_mut(row_warm) = sm;
1497 *up.get_unchecked_mut(row_warm) = mult.mul_add(atr, sm);
1498 *low_out.get_unchecked_mut(row_warm) = (-mult).mul_add(atr, sm);
1499 }
1500
1501 unsafe {
1502 let mut i = row_warm + 1;
1503 while i < cols {
1504 let tri = *tr.get_unchecked(i);
1505 atr = (tri - atr).mul_add(alpha_rma, atr);
1506 let s = (*ps.get_unchecked(i + 1) - *ps.get_unchecked(i + 1 - period)) / pf;
1507 sm = s;
1508 *mid.get_unchecked_mut(i) = sm;
1509 *up.get_unchecked_mut(i) = mult.mul_add(atr, sm);
1510 *low_out.get_unchecked_mut(i) = (-mult).mul_add(atr, sm);
1511 i += 1;
1512 }
1513 }
1514 } else {
1515 keltner_row_scalar(
1516 high, low, close, source, period, mult, ma, first, up, mid, low_out,
1517 );
1518 }
1519 };
1520 if parallel {
1521 #[cfg(not(target_arch = "wasm32"))]
1522 {
1523 upper
1524 .par_chunks_mut(cols)
1525 .zip(middle.par_chunks_mut(cols))
1526 .zip(lower.par_chunks_mut(cols))
1527 .enumerate()
1528 .for_each(|(row, ((u, m), l))| do_row(row, u, m, l));
1529 }
1530
1531 #[cfg(target_arch = "wasm32")]
1532 {
1533 for ((row, u), (m, l)) in upper
1534 .chunks_mut(cols)
1535 .enumerate()
1536 .zip(middle.chunks_mut(cols).zip(lower.chunks_mut(cols)))
1537 {
1538 do_row(row, u, m, l);
1539 }
1540 }
1541 } else {
1542 for ((row, u), (m, l)) in upper
1543 .chunks_mut(cols)
1544 .enumerate()
1545 .zip(middle.chunks_mut(cols).zip(lower.chunks_mut(cols)))
1546 {
1547 do_row(row, u, m, l);
1548 }
1549 }
1550
1551 let upper = unsafe {
1552 let ptr = upper_guard.as_mut_ptr() as *mut f64;
1553 let len = upper_guard.len();
1554 let cap = upper_guard.capacity();
1555 core::mem::forget(upper_guard);
1556 Vec::from_raw_parts(ptr, len, cap)
1557 };
1558
1559 let middle = unsafe {
1560 let ptr = middle_guard.as_mut_ptr() as *mut f64;
1561 let len = middle_guard.len();
1562 let cap = middle_guard.capacity();
1563 core::mem::forget(middle_guard);
1564 Vec::from_raw_parts(ptr, len, cap)
1565 };
1566
1567 let lower = unsafe {
1568 let ptr = lower_guard.as_mut_ptr() as *mut f64;
1569 let len = lower_guard.len();
1570 let cap = lower_guard.capacity();
1571 core::mem::forget(lower_guard);
1572 Vec::from_raw_parts(ptr, len, cap)
1573 };
1574
1575 Ok(KeltnerBatchOutput {
1576 upper_band: upper,
1577 middle_band: middle,
1578 lower_band: lower,
1579 combos,
1580 rows,
1581 cols,
1582 })
1583}
1584
/// Incremental (bar-by-bar) Keltner Channel calculator.
///
/// Bars are fed through `KeltnerStream::update`, which returns `None` until `period`
/// bars have been consumed and `(upper, middle, lower)` for each bar after that.
#[derive(Debug, Clone)]
pub struct KeltnerStream {
    /// Window length shared by the ATR and the middle-band moving average.
    period: usize,
    /// Cached `1.0 / period`.
    rcp_period: f64,
    /// Band half-width in ATR units (bands are `middle ± multiplier * atr`).
    multiplier: f64,

    /// Middle-band moving-average state (EMA or SMA).
    ma_impl: MaImpl,

    /// Current smoothed average true range (valid once `period` bars have been seen).
    atr: f64,
    /// Sum of true ranges during warmup, used to seed `atr` with their mean.
    atr_sum: f64,
    /// ATR smoothing factor; `try_new` sets it to `1.0 / period`.
    rma_alpha: f64,

    /// Number of bars consumed so far.
    count: usize,
    /// Previous bar's close for the true-range computation (NaN before the first bar).
    prev_close: f64,
}
1600
/// Middle-band moving-average state used by `KeltnerStream`.
#[derive(Debug, Clone)]
enum MaImpl {
    /// Exponential moving average.
    Ema {
        /// Smoothing factor; `KeltnerStream::try_new` sets it to `2 / (period + 1)`.
        alpha: f64,
        /// Current EMA value, seeded with the mean of the first `period` inputs.
        value: f64,
        /// Running sum of inputs during warmup, used to compute that seed.
        seed_sum: f64,
    },

    /// Simple moving average maintained over a ring buffer of `period` inputs.
    Sma {
        /// Ring buffer of the most recent source values (length `period`).
        buffer: Vec<f64>,
        /// Running sum of the values currently held in `buffer`.
        sum: f64,
        /// Next slot to write; once the buffer is full this is the oldest value.
        idx: usize,
        /// NOTE(review): set to `true` when the first window completes, but never read
        /// by `KeltnerStream::update`; possibly dead state.
        filled: bool,
    },
}
1616
1617impl KeltnerStream {
1618 pub fn try_new(params: KeltnerParams) -> Result<Self, KeltnerError> {
1619 let period = params.period.unwrap_or(20);
1620 let multiplier = params.multiplier.unwrap_or(2.0);
1621 let ma_type = params
1622 .ma_type
1623 .unwrap_or_else(|| "ema".to_string())
1624 .to_lowercase();
1625
1626 if period == 0 {
1627 return Err(KeltnerError::InvalidPeriod {
1628 period,
1629 data_len: 0,
1630 });
1631 }
1632
1633 let pf = period as f64;
1634 let rcp = 1.0 / pf;
1635
1636 let ma_impl = if ma_type == "sma" {
1637 MaImpl::Sma {
1638 buffer: vec![0.0; period],
1639 sum: 0.0,
1640 idx: 0,
1641 filled: false,
1642 }
1643 } else {
1644 MaImpl::Ema {
1645 alpha: 2.0 / (pf + 1.0),
1646 value: 0.0,
1647 seed_sum: 0.0,
1648 }
1649 };
1650
1651 Ok(Self {
1652 period,
1653 rcp_period: rcp,
1654 multiplier,
1655 ma_impl,
1656 atr: 0.0,
1657 atr_sum: 0.0,
1658 rma_alpha: rcp,
1659 count: 0,
1660 prev_close: f64::NAN,
1661 })
1662 }
1663
1664 #[inline(always)]
1665 pub fn update(
1666 &mut self,
1667 high: f64,
1668 low: f64,
1669 close: f64,
1670 source: f64,
1671 ) -> Option<(f64, f64, f64)> {
1672 let tr = if self.count == 0 {
1673 high - low
1674 } else {
1675 let hl = high - low;
1676 let hc = (high - self.prev_close).abs();
1677 let lc = (low - self.prev_close).abs();
1678 hl.max(hc).max(lc)
1679 };
1680
1681 self.prev_close = close;
1682 self.count += 1;
1683
1684 if self.count < self.period {
1685 self.atr_sum += tr;
1686
1687 match &mut self.ma_impl {
1688 MaImpl::Ema { seed_sum, .. } => {
1689 *seed_sum += source;
1690 }
1691 MaImpl::Sma {
1692 buffer, sum, idx, ..
1693 } => {
1694 *sum += source;
1695 buffer[*idx] = source;
1696 *idx = (*idx + 1) % self.period;
1697 }
1698 }
1699 return None;
1700 }
1701
1702 if self.count == self.period {
1703 self.atr = (self.atr_sum + tr) * self.rcp_period;
1704
1705 let mid = match &mut self.ma_impl {
1706 MaImpl::Ema {
1707 value, seed_sum, ..
1708 } => {
1709 *seed_sum += source;
1710 *value = *seed_sum * self.rcp_period;
1711 *value
1712 }
1713 MaImpl::Sma {
1714 buffer,
1715 sum,
1716 idx,
1717 filled,
1718 } => {
1719 *sum += source;
1720 buffer[*idx] = source;
1721 *idx = (*idx + 1) % self.period;
1722 *filled = true;
1723 *sum * self.rcp_period
1724 }
1725 };
1726
1727 let up = self.multiplier.mul_add(self.atr, mid);
1728 let lo = (-self.multiplier).mul_add(self.atr, mid);
1729 return Some((up, mid, lo));
1730 }
1731
1732 self.atr = (tr - self.atr).mul_add(self.rma_alpha, self.atr);
1733
1734 let mid = match &mut self.ma_impl {
1735 MaImpl::Ema { alpha, value, .. } => {
1736 *value = (source - *value).mul_add(*alpha, *value);
1737 *value
1738 }
1739 MaImpl::Sma {
1740 buffer, sum, idx, ..
1741 } => {
1742 let old = buffer[*idx];
1743 buffer[*idx] = source;
1744 *sum += source - old;
1745 *idx = (*idx + 1) % self.period;
1746 *sum * self.rcp_period
1747 }
1748 };
1749
1750 let up = self.multiplier.mul_add(self.atr, mid);
1751 let lo = (-self.multiplier).mul_add(self.atr, mid);
1752 Some((up, mid, lo))
1753 }
1754}
1755
1756#[cfg(test)]
1757mod tests {
1758 use super::*;
1759 use crate::skip_if_unsupported;
1760 use crate::utilities::data_loader::read_candles_from_csv;
1761 use crate::utilities::enums::Kernel;
1762 #[cfg(feature = "proptest")]
1763 use proptest::prelude::*;
1764
1765 fn check_keltner_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1766 skip_if_unsupported!(kernel, test_name);
1767 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1768 let candles = read_candles_from_csv(file_path)?;
1769
1770 let params = KeltnerParams {
1771 period: Some(20),
1772 multiplier: Some(2.0),
1773 ma_type: Some("ema".to_string()),
1774 };
1775 let input = KeltnerInput::from_candles(&candles, "close", params);
1776 let result = keltner_with_kernel(&input, kernel)?;
1777
1778 assert_eq!(result.upper_band.len(), candles.close.len());
1779 assert_eq!(result.middle_band.len(), candles.close.len());
1780 assert_eq!(result.lower_band.len(), candles.close.len());
1781
1782 let last_five_index = candles.close.len().saturating_sub(5);
1783 let expected_upper = [
1784 61619.504155205745,
1785 61503.56119134791,
1786 61387.47897150178,
1787 61286.61078267451,
1788 61206.25688331261,
1789 ];
1790 let expected_middle = [
1791 59758.339871629956,
1792 59703.35512195091,
1793 59640.083205574636,
1794 59593.884805043715,
1795 59504.46720456336,
1796 ];
1797 let expected_lower = [
1798 57897.17558805417,
1799 57903.14905255391,
1800 57892.68743964749,
1801 57901.158827412924,
1802 57802.67752581411,
1803 ];
1804 let last_five_upper = &result.upper_band[last_five_index..];
1805 let last_five_middle = &result.middle_band[last_five_index..];
1806 let last_five_lower = &result.lower_band[last_five_index..];
1807 for i in 0..5 {
1808 let diff_u = (last_five_upper[i] - expected_upper[i]).abs();
1809 let diff_m = (last_five_middle[i] - expected_middle[i]).abs();
1810 let diff_l = (last_five_lower[i] - expected_lower[i]).abs();
1811 assert!(
1812 diff_u < 1e-1,
1813 "Upper band mismatch at index {}: expected {}, got {}",
1814 i,
1815 expected_upper[i],
1816 last_five_upper[i]
1817 );
1818 assert!(
1819 diff_m < 1e-1,
1820 "Middle band mismatch at index {}: expected {}, got {}",
1821 i,
1822 expected_middle[i],
1823 last_five_middle[i]
1824 );
1825 assert!(
1826 diff_l < 1e-1,
1827 "Lower band mismatch at index {}: expected {}, got {}",
1828 i,
1829 expected_lower[i],
1830 last_five_lower[i]
1831 );
1832 }
1833 Ok(())
1834 }
1835
1836 fn check_keltner_default_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1837 skip_if_unsupported!(kernel, test_name);
1838 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1839 let candles = read_candles_from_csv(file_path)?;
1840 let default_params = KeltnerParams::default();
1841 let input = KeltnerInput::from_candles(&candles, "close", default_params);
1842 let result = keltner_with_kernel(&input, kernel)?;
1843 assert_eq!(result.upper_band.len(), candles.close.len());
1844 assert_eq!(result.middle_band.len(), candles.close.len());
1845 assert_eq!(result.lower_band.len(), candles.close.len());
1846 Ok(())
1847 }
1848
1849 fn check_keltner_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1850 skip_if_unsupported!(kernel, test_name);
1851 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1852 let candles = read_candles_from_csv(file_path)?;
1853 let params = KeltnerParams {
1854 period: Some(0),
1855 multiplier: Some(2.0),
1856 ma_type: Some("ema".to_string()),
1857 };
1858 let input = KeltnerInput::from_candles(&candles, "close", params);
1859 let result = keltner_with_kernel(&input, kernel);
1860 assert!(result.is_err());
1861 if let Err(e) = result {
1862 assert!(e.to_string().contains("invalid period"));
1863 }
1864 Ok(())
1865 }
1866
1867 fn check_keltner_large_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1868 skip_if_unsupported!(kernel, test_name);
1869 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1870 let candles = read_candles_from_csv(file_path)?;
1871 let params = KeltnerParams {
1872 period: Some(999999),
1873 multiplier: Some(2.0),
1874 ma_type: Some("ema".to_string()),
1875 };
1876 let input = KeltnerInput::from_candles(&candles, "close", params);
1877 let result = keltner_with_kernel(&input, kernel);
1878 assert!(result.is_err());
1879 if let Err(e) = result {
1880 assert!(e.to_string().contains("invalid period"));
1881 }
1882 Ok(())
1883 }
1884
1885 fn check_keltner_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1886 skip_if_unsupported!(kernel, test_name);
1887 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1888 let candles = read_candles_from_csv(file_path)?;
1889 let params = KeltnerParams::default();
1890 let input = KeltnerInput::from_candles(&candles, "close", params);
1891 let result = keltner_with_kernel(&input, kernel)?;
1892 assert_eq!(result.middle_band.len(), candles.close.len());
1893 if result.middle_band.len() > 240 {
1894 for (i, &val) in result.middle_band[240..].iter().enumerate() {
1895 assert!(
1896 !val.is_nan(),
1897 "[{}] Found unexpected NaN at out-index {}",
1898 test_name,
1899 240 + i
1900 );
1901 }
1902 }
1903 Ok(())
1904 }
1905
1906 fn check_keltner_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1907 skip_if_unsupported!(kernel, test_name);
1908
1909 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1910 let candles = read_candles_from_csv(file_path)?;
1911 let period = 20;
1912 let multiplier = 2.0;
1913
1914 let params = KeltnerParams {
1915 period: Some(period),
1916 multiplier: Some(multiplier),
1917 ma_type: Some("ema".to_string()),
1918 };
1919 let input = KeltnerInput::from_candles(&candles, "close", params.clone());
1920 let batch_output = keltner_with_kernel(&input, kernel)?;
1921
1922 let mut stream = KeltnerStream::try_new(params)?;
1923 let mut upper_stream = Vec::with_capacity(candles.close.len());
1924 let mut middle_stream = Vec::with_capacity(candles.close.len());
1925 let mut lower_stream = Vec::with_capacity(candles.close.len());
1926
1927 for i in 0..candles.close.len() {
1928 let hi = candles.high[i];
1929 let lo = candles.low[i];
1930 let cl = candles.close[i];
1931 let src = candles.close[i];
1932 match stream.update(hi, lo, cl, src) {
1933 Some((up, mid, low)) => {
1934 upper_stream.push(up);
1935 middle_stream.push(mid);
1936 lower_stream.push(low);
1937 }
1938 None => {
1939 upper_stream.push(f64::NAN);
1940 middle_stream.push(f64::NAN);
1941 lower_stream.push(f64::NAN);
1942 }
1943 }
1944 }
1945 assert_eq!(batch_output.upper_band.len(), upper_stream.len());
1946 assert_eq!(batch_output.middle_band.len(), middle_stream.len());
1947 assert_eq!(batch_output.lower_band.len(), lower_stream.len());
1948 for (i, (&b, &s)) in batch_output
1949 .middle_band
1950 .iter()
1951 .zip(middle_stream.iter())
1952 .enumerate()
1953 {
1954 if b.is_nan() && s.is_nan() {
1955 continue;
1956 }
1957 let diff = (b - s).abs();
1958 assert!(
1959 diff < 1e-8,
1960 "[{}] Keltner streaming mismatch at idx {}: batch={}, stream={}, diff={}",
1961 test_name,
1962 i,
1963 b,
1964 s,
1965 diff
1966 );
1967 }
1968 Ok(())
1969 }
1970
1971 #[cfg(debug_assertions)]
1972 fn check_keltner_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1973 skip_if_unsupported!(kernel, test_name);
1974
1975 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1976 let candles = read_candles_from_csv(file_path)?;
1977
1978 let test_params = vec![
1979 KeltnerParams::default(),
1980 KeltnerParams {
1981 period: Some(2),
1982 multiplier: Some(1.0),
1983 ma_type: Some("ema".to_string()),
1984 },
1985 KeltnerParams {
1986 period: Some(5),
1987 multiplier: Some(0.5),
1988 ma_type: Some("ema".to_string()),
1989 },
1990 KeltnerParams {
1991 period: Some(10),
1992 multiplier: Some(1.5),
1993 ma_type: Some("sma".to_string()),
1994 },
1995 KeltnerParams {
1996 period: Some(20),
1997 multiplier: Some(3.0),
1998 ma_type: Some("ema".to_string()),
1999 },
2000 KeltnerParams {
2001 period: Some(50),
2002 multiplier: Some(2.5),
2003 ma_type: Some("sma".to_string()),
2004 },
2005 KeltnerParams {
2006 period: Some(100),
2007 multiplier: Some(1.0),
2008 ma_type: Some("ema".to_string()),
2009 },
2010 KeltnerParams {
2011 period: Some(14),
2012 multiplier: Some(2.0),
2013 ma_type: Some("sma".to_string()),
2014 },
2015 KeltnerParams {
2016 period: Some(7),
2017 multiplier: Some(1.0),
2018 ma_type: Some("ema".to_string()),
2019 },
2020 KeltnerParams {
2021 period: Some(21),
2022 multiplier: Some(1.5),
2023 ma_type: Some("ema".to_string()),
2024 },
2025 KeltnerParams {
2026 period: Some(30),
2027 multiplier: Some(2.0),
2028 ma_type: Some("sma".to_string()),
2029 },
2030 KeltnerParams {
2031 period: Some(3),
2032 multiplier: Some(0.75),
2033 ma_type: Some("ema".to_string()),
2034 },
2035 ];
2036
2037 for (param_idx, params) in test_params.iter().enumerate() {
2038 let input = KeltnerInput::from_candles(&candles, "close", params.clone());
2039 let output = keltner_with_kernel(&input, kernel)?;
2040
2041 for (band_name, band_values) in [
2042 ("upper", &output.upper_band),
2043 ("middle", &output.middle_band),
2044 ("lower", &output.lower_band),
2045 ] {
2046 for (i, &val) in band_values.iter().enumerate() {
2047 if val.is_nan() {
2048 continue;
2049 }
2050
2051 let bits = val.to_bits();
2052
2053 if bits == 0x11111111_11111111 {
2054 panic!(
2055 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
2056 in {} band with params: period={}, multiplier={}, ma_type={} (param set {})",
2057 test_name, val, bits, i, band_name,
2058 params.period.unwrap_or(20),
2059 params.multiplier.unwrap_or(2.0),
2060 params.ma_type.as_deref().unwrap_or("ema"),
2061 param_idx
2062 );
2063 }
2064
2065 if bits == 0x22222222_22222222 {
2066 panic!(
2067 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
2068 in {} band with params: period={}, multiplier={}, ma_type={} (param set {})",
2069 test_name, val, bits, i, band_name,
2070 params.period.unwrap_or(20),
2071 params.multiplier.unwrap_or(2.0),
2072 params.ma_type.as_deref().unwrap_or("ema"),
2073 param_idx
2074 );
2075 }
2076
2077 if bits == 0x33333333_33333333 {
2078 panic!(
2079 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
2080 in {} band with params: period={}, multiplier={}, ma_type={} (param set {})",
2081 test_name, val, bits, i, band_name,
2082 params.period.unwrap_or(20),
2083 params.multiplier.unwrap_or(2.0),
2084 params.ma_type.as_deref().unwrap_or("ema"),
2085 param_idx
2086 );
2087 }
2088 }
2089 }
2090 }
2091
2092 Ok(())
2093 }
2094
2095 #[cfg(not(debug_assertions))]
2096 fn check_keltner_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2097 Ok(())
2098 }
2099
2100 fn check_batch_default_row(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2101 skip_if_unsupported!(kernel, test_name);
2102
2103 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2104 let c = read_candles_from_csv(file)?;
2105 let output = KeltnerBatchBuilder::new()
2106 .kernel(kernel)
2107 .apply_candles(&c, "close")?;
2108
2109 let def = KeltnerParams::default();
2110 let (upper, middle, lower) = output.values_for(&def).expect("default row missing");
2111
2112 assert_eq!(upper.len(), c.close.len());
2113 assert_eq!(middle.len(), c.close.len());
2114 assert_eq!(lower.len(), c.close.len());
2115
2116 Ok(())
2117 }
2118
2119 #[cfg(debug_assertions)]
2120 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2121 skip_if_unsupported!(kernel, test);
2122
2123 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2124 let c = read_candles_from_csv(file)?;
2125
2126 let test_configs = vec![
2127 (2, 10, 2, 0.5, 2.5, 0.5),
2128 (5, 25, 5, 1.0, 3.0, 1.0),
2129 (30, 60, 15, 2.0, 2.0, 0.0),
2130 (2, 5, 1, 1.5, 2.5, 0.25),
2131 (10, 30, 10, 0.75, 2.25, 0.75),
2132 (14, 21, 7, 1.0, 2.0, 0.5),
2133 (20, 20, 0, 0.5, 3.0, 0.5),
2134 ];
2135
2136 for (cfg_idx, &(p_start, p_end, p_step, m_start, m_end, m_step)) in
2137 test_configs.iter().enumerate()
2138 {
2139 let output = KeltnerBatchBuilder::new()
2140 .kernel(kernel)
2141 .period_range(p_start, p_end, p_step)
2142 .multiplier_range(m_start, m_end, m_step)
2143 .apply_candles(&c, "close")?;
2144
2145 for (band_name, band_values) in [
2146 ("upper", &output.upper_band),
2147 ("middle", &output.middle_band),
2148 ("lower", &output.lower_band),
2149 ] {
2150 for (idx, &val) in band_values.iter().enumerate() {
2151 if val.is_nan() {
2152 continue;
2153 }
2154
2155 let bits = val.to_bits();
2156 let row = idx / output.cols;
2157 let col = idx % output.cols;
2158 let combo = &output.combos[row];
2159
2160 if bits == 0x11111111_11111111 {
2161 panic!(
2162 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2163 at row {} col {} (flat index {}) in {} band with params: period={}, multiplier={}",
2164 test, cfg_idx, val, bits, row, col, idx, band_name,
2165 combo.period.unwrap_or(20),
2166 combo.multiplier.unwrap_or(2.0)
2167 );
2168 }
2169
2170 if bits == 0x22222222_22222222 {
2171 panic!(
2172 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2173 at row {} col {} (flat index {}) in {} band with params: period={}, multiplier={}",
2174 test, cfg_idx, val, bits, row, col, idx, band_name,
2175 combo.period.unwrap_or(20),
2176 combo.multiplier.unwrap_or(2.0)
2177 );
2178 }
2179
2180 if bits == 0x33333333_33333333 {
2181 panic!(
2182 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2183 at row {} col {} (flat index {}) in {} band with params: period={}, multiplier={}",
2184 test,
2185 cfg_idx,
2186 val,
2187 bits,
2188 row,
2189 col,
2190 idx,
2191 band_name,
2192 combo.period.unwrap_or(20),
2193 combo.multiplier.unwrap_or(2.0)
2194 );
2195 }
2196 }
2197 }
2198 }
2199
2200 Ok(())
2201 }
2202
2203 #[cfg(not(debug_assertions))]
2204 fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2205 Ok(())
2206 }
2207
2208 #[cfg(feature = "proptest")]
2209 #[allow(clippy::float_cmp)]
2210 fn check_keltner_property(
2211 test_name: &str,
2212 kernel: Kernel,
2213 ) -> Result<(), Box<dyn std::error::Error>> {
2214 use proptest::prelude::*;
2215 skip_if_unsupported!(kernel, test_name);
2216
2217 let strat = (
2218 2usize..=50,
2219 50usize..500,
2220 0.5f64..3.0f64,
2221 0usize..6,
2222 any::<u64>(),
2223 )
2224 .prop_map(|(period, len, multiplier, scenario, seed)| {
2225 let mut high = Vec::with_capacity(len);
2226 let mut low = Vec::with_capacity(len);
2227 let mut close = Vec::with_capacity(len);
2228
2229 let mut rng_state = seed;
2230 let mut next_random = || -> f64 {
2231 rng_state = rng_state.wrapping_mul(1664525).wrapping_add(1013904223);
2232 (rng_state as f64) / (u64::MAX as f64)
2233 };
2234
2235 match scenario {
2236 0 => {
2237 let mut prev_close = 100.0;
2238 for _ in 0..len {
2239 let volatility = 0.01 + next_random() * 0.04;
2240 let change = -volatility + next_random() * (2.0 * volatility);
2241 let new_close = (prev_close * (1.0 + change)).max(0.1f64);
2242 let high_val = new_close * (1.0 + next_random() * volatility);
2243 let low_val = new_close * (1.0 - next_random() * volatility);
2244
2245 high.push(high_val);
2246 low.push(low_val.min(high_val));
2247 close.push(new_close);
2248 prev_close = new_close;
2249 }
2250 }
2251 1 => {
2252 let start = 100.0;
2253 for i in 0..len {
2254 let base = start * (1.0 + 0.01 * i as f64);
2255 let spread = base * 0.02;
2256 high.push(base + spread);
2257 low.push(base - spread);
2258 close.push(base);
2259 }
2260 }
2261 2 => {
2262 let start = 100.0;
2263 for i in 0..len {
2264 let base = start * (1.0 - 0.005 * i as f64).max(10.0);
2265 let spread = base * 0.02;
2266 high.push(base + spread);
2267 low.push(base - spread);
2268 close.push(base);
2269 }
2270 }
2271 3 => {
2272 let mut price = 100.0;
2273 for i in 0..len {
2274 let volatility = 0.1 * (1.0 + (i as f64 * 0.1).sin());
2275 let change = if i % 2 == 0 {
2276 volatility
2277 } else {
2278 -volatility * 0.8
2279 };
2280 price = (price * (1.0 + change)).max(1.0);
2281
2282 let spread = price * volatility;
2283 high.push(price + spread);
2284 low.push(price - spread * 0.8);
2285 close.push(price);
2286 }
2287 }
2288 4 => {
2289 let constant_price = 50.0;
2290 high = vec![constant_price; len];
2291 low = vec![constant_price; len];
2292 close = vec![constant_price; len];
2293 }
2294 _ => {
2295 let mut price = 1000.0;
2296 let mut momentum = 0.0;
2297
2298 for i in 0..len {
2299 momentum = momentum * 0.9 + (if i % 20 < 10 { 0.001 } else { -0.001 });
2300 let noise = ((i as f64 * 0.3).sin() * 0.005);
2301 price = (price * (1.0 + momentum + noise)).max(100.0);
2302
2303 let daily_range = price * 0.02;
2304 let high_val = price + daily_range * 0.6;
2305 let low_val = price - daily_range * 0.4;
2306
2307 high.push(high_val);
2308 low.push(low_val);
2309 close.push(price);
2310 }
2311 }
2312 }
2313
2314 let ma_type = if next_random() > 0.5 { "ema" } else { "sma" };
2315 (high, low, close, period, multiplier, ma_type.to_string())
2316 });
2317
2318 proptest::test_runner::TestRunner::default()
2319 .run(&strat, |(high, low, close, period, multiplier, ma_type)| {
2320
2321 let source = close.clone();
2322
2323 let params = KeltnerParams {
2324 period: Some(period),
2325 multiplier: Some(multiplier),
2326 ma_type: Some(ma_type.clone()),
2327 };
2328 let input = KeltnerInput::from_slice(&high, &low, &close, &source, params.clone());
2329
2330 let result = keltner_with_kernel(&input, kernel).unwrap();
2331 let scalar_result = keltner_with_kernel(&input, Kernel::Scalar).unwrap();
2332
2333
2334 prop_assert_eq!(result.upper_band.len(), close.len());
2335 prop_assert_eq!(result.middle_band.len(), close.len());
2336 prop_assert_eq!(result.lower_band.len(), close.len());
2337
2338
2339 let warmup = period - 1;
2340 for i in 0..warmup.min(close.len()) {
2341 prop_assert!(
2342 result.upper_band[i].is_nan(),
2343 "Upper band[{}] should be NaN during warmup", i
2344 );
2345 prop_assert!(
2346 result.middle_band[i].is_nan(),
2347 "Middle band[{}] should be NaN during warmup", i
2348 );
2349 prop_assert!(
2350 result.lower_band[i].is_nan(),
2351 "Lower band[{}] should be NaN during warmup", i
2352 );
2353 }
2354
2355
2356 for i in warmup..close.len() {
2357 let upper = result.upper_band[i];
2358 let middle = result.middle_band[i];
2359 let lower = result.lower_band[i];
2360
2361 let scalar_upper = scalar_result.upper_band[i];
2362 let scalar_middle = scalar_result.middle_band[i];
2363 let scalar_lower = scalar_result.lower_band[i];
2364
2365
2366 if upper.is_nan() || middle.is_nan() || lower.is_nan() {
2367 continue;
2368 }
2369
2370
2371 prop_assert!(
2372 upper >= middle - 1e-10,
2373 "Upper band {} must be >= middle band {} at index {}", upper, middle, i
2374 );
2375 prop_assert!(
2376 middle >= lower - 1e-10,
2377 "Middle band {} must be >= lower band {} at index {}", middle, lower, i
2378 );
2379
2380
2381 let spread = upper - lower;
2382 prop_assert!(
2383 spread >= -1e-10,
2384 "Spread {} must be positive at index {}", spread, i
2385 );
2386
2387
2388 if i >= warmup + period {
2389
2390 let scalar_spread = scalar_upper - scalar_lower;
2391 if scalar_spread > 0.0 && spread > 0.0 {
2392 let ratio = spread / scalar_spread;
2393 prop_assert!(
2394 (ratio - 1.0).abs() < 0.01,
2395 "Spread ratio between kernels should be consistent at index {}: ratio={}", i, ratio
2396 );
2397 }
2398 }
2399
2400
2401 let window_start = i.saturating_sub(period - 1);
2402 let window_high = high[window_start..=i].iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
2403 let window_low = low[window_start..=i].iter().fold(f64::INFINITY, |a, &b| a.min(b));
2404
2405
2406 prop_assert!(
2407 middle <= window_high * 1.05 + 1.0,
2408 "Middle band {} exceeds window high {} at index {}", middle, window_high, i
2409 );
2410 prop_assert!(
2411 middle >= window_low * 0.95 - 1.0,
2412 "Middle band {} below window low {} at index {}", middle, window_low, i
2413 );
2414
2415
2416 let tolerance = 1e-9;
2417 prop_assert!(
2418 (upper - scalar_upper).abs() <= tolerance,
2419 "Upper band kernel mismatch at {}: {} vs {} (diff: {})",
2420 i, upper, scalar_upper, (upper - scalar_upper).abs()
2421 );
2422 prop_assert!(
2423 (middle - scalar_middle).abs() <= tolerance,
2424 "Middle band kernel mismatch at {}: {} vs {} (diff: {})",
2425 i, middle, scalar_middle, (middle - scalar_middle).abs()
2426 );
2427 prop_assert!(
2428 (lower - scalar_lower).abs() <= tolerance,
2429 "Lower band kernel mismatch at {}: {} vs {} (diff: {})",
2430 i, lower, scalar_lower, (lower - scalar_lower).abs()
2431 );
2432
2433
2434 #[cfg(debug_assertions)]
2435 {
2436 let upper_bits = upper.to_bits();
2437 let middle_bits = middle.to_bits();
2438 let lower_bits = lower.to_bits();
2439
2440 prop_assert!(
2441 upper_bits != 0x11111111_11111111 &&
2442 upper_bits != 0x22222222_22222222 &&
2443 upper_bits != 0x33333333_33333333,
2444 "Found poison value in upper band at index {}", i
2445 );
2446 prop_assert!(
2447 middle_bits != 0x11111111_11111111 &&
2448 middle_bits != 0x22222222_22222222 &&
2449 middle_bits != 0x33333333_33333333,
2450 "Found poison value in middle band at index {}", i
2451 );
2452 prop_assert!(
2453 lower_bits != 0x11111111_11111111 &&
2454 lower_bits != 0x22222222_22222222 &&
2455 lower_bits != 0x33333333_33333333,
2456 "Found poison value in lower band at index {}", i
2457 );
2458 }
2459
2460
2461
2462 let all_same = high[window_start..=i].windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10) &&
2463 low[window_start..=i].windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10) &&
2464 close[window_start..=i].windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);
2465
2466
2467 let no_spread = high[window_start..=i].iter()
2468 .zip(low[window_start..=i].iter())
2469 .zip(close[window_start..=i].iter())
2470 .all(|((h, l), c)| (h - l).abs() < 1e-10 && (h - c).abs() < 1e-10);
2471
2472 if all_same && no_spread && i >= warmup + period * 3 {
2473
2474
2475 let band_spread = upper - lower;
2476 prop_assert!(
2477 band_spread < 0.01 || band_spread < middle * 0.001,
2478 "Bands should converge for constant prices with no spread, but spread is {} at index {} (middle: {})",
2479 band_spread, i, middle
2480 );
2481 }
2482 }
2483
2484 Ok(())
2485 })
2486 .unwrap();
2487
2488 Ok(())
2489 }
2490
2491 macro_rules! generate_all_keltner_tests {
2492 ($($test_fn:ident),*) => {
2493 paste::paste! {
2494 $(
2495 #[test]
2496 fn [<$test_fn _scalar_f64>]() {
2497 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2498 }
2499 )*
2500 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2501 $(
2502 #[test]
2503 fn [<$test_fn _avx2_f64>]() {
2504 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2505 }
2506 #[test]
2507 fn [<$test_fn _avx512_f64>]() {
2508 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2509 }
2510 )*
2511 }
2512 }
2513 }
2514
    // Per-kernel `#[test]` wrappers for each single-series check, produced by
    // `generate_all_keltner_tests!`.
    generate_all_keltner_tests!(
        check_keltner_accuracy,
        check_keltner_default_params,
        check_keltner_zero_period,
        check_keltner_large_period,
        check_keltner_nan_handling,
        check_keltner_streaming,
        check_keltner_no_poison
    );

    // The property-based check is only compiled when the `proptest` feature is enabled.
    #[cfg(feature = "proptest")]
    generate_all_keltner_tests!(check_keltner_property);
2527
2528 macro_rules! gen_batch_tests {
2529 ($fn_name:ident) => {
2530 paste::paste! {
2531 #[test] fn [<$fn_name _scalar>]() {
2532 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2533 }
2534 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2535 #[test] fn [<$fn_name _avx2>]() {
2536 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2537 }
2538 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2539 #[test] fn [<$fn_name _avx512>]() {
2540 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2541 }
2542 #[test] fn [<$fn_name _auto_detect>]() {
2543 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2544 }
2545 }
2546 };
2547 }
    // Batch-kernel `#[test]` wrappers: scalar, AVX2/AVX-512 (behind `nightly-avx` on x86_64)
    // and auto-detected batch kernels.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
2550
2551 #[test]
2552 fn test_keltner_into_matches_api() {
2553 use crate::utilities::data_loader::Candles;
2554
2555 let n = 256usize;
2556 let mut ts = Vec::with_capacity(n);
2557 let mut open = Vec::with_capacity(n);
2558 let mut high = Vec::with_capacity(n);
2559 let mut low = Vec::with_capacity(n);
2560 let mut close = Vec::with_capacity(n);
2561 let mut volume = Vec::with_capacity(n);
2562
2563 let mut price = 1000.0f64;
2564 for i in 0..n {
2565 let i_f = i as f64;
2566 let drift = (i_f * 0.001).sin() * 0.5;
2567 let noise = (i_f * 0.07).cos() * 0.8;
2568 price = (price + drift + noise).max(1.0);
2569 let spread = 2.0 + (i % 5) as f64;
2570 let h = price + spread;
2571 let l = price - spread * 0.8;
2572 let o = price - 0.25 * spread;
2573 let c = price + 0.25 * spread;
2574
2575 ts.push(i as i64);
2576 open.push(o);
2577 high.push(h);
2578 low.push(l);
2579 close.push(c);
2580 volume.push(1000.0 + i as f64);
2581 }
2582
2583 let candles = Candles::new(ts, open, high, low, close, volume);
2584 let input = KeltnerInput::from_candles(&candles, "close", KeltnerParams::default());
2585
2586 let base = keltner(&input).expect("keltner baseline failed");
2587
2588 let len = candles.close.len();
2589 let mut up = vec![0.0; len];
2590 let mut mid = vec![0.0; len];
2591 let mut lo = vec![0.0; len];
2592
2593 keltner_into(&input, &mut up, &mut mid, &mut lo).expect("keltner_into failed");
2594
2595 assert_eq!(base.upper_band.len(), up.len());
2596 assert_eq!(base.middle_band.len(), mid.len());
2597 assert_eq!(base.lower_band.len(), lo.len());
2598
2599 fn eq_or_both_nan(a: f64, b: f64) -> bool {
2600 (a.is_nan() && b.is_nan()) || (a == b)
2601 }
2602
2603 for i in 0..len {
2604 assert!(
2605 eq_or_both_nan(base.upper_band[i], up[i]),
2606 "upper mismatch at {}: base={} into={}",
2607 i,
2608 base.upper_band[i],
2609 up[i]
2610 );
2611 assert!(
2612 eq_or_both_nan(base.middle_band[i], mid[i]),
2613 "middle mismatch at {}: base={} into={}",
2614 i,
2615 base.middle_band[i],
2616 mid[i]
2617 );
2618 assert!(
2619 eq_or_both_nan(base.lower_band[i], lo[i]),
2620 "lower mismatch at {}: base={} into={}",
2621 i,
2622 base.lower_band[i],
2623 lo[i]
2624 );
2625 }
2626 }
2627}
2628
2629#[cfg(all(feature = "python", feature = "cuda"))]
2630pub struct KeltnerDeviceArrayF32 {
2631 pub inner: DeviceArrayF32,
2632 pub context: Arc<Context>,
2633 pub device_id: u32,
2634}
2635
2636#[cfg(all(feature = "python", feature = "cuda"))]
2637#[pyclass(module = "ta_indicators.cuda", unsendable)]
2638pub struct KeltnerDeviceArrayF32Py {
2639 pub(crate) inner: KeltnerDeviceArrayF32,
2640}
2641
2642#[cfg(all(feature = "python", feature = "cuda"))]
2643#[pymethods]
2644impl KeltnerDeviceArrayF32Py {
2645 #[getter]
2646 fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
2647 let inner = &self.inner.inner;
2648 let d = PyDict::new(py);
2649 d.set_item("shape", (inner.rows, inner.cols))?;
2650 d.set_item("typestr", "<f4")?;
2651 d.set_item(
2652 "strides",
2653 (
2654 inner.cols * std::mem::size_of::<f32>(),
2655 std::mem::size_of::<f32>(),
2656 ),
2657 )?;
2658 d.set_item("data", (inner.device_ptr() as usize, false))?;
2659
2660 d.set_item("version", 3)?;
2661 Ok(d)
2662 }
2663
2664 fn __dlpack_device__(&self) -> (i32, i32) {
2665 (2, self.inner.device_id as i32)
2666 }
2667
2668 #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
2669 fn __dlpack__<'py>(
2670 &mut self,
2671 py: Python<'py>,
2672 stream: Option<pyo3::PyObject>,
2673 max_version: Option<pyo3::PyObject>,
2674 dl_device: Option<pyo3::PyObject>,
2675 copy: Option<pyo3::PyObject>,
2676 ) -> PyResult<PyObject> {
2677 use cust::memory::DeviceBuffer;
2678
2679 let (kdl, alloc_dev) = self.__dlpack_device__();
2680 if let Some(dev_obj) = dl_device.as_ref() {
2681 if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
2682 if dev_ty != kdl || dev_id != alloc_dev {
2683 let wants_copy = copy
2684 .as_ref()
2685 .and_then(|c| c.extract::<bool>(py).ok())
2686 .unwrap_or(false);
2687 if wants_copy {
2688 return Err(PyValueError::new_err(
2689 "device copy not implemented for __dlpack__",
2690 ));
2691 } else {
2692 return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
2693 }
2694 }
2695 }
2696 }
2697
2698 let _ = stream;
2699
2700 if let Some(copy_obj) = copy.as_ref() {
2701 let do_copy: bool = copy_obj.extract(py)?;
2702 if do_copy {
2703 return Err(PyValueError::new_err(
2704 "__dlpack__(copy=True) not supported for keltner CUDA buffers",
2705 ));
2706 }
2707 }
2708
2709 let dummy =
2710 DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
2711 let context = self.inner.context.clone();
2712 let device_id = self.inner.device_id;
2713 let inner = std::mem::replace(
2714 &mut self.inner,
2715 KeltnerDeviceArrayF32 {
2716 inner: DeviceArrayF32 {
2717 buf: dummy,
2718 rows: 0,
2719 cols: 0,
2720 },
2721 context,
2722 device_id,
2723 },
2724 );
2725
2726 let rows = inner.inner.rows;
2727 let cols = inner.inner.cols;
2728
2729 let max_version_bound = max_version.map(|obj| obj.into_bound(py));
2730
2731 export_f32_cuda_dlpack_2d(
2732 py,
2733 inner.inner.buf,
2734 rows,
2735 cols,
2736 alloc_dev,
2737 max_version_bound,
2738 )
2739 }
2740}
2741
/// Compute Keltner Channels for one series and return `(upper, middle, lower)` NumPy arrays.
///
/// `high`, `low`, `close` and `source` must be contiguous float64 arrays; the outputs have
/// the length of `close`. `kernel` names a single-series kernel (`None` = auto).
///
/// Raises `ValueError` for an invalid kernel name or when the computation fails.
#[cfg(feature = "python")]
#[pyfunction(name = "keltner")]
#[pyo3(signature = (high, low, close, source, period, multiplier, ma_type="ema", kernel=None))]
pub fn keltner_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    source: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    multiplier: f64,
    ma_type: &str,
    kernel: Option<&str>,
) -> PyResult<(
    Bound<'py, numpy::PyArray1<f64>>,
    Bound<'py, numpy::PyArray1<f64>>,
    Bound<'py, numpy::PyArray1<f64>>,
)> {
    use numpy::{PyArray1, PyArrayMethods};

    let h = high.as_slice()?;
    let l = low.as_slice()?;
    let c = close.as_slice()?;
    let s = source.as_slice()?;
    // Validate the kernel name before allocating any output arrays.
    let kern = validate_kernel(kernel, false)?;
    let len = c.len();

    // SAFETY: freshly allocated, uninitialized, contiguous arrays. They are only returned to
    // Python after keltner_into_slice succeeds, which is assumed to write every element
    // (NaN warmup prefix included) — on error they are dropped unseen.
    let up_arr = unsafe { PyArray1::<f64>::new(py, [len], false) };
    let mid_arr = unsafe { PyArray1::<f64>::new(py, [len], false) };
    let low_arr = unsafe { PyArray1::<f64>::new(py, [len], false) };

    // SAFETY: the arrays were just created here, so no other view of them exists.
    let up = unsafe { up_arr.as_slice_mut()? };
    let mid = unsafe { mid_arr.as_slice_mut()? };
    let lowo = unsafe { low_arr.as_slice_mut()? };

    let params = KeltnerParams {
        period: Some(period),
        multiplier: Some(multiplier),
        ma_type: Some(ma_type.to_string()),
    };
    let input = KeltnerInput::from_slice(h, l, c, s, params);

    // The computation is pure Rust, so release the GIL while it runs.
    py.allow_threads(|| keltner_into_slice(up, mid, lowo, &input, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok((up_arr, mid_arr, low_arr))
}
2789
/// Python class `KeltnerStream`: incremental, bar-by-bar Keltner updates backed by the Rust
/// `KeltnerStream`.
#[cfg(feature = "python")]
#[pyclass(name = "KeltnerStream")]
pub struct KeltnerStreamPy {
    // Underlying Rust streaming state.
    stream: KeltnerStream,
}
2795
#[cfg(feature = "python")]
#[pymethods]
impl KeltnerStreamPy {
    /// Create a stream from `period`, `multiplier` and a moving-average type name.
    ///
    /// Raises `ValueError` when `KeltnerStream::try_new` rejects the parameters.
    #[new]
    fn new(period: usize, multiplier: f64, ma_type: &str) -> PyResult<Self> {
        KeltnerStream::try_new(KeltnerParams {
            period: Some(period),
            multiplier: Some(multiplier),
            ma_type: Some(ma_type.to_owned()),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feed one bar and return what `KeltnerStream::update` yields
    /// (presumably `(upper, middle, lower)`, or `None` while not yet available).
    fn update(&mut self, high: f64, low: f64, close: f64, source: f64) -> Option<(f64, f64, f64)> {
        let stream = &mut self.stream;
        stream.update(high, low, close, source)
    }
}
2815
/// Run a Keltner parameter sweep and return a dict of NumPy arrays.
///
/// Keys: `upper`, `middle`, `lower` (each shaped `(rows, cols)`, one row per parameter
/// combination), plus `periods` (uint64) and `multipliers` (float64) describing each row.
/// `kernel` names a batch kernel (`None`/auto = best detected batch kernel).
///
/// Raises `ValueError` for an invalid kernel name, a failed computation, or a band buffer
/// whose length does not match `rows * cols`.
#[cfg(feature = "python")]
#[pyfunction(name = "keltner_batch")]
#[pyo3(signature = (high, low, close, source, period_range, multiplier_range, kernel=None))]
pub fn keltner_batch_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    source: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    multiplier_range: (f64, f64, f64),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let h = high.as_slice()?;
    let l = low.as_slice()?;
    let c = close.as_slice()?;
    let s = source.as_slice()?;

    let sweep = KeltnerBatchRange {
        period: period_range,
        multiplier: multiplier_range,
    };
    let kern = validate_kernel(kernel, true)?;
    let batch_kernel = match kern {
        Kernel::Auto => detect_best_batch_kernel(),
        k => k,
    };

    // Pure Rust computation: release the GIL while it runs.
    let out = py
        .allow_threads(|| keltner_batch_par_slice(h, l, c, s, &sweep, batch_kernel))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let (rows, cols) = (out.rows, out.cols);
    let periods: Vec<u64> = out
        .combos
        .iter()
        .map(|p| p.period.unwrap() as u64)
        .collect();
    let multipliers: Vec<f64> = out.combos.iter().map(|p| p.multiplier.unwrap()).collect();

    // Hand the owned result buffers to NumPy without copying, then view them as rows x cols;
    // `reshape` reports a length mismatch as a Python error instead of panicking.
    let dict = PyDict::new(py);
    dict.set_item("upper", out.upper_band.into_pyarray(py).reshape((rows, cols))?)?;
    dict.set_item("middle", out.middle_band.into_pyarray(py).reshape((rows, cols))?)?;
    dict.set_item("lower", out.lower_band.into_pyarray(py).reshape((rows, cols))?)?;
    dict.set_item("periods", periods.into_pyarray(py))?;
    dict.set_item("multipliers", multipliers.into_pyarray(py))?;
    Ok(dict)
}
2894
/// Run a Keltner parameter sweep on the GPU and return device-resident results.
///
/// Returns a dict with `upper`, `middle`, `lower` (`KeltnerDeviceArrayF32Py` handles sharing
/// one CUDA context) and the matrix shape as `rows` / `cols`. Raises `ValueError` when CUDA
/// is unavailable, the four inputs differ in length, or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "keltner_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, close_f32, source_f32, period_range, multiplier_range, ma_type="ema", device_id=0))]
pub fn keltner_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    high_f32: numpy::PyReadonlyArray1<'py, f32>,
    low_f32: numpy::PyReadonlyArray1<'py, f32>,
    close_f32: numpy::PyReadonlyArray1<'py, f32>,
    source_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    multiplier_range: (f64, f64, f64),
    ma_type: &str,
    device_id: usize,
) -> PyResult<Bound<'py, PyDict>> {
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (h, l, c, s) = (
        high_f32.as_slice()?,
        low_f32.as_slice()?,
        close_f32.as_slice()?,
        source_f32.as_slice()?,
    );
    let n = h.len();
    if [l.len(), c.len(), s.len()].iter().any(|&m| m != n) {
        return Err(PyValueError::new_err("input length mismatch"));
    }

    let sweep = KeltnerBatchRange {
        period: period_range,
        multiplier: multiplier_range,
    };

    // Device work happens without the GIL; only owned results leave the closure.
    let (upper, middle, lower, rows, cols, ctx, dev_id) = py.allow_threads(|| {
        let engine =
            CudaKeltner::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = engine.context_arc();
        let dev_id = engine.device_id();
        let result = engine
            .keltner_batch_dev(h, l, c, s, &sweep, ma_type)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let (rows, cols) = (result.outputs.upper.rows, result.outputs.upper.cols);
        Ok::<_, PyErr>((
            result.outputs.upper,
            result.outputs.middle,
            result.outputs.lower,
            rows,
            cols,
            ctx,
            dev_id,
        ))
    })?;

    let dict = PyDict::new(py);
    // Each handle keeps its own Arc clone of the shared context.
    let wrap = |buf: DeviceArrayF32| {
        Py::new(
            py,
            KeltnerDeviceArrayF32Py {
                inner: KeltnerDeviceArrayF32 {
                    inner: buf,
                    context: Arc::clone(&ctx),
                    device_id: dev_id,
                },
            },
        )
    };
    dict.set_item("upper", wrap(upper)?)?;
    dict.set_item("middle", wrap(middle)?)?;
    dict.set_item("lower", wrap(lower)?)?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
2987
/// Compute Keltner bands on the GPU for many time-major series with one parameter set.
///
/// Each input is a flat `rows * cols` float32 buffer. Returns `(upper, middle, lower)` as
/// device-resident handles sharing one CUDA context. Raises `ValueError` when CUDA is
/// unavailable, `rows * cols` overflows, an input has the wrong length, or the CUDA wrapper
/// reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "keltner_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, close_tm_f32, source_tm_f32, cols, rows, period, multiplier, ma_type="ema", device_id=0))]
pub fn keltner_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    low_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    close_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    source_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    multiplier: f32,
    ma_type: &str,
    device_id: usize,
) -> PyResult<(
    KeltnerDeviceArrayF32Py,
    KeltnerDeviceArrayF32Py,
    KeltnerDeviceArrayF32Py,
)> {
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (ht, lt, ct, st) = (
        high_tm_f32.as_slice()?,
        low_tm_f32.as_slice()?,
        close_tm_f32.as_slice()?,
        source_tm_f32.as_slice()?,
    );
    let expected = cols
        .checked_mul(rows)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;
    if [ht, lt, ct, st].iter().any(|series| series.len() != expected) {
        return Err(PyValueError::new_err("time-major input length mismatch"));
    }

    // Device work happens without the GIL; only owned results leave the closure.
    let (upper, middle, lower, ctx, dev_id) = py.allow_threads(|| {
        let engine =
            CudaKeltner::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = engine.context_arc();
        let dev_id = engine.device_id();
        let bands = engine
            .keltner_many_series_one_param_time_major_dev(
                ht, lt, ct, st, cols, rows, period, multiplier, ma_type,
            )
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((bands.upper, bands.middle, bands.lower, ctx, dev_id))
    })?;

    // Each handle keeps its own Arc clone of the shared context.
    let wrap = |buf: DeviceArrayF32| KeltnerDeviceArrayF32Py {
        inner: KeltnerDeviceArrayF32 {
            inner: buf,
            context: Arc::clone(&ctx),
            device_id: dev_id,
        },
    };
    Ok((wrap(upper), wrap(middle), wrap(lower)))
}
3058
/// Serialized result of the wasm `keltner` binding: a flattened `rows` x `cols` matrix.
///
/// `keltner_js` fills it with `rows = 3` (upper, middle, lower bands stored back to back)
/// and `cols` = input length.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct KeltnerResult {
    /// Flattened band values, one row after another.
    pub values: Vec<f64>,
    /// Number of rows in `values`.
    pub rows: usize,
    /// Number of values per row.
    pub cols: usize,
}
3066
/// Compute Keltner bands for one series and return a `KeltnerResult` with
/// `rows = 3` (upper, middle, lower concatenated) and `cols = close.len()`.
///
/// Errors (as JS strings) when the four inputs differ in length, the computation fails,
/// or serialization fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "keltner")]
pub fn keltner_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    source: &[f64],
    period: usize,
    multiplier: f64,
    ma_type: String,
) -> Result<JsValue, JsValue> {
    let n = close.len();
    let lengths_match = high.len() == n && low.len() == n && source.len() == n;
    if !lengths_match {
        return Err(JsValue::from_str("Input arrays must have equal length"));
    }

    let input = KeltnerInput::from_slice(
        high,
        low,
        close,
        source,
        KeltnerParams {
            period: Some(period),
            multiplier: Some(multiplier),
            ma_type: Some(ma_type),
        },
    );

    // One allocation, split into three contiguous band segments.
    let mut values = vec![0.0f64; 3 * n];
    {
        let (upper, tail) = values.split_at_mut(n);
        let (middle, lower) = tail.split_at_mut(n);
        keltner_into_slice(upper, middle, lower, &input, detect_best_kernel())
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    serde_wasm_bindgen::to_value(&KeltnerResult {
        values,
        rows: 3,
        cols: n,
    })
    .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
3105
/// Compute Keltner bands from caller-owned wasm memory into caller-owned output buffers.
///
/// Every pointer must be non-null and address `len` properly aligned f64 values that stay
/// valid for the duration of the call. Buffers may overlap: if any output region shares
/// memory with an input region or with another output region, the bands are first computed
/// into scratch buffers and then copied out in the order upper, middle, lower.
///
/// # Errors
/// Returns a JS string error for a null pointer or when the Keltner computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn keltner_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    source_ptr: *const f64,
    upper_ptr: *mut f64,
    middle_ptr: *mut f64,
    lower_ptr: *mut f64,
    len: usize,
    period: usize,
    multiplier: f64,
    ma_type: &str,
) -> Result<(), JsValue> {
    // True when the `len`-element f64 regions starting at `a` and `b` share any byte.
    // Empty regions (len == 0) never overlap.
    fn regions_overlap(a: *const f64, b: *const f64, len: usize) -> bool {
        let bytes = len.saturating_mul(std::mem::size_of::<f64>());
        let (a, b) = (a as usize, b as usize);
        a < b.saturating_add(bytes) && b < a.saturating_add(bytes)
    }

    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || source_ptr.is_null()
        || upper_ptr.is_null()
        || middle_ptr.is_null()
        || lower_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Exact pointer equality is not enough: partially overlapping regions would still create
    // aliasing `&`/`&mut` slices, so compare whole byte ranges.
    let inputs = [high_ptr, low_ptr, close_ptr, source_ptr];
    let outputs = [
        upper_ptr as *const f64,
        middle_ptr as *const f64,
        lower_ptr as *const f64,
    ];
    let input_output_overlap = inputs
        .iter()
        .any(|&i| outputs.iter().any(|&o| regions_overlap(i, o, len)));
    let output_output_overlap = regions_overlap(outputs[0], outputs[1], len)
        || regions_overlap(outputs[0], outputs[2], len)
        || regions_overlap(outputs[1], outputs[2], len);

    // SAFETY: the pointers are non-null (checked above) and the caller guarantees each
    // addresses `len` initialized, aligned f64 values valid for this call.
    let (high, low, close, source) = unsafe {
        (
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
            std::slice::from_raw_parts(close_ptr, len),
            std::slice::from_raw_parts(source_ptr, len),
        )
    };

    let params = KeltnerParams {
        period: Some(period),
        multiplier: Some(multiplier),
        ma_type: Some(ma_type.to_string()),
    };
    let input = KeltnerInput::from_slice(high, low, close, source, params);

    if input_output_overlap || output_output_overlap {
        let mut temp_upper = vec![0.0; len];
        let mut temp_middle = vec![0.0; len];
        let mut temp_lower = vec![0.0; len];

        keltner_into_slice(
            &mut temp_upper,
            &mut temp_middle,
            &mut temp_lower,
            &input,
            Kernel::Auto,
        )
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

        // SAFETY: `input` (and its borrows of the input regions) is no longer used, and each
        // output slice is created, written and dropped before the next one exists, so no two
        // live references alias even when the output regions overlap.
        unsafe {
            std::slice::from_raw_parts_mut(upper_ptr, len).copy_from_slice(&temp_upper);
            std::slice::from_raw_parts_mut(middle_ptr, len).copy_from_slice(&temp_middle);
            std::slice::from_raw_parts_mut(lower_ptr, len).copy_from_slice(&temp_lower);
        }
    } else {
        // SAFETY: the three output regions are disjoint from each other and from every input
        // region (checked above), so these `&mut` slices alias nothing.
        let (upper_out, middle_out, lower_out) = unsafe {
            (
                std::slice::from_raw_parts_mut(upper_ptr, len),
                std::slice::from_raw_parts_mut(middle_ptr, len),
                std::slice::from_raw_parts_mut(lower_ptr, len),
            )
        };

        keltner_into_slice(upper_out, middle_out, lower_out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
3194
/// Reserve room for `len` f64 values in wasm linear memory and return a pointer to it.
///
/// The memory is left uninitialized; it is intended to be released with
/// `keltner_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn keltner_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop leaks the Vec on purpose: ownership passes to the caller via the pointer.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
3203
/// Release a buffer previously returned by `keltner_alloc(len)`, passing the same `len`.
///
/// A null `ptr` is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn keltner_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr`/`len` must come from `keltner_alloc(len)` (a Vec<f64> allocation of
    // capacity `len`) that has not been freed yet. The length is 0 because the memory may
    // never have been initialized; f64 has no destructor, so only the allocation is released.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
3213
/// Sweep configuration accepted (deserialized from JS) by the wasm `keltner_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct KeltnerBatchConfig {
    /// Period sweep, forwarded to `KeltnerBatchRange::period`
    /// (presumably `(start, end, step)` — confirm against `KeltnerBatchRange`).
    pub period_range: (usize, usize, usize),
    /// Multiplier sweep, forwarded to `KeltnerBatchRange::multiplier` (same tuple convention).
    pub multiplier_range: (f64, f64, f64),
    /// Moving-average type name passed to `keltner_batch_inner` for every combination.
    pub ma_type: String,
}
3221
/// Serialized output of the wasm `keltner_batch` binding, copied field-for-field from the
/// batch result.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct KeltnerBatchJsOutput {
    /// Upper bands for all combinations, flattened (presumably row-major `rows` x `cols`).
    pub upper: Vec<f64>,
    /// Middle bands, same layout as `upper`.
    pub middle: Vec<f64>,
    /// Lower bands, same layout as `upper`.
    pub lower: Vec<f64>,
    /// Parameter set of each combination (presumably one per row, in row order).
    pub combos: Vec<KeltnerParams>,
    /// Number of rows in the flattened band matrices.
    pub rows: usize,
    /// Number of values per row.
    pub cols: usize,
}
3232
/// Run a Keltner parameter sweep described by a JS `KeltnerBatchConfig` object and return a
/// serialized `KeltnerBatchJsOutput`.
///
/// Errors (as JS strings) for an invalid config, a failed computation, or a serialization
/// failure.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "keltner_batch")]
pub fn keltner_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    source: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let KeltnerBatchConfig {
        period_range,
        multiplier_range,
        ma_type,
    } = serde_wasm_bindgen::from_value::<KeltnerBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = KeltnerBatchRange {
        period: period_range,
        multiplier: multiplier_range,
    };

    // Sequential (non-parallel) batch evaluation with the best detected batch kernel.
    let batch = keltner_batch_inner(
        high,
        low,
        close,
        source,
        &sweep,
        detect_best_batch_kernel(),
        false,
        Some(&ma_type),
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    serde_wasm_bindgen::to_value(&KeltnerBatchJsOutput {
        upper: batch.upper_band,
        middle: batch.middle_band,
        lower: batch.lower_band,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    })
    .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
3272
/// Compute Keltner bands into one caller-owned buffer laid out as
/// `[upper (len) | middle (len) | lower (len)]`.
///
/// Each input pointer must address `len` aligned f64 values and `out_ptr` must address
/// `3 * len`. If the output region overlaps any input region, the bands are computed into a
/// scratch buffer first and copied out afterwards, so inputs are never read while they are
/// being overwritten.
///
/// # Errors
/// Returns a JS string error for a null pointer, a `3 * len` overflow, or a failed
/// computation.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "keltner_into_concat")]
pub fn keltner_into_concat(
    h_ptr: *const f64,
    l_ptr: *const f64,
    c_ptr: *const f64,
    s_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
    multiplier: f64,
    ma_type: String,
) -> Result<(), JsValue> {
    if [h_ptr, l_ptr, c_ptr, s_ptr, out_ptr as *const f64]
        .iter()
        .any(|p| p.is_null())
    {
        return Err(JsValue::from_str(
            "null pointer passed to keltner_into_concat",
        ));
    }
    let out_len = len
        .checked_mul(3)
        .ok_or_else(|| JsValue::from_str("len too large in keltner_into_concat"))?;

    // Byte ranges of the output and each input, to detect any shared memory.
    let elem = std::mem::size_of::<f64>();
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(out_len.saturating_mul(elem));
    let in_bytes = len.saturating_mul(elem);
    let aliased = [h_ptr, l_ptr, c_ptr, s_ptr].iter().any(|&p| {
        let start = p as usize;
        start < out_end && out_start < start.saturating_add(in_bytes)
    });

    // SAFETY: the pointers are non-null (checked above) and the caller guarantees each
    // addresses `len` initialized, aligned f64 values valid for this call.
    let (h, l, c, s) = unsafe {
        (
            std::slice::from_raw_parts(h_ptr, len),
            std::slice::from_raw_parts(l_ptr, len),
            std::slice::from_raw_parts(c_ptr, len),
            std::slice::from_raw_parts(s_ptr, len),
        )
    };

    let params = KeltnerParams {
        period: Some(period),
        multiplier: Some(multiplier),
        ma_type: Some(ma_type),
    };
    let input = KeltnerInput::from_slice(h, l, c, s, params);

    if aliased {
        let mut scratch = vec![0.0f64; out_len];
        let (upper, rest) = scratch.split_at_mut(len);
        let (middle, lower) = rest.split_at_mut(len);
        keltner_into_slice(upper, middle, lower, &input, detect_best_kernel())
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        // SAFETY: `input` and its borrows of the input regions are no longer used; the caller
        // guarantees `out_ptr` addresses `out_len` writable f64 values.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, out_len) };
        out.copy_from_slice(&scratch);
    } else {
        // SAFETY: the output region is disjoint from every input region (checked above) and
        // the caller guarantees it addresses `out_len` writable f64 values.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, out_len) };
        let (upper, rest) = out.split_at_mut(len);
        let (middle, lower) = rest.split_at_mut(len);
        keltner_into_slice(upper, middle, lower, &input, detect_best_kernel())
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}