Skip to main content

vector_ta/indicators/moving_averages/
sgf.rs

1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5    make_uninit_matrix,
6};
7use aligned_vec::{AVec, CACHELINE_ALIGN};
8#[cfg(not(target_arch = "wasm32"))]
9use rayon::prelude::*;
10use std::convert::AsRef;
11use std::mem::MaybeUninit;
12use thiserror::Error;
13
14#[cfg(all(feature = "python", feature = "cuda"))]
15use crate::cuda::cuda_available;
16#[cfg(all(feature = "python", feature = "cuda"))]
17use crate::cuda::moving_averages::CudaSgf;
18#[cfg(all(feature = "python", feature = "cuda"))]
19use crate::utilities::dlpack_cuda::DeviceArrayF32Py;
20#[cfg(feature = "python")]
21use crate::utilities::kernel_validation::validate_kernel;
22#[cfg(feature = "python")]
23use numpy::{IntoPyArray, PyArray1};
24#[cfg(feature = "python")]
25use pyo3::exceptions::PyValueError;
26#[cfg(feature = "python")]
27use pyo3::prelude::*;
28#[cfg(feature = "python")]
29use pyo3::types::PyDict;
30#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
31use serde::{Deserialize, Serialize};
32#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
33use wasm_bindgen::prelude::*;
34
35impl<'a> AsRef<[f64]> for SgfInput<'a> {
36    #[inline(always)]
37    fn as_ref(&self) -> &[f64] {
38        match &self.data {
39            SgfData::Slice(slice) => slice,
40            SgfData::Candles { candles, source } => source_type(candles, source),
41        }
42    }
43}
44
45#[derive(Debug, Clone)]
46pub enum SgfData<'a> {
47    Candles {
48        candles: &'a Candles,
49        source: &'a str,
50    },
51    Slice(&'a [f64]),
52}
53
/// Result of a single-series SGF run.
#[derive(Debug, Clone)]
pub struct SgfOutput {
    /// One value per input element; the warm-up prefix is NaN (see `sgf_with_kernel`).
    pub values: Vec<f64>,
}
58
/// Tunable SGF parameters; `None` means "use the default".
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct SgfParams {
    /// Requested window length (default 21). Even values are reduced by one before use
    /// (see `effective_period`).
    pub period: Option<usize>,
    /// Degree of the fitted polynomial (default 2); must be below the effective window.
    pub poly_order: Option<usize>,
}

