Skip to main content

vector_ta/indicators/
garman_klass_volatility.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2pub use crate::utilities::dlpack_cuda::{make_device_array_py, DeviceArrayF32Py};
3
4#[cfg(feature = "python")]
5use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
6#[cfg(feature = "python")]
7use pyo3::exceptions::PyValueError;
8#[cfg(feature = "python")]
9use pyo3::prelude::*;
10#[cfg(feature = "python")]
11use pyo3::types::PyDict;
12
13#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
14use serde::{Deserialize, Serialize};
15#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
16use wasm_bindgen::prelude::*;
17
18use crate::utilities::data_loader::Candles;
19use crate::utilities::enums::Kernel;
20use crate::utilities::helpers::{
21    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
22    make_uninit_matrix,
23};
24#[cfg(feature = "python")]
25use crate::utilities::kernel_validation::validate_kernel;
26#[cfg(not(target_arch = "wasm32"))]
27use rayon::prelude::*;
28use std::mem::ManuallyDrop;
29use thiserror::Error;
30
/// Weight applied to the squared close/open log return in the per-bar
/// Garman-Klass term: `2 * ln(2) - 1`.
const GK_COEFF: f64 = std::f64::consts::LN_2 * 2.0 - 1.0;
32
33#[derive(Debug, Clone)]
34pub enum GarmanKlassVolatilityData<'a> {
35    Candles {
36        candles: &'a Candles,
37    },
38    Slices {
39        open: &'a [f64],
40        high: &'a [f64],
41        low: &'a [f64],
42        close: &'a [f64],
43    },
44}
45
/// Result of a single-series computation: one volatility value per input bar.
#[derive(Clone, Debug)]
pub struct GarmanKlassVolatilityOutput {
    /// Volatility series; bars without a complete valid window hold `NaN`.
    pub values: Vec<f64>,
}
50
/// Tunable parameters of the indicator.
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct GarmanKlassVolatilityParams {
    /// Rolling window length in bars; `None` is treated as 14 by the consumers in this file.
    pub lookback: Option<usize>,
}

