Skip to main content

vector_ta/indicators/
ewma_volatility.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::{source_type, Candles};
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18    alloc_with_nan_prefix, detect_best_batch_kernel, init_matrix_prefixes, make_uninit_matrix,
19};
20#[cfg(feature = "python")]
21use crate::utilities::kernel_validation::validate_kernel;
22#[cfg(not(target_arch = "wasm32"))]
23use rayon::prelude::*;
24use std::convert::AsRef;
25use std::mem::{ManuallyDrop, MaybeUninit};
26use thiserror::Error;
27
28impl<'a> AsRef<[f64]> for EwmaVolatilityInput<'a> {
29    #[inline(always)]
30    fn as_ref(&self) -> &[f64] {
31        match &self.data {
32            EwmaVolatilityData::Slice(slice) => slice,
33            EwmaVolatilityData::Candles { candles, source } => source_type(candles, source),
34        }
35    }
36}
37
38#[derive(Debug, Clone)]
39pub enum EwmaVolatilityData<'a> {
40    Candles {
41        candles: &'a Candles,
42        source: &'a str,
43    },
44    Slice(&'a [f64]),
45}
46
47#[derive(Debug, Clone)]
48pub struct EwmaVolatilityOutput {
49    pub values: Vec<f64>,
50}
51
/// Parameters for EWMA volatility.
///
/// `lambda` is the decay factor; it must be finite and lie in `[0, 1)`.
/// `None` means "use the default of `0.94`".
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct EwmaVolatilityParams {
    pub lambda: Option<f64>,
}

impl Default for EwmaVolatilityParams {
    /// Explicitly sets `lambda = 0.94`.
    fn default() -> Self {
        let lambda = Some(0.94);
        Self { lambda }
    }
}
66
67#[derive(Debug, Clone)]
68pub struct EwmaVolatilityInput<'a> {
69    pub data: EwmaVolatilityData<'a>,
70    pub params: EwmaVolatilityParams,
71}
72
73impl<'a> EwmaVolatilityInput<'a> {
74    #[inline]
75    pub fn from_candles(
76        candles: &'a Candles,
77        source: &'a str,
78        params: EwmaVolatilityParams,
79    ) -> Self {
80        Self {
81            data: EwmaVolatilityData::Candles { candles, source },
82            params,
83        }
84    }
85
86    #[inline]
87    pub fn from_slice(data: &'a [f64], params: EwmaVolatilityParams) -> Self {
88        Self {
89            data: EwmaVolatilityData::Slice(data),
90            params,
91        }
92    }
93
94    #[inline]
95    pub fn with_default_candles(candles: &'a Candles) -> Self {
96        Self::from_candles(candles, "close", EwmaVolatilityParams::default())
97    }
98
99    #[inline]
100    pub fn get_lambda(&self) -> f64 {
101        self.params.lambda.unwrap_or(0.94)
102    }
103}
104
105#[derive(Copy, Clone, Debug)]
106pub struct EwmaVolatilityBuilder {
107    lambda: Option<f64>,
108    kernel: Kernel,
109}
110
111impl Default for EwmaVolatilityBuilder {
112    fn default() -> Self {
113        Self {
114            lambda: None,
115            kernel: Kernel::Auto,
116        }
117    }
118}
119
120impl EwmaVolatilityBuilder {
121    #[inline(always)]
122    pub fn new() -> Self {
123        Self::default()
124    }
125
126    #[inline(always)]
127    pub fn lambda(mut self, value: f64) -> Self {
128        self.lambda = Some(value);
129        self
130    }
131
132    #[inline(always)]
133    pub fn kernel(mut self, kernel: Kernel) -> Self {
134        self.kernel = kernel;
135        self
136    }
137
138    #[inline(always)]
139    pub fn apply(self, candles: &Candles) -> Result<EwmaVolatilityOutput, EwmaVolatilityError> {
140        let input = EwmaVolatilityInput::from_candles(
141            candles,
142            "close",
143            EwmaVolatilityParams {
144                lambda: self.lambda,
145            },
146        );
147        ewma_volatility_with_kernel(&input, self.kernel)
148    }
149
150    #[inline(always)]
151    pub fn apply_slice(self, data: &[f64]) -> Result<EwmaVolatilityOutput, EwmaVolatilityError> {
152        let input = EwmaVolatilityInput::from_slice(
153            data,
154            EwmaVolatilityParams {
155                lambda: self.lambda,
156            },
157        );
158        ewma_volatility_with_kernel(&input, self.kernel)
159    }
160
161    #[inline(always)]
162    pub fn into_stream(self) -> Result<EwmaVolatilityStream, EwmaVolatilityError> {
163        EwmaVolatilityStream::try_new(EwmaVolatilityParams {
164            lambda: self.lambda,
165        })
166    }
167}
168
169#[derive(Debug, Error)]
170pub enum EwmaVolatilityError {
171    #[error("ewma_volatility: Input data slice is empty.")]
172    EmptyInputData,
173    #[error("ewma_volatility: All values are NaN.")]
174    AllValuesNaN,
175    #[error("ewma_volatility: Invalid lambda: {lambda}. Expected finite value in [0, 1).")]
176    InvalidLambda { lambda: f64 },
177    #[error("ewma_volatility: Not enough valid data: needed = {needed}, valid = {valid}")]
178    NotEnoughValidData { needed: usize, valid: usize },
179    #[error("ewma_volatility: Output length mismatch: expected = {expected}, got = {got}")]
180    OutputLengthMismatch { expected: usize, got: usize },
181    #[error("ewma_volatility: Invalid range: start={start}, end={end}, step={step}")]
182    InvalidRange {
183        start: String,
184        end: String,
185        step: String,
186    },
187    #[error("ewma_volatility: Invalid kernel for batch: {0:?}")]
188    InvalidKernelForBatch(Kernel),
189}
190
/// Squared log returns pre-computed once per input series.
///
/// `sq_returns[i]` holds `ln(p[i] / p[i-1])^2` when both prices are finite and
/// positive, and `NaN` otherwise. `valid_indices` / `valid_values` list the usable
/// entries in order; they are used to locate and compute the seed window.
#[derive(Debug, Clone)]
struct EwmaPrepared {
    sq_returns: Vec<f64>,
    valid_indices: Vec<usize>,
    valid_values: Vec<f64>,
}

