1#[cfg(all(feature = "python", feature = "cuda"))]
2pub use crate::utilities::dlpack_cuda::{make_device_array_py, DeviceArrayF32Py};
3
4#[cfg(feature = "python")]
5use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
6#[cfg(feature = "python")]
7use pyo3::exceptions::PyValueError;
8#[cfg(feature = "python")]
9use pyo3::prelude::*;
10#[cfg(feature = "python")]
11use pyo3::types::PyDict;
12
13#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
14use serde::{Deserialize, Serialize};
15#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
16use wasm_bindgen::prelude::*;
17
18use crate::utilities::data_loader::Candles;
19use crate::utilities::enums::Kernel;
20use crate::utilities::helpers::{
21 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
22 make_uninit_matrix,
23};
24#[cfg(feature = "python")]
25use crate::utilities::kernel_validation::validate_kernel;
26#[cfg(not(target_arch = "wasm32"))]
27use rayon::prelude::*;
28use std::mem::ManuallyDrop;
29use thiserror::Error;
30
/// Weight `2·ln(2) − 1` that `gk_term` applies to the squared
/// `ln(close / open)` component of each bar.
const GK_COEFF: f64 = std::f64::consts::LN_2 * 2.0 - 1.0;
32
33#[derive(Debug, Clone)]
34pub enum GarmanKlassVolatilityData<'a> {
35 Candles {
36 candles: &'a Candles,
37 },
38 Slices {
39 open: &'a [f64],
40 high: &'a [f64],
41 low: &'a [f64],
42 close: &'a [f64],
43 },
44}
45
/// Result of a single-series computation.
#[derive(Clone, Debug)]
pub struct GarmanKlassVolatilityOutput {
    /// One value per input bar; positions without a full valid window are NaN.
    pub values: Vec<f64>,
}
50
/// Tunable parameters of the indicator.
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct GarmanKlassVolatilityParams {
    /// Window length in bars; `None` is resolved to 14 by the consumers in this file.
    pub lookback: Option<usize>,
}