impl Default for GarmanKlassVolatilityParams {
    /// Default parameters: a 14-bar lookback.
    fn default() -> Self {
        GarmanKlassVolatilityParams { lookback: Some(14) }
    }
}
65
66#[derive(Debug, Clone)]
67pub struct GarmanKlassVolatilityInput<'a> {
68    pub data: GarmanKlassVolatilityData<'a>,
69    pub params: GarmanKlassVolatilityParams,
70}
71
72impl<'a> GarmanKlassVolatilityInput<'a> {
73    #[inline]
74    pub fn from_candles(candles: &'a Candles, params: GarmanKlassVolatilityParams) -> Self {
75        Self {
76            data: GarmanKlassVolatilityData::Candles { candles },
77            params,
78        }
79    }
80
81    #[inline]
82    pub fn from_slices(
83        open: &'a [f64],
84        high: &'a [f64],
85        low: &'a [f64],
86        close: &'a [f64],
87        params: GarmanKlassVolatilityParams,
88    ) -> Self {
89        Self {
90            data: GarmanKlassVolatilityData::Slices {
91                open,
92                high,
93                low,
94                close,
95            },
96            params,
97        }
98    }
99
100    #[inline]
101    pub fn with_default_candles(candles: &'a Candles) -> Self {
102        Self::from_candles(candles, GarmanKlassVolatilityParams::default())
103    }
104
105    #[inline]
106    pub fn get_lookback(&self) -> usize {
107        self.params.lookback.unwrap_or(14)
108    }
109}
110
111#[derive(Copy, Clone, Debug)]
112pub struct GarmanKlassVolatilityBuilder {
113    lookback: Option<usize>,
114    kernel: Kernel,
115}
116
117impl Default for GarmanKlassVolatilityBuilder {
118    fn default() -> Self {
119        Self {
120            lookback: None,
121            kernel: Kernel::Auto,
122        }
123    }
124}
125
126impl GarmanKlassVolatilityBuilder {
127    #[inline(always)]
128    pub fn new() -> Self {
129        Self::default()
130    }
131
132    #[inline(always)]
133    pub fn lookback(mut self, lookback: usize) -> Self {
134        self.lookback = Some(lookback);
135        self
136    }
137
138    #[inline(always)]
139    pub fn kernel(mut self, kernel: Kernel) -> Self {
140        self.kernel = kernel;
141        self
142    }
143
144    #[inline(always)]
145    pub fn apply(
146        self,
147        candles: &Candles,
148    ) -> Result<GarmanKlassVolatilityOutput, GarmanKlassVolatilityError> {
149        let input = GarmanKlassVolatilityInput::from_candles(
150            candles,
151            GarmanKlassVolatilityParams {
152                lookback: self.lookback,
153            },
154        );
155        garman_klass_volatility_with_kernel(&input, self.kernel)
156    }
157
158    #[inline(always)]
159    pub fn apply_slices(
160        self,
161        open: &[f64],
162        high: &[f64],
163        low: &[f64],
164        close: &[f64],
165    ) -> Result<GarmanKlassVolatilityOutput, GarmanKlassVolatilityError> {
166        let input = GarmanKlassVolatilityInput::from_slices(
167            open,
168            high,
169            low,
170            close,
171            GarmanKlassVolatilityParams {
172                lookback: self.lookback,
173            },
174        );
175        garman_klass_volatility_with_kernel(&input, self.kernel)
176    }
177
178    #[inline(always)]
179    pub fn into_stream(self) -> Result<GarmanKlassVolatilityStream, GarmanKlassVolatilityError> {
180        GarmanKlassVolatilityStream::try_new(GarmanKlassVolatilityParams {
181            lookback: self.lookback,
182        })
183    }
184}
185
186#[derive(Debug, Error)]
187pub enum GarmanKlassVolatilityError {
188    #[error("garman_klass_volatility: Input data slice is empty.")]
189    EmptyInputData,
190    #[error("garman_klass_volatility: All values are NaN or non-positive.")]
191    AllValuesNaN,
192    #[error(
193        "garman_klass_volatility: Invalid lookback: lookback = {lookback}, data length = {data_len}"
194    )]
195    InvalidLookback { lookback: usize, data_len: usize },
196    #[error("garman_klass_volatility: Not enough valid data: needed = {needed}, valid = {valid}")]
197    NotEnoughValidData { needed: usize, valid: usize },
198    #[error("garman_klass_volatility: Inconsistent slice lengths: open={open_len}, high={high_len}, low={low_len}, close={close_len}")]
199    InconsistentSliceLengths {
200        open_len: usize,
201        high_len: usize,
202        low_len: usize,
203        close_len: usize,
204    },
205    #[error("garman_klass_volatility: Output length mismatch: expected = {expected}, got = {got}")]
206    OutputLengthMismatch { expected: usize, got: usize },
207    #[error("garman_klass_volatility: Invalid range: start={start}, end={end}, step={step}")]
208    InvalidRange {
209        start: String,
210        end: String,
211        step: String,
212    },
213    #[error("garman_klass_volatility: Invalid kernel for batch: {0:?}")]
214    InvalidKernelForBatch(Kernel),
215}
216
217#[derive(Debug, Clone)]
218pub struct GarmanKlassVolatilityStream {
219    lookback: usize,
220    terms: Vec<f64>,
221    valid: Vec<u8>,
222    idx: usize,
223    cnt: usize,
224    valid_count: usize,
225    sum_terms: f64,
226}
227
228impl GarmanKlassVolatilityStream {
229    pub fn try_new(
230        params: GarmanKlassVolatilityParams,
231    ) -> Result<GarmanKlassVolatilityStream, GarmanKlassVolatilityError> {
232        let lookback = params.lookback.unwrap_or(14);
233        if lookback == 0 {
234            return Err(GarmanKlassVolatilityError::InvalidLookback {
235                lookback,
236                data_len: 0,
237            });
238        }
239        Ok(Self {
240            lookback,
241            terms: vec![0.0; lookback],
242            valid: vec![0u8; lookback],
243            idx: 0,
244            cnt: 0,
245            valid_count: 0,
246            sum_terms: 0.0,
247        })
248    }
249
250    #[inline(always)]
251    pub fn update(&mut self, open: f64, high: f64, low: f64, close: f64) -> Option<f64> {
252        if self.cnt >= self.lookback {
253            let old_idx = self.idx;
254            if self.valid[old_idx] != 0 {
255                self.valid_count = self.valid_count.saturating_sub(1);
256                self.sum_terms -= self.terms[old_idx];
257            }
258        } else {
259            self.cnt += 1;
260        }
261
262        if valid_ohlc_bar(open, high, low, close) {
263            let term = gk_term(open, high, low, close);
264            self.terms[self.idx] = term;
265            self.valid[self.idx] = 1;
266            self.valid_count += 1;
267            self.sum_terms += term;
268        } else {
269            self.terms[self.idx] = 0.0;
270            self.valid[self.idx] = 0;
271        }
272
273        self.idx += 1;
274        if self.idx == self.lookback {
275            self.idx = 0;
276        }
277
278        if self.cnt < self.lookback || self.valid_count != self.lookback {
279            return None;
280        }
281
282        let mut variance = self.sum_terms / self.lookback as f64;
283        if variance < 0.0 {
284            variance = 0.0;
285        }
286        Some(variance.sqrt())
287    }
288
289    #[inline(always)]
290    pub fn get_warmup_period(&self) -> usize {
291        self.lookback.saturating_sub(1)
292    }
293}
294
295#[inline]
296pub fn garman_klass_volatility(
297    input: &GarmanKlassVolatilityInput,
298) -> Result<GarmanKlassVolatilityOutput, GarmanKlassVolatilityError> {
299    garman_klass_volatility_with_kernel(input, Kernel::Auto)
300}
301
/// Returns `true` when all four prices are finite and strictly positive.
#[inline(always)]
fn valid_ohlc_bar(open: f64, high: f64, low: f64, close: f64) -> bool {
    [open, high, low, close]
        .iter()
        .all(|&price| price.is_finite() && price > 0.0)
}
313
314#[inline(always)]
315fn gk_term(open: f64, high: f64, low: f64, close: f64) -> f64 {
316    let hl = (high / low).ln();
317    let co = (close / open).ln();
318    0.5 * hl * hl - GK_COEFF * co * co
319}
320
321#[inline(always)]
322fn first_valid_ohlc(open: &[f64], high: &[f64], low: &[f64], close: &[f64]) -> usize {
323    let len = close.len();
324    let mut i = 0usize;
325    while i < len {
326        if valid_ohlc_bar(open[i], high[i], low[i], close[i]) {
327            break;
328        }
329        i += 1;
330    }
331    i.min(len)
332}
333
334#[inline(always)]
335fn count_valid_ohlc(open: &[f64], high: &[f64], low: &[f64], close: &[f64]) -> usize {
336    let mut count = 0usize;
337    for i in 0..close.len() {
338        if valid_ohlc_bar(open[i], high[i], low[i], close[i]) {
339            count += 1;
340        }
341    }
342    count
343}
344
345#[inline(always)]
346fn build_prefix_terms(
347    open: &[f64],
348    high: &[f64],
349    low: &[f64],
350    close: &[f64],
351) -> (Vec<u32>, Vec<f64>) {
352    let len = close.len();
353    let mut prefix_valid = vec![0u32; len + 1];
354    let mut prefix_sum = vec![0.0f64; len + 1];
355
356    for i in 0..len {
357        if valid_ohlc_bar(open[i], high[i], low[i], close[i]) {
358            prefix_valid[i + 1] = prefix_valid[i] + 1;
359            prefix_sum[i + 1] = prefix_sum[i] + gk_term(open[i], high[i], low[i], close[i]);
360        } else {
361            prefix_valid[i + 1] = prefix_valid[i];
362            prefix_sum[i + 1] = prefix_sum[i];
363        }
364    }
365
366    (prefix_valid, prefix_sum)
367}
368
/// Fills `out` with the rolling volatility for one lookback, using prefix arrays.
///
/// * `prefix_valid` / `prefix_sum` — prefix counts and sums with `out.len() + 1`
///   elements each, as produced by `build_prefix_terms`.
/// * `first` — index of the first valid bar; the first `first + lookback - 1`
///   outputs are `NaN`.
///
/// After the warm-up, an output is `NaN` unless all `lookback` bars of its window
/// are valid; otherwise it is the square root of the window mean, with a negative
/// mean clamped to zero.
#[inline(always)]
fn gk_row_from_prefix(
    prefix_valid: &[u32],
    prefix_sum: &[f64],
    lookback: usize,
    first: usize,
    out: &mut [f64],
) {
    let required = lookback as u32;
    let scale = 1.0 / lookback as f64;
    let warmup = first.saturating_add(lookback.saturating_sub(1));

    // Warm-up region: nothing can be computed yet.
    let nan_until = warmup.min(out.len());
    let (head, tail) = out.split_at_mut(nan_until);
    head.fill(f64::NAN);

    for (offset, slot) in tail.iter_mut().enumerate() {
        // Window covers bars `start..end` (half-open) in prefix coordinates.
        let end = nan_until + offset + 1;
        let start = end - lookback;

        let window_valid = prefix_valid[end] - prefix_valid[start];
        *slot = if window_valid == required {
            let mean = (prefix_sum[end] - prefix_sum[start]) * scale;
            let variance = if mean < 0.0 { 0.0 } else { mean };
            variance.sqrt()
        } else {
            f64::NAN
        };
    }
}
401
402#[inline(always)]
403fn garman_klass_prepare<'a>(
404    input: &'a GarmanKlassVolatilityInput,
405    kernel: Kernel,
406) -> Result<
407    (
408        &'a [f64],
409        &'a [f64],
410        &'a [f64],
411        &'a [f64],
412        usize,
413        usize,
414        Kernel,
415    ),
416    GarmanKlassVolatilityError,
417> {
418    let (open, high, low, close): (&[f64], &[f64], &[f64], &[f64]) = match &input.data {
419        GarmanKlassVolatilityData::Candles { candles } => {
420            (&candles.open, &candles.high, &candles.low, &candles.close)
421        }
422        GarmanKlassVolatilityData::Slices {
423            open,
424            high,
425            low,
426            close,
427        } => (open, high, low, close),
428    };
429
430    let len = close.len();
431    if len == 0 {
432        return Err(GarmanKlassVolatilityError::EmptyInputData);
433    }
434    if open.len() != len || high.len() != len || low.len() != len {
435        return Err(GarmanKlassVolatilityError::InconsistentSliceLengths {
436            open_len: open.len(),
437            high_len: high.len(),
438            low_len: low.len(),
439            close_len: close.len(),
440        });
441    }
442
443    let first = first_valid_ohlc(open, high, low, close);
444    if first >= len {
445        return Err(GarmanKlassVolatilityError::AllValuesNaN);
446    }
447
448    let lookback = input.get_lookback();
449    if lookback == 0 || lookback > len {
450        return Err(GarmanKlassVolatilityError::InvalidLookback {
451            lookback,
452            data_len: len,
453        });
454    }
455
456    let valid = count_valid_ohlc(open, high, low, close);
457    if valid < lookback {
458        return Err(GarmanKlassVolatilityError::NotEnoughValidData {
459            needed: lookback,
460            valid,
461        });
462    }
463
464    let chosen = match kernel {
465        Kernel::Auto => detect_best_kernel(),
466        other => other.to_non_batch(),
467    };
468
469    Ok((open, high, low, close, lookback, first, chosen))
470}
471
472#[inline]
473pub fn garman_klass_volatility_with_kernel(
474    input: &GarmanKlassVolatilityInput,
475    kernel: Kernel,
476) -> Result<GarmanKlassVolatilityOutput, GarmanKlassVolatilityError> {
477    let (open, high, low, close, lookback, first, _chosen) = garman_klass_prepare(input, kernel)?;
478    let len = close.len();
479    let warmup = first.saturating_add(lookback.saturating_sub(1));
480    let mut values = alloc_with_nan_prefix(len, warmup);
481    let (prefix_valid, prefix_sum) = build_prefix_terms(open, high, low, close);
482    gk_row_from_prefix(&prefix_valid, &prefix_sum, lookback, first, &mut values);
483    Ok(GarmanKlassVolatilityOutput { values })
484}
485
486#[inline]
487pub fn garman_klass_volatility_into_slice(
488    dst: &mut [f64],
489    input: &GarmanKlassVolatilityInput,
490    kernel: Kernel,
491) -> Result<(), GarmanKlassVolatilityError> {
492    let (open, high, low, close, lookback, first, _chosen) = garman_klass_prepare(input, kernel)?;
493    let expected = close.len();
494    if dst.len() != expected {
495        return Err(GarmanKlassVolatilityError::OutputLengthMismatch {
496            expected,
497            got: dst.len(),
498        });
499    }
500    let (prefix_valid, prefix_sum) = build_prefix_terms(open, high, low, close);
501    gk_row_from_prefix(&prefix_valid, &prefix_sum, lookback, first, dst);
502    Ok(())
503}
504
505#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
506#[inline]
507pub fn garman_klass_volatility_into(
508    input: &GarmanKlassVolatilityInput,
509    out: &mut [f64],
510) -> Result<(), GarmanKlassVolatilityError> {
511    garman_klass_volatility_into_slice(out, input, Kernel::Auto)
512}
513
/// Parameter sweep for batch computation.
#[derive(Debug, Clone)]
pub struct GarmanKlassVolatilityBatchRange {
    /// Lookback sweep as `(start, end, step)`.
    pub lookback: (usize, usize, usize),
}