/// Multiplier applied to `sqrt(variance)` before output (reported in percent-style units).
const EWMA_SCALE: f64 = 100.0;
199
200#[inline(always)]
201fn period_from_lambda(lambda: f64) -> Result<usize, EwmaVolatilityError> {
202    if !lambda.is_finite() || !(0.0..1.0).contains(&lambda) {
203        return Err(EwmaVolatilityError::InvalidLambda { lambda });
204    }
205    let raw = (2.0 / (1.0 - lambda) - 1.0).round();
206    let period = raw.max(1.0) as usize;
207    Ok(period)
208}
209
/// Standard EMA smoothing factor for a span of `period` samples: `2 / (period + 1)`.
#[inline(always)]
fn alpha_from_period(period: usize) -> f64 {
    let span = period as f64;
    2.0 / (span + 1.0)
}
214
215#[inline(always)]
216fn prepare_returns(data: &[f64]) -> Result<EwmaPrepared, EwmaVolatilityError> {
217    if data.is_empty() {
218        return Err(EwmaVolatilityError::EmptyInputData);
219    }
220    if !data.iter().any(|v| !v.is_nan()) {
221        return Err(EwmaVolatilityError::AllValuesNaN);
222    }
223
224    let len = data.len();
225    let mut sq_returns = vec![f64::NAN; len];
226    let mut valid_indices = Vec::with_capacity(len.saturating_sub(1));
227    let mut valid_values = Vec::with_capacity(len.saturating_sub(1));
228
229    for i in 1..len {
230        let prev = data[i - 1];
231        let curr = data[i];
232        if prev.is_finite() && curr.is_finite() && prev > 0.0 && curr > 0.0 {
233            let ret = (curr / prev).ln();
234            let sq = ret * ret;
235            sq_returns[i] = sq;
236            valid_indices.push(i);
237            valid_values.push(sq);
238        }
239    }
240
241    Ok(EwmaPrepared {
242        sq_returns,
243        valid_indices,
244        valid_values,
245    })
246}
247
248#[inline(always)]
249fn fill_row_from_precomputed(
250    prep: &EwmaPrepared,
251    period: usize,
252    alpha: f64,
253    out: &mut [f64],
254) -> Result<usize, EwmaVolatilityError> {
255    if prep.valid_values.len() < period {
256        return Err(EwmaVolatilityError::NotEnoughValidData {
257            needed: period,
258            valid: prep.valid_values.len(),
259        });
260    }
261
262    let seed_idx = prep.valid_indices[period - 1];
263    let mut ema = prep.valid_values[..period].iter().copied().sum::<f64>() / period as f64;
264    let beta = 1.0 - alpha;
265
266    out[seed_idx] = ema.max(0.0).sqrt() * EWMA_SCALE;
267    for i in (seed_idx + 1)..out.len() {
268        let sq = prep.sq_returns[i];
269        if sq.is_finite() {
270            ema = beta.mul_add(ema, alpha * sq);
271        }
272        out[i] = ema.max(0.0).sqrt() * EWMA_SCALE;
273    }
274
275    Ok(seed_idx)
276}
277
278#[inline(always)]
279fn ewma_prepare<'a>(
280    input: &'a EwmaVolatilityInput,
281    kernel: Kernel,
282) -> Result<(&'a [f64], usize, f64, Kernel, EwmaPrepared), EwmaVolatilityError> {
283    let data = input.as_ref();
284    let lambda = input.get_lambda();
285    let period = period_from_lambda(lambda)?;
286    let chosen = match kernel {
287        Kernel::Auto => Kernel::Scalar,
288        other => other,
289    };
290    let prep = prepare_returns(data)?;
291    Ok((data, period, alpha_from_period(period), chosen, prep))
292}
293
294#[inline]
295pub fn ewma_volatility(
296    input: &EwmaVolatilityInput,
297) -> Result<EwmaVolatilityOutput, EwmaVolatilityError> {
298    ewma_volatility_with_kernel(input, Kernel::Auto)
299}
300
301pub fn ewma_volatility_with_kernel(
302    input: &EwmaVolatilityInput,
303    kernel: Kernel,
304) -> Result<EwmaVolatilityOutput, EwmaVolatilityError> {
305    let (data, period, alpha, _chosen, prep) = ewma_prepare(input, kernel)?;
306    let seed_idx = prep
307        .valid_indices
308        .get(period.saturating_sub(1))
309        .copied()
310        .ok_or(EwmaVolatilityError::NotEnoughValidData {
311            needed: period,
312            valid: prep.valid_values.len(),
313        })?;
314    let mut out = alloc_with_nan_prefix(data.len(), seed_idx);
315    fill_row_from_precomputed(&prep, period, alpha, &mut out)?;
316    Ok(EwmaVolatilityOutput { values: out })
317}
318
319#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
320pub fn ewma_volatility_into(
321    input: &EwmaVolatilityInput,
322    out: &mut [f64],
323) -> Result<(), EwmaVolatilityError> {
324    ewma_volatility_into_slice(out, input, Kernel::Auto)
325}
326
327#[inline]
328pub fn ewma_volatility_into_slice(
329    dst: &mut [f64],
330    input: &EwmaVolatilityInput,
331    kernel: Kernel,
332) -> Result<(), EwmaVolatilityError> {
333    let (data, period, alpha, _chosen, prep) = ewma_prepare(input, kernel)?;
334    if dst.len() != data.len() {
335        return Err(EwmaVolatilityError::OutputLengthMismatch {
336            expected: data.len(),
337            got: dst.len(),
338        });
339    }
340    dst.fill(f64::NAN);
341    fill_row_from_precomputed(&prep, period, alpha, dst)?;
342    Ok(())
343}
344
345#[derive(Debug, Clone)]
346pub struct EwmaVolatilityStream {
347    period: usize,
348    alpha: f64,
349    prev_close: f64,
350    seed_window: Vec<f64>,
351    seed_sum: f64,
352    ema: f64,
353    seeded: bool,
354}
355
356impl EwmaVolatilityStream {
357    #[inline(always)]
358    pub fn try_new(params: EwmaVolatilityParams) -> Result<Self, EwmaVolatilityError> {
359        let lambda = params.lambda.unwrap_or(0.94);
360        let period = period_from_lambda(lambda)?;
361        Ok(Self {
362            period,
363            alpha: alpha_from_period(period),
364            prev_close: f64::NAN,
365            seed_window: Vec::with_capacity(period),
366            seed_sum: 0.0,
367            ema: f64::NAN,
368            seeded: false,
369        })
370    }
371
372    #[inline(always)]
373    pub fn update(&mut self, close: f64) -> Option<f64> {
374        let ret_valid = self.prev_close.is_finite()
375            && self.prev_close > 0.0
376            && close.is_finite()
377            && close > 0.0;
378
379        if ret_valid {
380            let ret = (close / self.prev_close).ln();
381            let sq = ret * ret;
382            if !self.seeded {
383                self.seed_window.push(sq);
384                self.seed_sum += sq;
385                if self.seed_window.len() == self.period {
386                    self.ema = self.seed_sum / self.period as f64;
387                    self.seeded = true;
388                }
389            } else {
390                self.ema = (1.0 - self.alpha).mul_add(self.ema, self.alpha * sq);
391            }
392        }
393
394        self.prev_close = close;
395
396        if self.seeded {
397            Some(self.ema.max(0.0).sqrt() * EWMA_SCALE)
398        } else {
399            None
400        }
401    }
402
403    #[inline(always)]
404    pub fn get_warmup_period(&self) -> usize {
405        self.period
406    }
407}
408
/// Inclusive `(start, end, step)` sweep over `lambda` for batch runs.
#[derive(Clone, Debug)]
pub struct EwmaVolatilityBatchRange {
    pub lambda: (f64, f64, f64),
}