impl Default for SgfParams {
    fn default() -> Self {
        let period = Some(21);
        let poly_order = Some(2);
        SgfParams { period, poly_order }
    }
}
77
78#[derive(Debug, Clone)]
79pub struct SgfInput<'a> {
80    pub data: SgfData<'a>,
81    pub params: SgfParams,
82}
83
84impl<'a> SgfInput<'a> {
85    #[inline]
86    pub fn from_candles(candles: &'a Candles, source: &'a str, params: SgfParams) -> Self {
87        Self {
88            data: SgfData::Candles { candles, source },
89            params,
90        }
91    }
92
93    #[inline]
94    pub fn from_slice(slice: &'a [f64], params: SgfParams) -> Self {
95        Self {
96            data: SgfData::Slice(slice),
97            params,
98        }
99    }
100
101    #[inline]
102    pub fn with_default_candles(candles: &'a Candles) -> Self {
103        Self::from_candles(candles, "close", SgfParams::default())
104    }
105
106    #[inline]
107    pub fn get_period(&self) -> usize {
108        self.params.period.unwrap_or(21)
109    }
110
111    #[inline]
112    pub fn get_poly_order(&self) -> usize {
113        self.params.poly_order.unwrap_or(2)
114    }
115}
116
117#[derive(Copy, Clone, Debug)]
118pub struct SgfBuilder {
119    period: Option<usize>,
120    poly_order: Option<usize>,
121    kernel: Kernel,
122}
123
124impl Default for SgfBuilder {
125    fn default() -> Self {
126        Self {
127            period: None,
128            poly_order: None,
129            kernel: Kernel::Auto,
130        }
131    }
132}
133
134impl SgfBuilder {
135    #[inline(always)]
136    pub fn new() -> Self {
137        Self::default()
138    }
139
140    #[inline(always)]
141    pub fn period(mut self, n: usize) -> Self {
142        self.period = Some(n);
143        self
144    }
145
146    #[inline(always)]
147    pub fn poly_order(mut self, n: usize) -> Self {
148        self.poly_order = Some(n);
149        self
150    }
151
152    #[inline(always)]
153    pub fn kernel(mut self, kernel: Kernel) -> Self {
154        self.kernel = kernel;
155        self
156    }
157
158    #[inline(always)]
159    pub fn apply(self, candles: &Candles) -> Result<SgfOutput, SgfError> {
160        let input = SgfInput::from_candles(
161            candles,
162            "close",
163            SgfParams {
164                period: self.period,
165                poly_order: self.poly_order,
166            },
167        );
168        sgf_with_kernel(&input, self.kernel)
169    }
170
171    #[inline(always)]
172    pub fn apply_slice(self, data: &[f64]) -> Result<SgfOutput, SgfError> {
173        let input = SgfInput::from_slice(
174            data,
175            SgfParams {
176                period: self.period,
177                poly_order: self.poly_order,
178            },
179        );
180        sgf_with_kernel(&input, self.kernel)
181    }
182
183    #[inline(always)]
184    pub fn into_stream(self) -> Result<SgfStream, SgfError> {
185        SgfStream::try_new(SgfParams {
186            period: self.period,
187            poly_order: self.poly_order,
188        })
189    }
190}
191
192#[derive(Debug, Error)]
193pub enum SgfError {
194    #[error("sgf: input data slice is empty.")]
195    EmptyInputData,
196    #[error("sgf: all values are NaN.")]
197    AllValuesNaN,
198    #[error(
199        "sgf: invalid period: period = {period}, effective_period = {effective_period}, data length = {data_len}"
200    )]
201    InvalidPeriod {
202        period: usize,
203        effective_period: usize,
204        data_len: usize,
205    },
206    #[error(
207        "sgf: invalid polynomial order: poly_order = {poly_order}, effective_period = {effective_period}"
208    )]
209    InvalidPolyOrder {
210        poly_order: usize,
211        effective_period: usize,
212    },
213    #[error("sgf: not enough valid data: needed = {needed}, valid = {valid}")]
214    NotEnoughValidData { needed: usize, valid: usize },
215    #[error("sgf: output length mismatch: expected {expected}, got {got}")]
216    OutputLengthMismatch { expected: usize, got: usize },
217    #[error("sgf: invalid range expansion: start={start}, end={end}, step={step}")]
218    InvalidRange {
219        start: usize,
220        end: usize,
221        step: usize,
222    },
223    #[error("sgf: invalid poly-order range expansion: start={start}, end={end}, step={step}")]
224    InvalidPolyOrderRange {
225        start: usize,
226        end: usize,
227        step: usize,
228    },
229    #[error("sgf: invalid kernel passed to batch path: {0:?}")]
230    InvalidKernelForBatch(Kernel),
231}
232
/// Rounds an even window length down to the next odd one; 0 and 1 pass through unchanged.
#[inline(always)]
pub(crate) fn effective_period(period: usize) -> usize {
    let is_even_window = period > 1 && period % 2 == 0;
    if is_even_window {
        period - 1
    } else {
        period
    }
}
243
244#[inline]
245pub(crate) fn validate_period_and_order(
246    period: usize,
247    poly_order: usize,
248    len: usize,
249) -> Result<usize, SgfError> {
250    let effective = effective_period(period);
251    if period < 3 || effective < 3 || effective > len {
252        return Err(SgfError::InvalidPeriod {
253            period,
254            effective_period: effective,
255            data_len: len,
256        });
257    }
258    if poly_order >= effective {
259        return Err(SgfError::InvalidPolyOrder {
260            poly_order,
261            effective_period: effective,
262        });
263    }
264    Ok(effective)
265}
266
267fn solve_linear_system(mut a: Vec<f64>, mut b: Vec<f64>, n: usize) -> Result<Vec<f64>, SgfError> {
268    for pivot in 0..n {
269        let mut best_row = pivot;
270        let mut best_abs = a[pivot * n + pivot].abs();
271        for row in (pivot + 1)..n {
272            let cand = a[row * n + pivot].abs();
273            if cand > best_abs {
274                best_abs = cand;
275                best_row = row;
276            }
277        }
278
279        if best_abs <= 1e-15 {
280            return Err(SgfError::InvalidPolyOrder {
281                poly_order: n - 1,
282                effective_period: 0,
283            });
284        }
285
286        if best_row != pivot {
287            for col in pivot..n {
288                a.swap(pivot * n + col, best_row * n + col);
289            }
290            b.swap(pivot, best_row);
291        }
292
293        let pivot_val = a[pivot * n + pivot];
294        for col in pivot..n {
295            a[pivot * n + col] /= pivot_val;
296        }
297        b[pivot] /= pivot_val;
298
299        for row in 0..n {
300            if row == pivot {
301                continue;
302            }
303            let factor = a[row * n + pivot];
304            if factor == 0.0 {
305                continue;
306            }
307            for col in pivot..n {
308                a[row * n + col] -= factor * a[pivot * n + col];
309            }
310            b[row] -= factor * b[pivot];
311        }
312    }
313
314    Ok(b)
315}
316
317pub(crate) fn build_endpoint_sgf_weights(
318    period: usize,
319    poly_order: usize,
320) -> Result<AVec<f64>, SgfError> {
321    let effective = validate_period_and_order(period, poly_order, period)?;
322    let order = poly_order + 1;
323    let mut gram = vec![0.0f64; order * order];
324
325    for i in 0..effective {
326        let x = (i as f64) - ((effective - 1) as f64);
327        let mut powers = vec![1.0f64; order];
328        for k in 1..order {
329            powers[k] = powers[k - 1] * x;
330        }
331        for row in 0..order {
332            for col in 0..order {
333                gram[row * order + col] += powers[row] * powers[col];
334            }
335        }
336    }
337
338    let mut rhs = vec![0.0f64; order];
339    rhs[0] = 1.0;
340    let coeffs = solve_linear_system(gram, rhs, order)?;
341
342    let mut weights = AVec::<f64>::with_capacity(CACHELINE_ALIGN, effective);
343    let mut sum = 0.0f64;
344    for i in 0..effective {
345        let x = (i as f64) - ((effective - 1) as f64);
346        let mut power = 1.0f64;
347        let mut weight = 0.0f64;
348        for &coef in &coeffs {
349            weight += coef * power;
350            power *= x;
351        }
352        weights.push(weight);
353        sum += weight;
354    }
355
356    if sum != 0.0 {
357        for weight in weights.iter_mut() {
358            *weight /= sum;
359        }
360    }
361
362    Ok(weights)
363}
364
365#[derive(Clone)]
366struct SgfPrepared<'a> {
367    data: &'a [f64],
368    weights: AVec<f64>,
369    period: usize,
370    poly_order: usize,
371    first: usize,
372    kernel: Kernel,
373}
374
375#[inline]
376pub fn sgf(input: &SgfInput) -> Result<SgfOutput, SgfError> {
377    sgf_with_kernel(input, Kernel::Auto)
378}
379
380#[inline]
381fn sgf_prepare<'a>(input: &'a SgfInput, kernel: Kernel) -> Result<SgfPrepared<'a>, SgfError> {
382    let data = input.as_ref();
383    let len = data.len();
384    if len == 0 {
385        return Err(SgfError::EmptyInputData);
386    }
387
388    let first = data
389        .iter()
390        .position(|x| !x.is_nan())
391        .ok_or(SgfError::AllValuesNaN)?;
392    let requested_period = input.get_period();
393    let poly_order = input.get_poly_order();
394    let period = validate_period_and_order(requested_period, poly_order, len)?;
395
396    if len - first < period {
397        return Err(SgfError::NotEnoughValidData {
398            needed: period,
399            valid: len - first,
400        });
401    }
402
403    let weights = build_endpoint_sgf_weights(requested_period, poly_order)?;
404    let kernel = match kernel {
405        Kernel::Auto => detect_best_kernel(),
406        other => other,
407    };
408
409    Ok(SgfPrepared {
410        data,
411        weights,
412        period,
413        poly_order,
414        first,
415        kernel,
416    })
417}
418
/// Dot product of the first `weights.len()` samples of `window` with `weights`.
///
/// Uses four independent partial sums (lane `j` takes every element at index `4k + j`); the
/// leftover tail is added to lane 0 and the lanes are combined as `(l0 + l1) + (l2 + l3)`.
/// Panics if `window` is shorter than `weights`.
#[inline(always)]
fn sgf_dot(window: &[f64], weights: &[f64]) -> f64 {
    let xs = &window[..weights.len()];
    let mut lanes = [0.0f64; 4];

    let x_chunks = xs.chunks_exact(4);
    let w_chunks = weights.chunks_exact(4);
    let x_tail = x_chunks.remainder();
    let w_tail = w_chunks.remainder();

    for (xc, wc) in x_chunks.zip(w_chunks) {
        lanes[0] += xc[0] * wc[0];
        lanes[1] += xc[1] * wc[1];
        lanes[2] += xc[2] * wc[2];
        lanes[3] += xc[3] * wc[3];
    }
    for (x, w) in x_tail.iter().zip(w_tail) {
        lanes[0] += x * w;
    }

    (lanes[0] + lanes[1]) + (lanes[2] + lanes[3])
}
442
443#[inline(always)]
444fn sgf_compute_into(
445    data: &[f64],
446    weights: &[f64],
447    period: usize,
448    first: usize,
449    _kernel: Kernel,
450    out: &mut [f64],
451) {
452    let start = first + period - 1;
453    for idx in start..data.len() {
454        let from = idx + 1 - period;
455        out[idx] = sgf_dot(&data[from..(idx + 1)], weights);
456    }
457}
458
459pub fn sgf_with_kernel(input: &SgfInput, kernel: Kernel) -> Result<SgfOutput, SgfError> {
460    let prepared = sgf_prepare(input, kernel)?;
461    let warm = prepared.first + prepared.period - 1;
462    let mut out = alloc_with_nan_prefix(prepared.data.len(), warm);
463    sgf_compute_into(
464        prepared.data,
465        &prepared.weights,
466        prepared.period,
467        prepared.first,
468        prepared.kernel,
469        &mut out,
470    );
471    Ok(SgfOutput { values: out })
472}
473
474#[inline]
475pub fn sgf_into_slice(dst: &mut [f64], input: &SgfInput, kernel: Kernel) -> Result<(), SgfError> {
476    let prepared = sgf_prepare(input, kernel)?;
477    if dst.len() != prepared.data.len() {
478        return Err(SgfError::OutputLengthMismatch {
479            expected: prepared.data.len(),
480            got: dst.len(),
481        });
482    }
483
484    let warm = prepared.first + prepared.period - 1;
485    for value in &mut dst[..warm] {
486        *value = f64::from_bits(0x7ff8_0000_0000_0000);
487    }
488    sgf_compute_into(
489        prepared.data,
490        &prepared.weights,
491        prepared.period,
492        prepared.first,
493        prepared.kernel,
494        dst,
495    );
496    Ok(())
497}
498
499#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
500#[inline]
501pub fn sgf_into(input: &SgfInput, out: &mut [f64]) -> Result<(), SgfError> {
502    sgf_into_slice(out, input, Kernel::Auto)
503}
504
505#[derive(Debug, Clone)]
506pub struct SgfStream {
507    period: usize,
508    weights: AVec<f64>,
509    ring: AVec<f64>,
510    next: usize,
511    count: usize,
512}
513
514impl SgfStream {
515    pub fn try_new(params: SgfParams) -> Result<Self, SgfError> {
516        let requested_period = params.period.unwrap_or(21);
517        let poly_order = params.poly_order.unwrap_or(2);
518        let period = validate_period_and_order(requested_period, poly_order, requested_period)?;
519        let weights = build_endpoint_sgf_weights(requested_period, poly_order)?;
520        let mut ring = AVec::<f64>::with_capacity(CACHELINE_ALIGN, period);
521        ring.resize(period, 0.0);
522        Ok(Self {
523            period,
524            weights,
525            ring,
526            next: 0,
527            count: 0,
528        })
529    }
530
531    #[inline(always)]
532    pub fn update(&mut self, value: f64) -> Option<f64> {
533        self.ring[self.next] = value;
534        self.next += 1;
535        if self.next == self.period {
536            self.next = 0;
537        }
538        if self.count < self.period {
539            self.count += 1;
540        }
541        if self.count < self.period {
542            return None;
543        }
544
545        let mut acc = 0.0f64;
546        for idx in 0..self.period {
547            let ring_idx = (self.next + idx) % self.period;
548            acc += self.ring[ring_idx] * self.weights[idx];
549        }
550        Some(acc)
551    }
552}
553
/// Parameter sweep for batch SGF.
///
/// Each axis is `(start, end, step)`. A `step` of 0 (or `start == end`) yields just `start`;
/// `end < start` sweeps downward (see `expand_axis`).
#[derive(Clone, Debug)]
pub struct SgfBatchRange {
    pub period: (usize, usize, usize),
    pub poly_order: (usize, usize, usize),
}