impl Default for GarmanKlassVolatilityBatchRange {
    /// Default sweep: lookbacks 14 through 252 in steps of 1.
    fn default() -> Self {
        let lookback = (14, 252, 1);
        Self { lookback }
    }
}
526
527#[derive(Clone, Debug, Default)]
528pub struct GarmanKlassVolatilityBatchBuilder {
529    range: GarmanKlassVolatilityBatchRange,
530    kernel: Kernel,
531}
532
533impl GarmanKlassVolatilityBatchBuilder {
534    pub fn new() -> Self {
535        Self::default()
536    }
537
538    pub fn kernel(mut self, kernel: Kernel) -> Self {
539        self.kernel = kernel;
540        self
541    }
542
543    #[inline]
544    pub fn lookback_range(mut self, start: usize, end: usize, step: usize) -> Self {
545        self.range.lookback = (start, end, step);
546        self
547    }
548
549    #[inline]
550    pub fn lookback_static(mut self, lookback: usize) -> Self {
551        self.range.lookback = (lookback, lookback, 0);
552        self
553    }
554
555    pub fn apply_slices(
556        self,
557        open: &[f64],
558        high: &[f64],
559        low: &[f64],
560        close: &[f64],
561    ) -> Result<GarmanKlassVolatilityBatchOutput, GarmanKlassVolatilityError> {
562        garman_klass_volatility_batch_with_kernel(open, high, low, close, &self.range, self.kernel)
563    }
564
565    pub fn apply_candles(
566        self,
567        candles: &Candles,
568    ) -> Result<GarmanKlassVolatilityBatchOutput, GarmanKlassVolatilityError> {
569        self.apply_slices(&candles.open, &candles.high, &candles.low, &candles.close)
570    }
571
572    pub fn with_default_candles(
573        candles: &Candles,
574    ) -> Result<GarmanKlassVolatilityBatchOutput, GarmanKlassVolatilityError> {
575        GarmanKlassVolatilityBatchBuilder::new()
576            .kernel(Kernel::Auto)
577            .apply_candles(candles)
578    }
579}
580
581#[derive(Clone, Debug)]
582pub struct GarmanKlassVolatilityBatchOutput {
583    pub values: Vec<f64>,
584    pub combos: Vec<GarmanKlassVolatilityParams>,
585    pub rows: usize,
586    pub cols: usize,
587}
588
589impl GarmanKlassVolatilityBatchOutput {
590    pub fn row_for_params(&self, params: &GarmanKlassVolatilityParams) -> Option<usize> {
591        self.combos
592            .iter()
593            .position(|combo| combo.lookback.unwrap_or(14) == params.lookback.unwrap_or(14))
594    }
595
596    pub fn values_for(&self, params: &GarmanKlassVolatilityParams) -> Option<&[f64]> {
597        self.row_for_params(params).and_then(|row| {
598            row.checked_mul(self.cols)
599                .and_then(|start| self.values.get(start..start + self.cols))
600        })
601    }
602}
603
604#[inline(always)]
605fn expand_grid_garman_klass(
606    range: &GarmanKlassVolatilityBatchRange,
607) -> Result<Vec<GarmanKlassVolatilityParams>, GarmanKlassVolatilityError> {
608    fn axis_usize(
609        (start, end, step): (usize, usize, usize),
610    ) -> Result<Vec<usize>, GarmanKlassVolatilityError> {
611        if step == 0 || start == end {
612            return Ok(vec![start]);
613        }
614        let step = step.max(1);
615        if start < end {
616            let mut out = Vec::new();
617            let mut x = start;
618            while x <= end {
619                out.push(x);
620                match x.checked_add(step) {
621                    Some(next) if next != x => x = next,
622                    _ => break,
623                }
624            }
625            if out.is_empty() {
626                return Err(GarmanKlassVolatilityError::InvalidRange {
627                    start: start.to_string(),
628                    end: end.to_string(),
629                    step: step.to_string(),
630                });
631            }
632            Ok(out)
633        } else {
634            let mut out = Vec::new();
635            let mut x = start;
636            loop {
637                out.push(x);
638                if x == end {
639                    break;
640                }
641                let next = x.saturating_sub(step);
642                if next == x || next < end {
643                    break;
644                }
645                x = next;
646            }
647            if out.is_empty() {
648                return Err(GarmanKlassVolatilityError::InvalidRange {
649                    start: start.to_string(),
650                    end: end.to_string(),
651                    step: step.to_string(),
652                });
653            }
654            Ok(out)
655        }
656    }
657
658    Ok(axis_usize(range.lookback)?
659        .into_iter()
660        .map(|lookback| GarmanKlassVolatilityParams {
661            lookback: Some(lookback),
662        })
663        .collect())
664}
665
666pub fn garman_klass_volatility_batch_with_kernel(
667    open: &[f64],
668    high: &[f64],
669    low: &[f64],
670    close: &[f64],
671    sweep: &GarmanKlassVolatilityBatchRange,
672    kernel: Kernel,
673) -> Result<GarmanKlassVolatilityBatchOutput, GarmanKlassVolatilityError> {
674    let batch_kernel = match kernel {
675        Kernel::Auto => detect_best_batch_kernel(),
676        other if other.is_batch() => other,
677        other => return Err(GarmanKlassVolatilityError::InvalidKernelForBatch(other)),
678    };
679    garman_klass_volatility_batch_par_slice(
680        open,
681        high,
682        low,
683        close,
684        sweep,
685        batch_kernel.to_non_batch(),
686    )
687}
688
689#[inline(always)]
690pub fn garman_klass_volatility_batch_slice(
691    open: &[f64],
692    high: &[f64],
693    low: &[f64],
694    close: &[f64],
695    sweep: &GarmanKlassVolatilityBatchRange,
696    kernel: Kernel,
697) -> Result<GarmanKlassVolatilityBatchOutput, GarmanKlassVolatilityError> {
698    garman_klass_volatility_batch_inner(open, high, low, close, sweep, kernel, false)
699}
700
701#[inline(always)]
702pub fn garman_klass_volatility_batch_par_slice(
703    open: &[f64],
704    high: &[f64],
705    low: &[f64],
706    close: &[f64],
707    sweep: &GarmanKlassVolatilityBatchRange,
708    kernel: Kernel,
709) -> Result<GarmanKlassVolatilityBatchOutput, GarmanKlassVolatilityError> {
710    garman_klass_volatility_batch_inner(open, high, low, close, sweep, kernel, true)
711}
712
713#[inline(always)]
714fn garman_klass_volatility_batch_inner(
715    open: &[f64],
716    high: &[f64],
717    low: &[f64],
718    close: &[f64],
719    sweep: &GarmanKlassVolatilityBatchRange,
720    _kernel: Kernel,
721    parallel: bool,
722) -> Result<GarmanKlassVolatilityBatchOutput, GarmanKlassVolatilityError> {
723    let combos = expand_grid_garman_klass(sweep)?;
724    let len = close.len();
725    if len == 0 {
726        return Err(GarmanKlassVolatilityError::EmptyInputData);
727    }
728    if open.len() != len || high.len() != len || low.len() != len {
729        return Err(GarmanKlassVolatilityError::InconsistentSliceLengths {
730            open_len: open.len(),
731            high_len: high.len(),
732            low_len: low.len(),
733            close_len: close.len(),
734        });
735    }
736
737    let first = first_valid_ohlc(open, high, low, close);
738    if first >= len {
739        return Err(GarmanKlassVolatilityError::AllValuesNaN);
740    }
741
742    let valid = count_valid_ohlc(open, high, low, close);
743    let max_lookback = combos
744        .iter()
745        .map(|combo| combo.lookback.unwrap_or(14))
746        .max()
747        .unwrap_or(0);
748    if max_lookback == 0 || valid < max_lookback {
749        return Err(GarmanKlassVolatilityError::NotEnoughValidData {
750            needed: max_lookback,
751            valid,
752        });
753    }
754
755    let rows = combos.len();
756    let cols = len;
757    let mut buf_mu = make_uninit_matrix(rows, cols);
758    let warmups: Vec<usize> = combos
759        .iter()
760        .map(|combo| first.saturating_add(combo.lookback.unwrap_or(14).saturating_sub(1)))
761        .collect();
762    init_matrix_prefixes(&mut buf_mu, cols, &warmups);
763
764    let mut guard = ManuallyDrop::new(buf_mu);
765    let out: &mut [f64] =
766        unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
767    let (prefix_valid, prefix_sum) = build_prefix_terms(open, high, low, close);
768
769    if parallel {
770        #[cfg(not(target_arch = "wasm32"))]
771        out.par_chunks_mut(cols)
772            .enumerate()
773            .for_each(|(row, out_row)| {
774                let lookback = combos[row].lookback.unwrap_or(14);
775                gk_row_from_prefix(&prefix_valid, &prefix_sum, lookback, first, out_row);
776            });
777
778        #[cfg(target_arch = "wasm32")]
779        for (row, out_row) in out.chunks_mut(cols).enumerate() {
780            let lookback = combos[row].lookback.unwrap_or(14);
781            gk_row_from_prefix(&prefix_valid, &prefix_sum, lookback, first, out_row);
782        }
783    } else {
784        for (row, out_row) in out.chunks_mut(cols).enumerate() {
785            let lookback = combos[row].lookback.unwrap_or(14);
786            gk_row_from_prefix(&prefix_valid, &prefix_sum, lookback, first, out_row);
787        }
788    }
789
790    let values = unsafe {
791        Vec::from_raw_parts(
792            guard.as_mut_ptr() as *mut f64,
793            guard.len(),
794            guard.capacity(),
795        )
796    };
797
798    Ok(GarmanKlassVolatilityBatchOutput {
799        values,
800        combos,
801        rows,
802        cols,
803    })
804}
805
// Python binding: Garman-Klass volatility of one OHLC series as a NumPy array.
// Raises ValueError on mismatched input lengths or on any computation error.
// Plain `//` comments are used so the Python-visible docstring is unchanged.
#[cfg(feature = "python")]
#[pyfunction(name = "garman_klass_volatility")]
#[pyo3(signature = (open, high, low, close, lookback=14, kernel=None))]
pub fn garman_klass_volatility_py<'py>(
    py: Python<'py>,
    open: PyReadonlyArray1<'py, f64>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    lookback: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let open = open.as_slice()?;
    let high = high.as_slice()?;
    let low = low.as_slice()?;
    let close = close.as_slice()?;

    let n = open.len();
    let same_length = high.len() == n && low.len() == n && close.len() == n;
    if !same_length {
        return Err(PyValueError::new_err("OHLC slice length mismatch"));
    }

    let kernel = validate_kernel(kernel, false)?;
    let params = GarmanKlassVolatilityParams {
        lookback: Some(lookback),
    };
    let input = GarmanKlassVolatilityInput::from_slices(open, high, low, close, params);

    // Release the GIL while the numeric work runs.
    let computed = py.allow_threads(|| garman_klass_volatility_with_kernel(&input, kernel));
    match computed {
        Ok(output) => Ok(output.values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
841
// Python wrapper around the bar-by-bar stream.
// Plain `//` comments are used so the Python-visible docstrings are unchanged.
#[cfg(feature = "python")]
#[pyclass(name = "GarmanKlassVolatilityStream")]
pub struct GarmanKlassVolatilityStreamPy {
    stream: GarmanKlassVolatilityStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl GarmanKlassVolatilityStreamPy {
    // Constructor; raises ValueError when the stream rejects the lookback.
    #[new]
    fn new(lookback: usize) -> PyResult<Self> {
        let params = GarmanKlassVolatilityParams {
            lookback: Some(lookback),
        };
        match GarmanKlassVolatilityStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    // Feeds one bar; returns None while no value is available.
    fn update(&mut self, open: f64, high: f64, low: f64, close: f64) -> Option<f64> {
        let stream = &mut self.stream;
        stream.update(open, high, low, close)
    }
}
864
// Python binding: lookback sweep over one OHLC series.
// Returns a dict with "values" (rows x cols array), "lookbacks", "rows", "cols".
// Raises ValueError on mismatched input lengths or on any computation error.
// Plain `//` comments are used so the Python-visible docstring is unchanged.
#[cfg(feature = "python")]
#[pyfunction(name = "garman_klass_volatility_batch")]
#[pyo3(signature = (open, high, low, close, lookback_range, kernel=None))]
pub fn garman_klass_volatility_batch_py<'py>(
    py: Python<'py>,
    open: PyReadonlyArray1<'py, f64>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    lookback_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let open = open.as_slice()?;
    let high = high.as_slice()?;
    let low = low.as_slice()?;
    let close = close.as_slice()?;

    let n = open.len();
    let same_length = high.len() == n && low.len() == n && close.len() == n;
    if !same_length {
        return Err(PyValueError::new_err("OHLC slice length mismatch"));
    }

    let sweep = GarmanKlassVolatilityBatchRange {
        lookback: lookback_range,
    };
    let kernel = validate_kernel(kernel, true)?;

    // Release the GIL while the numeric work runs.
    let output = py
        .allow_threads(|| {
            let batch = match kernel {
                Kernel::Auto => detect_best_batch_kernel(),
                explicit => explicit,
            };
            let row_kernel = batch.to_non_batch();
            garman_klass_volatility_batch_inner(open, high, low, close, &sweep, row_kernel, true)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let (rows, cols) = (output.rows, output.cols);
    let lookbacks: Vec<u64> = output
        .combos
        .iter()
        .map(|combo| combo.lookback.unwrap_or(14) as u64)
        .collect();
    let matrix = output.values.into_pyarray(py).reshape((rows, cols))?;

    let dict = PyDict::new(py);
    dict.set_item("values", matrix)?;
    dict.set_item("lookbacks", lookbacks.into_pyarray(py))?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
929
/// Registers the single-series function, the batch function and the stream
/// class on `module`, in that order.
///
/// # Errors
/// Returns the first registration error reported by PyO3.
#[cfg(feature = "python")]
pub fn register_garman_klass_volatility_module(
    module: &Bound<'_, pyo3::types::PyModule>,
) -> PyResult<()> {
    let single = wrap_pyfunction!(garman_klass_volatility_py, module)?;
    module.add_function(single)?;

    let batch = wrap_pyfunction!(garman_klass_volatility_batch_py, module)?;
    module.add_function(batch)?;

    module.add_class::<GarmanKlassVolatilityStreamPy>()
}
939
// Python binding: CUDA lookback sweep over one f32 OHLC series.
// Returns the device array plus a dict holding the "lookbacks" of each row.
// Raises ValueError when CUDA is unavailable or the CUDA wrapper fails.
// NOTE(review): input lengths are not compared here; presumably the CUDA
// wrapper validates them — confirm.
// Plain `//` comments are used so the Python-visible docstring is unchanged.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "garman_klass_volatility_cuda_batch_dev")]
#[pyo3(signature = (open_f32, high_f32, low_f32, close_f32, lookback_range, device_id=0))]
pub fn garman_klass_volatility_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    open_f32: PyReadonlyArray1<'py, f32>,
    high_f32: PyReadonlyArray1<'py, f32>,
    low_f32: PyReadonlyArray1<'py, f32>,
    close_f32: PyReadonlyArray1<'py, f32>,
    lookback_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, Bound<'py, PyDict>)> {
    use crate::cuda::{cuda_available, CudaGarmanKlassVolatility};

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let open = open_f32.as_slice()?;
    let high = high_f32.as_slice()?;
    let low = low_f32.as_slice()?;
    let close = close_f32.as_slice()?;
    let sweep = GarmanKlassVolatilityBatchRange {
        lookback: lookback_range,
    };

    // Release the GIL while the device work runs.
    let result = py.allow_threads(|| {
        let cuda = CudaGarmanKlassVolatility::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.garman_klass_volatility_batch_dev(open, high, low, close, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    let lookbacks: Vec<u64> = result
        .combos
        .iter()
        .map(|combo| combo.lookback.unwrap_or(14) as u64)
        .collect();

    let dict = PyDict::new(py);
    dict.set_item("lookbacks", lookbacks.into_pyarray(py))?;

    let device_array = make_device_array_py(device_id, result.outputs)?;
    Ok((device_array, dict))
}
984
// Python binding: CUDA computation of many series with one lookback.
// Inputs are flat f32 arrays forwarded with `cols` and `rows` to the
// time-major CUDA routine.
// Raises ValueError when CUDA is unavailable or the CUDA wrapper fails.
// NOTE(review): array lengths are not checked against `cols * rows` here;
// presumably the CUDA wrapper validates them — confirm.
// Plain `//` comments are used so the Python-visible docstring is unchanged.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "garman_klass_volatility_cuda_many_series_one_param_dev")]
#[pyo3(signature = (open_tm_f32, high_tm_f32, low_tm_f32, close_tm_f32, cols, rows, lookback=14, device_id=0))]
pub fn garman_klass_volatility_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    open_tm_f32: PyReadonlyArray1<'py, f32>,
    high_tm_f32: PyReadonlyArray1<'py, f32>,
    low_tm_f32: PyReadonlyArray1<'py, f32>,
    close_tm_f32: PyReadonlyArray1<'py, f32>,
    cols: usize,
    rows: usize,
    lookback: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use crate::cuda::{cuda_available, CudaGarmanKlassVolatility};

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let open = open_tm_f32.as_slice()?;
    let high = high_tm_f32.as_slice()?;
    let low = low_tm_f32.as_slice()?;
    let close = close_tm_f32.as_slice()?;

    // Release the GIL while the device work runs.
    let device_output = py.allow_threads(|| {
        let cuda = CudaGarmanKlassVolatility::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.garman_klass_volatility_many_series_one_param_time_major_dev(
            open, high, low, close, cols, rows, lookback,
        )
        .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    make_device_array_py(device_id, device_output)
}
1019
/// wasm binding: computes Garman-Klass volatility for one lookback and returns
/// a freshly allocated vector with one value per `close` element.
///
/// Any error from `garman_klass_volatility_into_slice` is converted to a
/// `JsValue` string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "garman_klass_volatility_js")]
pub fn garman_klass_volatility_js(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    lookback: usize,
) -> Result<Vec<f64>, JsValue> {
    let params = GarmanKlassVolatilityParams {
        lookback: Some(lookback),
    };
    let input = GarmanKlassVolatilityInput::from_slices(open, high, low, close, params);

    // The output buffer is sized from `close`; the callee writes into it in place.
    let mut values = vec![0.0; close.len()];
    match garman_klass_volatility_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(_) => Ok(values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
1043
/// wasm binding: allocates room for `len` `f64` values and returns the raw
/// pointer to it.
///
/// The buffer is deliberately not dropped here; release it with
/// `garman_klass_volatility_free`, passing the same `len`. Its contents are
/// uninitialized.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn garman_klass_volatility_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the allocation alive after this function returns.
    let mut buffer = ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1052
/// wasm binding: releases a buffer of `len` `f64` values. A null `ptr` is a
/// no-op.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn garman_klass_volatility_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: relies on the caller passing a pointer obtained from
    // `garman_klass_volatility_alloc` with the same `len`, and not freeing it
    // twice. Nothing in this function can verify that.
    unsafe {
        drop(Vec::from_raw_parts(ptr, len, len));
    }
}
1062
/// wasm binding: computes Garman-Klass volatility for one lookback from raw
/// pointers and writes `len` values to `out_ptr`.
///
/// `out_ptr` may alias or overlap any of the input buffers. When it does, the
/// result is computed into a temporary buffer and copied out afterwards, so an
/// exclusive output slice never coexists with a shared input slice over the
/// same memory.
///
/// # Errors
///
/// Returns `"Null pointer provided"` if any pointer is null, or the
/// stringified error from `garman_klass_volatility_into_slice`.
///
/// # Safety contract (caller side)
///
/// Every pointer must be valid for `len` `f64` elements; this cannot be
/// checked here.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn garman_klass_volatility_into(
    open_ptr: *const f64,
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    lookback: usize,
) -> Result<(), JsValue> {
    if open_ptr.is_null()
        || high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || out_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Detect overlap by byte range rather than by start address only: an
    // output that begins inside an input buffer is just as much an alias as
    // one that begins at its first element.
    let span_bytes = len.saturating_mul(std::mem::size_of::<f64>());
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(span_bytes);
    let output_aliases_input = [open_ptr, high_ptr, low_ptr, close_ptr]
        .iter()
        .any(|&src| {
            let src_start = src as usize;
            let src_end = src_start.saturating_add(span_bytes);
            // Equal start pointers always take the buffered path, including len == 0.
            src_start == out_start || (src_start < out_end && out_start < src_end)
        });

    // SAFETY: all pointers are non-null (checked above) and the caller
    // guarantees each is valid for `len` elements. The exclusive output slice
    // is created while the inputs are live only when the ranges are disjoint;
    // otherwise it is created after the last read of the inputs.
    unsafe {
        let open = std::slice::from_raw_parts(open_ptr, len);
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        let input = GarmanKlassVolatilityInput::from_slices(
            open,
            high,
            low,
            close,
            GarmanKlassVolatilityParams {
                lookback: Some(lookback),
            },
        );

        if output_aliases_input {
            let mut tmp = vec![0.0; len];
            garman_klass_volatility_into_slice(&mut tmp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&tmp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            garman_klass_volatility_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }
    Ok(())
}
1112
/// Batch configuration deserialized from the `config` argument of
/// `garman_klass_volatility_batch_js`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GarmanKlassVolatilityBatchConfig {
    /// Lookback sweep, forwarded unchanged as the `lookback` field of
    /// `GarmanKlassVolatilityBatchRange`.
    pub lookback_range: (usize, usize, usize),
}
1118
/// Batch result serialized back to JavaScript by
/// `garman_klass_volatility_batch_js`; each field is copied from the batch
/// output of the same name.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GarmanKlassVolatilityBatchJsOutput {
    /// Flat buffer of computed values for all parameter combinations.
    pub values: Vec<f64>,
    /// Parameter set of each evaluated combination.
    pub combos: Vec<GarmanKlassVolatilityParams>,
    /// Row count reported by the batch computation.
    pub rows: usize,
    /// Column count reported by the batch computation.
    pub cols: usize,
}
1127
/// wasm binding: runs a lookback sweep described by a JS `config` object and
/// returns the serialized `GarmanKlassVolatilityBatchJsOutput`.
///
/// # Errors
///
/// Returns `"Invalid config: ..."` when `config` cannot be deserialized into
/// `GarmanKlassVolatilityBatchConfig`, or the stringified error from the
/// batch computation or from result serialization.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "garman_klass_volatility_batch_js")]
pub fn garman_klass_volatility_batch_js(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let GarmanKlassVolatilityBatchConfig { lookback_range } =
        serde_wasm_bindgen::from_value(config)
            .map_err(|e| JsValue::from_str(&format!("Invalid config: {e}")))?;
    let sweep = GarmanKlassVolatilityBatchRange {
        lookback: lookback_range,
    };

    let kernel = detect_best_kernel();
    let batch = match garman_klass_volatility_batch_inner(
        open, high, low, close, &sweep, kernel, false,
    ) {
        Ok(batch) => batch,
        Err(err) => return Err(JsValue::from_str(&err.to_string())),
    };

    let payload = GarmanKlassVolatilityBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };
    serde_wasm_bindgen::to_value(&payload).map_err(|err| JsValue::from_str(&err.to_string()))
}
1160
/// wasm binding: runs a lookback sweep from raw pointers and writes the
/// result matrix (`rows * len` values) to `out_ptr`.
///
/// The sweep is `(lookback_start, lookback_end, lookback_step)`. Returns the
/// number of rows, i.e. the number of parameter combinations produced by
/// `expand_grid_garman_klass`.
///
/// The output region is only turned into a mutable slice after the
/// computation has finished reading the inputs, so an `out_ptr` that overlaps
/// an input buffer does not create aliasing references.
///
/// # Errors
///
/// Returns `"Null pointer provided"` if any pointer is null,
/// `"rows*cols overflow"` if the output size overflows `usize`, an
/// `"output length mismatch"` message if the batch result does not contain
/// exactly `rows * len` values, or the stringified error from grid expansion
/// or the batch computation.
///
/// # Safety contract (caller side)
///
/// Input pointers must be valid for `len` `f64` elements and `out_ptr` for
/// `rows * len` elements; this cannot be checked here.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn garman_klass_volatility_batch_into(
    open_ptr: *const f64,
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    lookback_start: usize,
    lookback_end: usize,
    lookback_step: usize,
) -> Result<usize, JsValue> {
    if open_ptr.is_null()
        || high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || out_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = GarmanKlassVolatilityBatchRange {
        lookback: (lookback_start, lookback_end, lookback_step),
    };
    let combos = expand_grid_garman_klass(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let total = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    let batch = {
        // SAFETY: the pointers are non-null (checked above) and the caller
        // guarantees each is valid for `len` elements. No mutable reference
        // to caller memory exists while these shared slices are alive.
        let (open, high, low, close) = unsafe {
            (
                std::slice::from_raw_parts(open_ptr, len),
                std::slice::from_raw_parts(high_ptr, len),
                std::slice::from_raw_parts(low_ptr, len),
                std::slice::from_raw_parts(close_ptr, len),
            )
        };
        garman_klass_volatility_batch_inner(
            open,
            high,
            low,
            close,
            &sweep,
            detect_best_kernel(),
            false,
        )
        .map_err(|e| JsValue::from_str(&e.to_string()))?
    };

    // Report a size disagreement as an error instead of panicking inside
    // `copy_from_slice`.
    if batch.values.len() != total {
        return Err(JsValue::from_str(&format!(
            "output length mismatch: expected {total}, got {}",
            batch.values.len()
        )));
    }

    // SAFETY: `out_ptr` is non-null and the caller guarantees it is valid for
    // `total` elements. The input slices are no longer in use, and
    // `batch.values` is a separate allocation owned by this function.
    unsafe {
        std::slice::from_raw_parts_mut(out_ptr, total).copy_from_slice(&batch.values);
    }

    Ok(rows)
}
1213
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds deterministic synthetic OHLC series of length `len`.
    ///
    /// The first two bars of every series are left as NaN; from index 2 on,
    /// each bar is derived from the previous close with small sinusoidal
    /// perturbations, keeping `low <= min(open, close)` and
    /// `high >= max(open, close)`.
    fn sample_ohlc(len: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
        let mut opens = vec![f64::NAN; len];
        let mut highs = vec![f64::NAN; len];
        let mut lows = vec![f64::NAN; len];
        let mut closes = vec![f64::NAN; len];

        let mut last_close = 100.0;
        for idx in 2..len {
            let t = idx as f64;
            let bar_open = (last_close + (t * 0.021).sin() * 1.5 + 0.03 * t).max(1.0);
            let bar_close = (bar_open + (t * 0.017).cos() * 0.8).max(1.0);
            let bar_high = bar_open.max(bar_close) + 0.5 + (t * 0.011).sin().abs() * 0.2;
            let bar_low = (bar_open.min(bar_close) - 0.45 - (t * 0.013).cos().abs() * 0.15).max(0.01);

            opens[idx] = bar_open;
            highs[idx] = bar_high;
            lows[idx] = bar_low;
            closes[idx] = bar_close;
            last_close = bar_close;
        }

        (opens, highs, lows, closes)
    }

    /// Output has the input length, contains finite values, and the first
    /// finite value appears no earlier than index 15 for lookback 14.
    #[test]
    fn gk_output_contract() {
        let (open, high, low, close) = sample_ohlc(128);
        let params = GarmanKlassVolatilityParams { lookback: Some(14) };
        let input = GarmanKlassVolatilityInput::from_slices(&open, &high, &low, &close, params);

        let result = garman_klass_volatility(&input).expect("gk");
        let values = &result.values;

        assert_eq!(values.len(), close.len());
        assert!(values.iter().any(|v| v.is_finite()));
        let first_valid = values
            .iter()
            .position(|v| v.is_finite())
            .expect("first valid");
        assert!(first_valid >= 15);
    }

    /// The `_into` entry point agrees with the allocating API element by element.
    #[test]
    fn gk_into_matches_api() {
        let (open, high, low, close) = sample_ohlc(192);
        let params = GarmanKlassVolatilityParams { lookback: Some(20) };
        let input = GarmanKlassVolatilityInput::from_slices(&open, &high, &low, &close, params);

        let api = garman_klass_volatility(&input).expect("api");
        let mut out = vec![0.0; close.len()];
        garman_klass_volatility_into(&input, &mut out).expect("into");

        for (i, &got) in out.iter().enumerate() {
            let want = api.values[i];
            if want.is_nan() {
                assert!(got.is_nan(), "expected NaN at index {i}");
            } else {
                assert!(
                    (want - got).abs() <= 1e-12,
                    "into mismatch at {i}: {} vs {}",
                    want,
                    got
                );
            }
        }
    }

    /// Feeding bars one at a time through the stream reproduces the slice API.
    #[test]
    fn gk_stream_matches_batch() {
        let (open, high, low, close) = sample_ohlc(160);
        let input = GarmanKlassVolatilityInput::from_slices(
            &open,
            &high,
            &low,
            &close,
            GarmanKlassVolatilityParams { lookback: Some(12) },
        );
        let batch = garman_klass_volatility(&input).expect("batch");

        let stream_params = GarmanKlassVolatilityParams { lookback: Some(12) };
        let mut stream = GarmanKlassVolatilityStream::try_new(stream_params).expect("stream");
        let streamed: Vec<f64> = (0..close.len())
            .map(|i| {
                stream
                    .update(open[i], high[i], low[i], close[i])
                    .unwrap_or(f64::NAN)
            })
            .collect();

        for (i, &got) in streamed.iter().enumerate() {
            let want = batch.values[i];
            if want.is_nan() {
                assert!(got.is_nan(), "stream index {i}");
            } else {
                assert!(
                    (want - got).abs() <= 1e-12,
                    "stream mismatch at {i}: {} vs {}",
                    want,
                    got
                );
            }
        }
    }

    /// A one-combination batch sweep produces a single row equal to the
    /// single-parameter result.
    #[test]
    fn gk_batch_single_param_matches_single() {
        let (open, high, low, close) = sample_ohlc(200);
        let single_input = GarmanKlassVolatilityInput::from_slices(
            &open,
            &high,
            &low,
            &close,
            GarmanKlassVolatilityParams { lookback: Some(16) },
        );
        let single = garman_klass_volatility(&single_input).expect("single");

        let sweep = GarmanKlassVolatilityBatchRange {
            lookback: (16, 16, 0),
        };
        let batch = garman_klass_volatility_batch_with_kernel(
            &open,
            &high,
            &low,
            &close,
            &sweep,
            Kernel::ScalarBatch,
        )
        .expect("batch");

        assert_eq!(batch.rows, 1);
        assert_eq!(batch.cols, close.len());
        for (i, &got) in batch.values.iter().enumerate() {
            let want = single.values[i];
            if want.is_nan() {
                assert!(got.is_nan(), "expected NaN at index {i}");
            } else {
                assert!(
                    (got - want).abs() <= 1e-12,
                    "batch mismatch at {i}: {} vs {}",
                    got,
                    want
                );
            }
        }
    }

    /// A single all-NaN bar at index 30 yields NaN output through index 39
    /// (lookback 10) and a finite value again at index 40.
    #[test]
    fn gk_internal_invalid_bar_produces_nan_window_and_recovers() {
        let (mut open, mut high, mut low, mut close) = sample_ohlc(80);
        for series in [&mut open, &mut high, &mut low, &mut close] {
            series[30] = f64::NAN;
        }

        let params = GarmanKlassVolatilityParams { lookback: Some(10) };
        let input = GarmanKlassVolatilityInput::from_slices(&open, &high, &low, &close, params);
        let result = garman_klass_volatility(&input).expect("gk");

        assert!(result.values[30].is_nan());
        assert!(result.values[39].is_nan());
        assert!(result.values[40].is_finite());
    }

    /// A zero lookback is rejected with `InvalidLookback` carrying the bad value.
    #[test]
    fn gk_rejects_invalid_lookback() {
        let (open, high, low, close) = sample_ohlc(8);
        let params = GarmanKlassVolatilityParams { lookback: Some(0) };
        let input = GarmanKlassVolatilityInput::from_slices(&open, &high, &low, &close, params);

        let err = garman_klass_volatility(&input).unwrap_err();
        if let GarmanKlassVolatilityError::InvalidLookback { lookback, .. } = &err {
            assert_eq!(*lookback, 0);
        } else {
            panic!("unexpected error: {err:?}");
        }
    }
}