impl Default for GarmanKlassVolatilityParams {
    /// Parameters with an explicit lookback of 14.
    fn default() -> Self {
        GarmanKlassVolatilityParams { lookback: Some(14) }
    }
}
65
66#[derive(Debug, Clone)]
67pub struct GarmanKlassVolatilityInput<'a> {
68 pub data: GarmanKlassVolatilityData<'a>,
69 pub params: GarmanKlassVolatilityParams,
70}
71
72impl<'a> GarmanKlassVolatilityInput<'a> {
73 #[inline]
74 pub fn from_candles(candles: &'a Candles, params: GarmanKlassVolatilityParams) -> Self {
75 Self {
76 data: GarmanKlassVolatilityData::Candles { candles },
77 params,
78 }
79 }
80
81 #[inline]
82 pub fn from_slices(
83 open: &'a [f64],
84 high: &'a [f64],
85 low: &'a [f64],
86 close: &'a [f64],
87 params: GarmanKlassVolatilityParams,
88 ) -> Self {
89 Self {
90 data: GarmanKlassVolatilityData::Slices {
91 open,
92 high,
93 low,
94 close,
95 },
96 params,
97 }
98 }
99
100 #[inline]
101 pub fn with_default_candles(candles: &'a Candles) -> Self {
102 Self::from_candles(candles, GarmanKlassVolatilityParams::default())
103 }
104
105 #[inline]
106 pub fn get_lookback(&self) -> usize {
107 self.params.lookback.unwrap_or(14)
108 }
109}
110
111#[derive(Copy, Clone, Debug)]
112pub struct GarmanKlassVolatilityBuilder {
113 lookback: Option<usize>,
114 kernel: Kernel,
115}
116
117impl Default for GarmanKlassVolatilityBuilder {
118 fn default() -> Self {
119 Self {
120 lookback: None,
121 kernel: Kernel::Auto,
122 }
123 }
124}
125
126impl GarmanKlassVolatilityBuilder {
127 #[inline(always)]
128 pub fn new() -> Self {
129 Self::default()
130 }
131
132 #[inline(always)]
133 pub fn lookback(mut self, lookback: usize) -> Self {
134 self.lookback = Some(lookback);
135 self
136 }
137
138 #[inline(always)]
139 pub fn kernel(mut self, kernel: Kernel) -> Self {
140 self.kernel = kernel;
141 self
142 }
143
144 #[inline(always)]
145 pub fn apply(
146 self,
147 candles: &Candles,
148 ) -> Result<GarmanKlassVolatilityOutput, GarmanKlassVolatilityError> {
149 let input = GarmanKlassVolatilityInput::from_candles(
150 candles,
151 GarmanKlassVolatilityParams {
152 lookback: self.lookback,
153 },
154 );
155 garman_klass_volatility_with_kernel(&input, self.kernel)
156 }
157
158 #[inline(always)]
159 pub fn apply_slices(
160 self,
161 open: &[f64],
162 high: &[f64],
163 low: &[f64],
164 close: &[f64],
165 ) -> Result<GarmanKlassVolatilityOutput, GarmanKlassVolatilityError> {
166 let input = GarmanKlassVolatilityInput::from_slices(
167 open,
168 high,
169 low,
170 close,
171 GarmanKlassVolatilityParams {
172 lookback: self.lookback,
173 },
174 );
175 garman_klass_volatility_with_kernel(&input, self.kernel)
176 }
177
178 #[inline(always)]
179 pub fn into_stream(self) -> Result<GarmanKlassVolatilityStream, GarmanKlassVolatilityError> {
180 GarmanKlassVolatilityStream::try_new(GarmanKlassVolatilityParams {
181 lookback: self.lookback,
182 })
183 }
184}
185
186#[derive(Debug, Error)]
187pub enum GarmanKlassVolatilityError {
188 #[error("garman_klass_volatility: Input data slice is empty.")]
189 EmptyInputData,
190 #[error("garman_klass_volatility: All values are NaN or non-positive.")]
191 AllValuesNaN,
192 #[error(
193 "garman_klass_volatility: Invalid lookback: lookback = {lookback}, data length = {data_len}"
194 )]
195 InvalidLookback { lookback: usize, data_len: usize },
196 #[error("garman_klass_volatility: Not enough valid data: needed = {needed}, valid = {valid}")]
197 NotEnoughValidData { needed: usize, valid: usize },
198 #[error("garman_klass_volatility: Inconsistent slice lengths: open={open_len}, high={high_len}, low={low_len}, close={close_len}")]
199 InconsistentSliceLengths {
200 open_len: usize,
201 high_len: usize,
202 low_len: usize,
203 close_len: usize,
204 },
205 #[error("garman_klass_volatility: Output length mismatch: expected = {expected}, got = {got}")]
206 OutputLengthMismatch { expected: usize, got: usize },
207 #[error("garman_klass_volatility: Invalid range: start={start}, end={end}, step={step}")]
208 InvalidRange {
209 start: String,
210 end: String,
211 step: String,
212 },
213 #[error("garman_klass_volatility: Invalid kernel for batch: {0:?}")]
214 InvalidKernelForBatch(Kernel),
215}
216
217#[derive(Debug, Clone)]
218pub struct GarmanKlassVolatilityStream {
219 lookback: usize,
220 terms: Vec<f64>,
221 valid: Vec<u8>,
222 idx: usize,
223 cnt: usize,
224 valid_count: usize,
225 sum_terms: f64,
226}
227
228impl GarmanKlassVolatilityStream {
229 pub fn try_new(
230 params: GarmanKlassVolatilityParams,
231 ) -> Result<GarmanKlassVolatilityStream, GarmanKlassVolatilityError> {
232 let lookback = params.lookback.unwrap_or(14);
233 if lookback == 0 {
234 return Err(GarmanKlassVolatilityError::InvalidLookback {
235 lookback,
236 data_len: 0,
237 });
238 }
239 Ok(Self {
240 lookback,
241 terms: vec![0.0; lookback],
242 valid: vec![0u8; lookback],
243 idx: 0,
244 cnt: 0,
245 valid_count: 0,
246 sum_terms: 0.0,
247 })
248 }
249
250 #[inline(always)]
251 pub fn update(&mut self, open: f64, high: f64, low: f64, close: f64) -> Option<f64> {
252 if self.cnt >= self.lookback {
253 let old_idx = self.idx;
254 if self.valid[old_idx] != 0 {
255 self.valid_count = self.valid_count.saturating_sub(1);
256 self.sum_terms -= self.terms[old_idx];
257 }
258 } else {
259 self.cnt += 1;
260 }
261
262 if valid_ohlc_bar(open, high, low, close) {
263 let term = gk_term(open, high, low, close);
264 self.terms[self.idx] = term;
265 self.valid[self.idx] = 1;
266 self.valid_count += 1;
267 self.sum_terms += term;
268 } else {
269 self.terms[self.idx] = 0.0;
270 self.valid[self.idx] = 0;
271 }
272
273 self.idx += 1;
274 if self.idx == self.lookback {
275 self.idx = 0;
276 }
277
278 if self.cnt < self.lookback || self.valid_count != self.lookback {
279 return None;
280 }
281
282 let mut variance = self.sum_terms / self.lookback as f64;
283 if variance < 0.0 {
284 variance = 0.0;
285 }
286 Some(variance.sqrt())
287 }
288
289 #[inline(always)]
290 pub fn get_warmup_period(&self) -> usize {
291 self.lookback.saturating_sub(1)
292 }
293}
294
295#[inline]
296pub fn garman_klass_volatility(
297 input: &GarmanKlassVolatilityInput,
298) -> Result<GarmanKlassVolatilityOutput, GarmanKlassVolatilityError> {
299 garman_klass_volatility_with_kernel(input, Kernel::Auto)
300}
301
/// Returns `true` when all four prices are finite and strictly positive.
#[inline(always)]
fn valid_ohlc_bar(open: f64, high: f64, low: f64, close: f64) -> bool {
    [open, high, low, close]
        .iter()
        .all(|price| price.is_finite() && *price > 0.0)
}
313
314#[inline(always)]
315fn gk_term(open: f64, high: f64, low: f64, close: f64) -> f64 {
316 let hl = (high / low).ln();
317 let co = (close / open).ln();
318 0.5 * hl * hl - GK_COEFF * co * co
319}
320
321#[inline(always)]
322fn first_valid_ohlc(open: &[f64], high: &[f64], low: &[f64], close: &[f64]) -> usize {
323 let len = close.len();
324 let mut i = 0usize;
325 while i < len {
326 if valid_ohlc_bar(open[i], high[i], low[i], close[i]) {
327 break;
328 }
329 i += 1;
330 }
331 i.min(len)
332}
333
334#[inline(always)]
335fn count_valid_ohlc(open: &[f64], high: &[f64], low: &[f64], close: &[f64]) -> usize {
336 let mut count = 0usize;
337 for i in 0..close.len() {
338 if valid_ohlc_bar(open[i], high[i], low[i], close[i]) {
339 count += 1;
340 }
341 }
342 count
343}
344
345#[inline(always)]
346fn build_prefix_terms(
347 open: &[f64],
348 high: &[f64],
349 low: &[f64],
350 close: &[f64],
351) -> (Vec<u32>, Vec<f64>) {
352 let len = close.len();
353 let mut prefix_valid = vec![0u32; len + 1];
354 let mut prefix_sum = vec![0.0f64; len + 1];
355
356 for i in 0..len {
357 if valid_ohlc_bar(open[i], high[i], low[i], close[i]) {
358 prefix_valid[i + 1] = prefix_valid[i] + 1;
359 prefix_sum[i + 1] = prefix_sum[i] + gk_term(open[i], high[i], low[i], close[i]);
360 } else {
361 prefix_valid[i + 1] = prefix_valid[i];
362 prefix_sum[i + 1] = prefix_sum[i];
363 }
364 }
365
366 (prefix_valid, prefix_sum)
367}
368
/// Fills `out` with the rolling volatility for one lookback.
///
/// `prefix_valid` and `prefix_sum` are indexed up to `out.len()`, so they must
/// have at least `out.len() + 1` entries. Positions before
/// `first + lookback - 1`, and positions whose window holds fewer than
/// `lookback` valid bars, are set to NaN. Every element of `out` is written.
#[inline(always)]
fn gk_row_from_prefix(
    prefix_valid: &[u32],
    prefix_sum: &[f64],
    lookback: usize,
    first: usize,
    out: &mut [f64],
) {
    let warmup = first.saturating_add(lookback.saturating_sub(1));
    let required = lookback as u32;
    let scale = 1.0 / lookback as f64;

    for (t, slot) in out.iter_mut().enumerate() {
        *slot = if t < warmup {
            f64::NAN
        } else {
            // Window covers bars [lo, hi) in prefix-array coordinates.
            let hi = t + 1;
            let lo = hi - lookback;
            if prefix_valid[hi] - prefix_valid[lo] != required {
                f64::NAN
            } else {
                let mean = (prefix_sum[hi] - prefix_sum[lo]) * scale;
                // A negative mean is clamped to zero before taking the root.
                let variance = if mean < 0.0 { 0.0 } else { mean };
                variance.sqrt()
            }
        };
    }
}
401
402#[inline(always)]
403fn garman_klass_prepare<'a>(
404 input: &'a GarmanKlassVolatilityInput,
405 kernel: Kernel,
406) -> Result<
407 (
408 &'a [f64],
409 &'a [f64],
410 &'a [f64],
411 &'a [f64],
412 usize,
413 usize,
414 Kernel,
415 ),
416 GarmanKlassVolatilityError,
417> {
418 let (open, high, low, close): (&[f64], &[f64], &[f64], &[f64]) = match &input.data {
419 GarmanKlassVolatilityData::Candles { candles } => {
420 (&candles.open, &candles.high, &candles.low, &candles.close)
421 }
422 GarmanKlassVolatilityData::Slices {
423 open,
424 high,
425 low,
426 close,
427 } => (open, high, low, close),
428 };
429
430 let len = close.len();
431 if len == 0 {
432 return Err(GarmanKlassVolatilityError::EmptyInputData);
433 }
434 if open.len() != len || high.len() != len || low.len() != len {
435 return Err(GarmanKlassVolatilityError::InconsistentSliceLengths {
436 open_len: open.len(),
437 high_len: high.len(),
438 low_len: low.len(),
439 close_len: close.len(),
440 });
441 }
442
443 let first = first_valid_ohlc(open, high, low, close);
444 if first >= len {
445 return Err(GarmanKlassVolatilityError::AllValuesNaN);
446 }
447
448 let lookback = input.get_lookback();
449 if lookback == 0 || lookback > len {
450 return Err(GarmanKlassVolatilityError::InvalidLookback {
451 lookback,
452 data_len: len,
453 });
454 }
455
456 let valid = count_valid_ohlc(open, high, low, close);
457 if valid < lookback {
458 return Err(GarmanKlassVolatilityError::NotEnoughValidData {
459 needed: lookback,
460 valid,
461 });
462 }
463
464 let chosen = match kernel {
465 Kernel::Auto => detect_best_kernel(),
466 other => other.to_non_batch(),
467 };
468
469 Ok((open, high, low, close, lookback, first, chosen))
470}
471
472#[inline]
473pub fn garman_klass_volatility_with_kernel(
474 input: &GarmanKlassVolatilityInput,
475 kernel: Kernel,
476) -> Result<GarmanKlassVolatilityOutput, GarmanKlassVolatilityError> {
477 let (open, high, low, close, lookback, first, _chosen) = garman_klass_prepare(input, kernel)?;
478 let len = close.len();
479 let warmup = first.saturating_add(lookback.saturating_sub(1));
480 let mut values = alloc_with_nan_prefix(len, warmup);
481 let (prefix_valid, prefix_sum) = build_prefix_terms(open, high, low, close);
482 gk_row_from_prefix(&prefix_valid, &prefix_sum, lookback, first, &mut values);
483 Ok(GarmanKlassVolatilityOutput { values })
484}
485
486#[inline]
487pub fn garman_klass_volatility_into_slice(
488 dst: &mut [f64],
489 input: &GarmanKlassVolatilityInput,
490 kernel: Kernel,
491) -> Result<(), GarmanKlassVolatilityError> {
492 let (open, high, low, close, lookback, first, _chosen) = garman_klass_prepare(input, kernel)?;
493 let expected = close.len();
494 if dst.len() != expected {
495 return Err(GarmanKlassVolatilityError::OutputLengthMismatch {
496 expected,
497 got: dst.len(),
498 });
499 }
500 let (prefix_valid, prefix_sum) = build_prefix_terms(open, high, low, close);
501 gk_row_from_prefix(&prefix_valid, &prefix_sum, lookback, first, dst);
502 Ok(())
503}
504
505#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
506#[inline]
507pub fn garman_klass_volatility_into(
508 input: &GarmanKlassVolatilityInput,
509 out: &mut [f64],
510) -> Result<(), GarmanKlassVolatilityError> {
511 garman_klass_volatility_into_slice(out, input, Kernel::Auto)
512}
513
/// Parameter sweep description for batch runs.
#[derive(Clone, Debug)]
pub struct GarmanKlassVolatilityBatchRange {
    /// `(start, end, step)` for the lookback axis.
    pub lookback: (usize, usize, usize),
}