impl Default for SgfBatchRange {
    fn default() -> Self {
        let period = (21, 81, 2);
        let poly_order = (2, 2, 0);
        SgfBatchRange { period, poly_order }
    }
}
568
569#[derive(Clone, Debug, Default)]
570pub struct SgfBatchBuilder {
571    range: SgfBatchRange,
572    kernel: Kernel,
573}
574
575impl SgfBatchBuilder {
576    pub fn new() -> Self {
577        Self::default()
578    }
579
580    pub fn kernel(mut self, kernel: Kernel) -> Self {
581        self.kernel = kernel;
582        self
583    }
584
585    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
586        self.range.period = (start, end, step);
587        self
588    }
589
590    pub fn period_static(mut self, period: usize) -> Self {
591        self.range.period = (period, period, 0);
592        self
593    }
594
595    pub fn poly_order_range(mut self, start: usize, end: usize, step: usize) -> Self {
596        self.range.poly_order = (start, end, step);
597        self
598    }
599
600    pub fn poly_order_static(mut self, poly_order: usize) -> Self {
601        self.range.poly_order = (poly_order, poly_order, 0);
602        self
603    }
604
605    pub fn apply_slice(self, data: &[f64]) -> Result<SgfBatchOutput, SgfError> {
606        sgf_batch_with_kernel(data, &self.range, self.kernel)
607    }
608
609    pub fn apply_candles(
610        self,
611        candles: &Candles,
612        source: &str,
613    ) -> Result<SgfBatchOutput, SgfError> {
614        self.apply_slice(source_type(candles, source))
615    }
616}
617
618#[derive(Clone, Debug)]
619pub struct SgfBatchOutput {
620    pub values: Vec<f64>,
621    pub combos: Vec<SgfParams>,
622    pub rows: usize,
623    pub cols: usize,
624}
625
626impl SgfBatchOutput {
627    pub fn row_for_params(&self, params: &SgfParams) -> Option<usize> {
628        self.combos.iter().position(|combo| {
629            combo.period == params.period && combo.poly_order == params.poly_order
630        })
631    }
632
633    pub fn values_for(&self, params: &SgfParams) -> Option<&[f64]> {
634        self.row_for_params(params).map(|row| {
635            let start = row * self.cols;
636            &self.values[start..start + self.cols]
637        })
638    }
639}
640
641#[inline(always)]
642fn expand_axis(range: (usize, usize, usize), is_poly_order: bool) -> Result<Vec<usize>, SgfError> {
643    let (start, end, step) = range;
644    let values = if step == 0 || start == end {
645        vec![start]
646    } else if start < end {
647        (start..=end).step_by(step.max(1)).collect()
648    } else {
649        let mut out = Vec::new();
650        let mut cur = start;
651        loop {
652            out.push(cur);
653            if cur <= end {
654                break;
655            }
656            match cur.checked_sub(step.max(1)) {
657                Some(next) => {
658                    cur = next;
659                    if cur < end {
660                        break;
661                    }
662                }
663                None => break,
664            }
665        }
666        out
667    };
668
669    if values.is_empty() {
670        if is_poly_order {
671            return Err(SgfError::InvalidPolyOrderRange { start, end, step });
672        }
673        return Err(SgfError::InvalidRange { start, end, step });
674    }
675    Ok(values)
676}
677
678#[inline(always)]
679pub fn expand_grid(range: &SgfBatchRange) -> Result<Vec<SgfParams>, SgfError> {
680    let periods = expand_axis(range.period, false)?;
681    let poly_orders = expand_axis(range.poly_order, true)?;
682    let mut out = Vec::with_capacity(periods.len() * poly_orders.len());
683    for &period in &periods {
684        for &poly_order in &poly_orders {
685            out.push(SgfParams {
686                period: Some(period),
687                poly_order: Some(poly_order),
688            });
689        }
690    }
691    Ok(out)
692}
693
694pub fn sgf_batch_with_kernel(
695    data: &[f64],
696    sweep: &SgfBatchRange,
697    kernel: Kernel,
698) -> Result<SgfBatchOutput, SgfError> {
699    let kernel = match kernel {
700        Kernel::Auto => detect_best_batch_kernel(),
701        other if other.is_batch() => other,
702        other => return Err(SgfError::InvalidKernelForBatch(other)),
703    };
704    let single_kernel = match kernel {
705        Kernel::ScalarBatch => Kernel::Scalar,
706        Kernel::Avx2Batch => Kernel::Avx2,
707        Kernel::Avx512Batch => Kernel::Avx512,
708        other => other,
709    };
710    sgf_batch_inner(data, sweep, single_kernel, true)
711}
712
713#[inline(always)]
714pub fn sgf_batch_slice(
715    data: &[f64],
716    sweep: &SgfBatchRange,
717    kernel: Kernel,
718) -> Result<SgfBatchOutput, SgfError> {
719    sgf_batch_inner(data, sweep, kernel, false)
720}
721
722#[inline(always)]
723pub fn sgf_batch_par_slice(
724    data: &[f64],
725    sweep: &SgfBatchRange,
726    kernel: Kernel,
727) -> Result<SgfBatchOutput, SgfError> {
728    sgf_batch_inner(data, sweep, kernel, true)
729}
730
731pub fn sgf_batch_into_slice(
732    dst: &mut [f64],
733    data: &[f64],
734    sweep: &SgfBatchRange,
735    kernel: Kernel,
736) -> Result<Vec<SgfParams>, SgfError> {
737    sgf_batch_inner_into(data, sweep, kernel, true, dst)
738}
739
740fn sgf_batch_inner(
741    data: &[f64],
742    sweep: &SgfBatchRange,
743    kernel: Kernel,
744    parallel: bool,
745) -> Result<SgfBatchOutput, SgfError> {
746    let combos = expand_grid(sweep)?;
747    if data.is_empty() {
748        return Err(SgfError::EmptyInputData);
749    }
750    let first = data
751        .iter()
752        .position(|x| !x.is_nan())
753        .ok_or(SgfError::AllValuesNaN)?;
754
755    let rows = combos.len();
756    let cols = data.len();
757    let max_period = combos
758        .iter()
759        .map(|combo| combo.period.unwrap_or(21))
760        .map(effective_period)
761        .max()
762        .unwrap_or(0);
763
764    if max_period == 0 || max_period > cols {
765        return Err(SgfError::InvalidPeriod {
766            period: max_period,
767            effective_period: max_period,
768            data_len: cols,
769        });
770    }
771    if cols - first < max_period {
772        return Err(SgfError::NotEnoughValidData {
773            needed: max_period,
774            valid: cols - first,
775        });
776    }
777
778    let mut weights_flat = AVec::<f64>::with_capacity(
779        CACHELINE_ALIGN,
780        rows.checked_mul(max_period).ok_or(SgfError::InvalidRange {
781            start: sweep.period.0,
782            end: sweep.period.1,
783            step: sweep.period.2,
784        })?,
785    );
786    weights_flat.resize(rows * max_period, 0.0);
787
788    let mut periods = Vec::with_capacity(rows);
789    let mut warm = Vec::with_capacity(rows);
790    for (row, combo) in combos.iter().enumerate() {
791        let requested_period = combo.period.unwrap_or(21);
792        let poly_order = combo.poly_order.unwrap_or(2);
793        let period = validate_period_and_order(requested_period, poly_order, cols)?;
794        if cols - first < period {
795            return Err(SgfError::NotEnoughValidData {
796                needed: period,
797                valid: cols - first,
798            });
799        }
800        let weights = build_endpoint_sgf_weights(requested_period, poly_order)?;
801        let row_offset = row * max_period;
802        weights_flat[row_offset..row_offset + period].copy_from_slice(&weights);
803        periods.push(period);
804        warm.push(first + period - 1);
805    }
806
807    let mut buf_mu = make_uninit_matrix(rows, cols);
808    init_matrix_prefixes(&mut buf_mu, cols, &warm);
809    let row_fn = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
810        let period = periods[row];
811        let row_offset = row * max_period;
812        let out_row =
813            core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
814        sgf_compute_into(
815            data,
816            &weights_flat[row_offset..row_offset + period],
817            period,
818            first,
819            kernel,
820            out_row,
821        );
822    };
823
824    #[cfg(not(target_arch = "wasm32"))]
825    if parallel {
826        buf_mu
827            .par_chunks_mut(cols)
828            .enumerate()
829            .for_each(|(row, slice)| row_fn(row, slice));
830    } else {
831        for (row, slice) in buf_mu.chunks_mut(cols).enumerate() {
832            row_fn(row, slice);
833        }
834    }
835
836    #[cfg(target_arch = "wasm32")]
837    for (row, slice) in buf_mu.chunks_mut(cols).enumerate() {
838        row_fn(row, slice);
839    }
840
841    use core::mem::ManuallyDrop;
842    let mut guard = ManuallyDrop::new(buf_mu);
843    let values = unsafe {
844        Vec::from_raw_parts(
845            guard.as_mut_ptr() as *mut f64,
846            guard.len(),
847            guard.capacity(),
848        )
849    };
850
851    Ok(SgfBatchOutput {
852        values,
853        combos,
854        rows,
855        cols,
856    })
857}
858
859fn sgf_batch_inner_into(
860    data: &[f64],
861    sweep: &SgfBatchRange,
862    kernel: Kernel,
863    parallel: bool,
864    out: &mut [f64],
865) -> Result<Vec<SgfParams>, SgfError> {
866    let result = sgf_batch_inner(data, sweep, kernel, parallel)?;
867    let expected = result.values.len();
868    if out.len() != expected {
869        return Err(SgfError::OutputLengthMismatch {
870            expected,
871            got: out.len(),
872        });
873    }
874    out.copy_from_slice(&result.values);
875    Ok(result.combos)
876}
877
// Python entry point: single-series SGF over a 1-D float64 array. The work runs with the
// GIL released, and SGF errors surface as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "sgf")]
#[pyo3(signature = (data, period=21, poly_order=2, kernel=None))]
pub fn sgf_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    poly_order: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    use numpy::PyArrayMethods;

    let series = data.as_slice()?;
    let chosen = validate_kernel(kernel, false)?;
    let params = SgfParams {
        period: Some(period),
        poly_order: Some(poly_order),
    };
    let input = SgfInput::from_slice(series, params);

    let computed = py.allow_threads(|| sgf_with_kernel(&input, chosen));
    let output = computed.map_err(|e| PyValueError::new_err(e.to_string()))?;
    Ok(output.values.into_pyarray(py))
}
904
// Python wrapper around `SgfStream`.
#[cfg(feature = "python")]
#[pyclass(name = "SgfStream")]
pub struct SgfStreamPy {
    stream: SgfStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl SgfStreamPy {
    // `poly_order` is optional (None -> 2). The default is declared explicitly because PyO3
    // deprecated implicit `None` defaults for trailing `Option` arguments.
    #[new]
    #[pyo3(signature = (period, poly_order=None))]
    fn new(period: usize, poly_order: Option<usize>) -> PyResult<Self> {
        let stream = SgfStream::try_new(SgfParams {
            period: Some(period),
            poly_order,
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(Self { stream })
    }

    // Pushes one sample; returns None until the window is full.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
928
// Python entry point for the parameter sweep. Returns a dict with "values" (rows x cols),
// "periods" and "poly_orders" (one entry per row).
#[cfg(feature = "python")]
#[pyfunction(name = "sgf_batch")]
#[pyo3(signature = (data, period_range, poly_order_range=(2, 2, 0), kernel=None))]
pub fn sgf_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    poly_order_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::PyArrayMethods;

    let slice = data.as_slice()?;
    let kernel = validate_kernel(kernel, true)?;
    let sweep = SgfBatchRange {
        period: period_range,
        poly_order: poly_order_range,
    };
    let out = py
        .allow_threads(|| sgf_batch_with_kernel(slice, &sweep, kernel))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    // Take ownership of the matrix instead of cloning it. The old `.clone()` copied
    // rows * cols f64s only to drop the original immediately afterwards.
    let SgfBatchOutput {
        values,
        combos,
        rows,
        cols,
    } = out;

    let dict = PyDict::new(py);
    dict.set_item("values", values.into_pyarray(py).reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|c| c.period.unwrap_or(21))
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "poly_orders",
        combos
            .iter()
            .map(|c| c.poly_order.unwrap_or(2))
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    Ok(dict)
}
977
// CUDA parameter sweep. The input is narrowed to f32 on the host, and the result stays on
// the device, exposed through CUDA Array Interface / DLPack.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "sgf_cuda_batch_dev")]
#[pyo3(signature = (data, period_range, poly_order_range=(2, 2, 0), device_id=0))]
pub fn sgf_cuda_batch_dev_py(
    py: Python<'_>,
    data: numpy::PyReadonlyArray1<'_, f64>,
    period_range: (usize, usize, usize),
    poly_order_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32SgfPy> {
    use numpy::PyArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host_f32: Vec<f32> = data.as_slice()?.iter().map(|&v| v as f32).collect();
    let sweep = SgfBatchRange {
        period: period_range,
        poly_order: poly_order_range,
    };

    let (inner, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaSgf::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let inner = cuda
            .sgf_batch_dev(&host_f32, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((inner, ctx, dev_id))
    })?;

    let device = DeviceArrayF32Py {
        inner,
        _ctx: Some(ctx),
        device_id: Some(dev_id),
    };
    Ok(DeviceArrayF32SgfPy {
        inner: Some(device),
    })
}
1018
// CUDA SGF over many series sharing one parameter set. `data_tm_f32` is time-major: shape
// (rows, cols), read as a contiguous flat slice and passed with `cols` then `rows`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "sgf_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, poly_order=2, device_id=0))]
pub fn sgf_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    poly_order: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32SgfPy> {
    use numpy::{PyArrayMethods, PyUntypedArrayMethods};

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat = data_tm_f32.as_slice()?;
    let shape = data_tm_f32.shape();
    let (rows, cols) = (shape[0], shape[1]);
    let params = SgfParams {
        period: Some(period),
        poly_order: Some(poly_order),
    };

    let (inner, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaSgf::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let inner = cuda
            .sgf_multi_series_one_param_time_major_dev(flat, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((inner, ctx, dev_id))
    })?;

    let device = DeviceArrayF32Py {
        inner,
        _ctx: Some(ctx),
        device_id: Some(dev_id),
    };
    Ok(DeviceArrayF32SgfPy {
        inner: Some(device),
    })
}
1064
// Device-resident SGF result. `inner` is `None` once the buffer has been handed off
// through `__dlpack__`, after which every accessor raises `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "DeviceArrayF32Sgf", unsendable)]
pub struct DeviceArrayF32SgfPy {
    pub(crate) inner: Option<DeviceArrayF32Py>,
}

// Rust-only helpers; not exposed to Python.
#[cfg(all(feature = "python", feature = "cuda"))]
impl DeviceArrayF32SgfPy {
    fn exported_err() -> PyErr {
        PyValueError::new_err("buffer already exported via __dlpack__")
    }

    fn live(&self) -> PyResult<&DeviceArrayF32Py> {
        self.inner.as_ref().ok_or_else(Self::exported_err)
    }
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32SgfPy {
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let device = self.live()?;
        device.__cuda_array_interface__(py)
    }

    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        let device = self.live()?;
        device.__dlpack_device__()
    }

    // Transfers ownership of the device buffer to the DLPack consumer (one-shot).
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        let device = self.inner.take().ok_or_else(Self::exported_err)?;
        device.__dlpack__(py, stream, max_version, dl_device, copy)
    }
}
1104
/// Adds the SGF Python bindings to `m`.
///
/// Always registers the scalar and batch functions and the streaming class.
/// Builds with the `cuda` feature also register the device-resident CUDA entry
/// points.
#[cfg(feature = "python")]
pub fn register_sgf_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(sgf_py, m)?)?;
    m.add_function(wrap_pyfunction!(sgf_batch_py, m)?)?;
    m.add_class::<SgfStreamPy>()?;
    #[cfg(feature = "cuda")]
    register_sgf_cuda_functions(m)?;
    Ok(())
}