impl Default for EwmaVolatilityBatchRange {
    /// A single-point sweep at `lambda = 0.94`.
    fn default() -> Self {
        let single_point = (0.94, 0.94, 0.0);
        Self {
            lambda: single_point,
        }
    }
}
421
422#[derive(Clone, Debug, Default)]
423pub struct EwmaVolatilityBatchBuilder {
424    range: EwmaVolatilityBatchRange,
425    kernel: Kernel,
426}
427
428impl EwmaVolatilityBatchBuilder {
429    pub fn new() -> Self {
430        Self::default()
431    }
432
433    pub fn kernel(mut self, kernel: Kernel) -> Self {
434        self.kernel = kernel;
435        self
436    }
437
438    pub fn lambda_range(mut self, start: f64, end: f64, step: f64) -> Self {
439        self.range.lambda = (start, end, step);
440        self
441    }
442
443    pub fn lambda_static(mut self, value: f64) -> Self {
444        self.range.lambda = (value, value, 0.0);
445        self
446    }
447
448    pub fn apply_slice(
449        self,
450        data: &[f64],
451    ) -> Result<EwmaVolatilityBatchOutput, EwmaVolatilityError> {
452        ewma_volatility_batch_with_kernel(data, &self.range, self.kernel)
453    }
454
455    pub fn apply_candles(
456        self,
457        candles: &Candles,
458    ) -> Result<EwmaVolatilityBatchOutput, EwmaVolatilityError> {
459        self.apply_slice(&candles.close)
460    }
461}
462
463#[derive(Clone, Debug)]
464pub struct EwmaVolatilityBatchOutput {
465    pub values: Vec<f64>,
466    pub combos: Vec<EwmaVolatilityParams>,
467    pub rows: usize,
468    pub cols: usize,
469}
470
471impl EwmaVolatilityBatchOutput {
472    pub fn row_for_params(&self, params: &EwmaVolatilityParams) -> Option<usize> {
473        let lambda = params.lambda.unwrap_or(0.94);
474        self.combos
475            .iter()
476            .position(|combo| (combo.lambda.unwrap_or(0.94) - lambda).abs() <= 1e-12)
477    }
478
479    pub fn values_for(&self, params: &EwmaVolatilityParams) -> Option<&[f64]> {
480        self.row_for_params(params).and_then(|row| {
481            let start = row * self.cols;
482            self.values.get(start..start + self.cols)
483        })
484    }
485}
486
487#[inline(always)]
488fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, EwmaVolatilityError> {
489    if !start.is_finite() || !end.is_finite() || !step.is_finite() {
490        return Err(EwmaVolatilityError::InvalidRange {
491            start: start.to_string(),
492            end: end.to_string(),
493            step: step.to_string(),
494        });
495    }
496    if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
497        return Ok(vec![start]);
498    }
499    let step_abs = step.abs();
500    let mut out = Vec::new();
501    if start < end {
502        let mut x = start;
503        while x <= end + 1e-12 {
504            out.push(x);
505            x += step_abs;
506        }
507    } else {
508        let mut x = start;
509        while x >= end - 1e-12 {
510            out.push(x);
511            x -= step_abs;
512        }
513    }
514    if out.is_empty() {
515        return Err(EwmaVolatilityError::InvalidRange {
516            start: start.to_string(),
517            end: end.to_string(),
518            step: step.to_string(),
519        });
520    }
521    Ok(out)
522}
523
524#[inline(always)]
525pub fn expand_grid(
526    range: &EwmaVolatilityBatchRange,
527) -> Result<Vec<EwmaVolatilityParams>, EwmaVolatilityError> {
528    Ok(axis_f64(range.lambda)?
529        .into_iter()
530        .map(|lambda| EwmaVolatilityParams {
531            lambda: Some(lambda),
532        })
533        .collect())
534}
535
536pub fn ewma_volatility_batch_with_kernel(
537    data: &[f64],
538    sweep: &EwmaVolatilityBatchRange,
539    kernel: Kernel,
540) -> Result<EwmaVolatilityBatchOutput, EwmaVolatilityError> {
541    let batch_kernel = match kernel {
542        Kernel::Auto => detect_best_batch_kernel(),
543        other if other.is_batch() => other,
544        _ => return Err(EwmaVolatilityError::InvalidKernelForBatch(kernel)),
545    };
546    ewma_volatility_batch_par_slice(data, sweep, batch_kernel.to_non_batch())
547}
548
549#[inline(always)]
550pub fn ewma_volatility_batch_slice(
551    data: &[f64],
552    sweep: &EwmaVolatilityBatchRange,
553    kernel: Kernel,
554) -> Result<EwmaVolatilityBatchOutput, EwmaVolatilityError> {
555    ewma_volatility_batch_inner(data, sweep, kernel, false)
556}
557
558#[inline(always)]
559pub fn ewma_volatility_batch_par_slice(
560    data: &[f64],
561    sweep: &EwmaVolatilityBatchRange,
562    kernel: Kernel,
563) -> Result<EwmaVolatilityBatchOutput, EwmaVolatilityError> {
564    ewma_volatility_batch_inner(data, sweep, kernel, true)
565}
566
567#[inline(always)]
568fn ewma_volatility_batch_inner(
569    data: &[f64],
570    sweep: &EwmaVolatilityBatchRange,
571    kernel: Kernel,
572    parallel: bool,
573) -> Result<EwmaVolatilityBatchOutput, EwmaVolatilityError> {
574    let combos = expand_grid(sweep)?;
575    if data.is_empty() {
576        return Err(EwmaVolatilityError::EmptyInputData);
577    }
578    let prep = prepare_returns(data)?;
579    let rows = combos.len();
580    let cols = data.len();
581
582    let warmups: Vec<usize> = combos
583        .iter()
584        .map(|combo| {
585            let period = period_from_lambda(combo.lambda.unwrap_or(0.94))?;
586            prep.valid_indices
587                .get(period.saturating_sub(1))
588                .copied()
589                .ok_or(EwmaVolatilityError::NotEnoughValidData {
590                    needed: period,
591                    valid: prep.valid_values.len(),
592                })
593        })
594        .collect::<Result<_, _>>()?;
595
596    let mut buf_mu = make_uninit_matrix(rows, cols);
597    init_matrix_prefixes(&mut buf_mu, cols, &warmups);
598    let mut buf_guard = ManuallyDrop::new(buf_mu);
599    let out: &mut [f64] = unsafe {
600        core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
601    };
602
603    ewma_volatility_batch_inner_into(data, sweep, kernel, parallel, out)?;
604
605    let values = unsafe {
606        Vec::from_raw_parts(
607            buf_guard.as_mut_ptr() as *mut f64,
608            buf_guard.len(),
609            buf_guard.capacity(),
610        )
611    };
612
613    Ok(EwmaVolatilityBatchOutput {
614        values,
615        combos,
616        rows,
617        cols,
618    })
619}
620
621pub fn ewma_volatility_batch_into_slice(
622    dst: &mut [f64],
623    data: &[f64],
624    sweep: &EwmaVolatilityBatchRange,
625    kernel: Kernel,
626) -> Result<(), EwmaVolatilityError> {
627    ewma_volatility_batch_inner_into(data, sweep, kernel, false, dst)?;
628    Ok(())
629}
630
631fn ewma_volatility_batch_inner_into(
632    data: &[f64],
633    sweep: &EwmaVolatilityBatchRange,
634    kernel: Kernel,
635    parallel: bool,
636    out: &mut [f64],
637) -> Result<Vec<EwmaVolatilityParams>, EwmaVolatilityError> {
638    let combos = expand_grid(sweep)?;
639    if data.is_empty() {
640        return Err(EwmaVolatilityError::EmptyInputData);
641    }
642    let prep = prepare_returns(data)?;
643    let rows = combos.len();
644    let cols = data.len();
645    let expected = rows
646        .checked_mul(cols)
647        .ok_or_else(|| EwmaVolatilityError::InvalidRange {
648            start: rows.to_string(),
649            end: cols.to_string(),
650            step: "rows*cols".to_string(),
651        })?;
652    if out.len() != expected {
653        return Err(EwmaVolatilityError::OutputLengthMismatch {
654            expected,
655            got: out.len(),
656        });
657    }
658
659    let chosen = match kernel {
660        Kernel::Auto => Kernel::Scalar,
661        other => other,
662    };
663    let _ = chosen;
664
665    let periods: Vec<usize> = combos
666        .iter()
667        .map(|combo| period_from_lambda(combo.lambda.unwrap_or(0.94)))
668        .collect::<Result<_, _>>()?;
669
670    let warmups: Vec<usize> = periods
671        .iter()
672        .map(|&period| {
673            prep.valid_indices
674                .get(period.saturating_sub(1))
675                .copied()
676                .ok_or(EwmaVolatilityError::NotEnoughValidData {
677                    needed: period,
678                    valid: prep.valid_values.len(),
679                })
680        })
681        .collect::<Result<_, _>>()?;
682
683    for (row, &warm) in warmups.iter().enumerate() {
684        let row_start = row * cols;
685        out[row_start..row_start + warm.min(cols)].fill(f64::NAN);
686    }
687
688    let do_row = |row: usize, dst: &mut [f64]| -> Result<(), EwmaVolatilityError> {
689        let period = periods[row];
690        let alpha = alpha_from_period(period);
691        fill_row_from_precomputed(&prep, period, alpha, dst)?;
692        Ok(())
693    };
694
695    if parallel {
696        #[cfg(not(target_arch = "wasm32"))]
697        {
698            out.par_chunks_mut(cols)
699                .enumerate()
700                .try_for_each(|(row, dst)| do_row(row, dst))?;
701        }
702        #[cfg(target_arch = "wasm32")]
703        {
704            for (row, dst) in out.chunks_mut(cols).enumerate() {
705                do_row(row, dst)?;
706            }
707        }
708    } else {
709        for (row, dst) in out.chunks_mut(cols).enumerate() {
710            do_row(row, dst)?;
711        }
712    }
713
714    Ok(combos)
715}
716
717#[cfg(feature = "python")]
718#[pyfunction(name = "ewma_volatility")]
719#[pyo3(signature = (data, lambda_=0.94, kernel=None))]
720pub fn ewma_volatility_py<'py>(
721    py: Python<'py>,
722    data: PyReadonlyArray1<'py, f64>,
723    lambda_: f64,
724    kernel: Option<&str>,
725) -> PyResult<Bound<'py, PyArray1<f64>>> {
726    let slice = data.as_slice()?;
727    let kernel = validate_kernel(kernel, false)?;
728    let input = EwmaVolatilityInput::from_slice(
729        slice,
730        EwmaVolatilityParams {
731            lambda: Some(lambda_),
732        },
733    );
734    let out = py
735        .allow_threads(|| ewma_volatility_with_kernel(&input, kernel))
736        .map_err(|e| PyValueError::new_err(e.to_string()))?;
737    Ok(out.values.into_pyarray(py))
738}
739
740#[cfg(feature = "python")]
741#[pyclass(name = "EwmaVolatilityStream")]
742pub struct EwmaVolatilityStreamPy {
743    stream: EwmaVolatilityStream,
744}
745
746#[cfg(feature = "python")]
747#[pymethods]
748impl EwmaVolatilityStreamPy {
749    #[new]
750    fn new(lambda_: Option<f64>) -> PyResult<Self> {
751        let stream = EwmaVolatilityStream::try_new(EwmaVolatilityParams { lambda: lambda_ })
752            .map_err(|e| PyValueError::new_err(e.to_string()))?;
753        Ok(Self { stream })
754    }
755
756    fn update(&mut self, value: f64) -> Option<f64> {
757        self.stream.update(value)
758    }
759}
760
761#[cfg(feature = "python")]
762#[pyfunction(name = "ewma_volatility_batch")]
763#[pyo3(signature = (data, lambda_range=(0.94, 0.94, 0.0), kernel=None))]
764pub fn ewma_volatility_batch_py<'py>(
765    py: Python<'py>,
766    data: PyReadonlyArray1<'py, f64>,
767    lambda_range: (f64, f64, f64),
768    kernel: Option<&str>,
769) -> PyResult<Bound<'py, PyDict>> {
770    let slice = data.as_slice()?;
771    let kernel = validate_kernel(kernel, true)?;
772    let sweep = EwmaVolatilityBatchRange {
773        lambda: lambda_range,
774    };
775    let out = py
776        .allow_threads(|| ewma_volatility_batch_with_kernel(slice, &sweep, kernel))
777        .map_err(|e| PyValueError::new_err(e.to_string()))?;
778
779    let dict = PyDict::new(py);
780    dict.set_item(
781        "values",
782        out.values
783            .clone()
784            .into_pyarray(py)
785            .reshape((out.rows, out.cols))?,
786    )?;
787    dict.set_item(
788        "lambdas",
789        out.combos
790            .iter()
791            .map(|combo| combo.lambda.unwrap_or(0.94))
792            .collect::<Vec<_>>()
793            .into_pyarray(py),
794    )?;
795    dict.set_item("rows", out.rows)?;
796    dict.set_item("cols", out.cols)?;
797    Ok(dict)
798}
799
800#[cfg(feature = "python")]
801pub fn register_ewma_volatility_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
802    m.add_function(wrap_pyfunction!(ewma_volatility_py, m)?)?;
803    m.add_function(wrap_pyfunction!(ewma_volatility_batch_py, m)?)?;
804    m.add_class::<EwmaVolatilityStreamPy>()?;
805    Ok(())
806}
807
808#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
809#[wasm_bindgen(js_name = "ewma_volatility_js")]
810pub fn ewma_volatility_js(data: &[f64], lambda_: f64) -> Result<Vec<f64>, JsValue> {
811    let input = EwmaVolatilityInput::from_slice(
812        data,
813        EwmaVolatilityParams {
814            lambda: Some(lambda_),
815        },
816    );
817    let mut out = vec![0.0; data.len()];
818    ewma_volatility_into_slice(&mut out, &input, Kernel::Auto)
819        .map_err(|e| JsValue::from_str(&e.to_string()))?;
820    Ok(out)
821}
822
823#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
824#[derive(Serialize, Deserialize)]
825pub struct EwmaVolatilityBatchConfig {
826    pub lambda_range: Vec<f64>,
827}
828
829#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
830#[derive(Serialize, Deserialize)]
831pub struct EwmaVolatilityBatchJsOutput {
832    pub values: Vec<f64>,
833    pub combos: Vec<EwmaVolatilityParams>,
834    pub rows: usize,
835    pub cols: usize,
836}
837
838#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
839#[wasm_bindgen(js_name = "ewma_volatility_batch_js")]
840pub fn ewma_volatility_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
841    let config: EwmaVolatilityBatchConfig = serde_wasm_bindgen::from_value(config)
842        .map_err(|e| JsValue::from_str(&format!("Invalid config: {e}")))?;
843    if config.lambda_range.len() != 3 {
844        return Err(JsValue::from_str(
845            "Invalid config: lambda_range must have exactly 3 elements [start, end, step]",
846        ));
847    }
848    let out = ewma_volatility_batch_with_kernel(
849        data,
850        &EwmaVolatilityBatchRange {
851            lambda: (
852                config.lambda_range[0],
853                config.lambda_range[1],
854                config.lambda_range[2],
855            ),
856        },
857        Kernel::Auto,
858    )
859    .map_err(|e| JsValue::from_str(&e.to_string()))?;
860    serde_wasm_bindgen::to_value(&EwmaVolatilityBatchJsOutput {
861        values: out.values,
862        combos: out.combos,
863        rows: out.rows,
864        cols: out.cols,
865    })
866    .map_err(|e| JsValue::from_str(&format!("Serialization error: {e}")))
867}
868
869#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
870#[wasm_bindgen]
871pub fn ewma_volatility_alloc(len: usize) -> *mut f64 {
872    let mut vec = Vec::<f64>::with_capacity(len);
873    let ptr = vec.as_mut_ptr();
874    std::mem::forget(vec);
875    ptr
876}
877
878#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
879#[wasm_bindgen]
880pub fn ewma_volatility_free(ptr: *mut f64, len: usize) {
881    if !ptr.is_null() {
882        unsafe {
883            let _ = Vec::from_raw_parts(ptr, len, len);
884        }
885    }
886}
887
888#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
889#[wasm_bindgen]
890pub fn ewma_volatility_into(
891    in_ptr: *const f64,
892    out_ptr: *mut f64,
893    len: usize,
894    lambda_: f64,
895) -> Result<(), JsValue> {
896    if in_ptr.is_null() || out_ptr.is_null() {
897        return Err(JsValue::from_str("Null pointer provided"));
898    }
899    unsafe {
900        let data = std::slice::from_raw_parts(in_ptr, len);
901        let out = std::slice::from_raw_parts_mut(out_ptr, len);
902        let input = EwmaVolatilityInput::from_slice(
903            data,
904            EwmaVolatilityParams {
905                lambda: Some(lambda_),
906            },
907        );
908        ewma_volatility_into_slice(out, &input, Kernel::Auto)
909            .map_err(|e| JsValue::from_str(&e.to_string()))
910    }
911}
912
913#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
914#[wasm_bindgen]
915pub fn ewma_volatility_batch_into(
916    in_ptr: *const f64,
917    out_ptr: *mut f64,
918    len: usize,
919    lambda_start: f64,
920    lambda_end: f64,
921    lambda_step: f64,
922) -> Result<usize, JsValue> {
923    if in_ptr.is_null() || out_ptr.is_null() {
924        return Err(JsValue::from_str(
925            "null pointer passed to ewma_volatility_batch_into",
926        ));
927    }
928    unsafe {
929        let data = std::slice::from_raw_parts(in_ptr, len);
930        let sweep = EwmaVolatilityBatchRange {
931            lambda: (lambda_start, lambda_end, lambda_step),
932        };
933        let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
934        let rows = combos.len();
935        let out = std::slice::from_raw_parts_mut(out_ptr, rows * len);
936        ewma_volatility_batch_into_slice(out, data, &sweep, Kernel::Auto)
937            .map_err(|e| JsValue::from_str(&e.to_string()))?;
938        Ok(rows)
939    }
940}
941
942#[cfg(test)]
943mod tests {
944    use super::*;
945
946    fn geometric_series(len: usize, start: f64, ratio: f64) -> Vec<f64> {
947        let mut out = Vec::with_capacity(len);
948        let mut v = start;
949        for _ in 0..len {
950            out.push(v);
951            v *= ratio;
952        }
953        out
954    }
955
956    fn assert_close_series(lhs: &[f64], rhs: &[f64], tol: f64) {
957        assert_eq!(lhs.len(), rhs.len());
958        for i in 0..lhs.len() {
959            let a = lhs[i];
960            let b = rhs[i];
961            assert!(
962                (a.is_nan() && b.is_nan()) || (a - b).abs() <= tol,
963                "mismatch at {i}: {a} vs {b}"
964            );
965        }
966    }
967
968    #[test]
969    fn ewma_volatility_constant_return_series_converges_exactly() {
970        let data = geometric_series(128, 100.0, 1.01);
971        let input =
972            EwmaVolatilityInput::from_slice(&data, EwmaVolatilityParams { lambda: Some(0.94) });
973        let out = ewma_volatility(&input).unwrap();
974        let expected = 1.01f64.ln().abs() * 100.0;
975        let period = period_from_lambda(0.94).unwrap();
976        for i in 0..period {
977            assert!(out.values[i].is_nan(), "expected warmup NaN at {i}");
978        }
979        for v in &out.values[period..] {
980            assert!((*v - expected).abs() <= 1e-12, "unexpected value {v}");
981        }
982    }
983
984    #[test]
985    fn ewma_volatility_stream_matches_batch() {
986        let data = geometric_series(96, 50.0, 1.005);
987        let batch = EwmaVolatilityBuilder::new()
988            .lambda(0.90)
989            .apply_slice(&data)
990            .unwrap();
991        let mut stream = EwmaVolatilityBuilder::new()
992            .lambda(0.90)
993            .into_stream()
994            .unwrap();
995        let stream_values: Vec<f64> = data
996            .iter()
997            .map(|&v| stream.update(v).unwrap_or(f64::NAN))
998            .collect();
999        assert_close_series(&batch.values, &stream_values, 1e-12);
1000    }
1001
1002    #[test]
1003    fn ewma_volatility_batch_rows_match_single() {
1004        let data = geometric_series(128, 100.0, 1.002);
1005        let sweep = EwmaVolatilityBatchRange {
1006            lambda: (0.90, 0.94, 0.02),
1007        };
1008        let batch = ewma_volatility_batch_with_kernel(&data, &sweep, Kernel::Auto).unwrap();
1009        assert_eq!(batch.rows, 3);
1010        assert_eq!(batch.cols, data.len());
1011
1012        for (row, &lambda) in [0.90, 0.92, 0.94].iter().enumerate() {
1013            let single = EwmaVolatilityBuilder::new()
1014                .lambda(lambda)
1015                .apply_slice(&data)
1016                .unwrap();
1017            let start = row * data.len();
1018            assert_close_series(
1019                &batch.values[start..start + data.len()],
1020                &single.values,
1021                1e-12,
1022            );
1023        }
1024    }
1025
1026    #[test]
1027    fn ewma_volatility_into_slice_matches_single() {
1028        let data = geometric_series(80, 25.0, 1.003);
1029        let input =
1030            EwmaVolatilityInput::from_slice(&data, EwmaVolatilityParams { lambda: Some(0.94) });
1031        let direct = ewma_volatility(&input).unwrap();
1032        let mut out = vec![0.0; data.len()];
1033        ewma_volatility_into_slice(&mut out, &input, Kernel::Auto).unwrap();
1034        assert_close_series(&direct.values, &out, 1e-12);
1035    }
1036
1037    #[test]
1038    fn ewma_volatility_invalid_lambda_errors() {
1039        let data = geometric_series(40, 10.0, 1.01);
1040        let err = EwmaVolatilityBuilder::new()
1041            .lambda(1.0)
1042            .apply_slice(&data)
1043            .unwrap_err();
1044        assert!(matches!(err, EwmaVolatilityError::InvalidLambda { .. }));
1045    }
1046}