impl Default for GarmanKlassVolatilityBatchRange {
    /// Sweeps lookbacks from 14 to 252 in steps of 1.
    fn default() -> Self {
        let lookback = (14, 252, 1);
        GarmanKlassVolatilityBatchRange { lookback }
    }
}
526
527#[derive(Clone, Debug, Default)]
528pub struct GarmanKlassVolatilityBatchBuilder {
529 range: GarmanKlassVolatilityBatchRange,
530 kernel: Kernel,
531}
532
533impl GarmanKlassVolatilityBatchBuilder {
534 pub fn new() -> Self {
535 Self::default()
536 }
537
538 pub fn kernel(mut self, kernel: Kernel) -> Self {
539 self.kernel = kernel;
540 self
541 }
542
543 #[inline]
544 pub fn lookback_range(mut self, start: usize, end: usize, step: usize) -> Self {
545 self.range.lookback = (start, end, step);
546 self
547 }
548
549 #[inline]
550 pub fn lookback_static(mut self, lookback: usize) -> Self {
551 self.range.lookback = (lookback, lookback, 0);
552 self
553 }
554
555 pub fn apply_slices(
556 self,
557 open: &[f64],
558 high: &[f64],
559 low: &[f64],
560 close: &[f64],
561 ) -> Result<GarmanKlassVolatilityBatchOutput, GarmanKlassVolatilityError> {
562 garman_klass_volatility_batch_with_kernel(open, high, low, close, &self.range, self.kernel)
563 }
564
565 pub fn apply_candles(
566 self,
567 candles: &Candles,
568 ) -> Result<GarmanKlassVolatilityBatchOutput, GarmanKlassVolatilityError> {
569 self.apply_slices(&candles.open, &candles.high, &candles.low, &candles.close)
570 }
571
572 pub fn with_default_candles(
573 candles: &Candles,
574 ) -> Result<GarmanKlassVolatilityBatchOutput, GarmanKlassVolatilityError> {
575 GarmanKlassVolatilityBatchBuilder::new()
576 .kernel(Kernel::Auto)
577 .apply_candles(candles)
578 }
579}
580
581#[derive(Clone, Debug)]
582pub struct GarmanKlassVolatilityBatchOutput {
583 pub values: Vec<f64>,
584 pub combos: Vec<GarmanKlassVolatilityParams>,
585 pub rows: usize,
586 pub cols: usize,
587}
588
589impl GarmanKlassVolatilityBatchOutput {
590 pub fn row_for_params(&self, params: &GarmanKlassVolatilityParams) -> Option<usize> {
591 self.combos
592 .iter()
593 .position(|combo| combo.lookback.unwrap_or(14) == params.lookback.unwrap_or(14))
594 }
595
596 pub fn values_for(&self, params: &GarmanKlassVolatilityParams) -> Option<&[f64]> {
597 self.row_for_params(params).and_then(|row| {
598 row.checked_mul(self.cols)
599 .and_then(|start| self.values.get(start..start + self.cols))
600 })
601 }
602}
603
604#[inline(always)]
605fn expand_grid_garman_klass(
606 range: &GarmanKlassVolatilityBatchRange,
607) -> Result<Vec<GarmanKlassVolatilityParams>, GarmanKlassVolatilityError> {
608 fn axis_usize(
609 (start, end, step): (usize, usize, usize),
610 ) -> Result<Vec<usize>, GarmanKlassVolatilityError> {
611 if step == 0 || start == end {
612 return Ok(vec![start]);
613 }
614 let step = step.max(1);
615 if start < end {
616 let mut out = Vec::new();
617 let mut x = start;
618 while x <= end {
619 out.push(x);
620 match x.checked_add(step) {
621 Some(next) if next != x => x = next,
622 _ => break,
623 }
624 }
625 if out.is_empty() {
626 return Err(GarmanKlassVolatilityError::InvalidRange {
627 start: start.to_string(),
628 end: end.to_string(),
629 step: step.to_string(),
630 });
631 }
632 Ok(out)
633 } else {
634 let mut out = Vec::new();
635 let mut x = start;
636 loop {
637 out.push(x);
638 if x == end {
639 break;
640 }
641 let next = x.saturating_sub(step);
642 if next == x || next < end {
643 break;
644 }
645 x = next;
646 }
647 if out.is_empty() {
648 return Err(GarmanKlassVolatilityError::InvalidRange {
649 start: start.to_string(),
650 end: end.to_string(),
651 step: step.to_string(),
652 });
653 }
654 Ok(out)
655 }
656 }
657
658 Ok(axis_usize(range.lookback)?
659 .into_iter()
660 .map(|lookback| GarmanKlassVolatilityParams {
661 lookback: Some(lookback),
662 })
663 .collect())
664}
665
666pub fn garman_klass_volatility_batch_with_kernel(
667 open: &[f64],
668 high: &[f64],
669 low: &[f64],
670 close: &[f64],
671 sweep: &GarmanKlassVolatilityBatchRange,
672 kernel: Kernel,
673) -> Result<GarmanKlassVolatilityBatchOutput, GarmanKlassVolatilityError> {
674 let batch_kernel = match kernel {
675 Kernel::Auto => detect_best_batch_kernel(),
676 other if other.is_batch() => other,
677 other => return Err(GarmanKlassVolatilityError::InvalidKernelForBatch(other)),
678 };
679 garman_klass_volatility_batch_par_slice(
680 open,
681 high,
682 low,
683 close,
684 sweep,
685 batch_kernel.to_non_batch(),
686 )
687}
688
689#[inline(always)]
690pub fn garman_klass_volatility_batch_slice(
691 open: &[f64],
692 high: &[f64],
693 low: &[f64],
694 close: &[f64],
695 sweep: &GarmanKlassVolatilityBatchRange,
696 kernel: Kernel,
697) -> Result<GarmanKlassVolatilityBatchOutput, GarmanKlassVolatilityError> {
698 garman_klass_volatility_batch_inner(open, high, low, close, sweep, kernel, false)
699}
700
701#[inline(always)]
702pub fn garman_klass_volatility_batch_par_slice(
703 open: &[f64],
704 high: &[f64],
705 low: &[f64],
706 close: &[f64],
707 sweep: &GarmanKlassVolatilityBatchRange,
708 kernel: Kernel,
709) -> Result<GarmanKlassVolatilityBatchOutput, GarmanKlassVolatilityError> {
710 garman_klass_volatility_batch_inner(open, high, low, close, sweep, kernel, true)
711}
712
713#[inline(always)]
714fn garman_klass_volatility_batch_inner(
715 open: &[f64],
716 high: &[f64],
717 low: &[f64],
718 close: &[f64],
719 sweep: &GarmanKlassVolatilityBatchRange,
720 _kernel: Kernel,
721 parallel: bool,
722) -> Result<GarmanKlassVolatilityBatchOutput, GarmanKlassVolatilityError> {
723 let combos = expand_grid_garman_klass(sweep)?;
724 let len = close.len();
725 if len == 0 {
726 return Err(GarmanKlassVolatilityError::EmptyInputData);
727 }
728 if open.len() != len || high.len() != len || low.len() != len {
729 return Err(GarmanKlassVolatilityError::InconsistentSliceLengths {
730 open_len: open.len(),
731 high_len: high.len(),
732 low_len: low.len(),
733 close_len: close.len(),
734 });
735 }
736
737 let first = first_valid_ohlc(open, high, low, close);
738 if first >= len {
739 return Err(GarmanKlassVolatilityError::AllValuesNaN);
740 }
741
742 let valid = count_valid_ohlc(open, high, low, close);
743 let max_lookback = combos
744 .iter()
745 .map(|combo| combo.lookback.unwrap_or(14))
746 .max()
747 .unwrap_or(0);
748 if max_lookback == 0 || valid < max_lookback {
749 return Err(GarmanKlassVolatilityError::NotEnoughValidData {
750 needed: max_lookback,
751 valid,
752 });
753 }
754
755 let rows = combos.len();
756 let cols = len;
757 let mut buf_mu = make_uninit_matrix(rows, cols);
758 let warmups: Vec<usize> = combos
759 .iter()
760 .map(|combo| first.saturating_add(combo.lookback.unwrap_or(14).saturating_sub(1)))
761 .collect();
762 init_matrix_prefixes(&mut buf_mu, cols, &warmups);
763
764 let mut guard = ManuallyDrop::new(buf_mu);
765 let out: &mut [f64] =
766 unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
767 let (prefix_valid, prefix_sum) = build_prefix_terms(open, high, low, close);
768
769 if parallel {
770 #[cfg(not(target_arch = "wasm32"))]
771 out.par_chunks_mut(cols)
772 .enumerate()
773 .for_each(|(row, out_row)| {
774 let lookback = combos[row].lookback.unwrap_or(14);
775 gk_row_from_prefix(&prefix_valid, &prefix_sum, lookback, first, out_row);
776 });
777
778 #[cfg(target_arch = "wasm32")]
779 for (row, out_row) in out.chunks_mut(cols).enumerate() {
780 let lookback = combos[row].lookback.unwrap_or(14);
781 gk_row_from_prefix(&prefix_valid, &prefix_sum, lookback, first, out_row);
782 }
783 } else {
784 for (row, out_row) in out.chunks_mut(cols).enumerate() {
785 let lookback = combos[row].lookback.unwrap_or(14);
786 gk_row_from_prefix(&prefix_valid, &prefix_sum, lookback, first, out_row);
787 }
788 }
789
790 let values = unsafe {
791 Vec::from_raw_parts(
792 guard.as_mut_ptr() as *mut f64,
793 guard.len(),
794 guard.capacity(),
795 )
796 };
797
798 Ok(GarmanKlassVolatilityBatchOutput {
799 values,
800 combos,
801 rows,
802 cols,
803 })
804}
805
/// Python entry point `garman_klass_volatility(open, high, low, close, lookback=14, kernel=None)`.
///
/// Returns a 1-D float64 array with one value per bar. Raises `ValueError`
/// when the four arrays differ in length or when the computation fails.
#[cfg(feature = "python")]
#[pyfunction(name = "garman_klass_volatility")]
#[pyo3(signature = (open, high, low, close, lookback=14, kernel=None))]
pub fn garman_klass_volatility_py<'py>(
    py: Python<'py>,
    open: PyReadonlyArray1<'py, f64>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    lookback: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let open_s = open.as_slice()?;
    let high_s = high.as_slice()?;
    let low_s = low.as_slice()?;
    let close_s = close.as_slice()?;

    let n = open_s.len();
    let lengths_agree = high_s.len() == n && low_s.len() == n && close_s.len() == n;
    if !lengths_agree {
        return Err(PyValueError::new_err("OHLC slice length mismatch"));
    }

    let kernel = validate_kernel(kernel, false)?;
    let params = GarmanKlassVolatilityParams {
        lookback: Some(lookback),
    };
    let input = GarmanKlassVolatilityInput::from_slices(open_s, high_s, low_s, close_s, params);

    // The computation runs inside `allow_threads`.
    let computed = py.allow_threads(|| garman_klass_volatility_with_kernel(&input, kernel));
    match computed {
        Ok(output) => Ok(output.values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
841
/// Python wrapper around [`GarmanKlassVolatilityStream`].
#[cfg(feature = "python")]
#[pyclass(name = "GarmanKlassVolatilityStream")]
pub struct GarmanKlassVolatilityStreamPy {
    stream: GarmanKlassVolatilityStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl GarmanKlassVolatilityStreamPy {
    /// Creates a stream with the given lookback; raises `ValueError` when the
    /// underlying constructor rejects it.
    #[new]
    fn new(lookback: usize) -> PyResult<Self> {
        let params = GarmanKlassVolatilityParams {
            lookback: Some(lookback),
        };
        match GarmanKlassVolatilityStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    /// Feeds one bar; returns the latest value or `None` while unavailable.
    fn update(&mut self, open: f64, high: f64, low: f64, close: f64) -> Option<f64> {
        let stream = &mut self.stream;
        stream.update(open, high, low, close)
    }
}
864
/// Python entry point `garman_klass_volatility_batch(open, high, low, close, lookback_range, kernel=None)`.
///
/// Returns a dict with keys `"values"` (2-D array of shape `(rows, cols)`),
/// `"lookbacks"` (one entry per row), `"rows"` and `"cols"`. Raises
/// `ValueError` on mismatched input lengths or when the batch fails.
#[cfg(feature = "python")]
#[pyfunction(name = "garman_klass_volatility_batch")]
#[pyo3(signature = (open, high, low, close, lookback_range, kernel=None))]
pub fn garman_klass_volatility_batch_py<'py>(
    py: Python<'py>,
    open: PyReadonlyArray1<'py, f64>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    lookback_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let open_s = open.as_slice()?;
    let high_s = high.as_slice()?;
    let low_s = low.as_slice()?;
    let close_s = close.as_slice()?;

    let n = open_s.len();
    let lengths_agree = high_s.len() == n && low_s.len() == n && close_s.len() == n;
    if !lengths_agree {
        return Err(PyValueError::new_err("OHLC slice length mismatch"));
    }

    let sweep = GarmanKlassVolatilityBatchRange {
        lookback: lookback_range,
    };
    let kernel = validate_kernel(kernel, true)?;

    let computed = py.allow_threads(|| {
        let batch = if matches!(kernel, Kernel::Auto) {
            detect_best_batch_kernel()
        } else {
            kernel
        };
        garman_klass_volatility_batch_inner(
            open_s,
            high_s,
            low_s,
            close_s,
            &sweep,
            batch.to_non_batch(),
            true,
        )
    });
    let GarmanKlassVolatilityBatchOutput {
        values,
        combos,
        rows,
        cols,
    } = computed.map_err(|e| PyValueError::new_err(e.to_string()))?;

    let matrix = values.into_pyarray(py).reshape((rows, cols))?;
    let lookbacks: Vec<u64> = combos
        .iter()
        .map(|combo| combo.lookback.unwrap_or(14) as u64)
        .collect();

    let dict = PyDict::new(py);
    dict.set_item("values", matrix)?;
    dict.set_item("lookbacks", lookbacks.into_pyarray(py))?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
929
/// Adds the single-series function, the batch function and the stream class
/// to `module`, in that order.
#[cfg(feature = "python")]
pub fn register_garman_klass_volatility_module(
    module: &Bound<'_, pyo3::types::PyModule>,
) -> PyResult<()> {
    let single = wrap_pyfunction!(garman_klass_volatility_py, module)?;
    module.add_function(single)?;

    let batch = wrap_pyfunction!(garman_klass_volatility_batch_py, module)?;
    module.add_function(batch)?;

    module.add_class::<GarmanKlassVolatilityStreamPy>()?;
    Ok(())
}
939
/// Python entry point for the CUDA batch sweep on float32 inputs.
///
/// Returns the device array produced by the CUDA wrapper together with a dict
/// holding the `"lookbacks"` of each row. Raises `ValueError` when CUDA is
/// unavailable or when the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "garman_klass_volatility_cuda_batch_dev")]
#[pyo3(signature = (open_f32, high_f32, low_f32, close_f32, lookback_range, device_id=0))]
pub fn garman_klass_volatility_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    open_f32: PyReadonlyArray1<'py, f32>,
    high_f32: PyReadonlyArray1<'py, f32>,
    low_f32: PyReadonlyArray1<'py, f32>,
    close_f32: PyReadonlyArray1<'py, f32>,
    lookback_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, Bound<'py, PyDict>)> {
    use crate::cuda::{cuda_available, CudaGarmanKlassVolatility};

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let open_s = open_f32.as_slice()?;
    let high_s = high_f32.as_slice()?;
    let low_s = low_f32.as_slice()?;
    let close_s = close_f32.as_slice()?;
    let sweep = GarmanKlassVolatilityBatchRange {
        lookback: lookback_range,
    };

    let batch = py.allow_threads(|| {
        let engine = CudaGarmanKlassVolatility::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        engine
            .garman_klass_volatility_batch_dev(open_s, high_s, low_s, close_s, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    let lookbacks: Vec<u64> = batch
        .combos
        .iter()
        .map(|combo| combo.lookback.unwrap_or(14) as u64)
        .collect();

    let meta = PyDict::new(py);
    meta.set_item("lookbacks", lookbacks.into_pyarray(py))?;

    let device_array = make_device_array_py(device_id, batch.outputs)?;
    Ok((device_array, meta))
}
984
/// Python entry point for the CUDA many-series, single-lookback computation.
///
/// The four float32 arrays, `cols`, `rows` and `lookback` are passed through
/// unchanged to the CUDA wrapper's time-major method. Raises `ValueError` when
/// CUDA is unavailable or when the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "garman_klass_volatility_cuda_many_series_one_param_dev")]
#[pyo3(signature = (open_tm_f32, high_tm_f32, low_tm_f32, close_tm_f32, cols, rows, lookback=14, device_id=0))]
pub fn garman_klass_volatility_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    open_tm_f32: PyReadonlyArray1<'py, f32>,
    high_tm_f32: PyReadonlyArray1<'py, f32>,
    low_tm_f32: PyReadonlyArray1<'py, f32>,
    close_tm_f32: PyReadonlyArray1<'py, f32>,
    cols: usize,
    rows: usize,
    lookback: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use crate::cuda::{cuda_available, CudaGarmanKlassVolatility};

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let open_s = open_tm_f32.as_slice()?;
    let high_s = high_tm_f32.as_slice()?;
    let low_s = low_tm_f32.as_slice()?;
    let close_s = close_tm_f32.as_slice()?;

    let device_buf = py.allow_threads(|| {
        let engine = CudaGarmanKlassVolatility::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        engine
            .garman_klass_volatility_many_series_one_param_time_major_dev(
                open_s, high_s, low_s, close_s, cols, rows, lookback,
            )
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    make_device_array_py(device_id, device_buf)
}
1019
/// WASM entry point returning the indicator as a new `Vec<f64>`.
///
/// Errors from the computation are converted to a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "garman_klass_volatility_js")]
pub fn garman_klass_volatility_js(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    lookback: usize,
) -> Result<Vec<f64>, JsValue> {
    let params = GarmanKlassVolatilityParams {
        lookback: Some(lookback),
    };
    let input = GarmanKlassVolatilityInput::from_slices(open, high, low, close, params);

    let mut output = vec![0.0; close.len()];
    match garman_klass_volatility_into_slice(&mut output, &input, Kernel::Auto) {
        Ok(()) => Ok(output),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1043
/// Allocates an `f64` buffer with capacity `len` and returns its pointer.
///
/// The buffer is deliberately not dropped here; release it with
/// `garman_klass_volatility_free` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn garman_klass_volatility_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop keeps the allocation alive after this function returns.
    let mut buffer = ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1052
/// Releases a buffer of `len` elements; a null `ptr` is ignored.
///
/// `ptr` and `len` are presumably the values from a matching
/// `garman_klass_volatility_alloc(len)` call — nothing here checks that.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn garman_klass_volatility_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: relies on the caller passing a pointer/length pair that
    // describes a live allocation of exactly `len` `f64` elements.
    unsafe {
        drop(Vec::from_raw_parts(ptr, len, len));
    }
}
1062
/// WASM entry point that writes the indicator into a caller-owned buffer.
///
/// All five pointers are treated as buffers of `len` `f64` values. When
/// `out_ptr` equals one of the input pointers the result is computed into a
/// temporary buffer and copied over afterwards.
///
/// Returns a string `JsValue` error for a null pointer or a failed computation.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn garman_klass_volatility_into(
    open_ptr: *const f64,
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    lookback: usize,
) -> Result<(), JsValue> {
    let inputs = [open_ptr, high_ptr, low_ptr, close_ptr];
    if inputs.iter().any(|p| p.is_null()) || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // SAFETY: relies on the caller guaranteeing that every pointer addresses
    // `len` readable (and, for `out_ptr`, writable) `f64` values.
    unsafe {
        let open = std::slice::from_raw_parts(open_ptr, len);
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        let params = GarmanKlassVolatilityParams {
            lookback: Some(lookback),
        };
        let input = GarmanKlassVolatilityInput::from_slices(open, high, low, close, params);

        let out_const = out_ptr as *const f64;
        let in_place = inputs.iter().any(|&p| p == out_const);
        if in_place {
            // Output aliases an input: compute aside, then copy over.
            let mut scratch = vec![0.0; len];
            garman_klass_volatility_into_slice(&mut scratch, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        } else {
            let dst = std::slice::from_raw_parts_mut(out_ptr, len);
            garman_klass_volatility_into_slice(dst, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
1112
/// Batch configuration deserialized from the JavaScript caller.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Deserialize, Serialize)]
pub struct GarmanKlassVolatilityBatchConfig {
    /// `(start, end, step)` for the lookback sweep.
    pub lookback_range: (usize, usize, usize),
}
1118
/// Batch result serialized back to the JavaScript caller.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Deserialize, Serialize)]
pub struct GarmanKlassVolatilityBatchJsOutput {
    /// Flattened `rows × cols` matrix of results.
    pub values: Vec<f64>,
    /// Parameter set for each row.
    pub combos: Vec<GarmanKlassVolatilityParams>,
    /// Number of rows.
    pub rows: usize,
    /// Number of values per row.
    pub cols: usize,
}
1127
/// WASM entry point for the batch sweep.
///
/// `config` is deserialized into [`GarmanKlassVolatilityBatchConfig`]; the
/// result is serialized from [`GarmanKlassVolatilityBatchJsOutput`]. Rows are
/// filled sequentially. Any failure is reported as a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "garman_klass_volatility_batch_js")]
pub fn garman_klass_volatility_batch_js(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let parsed: GarmanKlassVolatilityBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {e}"))),
    };
    let sweep = GarmanKlassVolatilityBatchRange {
        lookback: parsed.lookback_range,
    };

    let kernel = detect_best_kernel();
    let batch = garman_klass_volatility_batch_inner(open, high, low, close, &sweep, kernel, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = GarmanKlassVolatilityBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };
    serde_wasm_bindgen::to_value(&payload).map_err(|e| JsValue::from_str(&e.to_string()))
}
1160
/// Runs a lookback sweep over caller-owned OHLC buffers and copies the
/// flattened result through `out_ptr`.
///
/// The sweep is described by `(lookback_start, lookback_end, lookback_step)`.
/// On success the number of rows (expanded parameter combinations) is
/// returned, and `rows * len` values have been written to `out_ptr`.
///
/// # Errors
///
/// Returns a string `JsValue` when any pointer is null, when the sweep cannot
/// be expanded, when `rows * len` overflows `usize`, when the batch
/// computation fails, or when the computed output does not contain exactly
/// `rows * len` values.
///
/// # Safety contract (upheld by the caller)
///
/// * `open_ptr`, `high_ptr`, `low_ptr` and `close_ptr` must each be aligned
///   and valid for reading `len` `f64` values.
/// * `out_ptr` must be aligned and valid for writing `rows * len` `f64`
///   values.
/// * `out_ptr` may alias or overlap an input buffer: the inputs are fully
///   consumed before the output is written.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn garman_klass_volatility_batch_into(
    open_ptr: *const f64,
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    lookback_start: usize,
    lookback_end: usize,
    lookback_step: usize,
) -> Result<usize, JsValue> {
    if open_ptr.is_null()
        || high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || out_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = GarmanKlassVolatilityBatchRange {
        lookback: (lookback_start, lookback_end, lookback_step),
    };
    let combos = expand_grid_garman_klass(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let total = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    // SAFETY: the pointers were checked for null above and the caller
    // guarantees each input is valid for `len` reads. Only shared slices exist
    // inside this block; the mutable output view is created afterwards, so an
    // output buffer that aliases an input never coexists with a live `&[f64]`.
    let batch = unsafe {
        let open = std::slice::from_raw_parts(open_ptr, len);
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        garman_klass_volatility_batch_inner(
            open,
            high,
            low,
            close,
            &sweep,
            detect_best_kernel(),
            false,
        )
    }
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    // `copy_from_slice` panics on a length mismatch; report it as an error
    // instead of aborting the wasm instance.
    if batch.values.len() != total {
        return Err(JsValue::from_str(
            "batch output size does not match rows*cols",
        ));
    }

    // SAFETY: the caller guarantees `out_ptr` is valid for `total` writes. The
    // input slices are no longer used, and `batch.values` is an owned buffer
    // that cannot overlap the destination.
    unsafe {
        std::slice::from_raw_parts_mut(out_ptr, total).copy_from_slice(&batch.values);
    }

    Ok(rows)
}
1213
1214#[cfg(test)]
1215mod tests {
1216 use super::*;
1217
1218 fn sample_ohlc(len: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
1219 let mut open = vec![f64::NAN; len];
1220 let mut high = vec![f64::NAN; len];
1221 let mut low = vec![f64::NAN; len];
1222 let mut close = vec![f64::NAN; len];
1223 let mut prev = 100.0;
1224 for i in 2..len {
1225 let x = i as f64;
1226 let o = (prev + (x * 0.021).sin() * 1.5 + 0.03 * x).max(1.0);
1227 let c = (o + (x * 0.017).cos() * 0.8).max(1.0);
1228 let h = o.max(c) + 0.5 + (x * 0.011).sin().abs() * 0.2;
1229 let l = (o.min(c) - 0.45 - (x * 0.013).cos().abs() * 0.15).max(0.01);
1230 open[i] = o;
1231 high[i] = h;
1232 low[i] = l;
1233 close[i] = c;
1234 prev = c;
1235 }
1236 (open, high, low, close)
1237 }
1238
1239 #[test]
1240 fn gk_output_contract() {
1241 let (open, high, low, close) = sample_ohlc(128);
1242 let input = GarmanKlassVolatilityInput::from_slices(
1243 &open,
1244 &high,
1245 &low,
1246 &close,
1247 GarmanKlassVolatilityParams { lookback: Some(14) },
1248 );
1249 let out = garman_klass_volatility(&input).expect("gk");
1250 assert_eq!(out.values.len(), close.len());
1251 assert!(out.values.iter().any(|v| v.is_finite()));
1252 let first_valid = out
1253 .values
1254 .iter()
1255 .position(|v| v.is_finite())
1256 .expect("first valid");
1257 assert!(first_valid >= 15);
1258 }
1259
1260 #[test]
1261 fn gk_into_matches_api() {
1262 let (open, high, low, close) = sample_ohlc(192);
1263 let input = GarmanKlassVolatilityInput::from_slices(
1264 &open,
1265 &high,
1266 &low,
1267 &close,
1268 GarmanKlassVolatilityParams { lookback: Some(20) },
1269 );
1270 let api = garman_klass_volatility(&input).expect("api");
1271 let mut out = vec![0.0; close.len()];
1272 garman_klass_volatility_into(&input, &mut out).expect("into");
1273 for i in 0..out.len() {
1274 if api.values[i].is_nan() {
1275 assert!(out[i].is_nan(), "expected NaN at index {i}");
1276 } else {
1277 assert!(
1278 (api.values[i] - out[i]).abs() <= 1e-12,
1279 "into mismatch at {i}: {} vs {}",
1280 api.values[i],
1281 out[i]
1282 );
1283 }
1284 }
1285 }
1286
1287 #[test]
1288 fn gk_stream_matches_batch() {
1289 let (open, high, low, close) = sample_ohlc(160);
1290 let input = GarmanKlassVolatilityInput::from_slices(
1291 &open,
1292 &high,
1293 &low,
1294 &close,
1295 GarmanKlassVolatilityParams { lookback: Some(12) },
1296 );
1297 let batch = garman_klass_volatility(&input).expect("batch");
1298 let mut stream = GarmanKlassVolatilityStream::try_new(GarmanKlassVolatilityParams {
1299 lookback: Some(12),
1300 })
1301 .expect("stream");
1302 let mut streamed = Vec::with_capacity(close.len());
1303 for i in 0..close.len() {
1304 streamed.push(
1305 stream
1306 .update(open[i], high[i], low[i], close[i])
1307 .unwrap_or(f64::NAN),
1308 );
1309 }
1310 for i in 0..streamed.len() {
1311 if batch.values[i].is_nan() {
1312 assert!(streamed[i].is_nan(), "stream index {i}");
1313 } else {
1314 assert!(
1315 (batch.values[i] - streamed[i]).abs() <= 1e-12,
1316 "stream mismatch at {i}: {} vs {}",
1317 batch.values[i],
1318 streamed[i]
1319 );
1320 }
1321 }
1322 }
1323
1324 #[test]
1325 fn gk_batch_single_param_matches_single() {
1326 let (open, high, low, close) = sample_ohlc(200);
1327 let single_input = GarmanKlassVolatilityInput::from_slices(
1328 &open,
1329 &high,
1330 &low,
1331 &close,
1332 GarmanKlassVolatilityParams { lookback: Some(16) },
1333 );
1334 let single = garman_klass_volatility(&single_input).expect("single");
1335 let batch = garman_klass_volatility_batch_with_kernel(
1336 &open,
1337 &high,
1338 &low,
1339 &close,
1340 &GarmanKlassVolatilityBatchRange {
1341 lookback: (16, 16, 0),
1342 },
1343 Kernel::ScalarBatch,
1344 )
1345 .expect("batch");
1346 assert_eq!(batch.rows, 1);
1347 assert_eq!(batch.cols, close.len());
1348 for i in 0..batch.values.len() {
1349 if single.values[i].is_nan() {
1350 assert!(batch.values[i].is_nan(), "expected NaN at index {i}");
1351 } else {
1352 assert!(
1353 (batch.values[i] - single.values[i]).abs() <= 1e-12,
1354 "batch mismatch at {i}: {} vs {}",
1355 batch.values[i],
1356 single.values[i]
1357 );
1358 }
1359 }
1360 }
1361
1362 #[test]
1363 fn gk_internal_invalid_bar_produces_nan_window_and_recovers() {
1364 let (mut open, mut high, mut low, mut close) = sample_ohlc(80);
1365 open[30] = f64::NAN;
1366 high[30] = f64::NAN;
1367 low[30] = f64::NAN;
1368 close[30] = f64::NAN;
1369
1370 let input = GarmanKlassVolatilityInput::from_slices(
1371 &open,
1372 &high,
1373 &low,
1374 &close,
1375 GarmanKlassVolatilityParams { lookback: Some(10) },
1376 );
1377 let out = garman_klass_volatility(&input).expect("gk");
1378 assert!(out.values[30].is_nan());
1379 assert!(out.values[39].is_nan());
1380 assert!(out.values[40].is_finite());
1381 }
1382
1383 #[test]
1384 fn gk_rejects_invalid_lookback() {
1385 let (open, high, low, close) = sample_ohlc(8);
1386 let input = GarmanKlassVolatilityInput::from_slices(
1387 &open,
1388 &high,
1389 &low,
1390 &close,
1391 GarmanKlassVolatilityParams { lookback: Some(0) },
1392 );
1393 let err = garman_klass_volatility(&input).unwrap_err();
1394 match err {
1395 GarmanKlassVolatilityError::InvalidLookback { lookback, .. } => {
1396 assert_eq!(lookback, 0);
1397 }
1398 other => panic!("unexpected error: {other:?}"),
1399 }
1400 }
1401}