/// Registers the CUDA-backed SGF functions on `m`.
#[cfg(all(feature = "python", feature = "cuda"))]
fn register_sgf_cuda_functions(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(sgf_cuda_batch_dev_py, m)?)?;
    m.add_function(wrap_pyfunction!(sgf_cuda_many_series_one_param_dev_py, m)?)?;
    Ok(())
}
1117
/// Computes the SGF filter over `data` from JavaScript.
///
/// Returns a vector with one entry per input sample, filled by
/// `sgf_into_slice` using automatic kernel selection.
///
/// # Errors
/// Rejects with the error's message when the computation fails, for example on
/// invalid parameters.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sgf_js(data: &[f64], period: usize, poly_order: usize) -> Result<Vec<f64>, JsValue> {
    let params = SgfParams {
        period: Some(period),
        poly_order: Some(poly_order),
    };
    let input = SgfInput::from_slice(data, params);

    let mut values = vec![0.0; data.len()];
    match sgf_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(_) => Ok(values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
1133
/// Parameter sweep accepted by the JavaScript `sgf_batch` entry point.
///
/// Each range is a `(start, end, step)` triple and is forwarded unchanged to
/// `SgfBatchRange`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SgfBatchConfig {
    /// `(start, end, step)` for the window length.
    pub period_range: (usize, usize, usize),
    /// `(start, end, step)` for the fitted polynomial order.
    pub poly_order_range: (usize, usize, usize),
}
1140
/// Result returned to JavaScript by the `sgf_batch` entry point.
///
/// Built from the fields of the native batch output.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SgfBatchJsOutput {
    /// Flattened filter outputs. Presumably row-major `rows * cols`, one row per
    /// entry of `combos`; confirm against the native batch output.
    pub values: Vec<f64>,
    /// Parameter combination for each output row (presumably in row order).
    pub combos: Vec<SgfParams>,
    /// Number of output rows (parameter combinations).
    pub rows: usize,
    /// Number of values per row (presumably the input length).
    pub cols: usize,
}
1149
/// Runs an SGF parameter sweep from JavaScript (exported as `sgf_batch`).
///
/// `config` is decoded into an `SgfBatchConfig`. The sweep runs with automatic
/// kernel selection, and the result is serialized back as an `SgfBatchJsOutput`.
///
/// # Errors
/// Rejects with a message when the config cannot be decoded, when the batch
/// computation fails, or when the output cannot be serialized.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = sgf_batch)]
pub fn sgf_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let SgfBatchConfig {
        period_range,
        poly_order_range,
    } = serde_wasm_bindgen::from_value::<SgfBatchConfig>(config)
        .map_err(|err| JsValue::from_str(&format!("Invalid config: {}", err)))?;

    let sweep = SgfBatchRange {
        period: period_range,
        poly_order: poly_order_range,
    };
    let batch = match sgf_batch_with_kernel(data, &sweep, Kernel::Auto) {
        Ok(batch) => batch,
        Err(err) => return Err(JsValue::from_str(&err.to_string())),
    };

    let payload = SgfBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|err| JsValue::from_str(&format!("Serialization error: {}", err)))
}
1172
/// Reserves space for `len` `f64` values in wasm linear memory and returns the
/// pointer to JavaScript.
///
/// The buffer is intentionally leaked and left uninitialized. Release it with
/// `sgf_free(ptr, len)`, passing the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sgf_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the destructor, so the allocation outlives this call.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1181
/// Releases a buffer previously returned by `sgf_alloc(len)`.
///
/// `len` must be the value that was passed to `sgf_alloc`. A null pointer is
/// ignored.
///
/// The vector is rebuilt with length 0 and capacity `len`. The buffer may never
/// have been written, and `Vec::from_raw_parts` requires the first `length`
/// elements to be initialized. A length of 0 makes no such claim, while the
/// capacity still matches the original allocation so it is deallocated
/// correctly. `f64` has no destructor, so there are no elements to drop.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sgf_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: the caller must pass a pointer obtained from `sgf_alloc(len)`, i.e.
    // from a `Vec<f64>` with capacity `len` that has not been freed yet. Length 0
    // asserts nothing about the (possibly uninitialized) contents.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1191
/// Computes the SGF filter from JavaScript into caller-provided memory.
///
/// `in_ptr` must point to `len` readable `f64` values and `out_ptr` to `len`
/// writable `f64` values, typically buffers obtained from `sgf_alloc`.
///
/// The two regions may overlap, including `in_ptr == out_ptr` for in-place use.
/// In that case the result is computed into a scratch buffer and copied out
/// afterwards. This avoids holding a shared and a mutable slice over the same
/// memory, and avoids overwriting input samples that are still being read.
///
/// # Errors
/// Rejects when either pointer is null or when the computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sgf_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
    poly_order: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }
    let params = SgfParams {
        period: Some(period),
        poly_order: Some(poly_order),
    };

    // Both regions span `len` elements; compare their byte ranges.
    let bytes = len.saturating_mul(std::mem::size_of::<f64>());
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps =
        in_start < out_start.saturating_add(bytes) && out_start < in_start.saturating_add(bytes);

    if overlaps {
        let mut scratch = vec![0.0; len];
        {
            // SAFETY: the caller guarantees `in_ptr` is valid for `len` reads, and
            // no mutable reference to that memory exists while `data` is alive.
            let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
            let input = SgfInput::from_slice(data, params);
            sgf_into_slice(&mut scratch, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
        // SAFETY: the caller guarantees `out_ptr` is valid for `len` writes;
        // `scratch` is a fresh allocation and cannot overlap it.
        unsafe { std::ptr::copy_nonoverlapping(scratch.as_ptr(), out_ptr, len) };
        return Ok(());
    }

    // SAFETY: the caller guarantees both pointers are valid for `len` elements,
    // and the overlap check above proved the two regions are disjoint.
    let (data, out) = unsafe {
        (
            std::slice::from_raw_parts(in_ptr, len),
            std::slice::from_raw_parts_mut(out_ptr, len),
        )
    };
    let input = SgfInput::from_slice(data, params);
    sgf_into_slice(out, &input, Kernel::Auto).map_err(|e| JsValue::from_str(&e.to_string()))
}
1217
/// Computes an SGF parameter sweep from JavaScript into caller-provided memory.
///
/// `rows` is the number of `(period, poly_order)` combinations generated from
/// the two `(start, end, step)` ranges. `in_ptr` must point to `len` readable
/// `f64` values and `out_ptr` to `rows * len` writable `f64` values.
///
/// The regions may overlap, including `in_ptr == out_ptr`. In that case the
/// sweep is computed into a scratch buffer and copied out afterwards, so a
/// shared and a mutable slice never alias.
///
/// Returns `rows`.
///
/// # Errors
/// Rejects when either pointer is null, when the ranges are invalid, when
/// `rows * len` overflows `usize`, or when the computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sgf_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
    poly_order_start: usize,
    poly_order_end: usize,
    poly_order_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to sgf_batch_into"));
    }

    // Build the sweep once: it sizes the output and drives the batch run.
    let sweep = SgfBatchRange {
        period: (period_start, period_end, period_step),
        poly_order: (poly_order_start, poly_order_end, poly_order_step),
    };
    let rows = expand_grid(&sweep)
        .map_err(|e| JsValue::from_str(&e.to_string()))?
        .len();
    let total = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("rows * len overflows usize in sgf_batch_into"))?;

    // Byte ranges [start, end) of the input (len values) and output (total values).
    let elem = std::mem::size_of::<f64>();
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let in_end = in_start.saturating_add(len.saturating_mul(elem));
    let out_end = out_start.saturating_add(total.saturating_mul(elem));
    let overlaps = in_start < out_end && out_start < in_end;

    if overlaps {
        let mut scratch = vec![0.0; total];
        {
            // SAFETY: the caller guarantees `in_ptr` is valid for `len` reads, and
            // no mutable reference to that memory exists while `data` is alive.
            let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
            sgf_batch_into_slice(scratch.as_mut_slice(), data, &sweep, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
        // SAFETY: the caller guarantees `out_ptr` is valid for `total` writes;
        // `scratch` is a fresh allocation and cannot overlap it.
        unsafe { std::ptr::copy_nonoverlapping(scratch.as_ptr(), out_ptr, total) };
        return Ok(rows);
    }

    // SAFETY: the caller guarantees `in_ptr` is valid for `len` reads and
    // `out_ptr` for `total` writes; the overlap check proved they are disjoint.
    let (data, out) = unsafe {
        (
            std::slice::from_raw_parts(in_ptr, len),
            std::slice::from_raw_parts_mut(out_ptr, total),
        )
    };
    sgf_batch_into_slice(out, data, &sweep, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    Ok(rows)
}
1256
#[cfg(test)]
mod tests {
    use super::*;

    /// Samples `coeffs[0] + coeffs[1]*x + coeffs[2]*x^2 + ...` at
    /// `x = 0, 1, ..., len - 1` and replaces the first `warm_prefix` samples with NaN.
    fn polynomial_series(len: usize, coeffs: &[f64], warm_prefix: usize) -> Vec<f64> {
        (0..len)
            .map(|idx| {
                if idx < warm_prefix {
                    return f64::NAN;
                }
                let x = idx as f64;
                // Ascending powers: add `coef * x^k`, then advance the power.
                coeffs
                    .iter()
                    .fold((0.0, 1.0), |(sum, pow), &coef| (sum + coef * pow, pow * x))
                    .0
            })
            .collect()
    }

    /// Asserts `|actual - expected| < tol`, naming the case and index on failure.
    fn assert_close(case: &str, idx: usize, actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "{}: index {}: got {}, expected {} (tol {})",
            case,
            idx,
            actual,
            expected,
            tol
        );
    }

    /// Like `assert_close`, but also accepts both values being NaN (warm-up slots).
    fn assert_close_or_both_nan(case: &str, idx: usize, actual: f64, expected: f64, tol: f64) {
        if actual.is_nan() && expected.is_nan() {
            return;
        }
        assert_close(case, idx, actual, expected, tol);
    }

    #[test]
    fn sgf_reproduces_quadratic_endpoint() {
        let data = polynomial_series(64, &[3.0, -0.25, 0.75], 3);
        let out = SgfBuilder::new()
            .period(9)
            .poly_order(2)
            .apply_slice(&data)
            .unwrap();
        // 11 = warm_prefix (3) + period (9) - 1: the first index whose trailing
        // 9-sample window contains no NaN.
        for idx in 11..data.len() {
            assert_close("quadratic", idx, out.values[idx], data[idx], 1e-10);
        }
    }

    #[test]
    fn sgf_reproduces_quartic_endpoint() {
        let data = polynomial_series(96, &[1.0, -2.0, 0.5, 0.125, -0.01], 4);
        let out = SgfBuilder::new()
            .period(11)
            .poly_order(4)
            .apply_slice(&data)
            .unwrap();
        // 14 = warm_prefix (4) + period (11) - 1.
        for idx in 14..data.len() {
            assert_close("quartic", idx, out.values[idx], data[idx], 1e-7);
        }
    }

    #[test]
    fn sgf_stream_matches_batch() {
        let data = polynomial_series(80, &[2.0, 0.1, -0.03], 2);
        let batch = SgfBuilder::new()
            .period(9)
            .poly_order(2)
            .apply_slice(&data)
            .unwrap();
        let mut stream = SgfBuilder::new()
            .period(9)
            .poly_order(2)
            .into_stream()
            .unwrap();
        // NaN warm-up samples are not fed to the stream; slots that never get a
        // streamed value stay NaN.
        let mut streamed = vec![f64::NAN; data.len()];
        for (idx, &value) in data.iter().enumerate() {
            if value.is_nan() {
                continue;
            }
            if let Some(out) = stream.update(value) {
                streamed[idx] = out;
            }
        }
        assert_eq!(batch.values.len(), streamed.len());
        for (idx, (&from_batch, &from_stream)) in batch.values.iter().zip(&streamed).enumerate() {
            assert_close_or_both_nan("stream vs batch", idx, from_batch, from_stream, 1e-10);
        }
    }

    #[test]
    fn sgf_batch_rows_match_single() {
        let data = polynomial_series(72, &[0.5, -0.2, 0.03], 1);
        let sweep = SgfBatchRange {
            period: (7, 11, 2),
            poly_order: (2, 2, 0),
        };
        let batch = sgf_batch_with_kernel(&data, &sweep, Kernel::ScalarBatch).unwrap();
        for period in [7usize, 9, 11] {
            let params = SgfParams {
                period: Some(period),
                poly_order: Some(2),
            };
            let row = batch
                .values_for(&params)
                .expect("every swept period should have a batch row");
            let single = sgf(&SgfInput::from_slice(&data, params.clone())).unwrap();
            let case = format!("batch row period={}", period);
            for idx in 0..data.len() {
                assert_close_or_both_nan(&case, idx, row[idx], single.values[idx], 1e-10);
            }
        }
    }

    #[test]
    fn sgf_rejects_invalid_poly_order() {
        let data = polynomial_series(32, &[1.0, 2.0], 0);
        let err = SgfBuilder::new()
            .period(5)
            .poly_order(5)
            .apply_slice(&data)
            .unwrap_err()
            .to_string();
        assert!(
            err.contains("invalid polynomial order"),
            "unexpected error message: {}",
            err
        );
    }
}