vector_ta/indicators/moving_averages/
linreg.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::moving_averages::{CudaLinreg, DeviceArrayF32};
3use crate::utilities::data_loader::{source_type, Candles};
4use crate::utilities::enums::Kernel;
5use crate::utilities::helpers::{
6    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
7    make_uninit_matrix,
8};
9#[cfg(feature = "python")]
10use crate::utilities::kernel_validation::validate_kernel;
11use aligned_vec::{AVec, CACHELINE_ALIGN};
12#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
13use core::arch::x86_64::*;
14#[cfg(all(feature = "python", feature = "cuda"))]
15use cust::context::Context;
16#[cfg(all(feature = "python", feature = "cuda"))]
17use cust::memory::DeviceBuffer;
18#[cfg(not(target_arch = "wasm32"))]
19use rayon::prelude::*;
20#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
21use serde::{Deserialize, Serialize};
22use std::convert::AsRef;
23use std::error::Error;
24use std::mem::MaybeUninit;
25#[cfg(all(feature = "python", feature = "cuda"))]
26use std::sync::Arc;
27use thiserror::Error;
28
29#[derive(Debug, Clone)]
30pub enum LinRegData<'a> {
31    Candles {
32        candles: &'a Candles,
33        source: &'a str,
34    },
35    Slice(&'a [f64]),
36}
37
/// Result of a single linear-regression run.
#[derive(Clone, Debug)]
pub struct LinRegOutput {
    /// Output series; the single-run entry points allocate it with the same
    /// length as the input and a NaN-filled warm-up prefix.
    pub values: Vec<f64>,
}
42
/// Tunable parameters of the linear-regression indicator.
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct LinRegParams {
    /// Window length. `None` is resolved to 14 by the accessors in this file.
    pub period: Option<usize>,
}
51
52impl Default for LinRegParams {
53    fn default() -> Self {
54        Self { period: Some(14) }
55    }
56}
57
58#[derive(Debug, Clone)]
59pub struct LinRegInput<'a> {
60    pub data: LinRegData<'a>,
61    pub params: LinRegParams,
62}
63
64impl<'a> AsRef<[f64]> for LinRegInput<'a> {
65    #[inline(always)]
66    fn as_ref(&self) -> &[f64] {
67        match &self.data {
68            LinRegData::Slice(slice) => slice,
69            LinRegData::Candles { candles, source } => source_type(candles, source),
70        }
71    }
72}
73
74impl<'a> LinRegInput<'a> {
75    #[inline]
76    pub fn from_candles(c: &'a Candles, s: &'a str, p: LinRegParams) -> Self {
77        Self {
78            data: LinRegData::Candles {
79                candles: c,
80                source: s,
81            },
82            params: p,
83        }
84    }
85    #[inline]
86    pub fn from_slice(sl: &'a [f64], p: LinRegParams) -> Self {
87        Self {
88            data: LinRegData::Slice(sl),
89            params: p,
90        }
91    }
92    #[inline]
93    pub fn with_default_candles(c: &'a Candles) -> Self {
94        Self::from_candles(c, "close", LinRegParams::default())
95    }
96    #[inline]
97    pub fn get_period(&self) -> usize {
98        self.params.period.unwrap_or(14)
99    }
100}
101
102#[derive(Copy, Clone, Debug)]
103pub struct LinRegBuilder {
104    period: Option<usize>,
105    kernel: Kernel,
106}
107
108impl Default for LinRegBuilder {
109    fn default() -> Self {
110        Self {
111            period: None,
112            kernel: Kernel::Auto,
113        }
114    }
115}
116
117impl LinRegBuilder {
118    #[inline(always)]
119    pub fn new() -> Self {
120        Self::default()
121    }
122    #[inline(always)]
123    pub fn period(mut self, n: usize) -> Self {
124        self.period = Some(n);
125        self
126    }
127    #[inline(always)]
128    pub fn kernel(mut self, k: Kernel) -> Self {
129        self.kernel = k;
130        self
131    }
132    #[inline(always)]
133    pub fn apply(self, c: &Candles) -> Result<LinRegOutput, LinRegError> {
134        let p = LinRegParams {
135            period: self.period,
136        };
137        let i = LinRegInput::from_candles(c, "close", p);
138        linreg_with_kernel(&i, self.kernel)
139    }
140    #[inline(always)]
141    pub fn apply_slice(self, d: &[f64]) -> Result<LinRegOutput, LinRegError> {
142        let p = LinRegParams {
143            period: self.period,
144        };
145        let i = LinRegInput::from_slice(d, p);
146        linreg_with_kernel(&i, self.kernel)
147    }
148    #[inline(always)]
149    pub fn into_stream(self) -> Result<LinRegStream, LinRegError> {
150        let p = LinRegParams {
151            period: self.period,
152        };
153        LinRegStream::try_new(p)
154    }
155}
156
157#[derive(Debug, Error)]
158pub enum LinRegError {
159    #[error("linreg: No data provided (All values are NaN).")]
160    EmptyInputData,
161    #[error("linreg: All values are NaN.")]
162    AllValuesNaN,
163    #[error("linreg: Invalid period: period = {period}, data length = {data_len}")]
164    InvalidPeriod { period: usize, data_len: usize },
165    #[error("linreg: Not enough valid data: needed = {needed}, valid = {valid}")]
166    NotEnoughValidData { needed: usize, valid: usize },
167    #[error("linreg: Output length mismatch: expected = {expected}, got = {got}")]
168    OutputLengthMismatch { expected: usize, got: usize },
169    #[error("linreg: Invalid range: start = {start}, end = {end}, step = {step}")]
170    InvalidRange {
171        start: usize,
172        end: usize,
173        step: usize,
174    },
175    #[error("linreg: Invalid kernel for batch API: {0:?}")]
176    InvalidKernelForBatch(Kernel),
177    #[error("linreg: arithmetic overflow when computing {what}")]
178    ArithmeticOverflow { what: &'static str },
179}
180
181#[inline]
182pub fn linreg(input: &LinRegInput) -> Result<LinRegOutput, LinRegError> {
183    linreg_with_kernel(input, Kernel::Auto)
184}
185
186#[inline(always)]
187fn linreg_prepare<'a>(
188    input: &'a LinRegInput,
189    kernel: Kernel,
190) -> Result<(&'a [f64], usize, usize, Kernel), LinRegError> {
191    let data: &[f64] = input.as_ref();
192    if data.is_empty() {
193        return Err(LinRegError::EmptyInputData);
194    }
195    let first = data
196        .iter()
197        .position(|x| !x.is_nan())
198        .ok_or(LinRegError::AllValuesNaN)?;
199    let len = data.len();
200    let period = input.get_period();
201
202    if period == 0 || period > len {
203        return Err(LinRegError::InvalidPeriod {
204            period,
205            data_len: len,
206        });
207    }
208    if (len - first) < period {
209        return Err(LinRegError::NotEnoughValidData {
210            needed: period,
211            valid: len - first,
212        });
213    }
214
215    let chosen = match kernel {
216        Kernel::Auto => Kernel::Scalar,
217        other => other,
218    };
219
220    Ok((data, period, first, chosen))
221}
222
223pub fn linreg_with_kernel(
224    input: &LinRegInput,
225    kernel: Kernel,
226) -> Result<LinRegOutput, LinRegError> {
227    let (data, period, first, chosen) = linreg_prepare(input, kernel)?;
228
229    let warm = first + period;
230    let mut out = alloc_with_nan_prefix(data.len(), warm);
231
232    unsafe {
233        match chosen {
234            Kernel::Scalar | Kernel::ScalarBatch => linreg_scalar(data, period, first, &mut out),
235            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
236            Kernel::Avx2 | Kernel::Avx2Batch => linreg_avx2(data, period, first, &mut out),
237            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
238            Kernel::Avx512 | Kernel::Avx512Batch => linreg_avx512(data, period, first, &mut out),
239            _ => unreachable!(),
240        }
241    }
242
243    Ok(LinRegOutput { values: out })
244}
245
246#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
247#[inline]
248pub fn linreg_into(input: &LinRegInput, out: &mut [f64]) -> Result<(), LinRegError> {
249    linreg_compute_into(input, Kernel::Scalar, out)
250}
251
252pub fn linreg_compute_into(
253    input: &LinRegInput,
254    kernel: Kernel,
255    out: &mut [f64],
256) -> Result<(), LinRegError> {
257    let data: &[f64] = input.as_ref();
258    if data.is_empty() {
259        return Err(LinRegError::EmptyInputData);
260    }
261    let first = data
262        .iter()
263        .position(|x| !x.is_nan())
264        .ok_or(LinRegError::AllValuesNaN)?;
265    let len = data.len();
266    let period = input.get_period();
267
268    if period == 0 || period > len {
269        return Err(LinRegError::InvalidPeriod {
270            period,
271            data_len: len,
272        });
273    }
274    if (len - first) < period {
275        return Err(LinRegError::NotEnoughValidData {
276            needed: period,
277            valid: len - first,
278        });
279    }
280    if out.len() != len {
281        return Err(LinRegError::OutputLengthMismatch {
282            expected: len,
283            got: out.len(),
284        });
285    }
286
287    let chosen = match kernel {
288        Kernel::Auto => Kernel::Scalar,
289        other => other,
290    };
291
292    let warm = first + period;
293
294    out[..warm].fill(f64::NAN);
295
296    unsafe {
297        match chosen {
298            Kernel::Scalar | Kernel::ScalarBatch => linreg_scalar(data, period, first, out),
299            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
300            Kernel::Avx2 | Kernel::Avx2Batch => linreg_avx2(data, period, first, out),
301            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
302            Kernel::Avx512 | Kernel::Avx512Batch => linreg_avx512(data, period, first, out),
303            _ => unreachable!(),
304        }
305    }
306
307    Ok(())
308}
309
/// Scalar rolling linear-regression kernel.
///
/// For every index `i >= first + period - 1` this fits a least-squares line
/// through the last `period` values (x = 1..=period) and stores the line's
/// value at x = `period` in `out[i]`. Elements of `out` before that index are
/// left untouched. The running sums are updated incrementally, so the cost is
/// linear in `data.len()`.
///
/// With `period == 1` the regression denominator is zero and the outputs are
/// NaN.
///
/// # Panics
/// Panics if `out` is shorter than `data`, if `period` is zero, or if
/// `first + period - 1` exceeds `data.len()`.
#[inline(always)]
fn linreg_scalar(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    // One up-front check keeps the unchecked writes below in bounds even if a
    // future caller passes a short buffer.
    assert!(
        out.len() >= data.len(),
        "linreg: output buffer shorter than input"
    );

    let period_f = period as f64;
    // The closed-form sums are evaluated in u128: the cubic product overflows
    // a 32-bit usize from period ~1290 upwards (and a 64-bit one from ~2.1M).
    let p = period as u128;
    let x_sum = ((p * (p + 1)) / 2) as f64;
    let x2_sum = ((p * (p + 1) * (2 * p + 1)) / 6) as f64;
    let denom_inv = 1.0 / (period_f * x2_sum - x_sum * x_sum);
    let inv_period = 1.0 / period_f;

    // Seed the sums with the first `period - 1` values, weighted 1..period-1.
    let mut y_sum = 0.0;
    let mut xy_sum = 0.0;
    for (k, &v) in data[first..first + period - 1].iter().enumerate() {
        y_sum += v;
        xy_sum += ((k + 1) as f64) * v;
    }

    let len = data.len();
    let mut idx = first + period - 1;
    let mut old_idx = first;
    // SAFETY: `idx < len == data.len() <= out.len()` holds inside the loop,
    // and `old_idx = idx - (period - 1) <= idx`, so every access is in bounds.
    unsafe {
        while idx < len {
            // Newest value enters the window with weight `period`.
            let new_val = *data.get_unchecked(idx);
            y_sum += new_val;
            xy_sum += new_val * period_f;

            let b = (period_f * xy_sum - x_sum * y_sum) * denom_inv;
            let a = (y_sum - b * x_sum) * inv_period;
            *out.get_unchecked_mut(idx) = a + b * period_f;

            // Slide: every weight drops by one, then the oldest value leaves.
            xy_sum -= y_sum;
            y_sum -= *data.get_unchecked(old_idx);

            idx += 1;
            old_idx += 1;
        }
    }
}
349
/// AVX2/FMA rolling linear-regression kernel.
///
/// The first `period - 1` values are summed four at a time with vector
/// instructions (weights 1, 2, 3, …), the remainder is handled one value at
/// a time, and the rolling update afterwards is scalar. Because the seed
/// sums are accumulated per lane, results can differ from the scalar kernel
/// in the last bits.
///
/// For every index `i >= first + period - 1` the fitted line's value at
/// x = `period` is written to `out[i]`; earlier elements are left untouched.
///
/// # Safety
/// * The CPU must support AVX2 and FMA.
/// * `period >= 1` and `first + period - 1 <= data.len()`.
/// * `out.len() >= data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
pub unsafe fn linreg_avx2(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;

    let pf = period as f64;
    // Closed-form sums in u128 so the cubic product cannot wrap a usize.
    let n = period as u128;
    let x_sum = ((n * (n + 1)) / 2) as f64;
    let x2_sum = ((n * (n + 1) * (2 * n + 1)) / 6) as f64;
    let denom_inv = 1.0 / (pf * x2_sum - x_sum * x_sum);
    let inv_pf = 1.0 / pf;

    let mut y_sum = 0.0f64;
    let mut xy_sum = 0.0f64;

    let init_len = period.saturating_sub(1);
    let mut p = data.as_ptr().add(first);

    // Vectorised seed: four values per iteration, weights base + off.
    let vec_blocks = init_len / 4;
    if vec_blocks > 0 {
        let base = _mm256_setr_pd(1.0, 2.0, 3.0, 4.0);
        let mut off = 0.0f64;
        let mut y_acc = _mm256_set1_pd(0.0);
        let mut xy_acc = _mm256_set1_pd(0.0);

        for _ in 0..vec_blocks {
            let y = _mm256_loadu_pd(p);
            let x = _mm256_add_pd(base, _mm256_set1_pd(off));
            y_acc = _mm256_add_pd(y_acc, y);
            xy_acc = _mm256_fmadd_pd(y, x, xy_acc);
            p = p.add(4);
            off += 4.0;
        }

        // Horizontal reduction of both accumulators.
        let mut buf = [0.0f64; 4];
        _mm256_storeu_pd(buf.as_mut_ptr(), y_acc);
        y_sum += buf.iter().sum::<f64>();
        _mm256_storeu_pd(buf.as_mut_ptr(), xy_acc);
        xy_sum += buf.iter().sum::<f64>();
    }

    // Scalar tail of the seed (fewer than four values left).
    let tail = init_len & 3;
    let mut k_off = (vec_blocks * 4 + 1) as f64;
    for _ in 0..tail {
        let v = *p;
        y_sum += v;
        xy_sum += k_off * v;
        k_off += 1.0;
        p = p.add(1);
    }

    let len = data.len();
    let mut idx = first + period - 1;
    let mut old_idx = first;
    while idx < len {
        // Newest value enters the window with weight `period`.
        let new_v = *data.get_unchecked(idx);
        y_sum += new_v;
        xy_sum = f64::mul_add(pf, new_v, xy_sum);

        let b = (pf * xy_sum - x_sum * y_sum) * denom_inv;
        let a = (y_sum - b * x_sum) * inv_pf;
        *out.get_unchecked_mut(idx) = a + b * pf;

        // Slide: every weight drops by one, then the oldest value leaves.
        xy_sum -= y_sum;
        y_sum -= *data.get_unchecked(old_idx);
        idx += 1;
        old_idx += 1;
    }
}
418
/// AVX-512 rolling linear-regression kernel.
///
/// The first `period - 1` values are summed eight at a time with vector
/// instructions (weights 1, 2, 3, …), the remainder is handled one value at
/// a time, and the rolling update afterwards is scalar. Because the seed
/// sums are accumulated per lane, results can differ from the scalar kernel
/// in the last bits.
///
/// For every index `i >= first + period - 1` the fitted line's value at
/// x = `period` is written to `out[i]`; earlier elements are left untouched.
///
/// # Safety
/// * The CPU must support AVX-512F, AVX-512DQ and FMA.
/// * `period >= 1` and `first + period - 1 <= data.len()`.
/// * `out.len() >= data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,fma")]
pub unsafe fn linreg_avx512(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;

    let pf = period as f64;
    // Closed-form sums in u128 so the cubic product cannot wrap a usize.
    let n = period as u128;
    let x_sum = ((n * (n + 1)) / 2) as f64;
    let x2_sum = ((n * (n + 1) * (2 * n + 1)) / 6) as f64;
    let denom_inv = 1.0 / (pf * x2_sum - x_sum * x_sum);
    let inv_pf = 1.0 / pf;

    let mut y_sum = 0.0f64;
    let mut xy_sum = 0.0f64;

    let init_len = period.saturating_sub(1);
    let mut p = data.as_ptr().add(first);

    // Vectorised seed: eight values per iteration, weights base + off.
    let vec_blocks = init_len / 8;
    if vec_blocks > 0 {
        let base = _mm512_setr_pd(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0);
        let mut off = 0.0f64;
        let mut y_acc = _mm512_set1_pd(0.0);
        let mut xy_acc = _mm512_set1_pd(0.0);

        for _ in 0..vec_blocks {
            let y = _mm512_loadu_pd(p);
            let x = _mm512_add_pd(base, _mm512_set1_pd(off));
            y_acc = _mm512_add_pd(y_acc, y);
            xy_acc = _mm512_fmadd_pd(y, x, xy_acc);
            p = p.add(8);
            off += 8.0;
        }

        // Horizontal reduction of both accumulators.
        let mut buf = [0.0f64; 8];
        _mm512_storeu_pd(buf.as_mut_ptr(), y_acc);
        y_sum += buf.iter().sum::<f64>();
        _mm512_storeu_pd(buf.as_mut_ptr(), xy_acc);
        xy_sum += buf.iter().sum::<f64>();
    }

    // Scalar tail of the seed (fewer than eight values left).
    let tail = init_len & 7;
    let mut k_off = (vec_blocks * 8 + 1) as f64;
    for _ in 0..tail {
        let v = *p;
        y_sum += v;
        xy_sum += k_off * v;
        k_off += 1.0;
        p = p.add(1);
    }

    let len = data.len();
    let mut idx = first + period - 1;
    let mut old_idx = first;
    while idx < len {
        // Newest value enters the window with weight `period`.
        let new_v = *data.get_unchecked(idx);
        y_sum += new_v;
        xy_sum = f64::mul_add(pf, new_v, xy_sum);

        let b = (pf * xy_sum - x_sum * y_sum) * denom_inv;
        let a = (y_sum - b * x_sum) * inv_pf;
        *out.get_unchecked_mut(idx) = a + b * pf;

        // Slide: every weight drops by one, then the oldest value leaves.
        xy_sum -= y_sum;
        y_sum -= *data.get_unchecked(old_idx);
        idx += 1;
        old_idx += 1;
    }
}
487
/// Parameter sweep for the batch API.
#[derive(Debug, Clone)]
pub struct LinRegBatchRange {
    /// Period sweep as `(start, end, step)`; expanded by `expand_grid_linreg`.
    pub period: (usize, usize, usize),
}
492
493impl Default for LinRegBatchRange {
494    fn default() -> Self {
495        Self {
496            period: (14, 263, 1),
497        }
498    }
499}
500
501#[derive(Clone, Debug, Default)]
502pub struct LinRegBatchBuilder {
503    range: LinRegBatchRange,
504    kernel: Kernel,
505}
506
507impl LinRegBatchBuilder {
508    pub fn new() -> Self {
509        Self::default()
510    }
511    pub fn kernel(mut self, k: Kernel) -> Self {
512        self.kernel = k;
513        self
514    }
515    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
516        self.range.period = (start, end, step);
517        self
518    }
519    pub fn period_static(mut self, p: usize) -> Self {
520        self.range.period = (p, p, 0);
521        self
522    }
523    pub fn apply_slice(self, data: &[f64]) -> Result<LinRegBatchOutput, LinRegError> {
524        linreg_batch_with_kernel(data, &self.range, self.kernel)
525    }
526    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<LinRegBatchOutput, LinRegError> {
527        LinRegBatchBuilder::new().kernel(k).apply_slice(data)
528    }
529    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<LinRegBatchOutput, LinRegError> {
530        let slice = source_type(c, src);
531        self.apply_slice(slice)
532    }
533    pub fn with_default_candles(c: &Candles) -> Result<LinRegBatchOutput, LinRegError> {
534        LinRegBatchBuilder::new()
535            .kernel(Kernel::Auto)
536            .apply_candles(c, "close")
537    }
538}
539
540#[derive(Clone, Debug)]
541#[cfg_attr(
542    all(target_arch = "wasm32", feature = "wasm"),
543    derive(Serialize, Deserialize)
544)]
545pub struct LinRegBatchOutput {
546    pub values: Vec<f64>,
547    pub combos: Vec<LinRegParams>,
548    pub rows: usize,
549    pub cols: usize,
550}
551
552impl LinRegBatchOutput {
553    pub fn row_for_params(&self, p: &LinRegParams) -> Option<usize> {
554        self.combos
555            .iter()
556            .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
557    }
558    pub fn values_for(&self, p: &LinRegParams) -> Option<&[f64]> {
559        self.row_for_params(p).map(|row| {
560            let start = row * self.cols;
561            &self.values[start..start + self.cols]
562        })
563    }
564}
565
566pub fn linreg_batch_with_kernel(
567    data: &[f64],
568    sweep: &LinRegBatchRange,
569    k: Kernel,
570) -> Result<LinRegBatchOutput, LinRegError> {
571    let kernel = match k {
572        Kernel::Auto => Kernel::ScalarBatch,
573        other if other.is_batch() => other,
574        _ => return Err(LinRegError::InvalidKernelForBatch(k)),
575    };
576    let simd = match kernel {
577        Kernel::Avx512Batch => Kernel::Avx512,
578        Kernel::Avx2Batch => Kernel::Avx2,
579        Kernel::ScalarBatch => Kernel::Scalar,
580        _ => unreachable!(),
581    };
582    linreg_batch_par_slice(data, sweep, simd)
583}
584
585#[inline(always)]
586pub fn linreg_batch_slice(
587    data: &[f64],
588    sweep: &LinRegBatchRange,
589    kern: Kernel,
590) -> Result<LinRegBatchOutput, LinRegError> {
591    linreg_batch_inner(data, sweep, kern, false)
592}
593
594#[inline(always)]
595pub fn linreg_batch_par_slice(
596    data: &[f64],
597    sweep: &LinRegBatchRange,
598    kern: Kernel,
599) -> Result<LinRegBatchOutput, LinRegError> {
600    linreg_batch_inner(data, sweep, kern, true)
601}
602
603#[inline(always)]
604fn linreg_batch_inner(
605    data: &[f64],
606    sweep: &LinRegBatchRange,
607    kern: Kernel,
608    parallel: bool,
609) -> Result<LinRegBatchOutput, LinRegError> {
610    let combos = expand_grid_linreg(sweep);
611    if combos.is_empty() {
612        let (s, e, t) = sweep.period;
613        return Err(LinRegError::InvalidRange {
614            start: s,
615            end: e,
616            step: t,
617        });
618    }
619    let first = data
620        .iter()
621        .position(|x| !x.is_nan())
622        .ok_or(LinRegError::AllValuesNaN)?;
623    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
624    if data.len() - first < max_p {
625        return Err(LinRegError::NotEnoughValidData {
626            needed: max_p,
627            valid: data.len() - first,
628        });
629    }
630
631    let rows = combos.len();
632    let cols = data.len();
633    let _ = rows
634        .checked_mul(cols)
635        .ok_or(LinRegError::ArithmeticOverflow { what: "rows*cols" })?;
636
637    let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
638
639    let mut raw = make_uninit_matrix(rows, cols);
640    unsafe { init_matrix_prefixes(&mut raw, cols, &warm) };
641
642    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
643        let period = combos[row].period.unwrap();
644
645        let out_row =
646            core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
647
648        match kern {
649            Kernel::Scalar => linreg_row_scalar(data, first, period, out_row),
650            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
651            Kernel::Avx2 => linreg_row_avx2(data, first, period, out_row),
652            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
653            Kernel::Avx512 => linreg_row_avx512(data, first, period, out_row),
654            _ => unreachable!(),
655        }
656    };
657
658    if parallel {
659        #[cfg(not(target_arch = "wasm32"))]
660        {
661            raw.par_chunks_mut(cols)
662                .enumerate()
663                .for_each(|(row, slice)| do_row(row, slice));
664        }
665
666        #[cfg(target_arch = "wasm32")]
667        {
668            for (row, slice) in raw.chunks_mut(cols).enumerate() {
669                do_row(row, slice);
670            }
671        }
672    } else {
673        for (row, slice) in raw.chunks_mut(cols).enumerate() {
674            do_row(row, slice);
675        }
676    }
677
678    let values: Vec<f64> = unsafe { std::mem::transmute(raw) };
679
680    Ok(LinRegBatchOutput {
681        values,
682        combos,
683        rows,
684        cols,
685    })
686}
687
688pub fn linreg_batch_inner_into(
689    data: &[f64],
690    sweep: &LinRegBatchRange,
691    kern: Kernel,
692    parallel: bool,
693    out: &mut [f64],
694) -> Result<Vec<LinRegParams>, LinRegError> {
695    let combos = expand_grid_linreg(sweep);
696    if combos.is_empty() {
697        let (s, e, t) = sweep.period;
698        return Err(LinRegError::InvalidRange {
699            start: s,
700            end: e,
701            step: t,
702        });
703    }
704    let first = data
705        .iter()
706        .position(|x| !x.is_nan())
707        .ok_or(LinRegError::AllValuesNaN)?;
708    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
709    if data.len() - first < max_p {
710        return Err(LinRegError::NotEnoughValidData {
711            needed: max_p,
712            valid: data.len() - first,
713        });
714    }
715
716    let rows = combos.len();
717    let cols = data.len();
718    let expected = rows
719        .checked_mul(cols)
720        .ok_or(LinRegError::ArithmeticOverflow { what: "rows*cols" })?;
721
722    if out.len() != expected {
723        return Err(LinRegError::OutputLengthMismatch {
724            expected,
725            got: out.len(),
726        });
727    }
728
729    let out_uninit = unsafe {
730        std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
731    };
732
733    let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
734
735    unsafe { init_matrix_prefixes(out_uninit, cols, &warm) };
736
737    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
738        let period = combos[row].period.unwrap();
739
740        let out_row =
741            core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
742
743        match kern {
744            Kernel::Scalar => linreg_row_scalar(data, first, period, out_row),
745            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
746            Kernel::Avx2 => linreg_row_avx2(data, first, period, out_row),
747            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
748            Kernel::Avx512 => linreg_row_avx512(data, first, period, out_row),
749            _ => unreachable!(),
750        }
751    };
752
753    if parallel {
754        #[cfg(not(target_arch = "wasm32"))]
755        {
756            out_uninit
757                .par_chunks_mut(cols)
758                .enumerate()
759                .for_each(|(row, slice)| do_row(row, slice));
760        }
761
762        #[cfg(target_arch = "wasm32")]
763        {
764            for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
765                do_row(row, slice);
766            }
767        }
768    } else {
769        for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
770            do_row(row, slice);
771        }
772    }
773
774    Ok(combos)
775}
776
777#[inline(always)]
778unsafe fn linreg_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
779    linreg_scalar(data, period, first, out)
780}
781
/// Computes one linear-regression row from precomputed prefix sums instead of
/// rolling sums.
///
/// As read by this function, with `y_j = data[first + j - 1]`:
/// * `s[k]`  is the sum of `y_1..=y_k` (so `s[0] == 0`),
/// * `sp[k]` is the sum of `j * y_j` for `j = 1..=k` (so `sp[0] == 0`).
///
/// For every index `i >= first + period - 1` the fitted line's value at
/// x = `period` is written to `out[i]`; earlier elements are left untouched.
/// `data` is only used for its length.
///
/// # Safety
/// * `period >= 1`.
/// * `s` and `sp` must each hold at least `data.len() - first + 1` elements.
/// * `out.len() >= data.len()`.
#[inline(always)]
unsafe fn linreg_row_prefix_sums_scalar(
    data: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
    s: &[f64],
    sp: &[f64],
) {
    let len = data.len();
    let pf = period as f64;
    // Closed-form sums in u128 so the cubic product cannot wrap a usize
    // (a 32-bit usize would overflow from period ~1290 upwards).
    let n = period as u128;
    let x_sum = ((n * (n + 1)) / 2) as f64;
    let x2_sum = ((n * (n + 1) * (2 * n + 1)) / 6) as f64;
    let denom_inv = 1.0 / (pf * x2_sum - x_sum * x_sum);
    let inv_pf = 1.0 / pf;

    for idx in (first + period - 1)..len {
        // 1-based position of `idx` within the valid part of the series.
        let pos = idx - first + 1;
        let window_start = pos - period;
        let y_sum = *s.get_unchecked(pos) - *s.get_unchecked(window_start);

        // Re-base the weights from `window_start+1..=pos` to `1..=period`.
        let xy_sum = (*sp.get_unchecked(pos) - *sp.get_unchecked(window_start))
            - (window_start as f64) * y_sum;

        let b = (pf * xy_sum - x_sum * y_sum) * denom_inv;
        let a = (y_sum - b * x_sum) * inv_pf;
        *out.get_unchecked_mut(idx) = a + b * pf;
    }
}
813
/// Batch-row adapter for the AVX2 kernel. Row helpers take
/// `(first, period)` while the kernel takes `(period, first)`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn linreg_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    linreg_avx2(data, period, first, out);
}
819
/// Batch-row adapter for the AVX-512 kernel. Row helpers take
/// `(first, period)` while the kernel takes `(period, first)`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn linreg_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    linreg_avx512(data, period, first, out);
}
825
/// Streaming linear-regression calculator fed one value at a time.
#[derive(Clone, Debug)]
pub struct LinRegStream {
    /// Window length.
    period: usize,
    /// Ring buffer holding the most recent `period` values.
    buffer: Vec<f64>,
    /// Next write position in `buffer`.
    head: usize,
    /// Set once the write position has wrapped around for the first time.
    filled: bool,
    /// Sum of x for x = 1..=period.
    x_sum: f64,
    /// Sum of x² for x = 1..=period.
    x2_sum: f64,
}
835
836impl LinRegStream {
837    pub fn try_new(params: LinRegParams) -> Result<Self, LinRegError> {
838        let period = params.period.unwrap_or(14);
839        if period == 0 {
840            return Err(LinRegError::InvalidPeriod {
841                period,
842                data_len: 0,
843            });
844        }
845        let mut x_sum = 0.0;
846        let mut x2_sum = 0.0;
847        for i in 1..=period {
848            let xi = i as f64;
849            x_sum += xi;
850            x2_sum += xi * xi;
851        }
852        Ok(Self {
853            period,
854            buffer: vec![f64::NAN; period],
855            head: 0,
856            filled: false,
857            x_sum,
858            x2_sum,
859        })
860    }
861
862    #[inline(always)]
863    pub fn update(&mut self, value: f64) -> Option<f64> {
864        self.buffer[self.head] = value;
865        self.head = (self.head + 1) % self.period;
866        if !self.filled && self.head == 0 {
867            self.filled = true;
868        }
869        if !self.filled {
870            return None;
871        }
872        Some(self.dot_ring())
873    }
874
875    #[inline(always)]
876    fn dot_ring(&self) -> f64 {
877        let mut y_sum = 0.0;
878        let mut xy_sum = 0.0;
879        for (i, &y) in
880            (1..=self.period).zip(self.buffer.iter().cycle().skip(self.head).take(self.period))
881        {
882            y_sum += y;
883            xy_sum += y * (i as f64);
884        }
885        let pf = self.period as f64;
886        let bd = 1.0 / (pf * self.x2_sum - self.x_sum * self.x_sum);
887        let b = (pf * xy_sum - self.x_sum * y_sum) * bd;
888        let a = (y_sum - b * self.x_sum) / pf;
889        a + b * pf
890    }
891}
892
/// Rounds `x` up to the next multiple of 8 (multiples of 8 are unchanged).
#[inline(always)]
fn round_up8(x: usize) -> usize {
    let bumped = x + 7;
    // Dropping the low three bits floors to a multiple of 8.
    (bumped >> 3) << 3
}
897
898#[inline(always)]
899pub fn expand_grid_linreg(r: &LinRegBatchRange) -> Vec<LinRegParams> {
900    fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
901        if step == 0 || start == end {
902            return vec![start];
903        }
904        let (lo, hi) = if start <= end {
905            (start, end)
906        } else {
907            (end, start)
908        };
909        let mut v = Vec::new();
910        let mut x = lo;
911        while x <= hi {
912            v.push(x);
913            match x.checked_add(step) {
914                Some(nx) => x = nx,
915                None => break,
916            }
917        }
918        v
919    }
920    let periods = axis_usize(r.period);
921    let mut out = Vec::with_capacity(periods.len());
922    for &p in &periods {
923        out.push(LinRegParams { period: Some(p) });
924    }
925    out
926}
927
928#[cfg(test)]
929mod tests {
930    use super::*;
931    use crate::skip_if_unsupported;
932    use crate::utilities::data_loader::read_candles_from_csv;
933
934    fn check_linreg_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
935        skip_if_unsupported!(kernel, test_name);
936        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
937        let candles = read_candles_from_csv(file_path)?;
938        let close_prices = candles.select_candle_field("close")?;
939        let params = LinRegParams { period: Some(14) };
940        let input = LinRegInput::from_candles(&candles, "close", params);
941        let linreg_result = linreg_with_kernel(&input, kernel)?;
942        let expected_last_five = [
943            58929.37142857143,
944            58899.42857142857,
945            58918.857142857145,
946            59100.6,
947            58987.94285714286,
948        ];
949        assert!(linreg_result.values.len() >= 5);
950        assert_eq!(linreg_result.values.len(), close_prices.len());
951        let start_index = linreg_result.values.len() - 5;
952        let result_last_five = &linreg_result.values[start_index..];
953        for (i, &value) in result_last_five.iter().enumerate() {
954            let expected_value = expected_last_five[i];
955            assert!(
956                (value - expected_value).abs() < 1e-1,
957                "Mismatch at index {}: expected {}, got {}",
958                i,
959                expected_value,
960                value
961            );
962        }
963        Ok(())
964    }
965
966    fn check_linreg_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
967        skip_if_unsupported!(kernel, test_name);
968        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
969        let candles = read_candles_from_csv(file_path)?;
970        let default_params = LinRegParams { period: None };
971        let input = LinRegInput::from_candles(&candles, "close", default_params);
972        let output = linreg_with_kernel(&input, kernel)?;
973        assert_eq!(output.values.len(), candles.close.len());
974        Ok(())
975    }
976
977    fn check_linreg_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
978        skip_if_unsupported!(kernel, test_name);
979        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
980        let candles = read_candles_from_csv(file_path)?;
981        let input = LinRegInput::with_default_candles(&candles);
982        match input.data {
983            LinRegData::Candles { source, .. } => assert_eq!(source, "close"),
984            _ => panic!("Expected LinRegData::Candles"),
985        }
986        let output = linreg_with_kernel(&input, kernel)?;
987        assert_eq!(output.values.len(), candles.close.len());
988        Ok(())
989    }
990
991    #[test]
992    fn test_linreg_into_matches_api() -> Result<(), Box<dyn Error>> {
993        let mut data = Vec::with_capacity(5 + 256);
994        for _ in 0..5 {
995            data.push(f64::NAN);
996        }
997        for i in 0..256u32 {
998            let x = i as f64;
999            let v = (x * 0.137).sin() * 3.0 + x * 0.25;
1000            data.push(v);
1001        }
1002
1003        let params = LinRegParams { period: Some(14) };
1004        let input = LinRegInput::from_slice(&data, params);
1005
1006        let baseline = linreg(&input)?.values;
1007
1008        let mut out = vec![0.0; data.len()];
1009        linreg_into(&input, &mut out)?;
1010
1011        assert_eq!(out.len(), baseline.len());
1012        for (i, (&a, &b)) in out.iter().zip(baseline.iter()).enumerate() {
1013            if a.is_nan() || b.is_nan() {
1014                assert!(
1015                    a.is_nan() && b.is_nan(),
1016                    "NaN parity mismatch at index {}",
1017                    i
1018                );
1019            } else {
1020                let diff = (a - b).abs();
1021                assert!(
1022                    diff <= 1e-12,
1023                    "Value mismatch at index {}: {} vs {} (diff={})",
1024                    i,
1025                    a,
1026                    b,
1027                    diff
1028                );
1029            }
1030        }
1031        Ok(())
1032    }
1033
1034    fn check_linreg_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1035        skip_if_unsupported!(kernel, test_name);
1036        let input_data = [10.0, 20.0, 30.0];
1037        let params = LinRegParams { period: Some(0) };
1038        let input = LinRegInput::from_slice(&input_data, params);
1039        let res = linreg_with_kernel(&input, kernel);
1040        assert!(
1041            res.is_err(),
1042            "[{}] LINREG should fail with zero period",
1043            test_name
1044        );
1045        Ok(())
1046    }
1047
1048    fn check_linreg_period_exceeds_length(
1049        test_name: &str,
1050        kernel: Kernel,
1051    ) -> Result<(), Box<dyn Error>> {
1052        skip_if_unsupported!(kernel, test_name);
1053        let data_small = [10.0, 20.0, 30.0];
1054        let params = LinRegParams { period: Some(10) };
1055        let input = LinRegInput::from_slice(&data_small, params);
1056        let res = linreg_with_kernel(&input, kernel);
1057        assert!(
1058            res.is_err(),
1059            "[{}] LINREG should fail with period exceeding length",
1060            test_name
1061        );
1062        Ok(())
1063    }
1064
1065    fn check_linreg_very_small_dataset(
1066        test_name: &str,
1067        kernel: Kernel,
1068    ) -> Result<(), Box<dyn Error>> {
1069        skip_if_unsupported!(kernel, test_name);
1070        let single_point = [42.0];
1071        let params = LinRegParams { period: Some(14) };
1072        let input = LinRegInput::from_slice(&single_point, params);
1073        let res = linreg_with_kernel(&input, kernel);
1074        assert!(
1075            res.is_err(),
1076            "[{}] LINREG should fail with insufficient data",
1077            test_name
1078        );
1079        Ok(())
1080    }
1081
1082    fn check_linreg_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1083        skip_if_unsupported!(kernel, test_name);
1084        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1085        let candles = read_candles_from_csv(file_path)?;
1086        let first_params = LinRegParams { period: Some(14) };
1087        let first_input = LinRegInput::from_candles(&candles, "close", first_params);
1088        let first_result = linreg_with_kernel(&first_input, kernel)?;
1089        let second_params = LinRegParams { period: Some(10) };
1090        let second_input = LinRegInput::from_slice(&first_result.values, second_params);
1091        let second_result = linreg_with_kernel(&second_input, kernel)?;
1092        assert_eq!(second_result.values.len(), first_result.values.len());
1093        Ok(())
1094    }
1095
1096    fn check_linreg_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1097        skip_if_unsupported!(kernel, test_name);
1098        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1099        let candles = read_candles_from_csv(file_path)?;
1100        let input = LinRegInput::from_candles(&candles, "close", LinRegParams { period: Some(14) });
1101        let res = linreg_with_kernel(&input, kernel)?;
1102        assert_eq!(res.values.len(), candles.close.len());
1103        if res.values.len() > 240 {
1104            for (i, &val) in res.values[240..].iter().enumerate() {
1105                assert!(
1106                    !val.is_nan(),
1107                    "[{}] Found unexpected NaN at out-index {}",
1108                    test_name,
1109                    240 + i
1110                );
1111            }
1112        }
1113        Ok(())
1114    }
1115
1116    fn check_linreg_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1117        skip_if_unsupported!(kernel, test_name);
1118        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1119        let candles = read_candles_from_csv(file_path)?;
1120        let period = 14;
1121        let input = LinRegInput::from_candles(
1122            &candles,
1123            "close",
1124            LinRegParams {
1125                period: Some(period),
1126            },
1127        );
1128        let batch_output = linreg_with_kernel(&input, kernel)?.values;
1129        let mut stream = LinRegStream::try_new(LinRegParams {
1130            period: Some(period),
1131        })?;
1132        let mut stream_values = Vec::with_capacity(candles.close.len());
1133        for &price in &candles.close {
1134            match stream.update(price) {
1135                Some(val) => stream_values.push(val),
1136                None => stream_values.push(f64::NAN),
1137            }
1138        }
1139        assert_eq!(batch_output.len(), stream_values.len());
1140        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1141            if b.is_nan() && s.is_nan() {
1142                continue;
1143            }
1144            let diff = (b - s).abs();
1145            assert!(
1146                diff < 1e-6,
1147                "[{}] LINREG streaming mismatch at idx {}: batch={}, stream={}, diff={}",
1148                test_name,
1149                i,
1150                b,
1151                s,
1152                diff
1153            );
1154        }
1155        Ok(())
1156    }
1157
1158    macro_rules! generate_all_linreg_tests {
1159        ($($test_fn:ident),*) => {
1160            paste::paste! {
1161                $(#[test] fn [<$test_fn _scalar_f64>]() { let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar); })*
1162                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1163                $(
1164                    #[test] fn [<$test_fn _avx2_f64>]() { let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2); }
1165                    #[test] fn [<$test_fn _avx512_f64>]() { let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512); }
1166                )*
1167            }
1168        }
1169    }
1170
1171    #[cfg(debug_assertions)]
1172    fn check_linreg_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1173        skip_if_unsupported!(kernel, test_name);
1174
1175        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1176        let candles = read_candles_from_csv(file_path)?;
1177
1178        let test_periods = vec![2, 5, 10, 14, 20, 30, 50, 100, 200];
1179        let test_sources = vec!["open", "high", "low", "close", "hl2", "hlc3", "ohlc4"];
1180
1181        for period in &test_periods {
1182            for source in &test_sources {
1183                let input = LinRegInput::from_candles(
1184                    &candles,
1185                    source,
1186                    LinRegParams {
1187                        period: Some(*period),
1188                    },
1189                );
1190                let output = linreg_with_kernel(&input, kernel)?;
1191
1192                for (i, &val) in output.values.iter().enumerate() {
1193                    if val.is_nan() {
1194                        continue;
1195                    }
1196
1197                    let bits = val.to_bits();
1198
1199                    if bits == 0x11111111_11111111 {
1200                        panic!(
1201                            "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with period={}, source={}",
1202                            test_name, val, bits, i, period, source
1203                        );
1204                    }
1205
1206                    if bits == 0x22222222_22222222 {
1207                        panic!(
1208                            "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with period={}, source={}",
1209                            test_name, val, bits, i, period, source
1210                        );
1211                    }
1212
1213                    if bits == 0x33333333_33333333 {
1214                        panic!(
1215                            "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with period={}, source={}",
1216                            test_name, val, bits, i, period, source
1217                        );
1218                    }
1219                }
1220            }
1221        }
1222
1223        Ok(())
1224    }
1225
    /// Build variant of `check_linreg_no_poison` used when `debug_assertions` is
    /// off: it performs no checks and always returns `Ok(())`, so the generated
    /// test names exist in every build profile.
    #[cfg(not(debug_assertions))]
    fn check_linreg_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1230
1231    #[cfg(feature = "proptest")]
1232    fn check_linreg_property(
1233        test_name: &str,
1234        kernel: Kernel,
1235    ) -> Result<(), Box<dyn std::error::Error>> {
1236        use proptest::prelude::*;
1237        skip_if_unsupported!(kernel, test_name);
1238
1239        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1240        let candles = read_candles_from_csv(file_path)?;
1241        let close_data = &candles.close;
1242
1243        let strat = (
1244            2usize..=50,
1245            0usize..close_data.len().saturating_sub(200),
1246            100usize..=200,
1247        );
1248
1249        proptest::test_runner::TestRunner::default()
1250            .run(&strat, |(period, start_idx, slice_len)| {
1251                let end_idx = (start_idx + slice_len).min(close_data.len());
1252                if end_idx <= start_idx || end_idx - start_idx < period + 10 {
1253                    return Ok(());
1254                }
1255
1256                let data_slice = &close_data[start_idx..end_idx];
1257                let params = LinRegParams {
1258                    period: Some(period),
1259                };
1260                let input = LinRegInput::from_slice(data_slice, params.clone());
1261
1262                let result = linreg_with_kernel(&input, kernel);
1263
1264                let scalar_result = linreg_with_kernel(&input, Kernel::Scalar);
1265
1266                match (result, scalar_result) {
1267                    (Ok(LinRegOutput { values: out }), Ok(LinRegOutput { values: ref_out })) => {
1268                        prop_assert_eq!(out.len(), data_slice.len());
1269                        prop_assert_eq!(ref_out.len(), data_slice.len());
1270
1271                        let first = data_slice.iter().position(|x| !x.is_nan()).unwrap_or(0);
1272                        let expected_warmup = first + period;
1273
1274                        let first_valid = out.iter().position(|x| !x.is_nan());
1275                        if let Some(first_idx) = first_valid {
1276                            prop_assert_eq!(
1277                                first_idx,
1278                                expected_warmup,
1279                                "First valid at {} but expected warmup is {}",
1280                                first_idx,
1281                                expected_warmup
1282                            );
1283
1284                            for i in 0..first_idx {
1285                                prop_assert!(
1286                                    out[i].is_nan(),
1287                                    "Expected NaN at index {} during warmup, got {}",
1288                                    i,
1289                                    out[i]
1290                                );
1291                            }
1292                        }
1293
1294                        for i in 0..out.len() {
1295                            let y = out[i];
1296                            let r = ref_out[i];
1297
1298                            if y.is_nan() {
1299                                prop_assert!(
1300                                    r.is_nan(),
1301                                    "Kernel mismatch at {}: {} vs {}",
1302                                    i,
1303                                    y,
1304                                    r
1305                                );
1306                                continue;
1307                            }
1308
1309                            prop_assert!(y.is_finite(), "Non-finite value at index {}: {}", i, y);
1310
1311                            let ulps_diff = if y == r {
1312                                0
1313                            } else {
1314                                let y_bits = y.to_bits();
1315                                let r_bits = r.to_bits();
1316                                ((y_bits as i64) - (r_bits as i64)).unsigned_abs()
1317                            };
1318
1319                            prop_assert!(
1320                                ulps_diff <= 3 || (y - r).abs() < 1e-9,
1321                                "Kernel mismatch at {}: {} vs {} (diff: {}, ulps: {})",
1322                                i,
1323                                y,
1324                                r,
1325                                (y - r).abs(),
1326                                ulps_diff
1327                            );
1328                        }
1329
1330                        if first_valid.is_some() {
1331                            let mut linear_data = vec![0.0; period + 5];
1332                            for i in 0..linear_data.len() {
1333                                linear_data[i] = 100.0 + i as f64 * 2.0;
1334                            }
1335                            let linear_input =
1336                                LinRegInput::from_slice(&linear_data, params.clone());
1337                            if let Ok(LinRegOutput { values: linear_out }) =
1338                                linreg_with_kernel(&linear_input, kernel)
1339                            {
1340                                for i in period..linear_data.len() {
1341                                    if !linear_out[i].is_nan() {
1342                                        let expected = 100.0 + (i + 1) as f64 * 2.0;
1343                                        prop_assert!(
1344                                            (linear_out[i] - expected).abs() < 1e-6,
1345                                            "Linear prediction error at {}: got {} expected {}",
1346                                            i,
1347                                            linear_out[i],
1348                                            expected
1349                                        );
1350                                    }
1351                                }
1352                            }
1353
1354                            let constant_val = 42.0;
1355                            let constant_data = vec![constant_val; period + 5];
1356                            let const_input = LinRegInput::from_slice(&constant_data, params);
1357                            if let Ok(LinRegOutput { values: const_out }) =
1358                                linreg_with_kernel(&const_input, kernel)
1359                            {
1360                                for i in period..constant_data.len() {
1361                                    if !const_out[i].is_nan() {
1362                                        prop_assert!(
1363                                            (const_out[i] - constant_val).abs() < 1e-9,
1364                                            "Constant prediction error at {}: got {} expected {}",
1365                                            i,
1366                                            const_out[i],
1367                                            constant_val
1368                                        );
1369                                    }
1370                                }
1371                            }
1372
1373                            for i in expected_warmup..out.len() {
1374                                if !out[i].is_nan() {
1375                                    let window_start = i + 1 - period;
1376                                    let window_end = i + 1;
1377                                    let window = &data_slice[window_start..window_end];
1378
1379                                    let min_val =
1380                                        window.iter().copied().fold(f64::INFINITY, f64::min);
1381                                    let max_val =
1382                                        window.iter().copied().fold(f64::NEG_INFINITY, f64::max);
1383
1384                                    let range = max_val - min_val;
1385                                    let lower_bound = min_val - range;
1386                                    let upper_bound = max_val + range;
1387
1388                                    prop_assert!(
1389                                        out[i] >= lower_bound && out[i] <= upper_bound,
1390                                        "Output {} at index {} outside reasonable bounds [{}, {}]",
1391                                        out[i],
1392                                        i,
1393                                        lower_bound,
1394                                        upper_bound
1395                                    );
1396                                }
1397                            }
1398                        }
1399
1400                        Ok(())
1401                    }
1402                    (Err(e1), Err(e2)) => {
1403                        prop_assert_eq!(
1404                            std::mem::discriminant(&e1),
1405                            std::mem::discriminant(&e2),
1406                            "Different error types: {:?} vs {:?}",
1407                            e1,
1408                            e2
1409                        );
1410                        Ok(())
1411                    }
1412                    _ => {
1413                        prop_assert!(
1414                            false,
1415                            "Kernel consistency failed - one succeeded, one failed"
1416                        );
1417                        Ok(())
1418                    }
1419                }
1420            })
1421            .map_err(|e| e.into())
1422    }
1423
    // Instantiate the kernel-parameterised checks as `#[test]` functions via
    // `generate_all_linreg_tests!` (one test per kernel variant the macro emits).
    generate_all_linreg_tests!(
        check_linreg_accuracy,
        check_linreg_partial_params,
        check_linreg_default_candles,
        check_linreg_zero_period,
        check_linreg_period_exceeds_length,
        check_linreg_very_small_dataset,
        check_linreg_reinput,
        check_linreg_nan_handling,
        check_linreg_streaming,
        check_linreg_no_poison
    );

    // The property-based check is only compiled with the `proptest` feature, so
    // its tests are generated under the same gate.
    #[cfg(feature = "proptest")]
    generate_all_linreg_tests!(check_linreg_property);
1439
1440    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1441        skip_if_unsupported!(kernel, test);
1442        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1443        let c = read_candles_from_csv(file)?;
1444        let output = LinRegBatchBuilder::new()
1445            .kernel(kernel)
1446            .apply_candles(&c, "close")?;
1447        let def = LinRegParams::default();
1448        let row = output.values_for(&def).expect("default row missing");
1449        assert_eq!(row.len(), c.close.len());
1450        Ok(())
1451    }
1452
1453    macro_rules! gen_batch_tests {
1454        ($fn_name:ident) => {
1455            paste::paste! {
1456                #[test] fn [<$fn_name _scalar>]()      {
1457                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1458                }
1459                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1460                #[test] fn [<$fn_name _avx2>]()        {
1461                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1462                }
1463                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1464                #[test] fn [<$fn_name _avx512>]()      {
1465                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1466                }
1467                #[test] fn [<$fn_name _auto_detect>]() {
1468                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1469                }
1470            }
1471        };
1472    }
1473
1474    #[cfg(debug_assertions)]
1475    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1476        skip_if_unsupported!(kernel, test);
1477
1478        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1479        let c = read_candles_from_csv(file)?;
1480
1481        let test_sources = vec!["open", "high", "low", "close", "hl2", "hlc3", "ohlc4"];
1482
1483        for source in &test_sources {
1484            let output = LinRegBatchBuilder::new()
1485                .kernel(kernel)
1486                .period_range(2, 200, 3)
1487                .apply_candles(&c, source)?;
1488
1489            for (idx, &val) in output.values.iter().enumerate() {
1490                if val.is_nan() {
1491                    continue;
1492                }
1493
1494                let bits = val.to_bits();
1495                let row = idx / output.cols;
1496                let col = idx % output.cols;
1497
1498                if bits == 0x11111111_11111111 {
1499                    panic!(
1500                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {}) with source={}",
1501                        test, val, bits, row, col, idx, source
1502                    );
1503                }
1504
1505                if bits == 0x22222222_22222222 {
1506                    panic!(
1507                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {}) with source={}",
1508                        test, val, bits, row, col, idx, source
1509                    );
1510                }
1511
1512                if bits == 0x33333333_33333333 {
1513                    panic!(
1514                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {}) with source={}",
1515                        test, val, bits, row, col, idx, source
1516                    );
1517                }
1518            }
1519        }
1520
1521        let edge_case_ranges = vec![(2, 5, 1), (190, 200, 2), (50, 100, 10)];
1522        for (start, end, step) in edge_case_ranges {
1523            let output = LinRegBatchBuilder::new()
1524                .kernel(kernel)
1525                .period_range(start, end, step)
1526                .apply_candles(&c, "close")?;
1527
1528            for (idx, &val) in output.values.iter().enumerate() {
1529                if val.is_nan() {
1530                    continue;
1531                }
1532
1533                let bits = val.to_bits();
1534                let row = idx / output.cols;
1535                let col = idx % output.cols;
1536
1537                if bits == 0x11111111_11111111
1538                    || bits == 0x22222222_22222222
1539                    || bits == 0x33333333_33333333
1540                {
1541                    panic!(
1542						"[{}] Found poison value {} (0x{:016X}) at row {} col {} with range ({},{},{})",
1543						test, val, bits, row, col, start, end, step
1544					);
1545                }
1546            }
1547        }
1548
1549        Ok(())
1550    }
1551
    /// Build variant of `check_batch_no_poison` used when `debug_assertions` is
    /// off: it performs no checks and always returns `Ok(())`, so the generated
    /// test names exist in every build profile.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1556
    // Instantiate the batch checks as `#[test]` functions, one per batch kernel
    // variant emitted by `gen_batch_tests!`.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
1559}
1560
1561#[cfg(feature = "python")]
1562use numpy::{PyArray1, PyArrayMethods, PyReadonlyArray1};
1563#[cfg(feature = "python")]
1564use pyo3::exceptions::PyValueError;
1565#[cfg(feature = "python")]
1566use pyo3::prelude::*;
1567
1568#[cfg(feature = "python")]
1569use numpy::IntoPyArray;
1570#[cfg(feature = "python")]
1571use pyo3::types::PyDict;
1572
// Python entry point `linreg(data, period, kernel=None)`.
//
// Borrows the NumPy input as a slice, validates the kernel name, runs
// `linreg_with_kernel` inside `allow_threads`, and returns the values as a new
// 1-D NumPy array. Any indicator error is surfaced as a Python `ValueError`.
#[cfg(feature = "python")]
#[pyfunction]
#[pyo3(name = "linreg", signature = (data, period, kernel=None))]
pub fn linreg_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;
    let input = LinRegInput::from_slice(
        slice_in,
        LinRegParams {
            period: Some(period),
        },
    );

    let computed = py.allow_threads(|| linreg_with_kernel(&input, kern));
    match computed {
        Ok(output) => Ok(output.values.into_pyarray(py)),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
1597
// Python entry point `linreg_batch(data, period_range, kernel=None)`.
//
// Expands `period_range` into a grid of parameter sets, computes one output row
// per set directly into a NumPy buffer, and returns a dict with:
//   "values"  -> 2-D array of shape (rows, cols), cols = len(data)
//   "periods" -> 1-D u64 array holding the period of each row
//
// Raises `ValueError` if the kernel name is rejected, if `rows * cols` does not
// fit in `usize`, or if the batch computation reports an error.
#[cfg(feature = "python")]
#[pyfunction]
#[pyo3(name = "linreg_batch", signature = (data, period_range, kernel=None))]
pub fn linreg_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;
    let sweep = LinRegBatchRange {
        period: period_range,
    };

    let combos = expand_grid_linreg(&sweep);
    let rows = combos.len();
    let cols = slice_in.len();

    // A plain `rows * cols` wraps silently in release builds, which would size
    // the output buffer too small for the matrix the caller expects.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("linreg_batch: rows * cols overflows usize"))?;

    // SAFETY: the array is freshly allocated here and not yet visible to Python,
    // so the mutable slice is the only reference to its storage.
    // NOTE(review): the buffer starts uninitialised; this relies on
    // `linreg_batch_inner_into` writing every element on success - confirm.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            // Resolve `Auto` first, then map the batch kernel to the per-row
            // kernel that is handed to the inner routine.
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => kernel,
            };

            linreg_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
1653
// Python entry point `linreg_cuda_batch_dev(data_f32, period_range, device_id=0)`.
//
// Runs the batch sweep through `CudaLinreg` and returns a pair: the device-side
// result wrapped in `DeviceArrayF32LinregPy`, and a dict whose "periods" entry
// lists the period of each parameter combination. Raises `ValueError` when CUDA
// is unavailable or any CUDA step reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "linreg_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn linreg_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(DeviceArrayF32LinregPy, Bound<'py, PyDict>)> {
    use crate::cuda::cuda_available;
    use numpy::IntoPyArray;
    use pyo3::types::PyDict;

    // Converts any stringifiable error into a Python `ValueError`.
    fn as_value_error<E: ToString>(err: E) -> PyErr {
        PyValueError::new_err(err.to_string())
    }

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let slice_in = data_f32.as_slice()?;
    let sweep = LinRegBatchRange {
        period: period_range,
    };

    let (inner, combos, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaLinreg::new(device_id).map_err(as_value_error)?;
        let ctx = cuda.ctx();
        let dev_id = cuda.device_id();
        let (dev_arr, cmb) = cuda
            .linreg_batch_dev(slice_in, &sweep)
            .map_err(as_value_error)?;
        Ok((dev_arr, cmb, ctx, dev_id))
    })?;

    let mut periods: Vec<u64> = Vec::with_capacity(combos.len());
    periods.extend(combos.iter().map(|c| c.period.unwrap() as u64));

    let dict = PyDict::new(py);
    dict.set_item("periods", periods.into_pyarray(py))?;
    Ok((DeviceArrayF32LinregPy::new(inner, ctx, dev_id), dict))
}
1691
// Python entry point
// `linreg_cuda_many_series_one_param_dev(data_tm_f32, period, device_id=0)`.
//
// Takes a 2-D f32 array, flattens it, and runs a single-period LINREG over it
// through `CudaLinreg`, returning the device-side result wrapped in
// `DeviceArrayF32LinregPy`. Raises `ValueError` when CUDA is unavailable or any
// CUDA step reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "linreg_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn linreg_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32LinregPy> {
    use crate::cuda::cuda_available;
    use numpy::PyUntypedArrayMethods;

    // Converts any stringifiable error into a Python `ValueError`.
    fn as_value_error<E: ToString>(err: E) -> PyErr {
        PyValueError::new_err(err.to_string())
    }

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in = data_tm_f32.as_slice()?;
    let dims = data_tm_f32.shape();
    let (rows, cols) = (dims[0], dims[1]);
    let params = LinRegParams {
        period: Some(period),
    };

    let (inner, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaLinreg::new(device_id).map_err(as_value_error)?;
        let ctx = cuda.ctx();
        let dev_id = cuda.device_id();
        // The callee receives the second array dimension before the first.
        let arr = cuda
            .linreg_multi_series_one_param_time_major_dev(flat_in, cols, rows, &params)
            .map_err(as_value_error)?;
        Ok((arr, ctx, dev_id))
    })?;

    Ok(DeviceArrayF32LinregPy::new(inner, ctx, dev_id))
}
1727
// Python-visible wrapper around a device-resident f32 array produced by the
// LINREG CUDA functions. Exposed to Python as
// `ta_indicators.cuda.DeviceArrayF32Linreg`; the class is marked `unsendable`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(
    module = "ta_indicators.cuda",
    name = "DeviceArrayF32Linreg",
    unsendable
)]
pub struct DeviceArrayF32LinregPy {
    // Device array; its `rows`, `cols` and `buf` are read by the interop methods.
    pub(crate) inner: DeviceArrayF32,
    // NOTE(review): presumably held so the CUDA context outlives `inner` - confirm.
    _ctx_guard: Arc<Context>,
    // Device index reported by `__dlpack_device__`.
    _device_id: u32,
}
1739
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32LinregPy {
    /// Direct construction from Python is rejected with a `TypeError`.
    #[new]
    fn py_new() -> PyResult<Self> {
        let msg = "use factory methods from CUDA functions";
        Err(pyo3::exceptions::PyTypeError::new_err(msg))
    }

    /// Builds the `__cuda_array_interface__` dict (version 3) for the wrapped
    /// buffer: `shape`, `typestr` (`"<f4"`), byte `strides`, `data`, `version`.
    ///
    /// When `rows * cols` is zero the reported device pointer is `0`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let elem_bytes = std::mem::size_of::<f32>();
        let (n_rows, n_cols) = (self.inner.rows, self.inner.cols);

        // Empty arrays expose a null pointer instead of the buffer address.
        let device_ptr: usize = match n_rows.saturating_mul(n_cols) {
            0 => 0,
            _ => self.inner.buf.as_device_ptr().as_raw() as usize,
        };

        let iface = PyDict::new(py);
        iface.set_item("shape", (n_rows, n_cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (n_cols * elem_bytes, elem_bytes))?;
        // Second tuple element is the interface's read-only flag.
        iface.set_item("data", (device_ptr, false))?;
        iface.set_item("version", 3)?;
        Ok(iface)
    }

    /// Returns the DLPack `(device_type, device_id)` pair for this array.
    fn __dlpack_device__(&self) -> (i32, i32) {
        // 2 is `kDLCUDA` in DLPack's `DLDeviceType` enumeration.
        const DL_DEVICE_TYPE: i32 = 2;
        (DL_DEVICE_TYPE, self._device_id as i32)
    }

    /// Exports the device buffer through DLPack.
    ///
    /// The buffer is *moved* out of `self`: afterwards `self.inner` holds an
    /// empty placeholder with `rows == 0` and `cols == 0`.
    ///
    /// * `stream` is accepted but ignored.
    /// * `dl_device` is only checked when it extracts as an `(i32, i32)`
    ///   tuple; any other value is ignored. A tuple that differs from
    ///   `__dlpack_device__()` raises `ValueError`; the message depends on
    ///   whether `copy` extracts to `true`.
    /// * `max_version` is forwarded to `export_f32_cuda_dlpack_2d`.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;

        let (own_ty, own_id) = self.__dlpack_device__();

        let requested = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some((req_ty, req_id)) = requested {
            if req_ty != own_ty || req_id != own_id {
                let copy_requested = copy
                    .as_ref()
                    .and_then(|flag| flag.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let msg = if copy_requested {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(msg));
            }
        }
        let _ = stream;

        // Swap in an empty placeholder so the real buffer can be handed to
        // the exporter by value.
        let placeholder =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let DeviceArrayF32 { buf, rows, cols } = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: placeholder,
                rows: 0,
                cols: 0,
            },
        );

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, own_id, max_version_bound)
    }
}
1823
#[cfg(all(feature = "python", feature = "cuda"))]
impl DeviceArrayF32LinregPy {
    /// Wraps a device array together with its context handle and device id.
    ///
    /// This is the Rust-side constructor; the Python-side `#[new]` always
    /// raises.
    pub fn new(inner: DeviceArrayF32, ctx_guard: Arc<Context>, device_id: u32) -> Self {
        let (_ctx_guard, _device_id) = (ctx_guard, device_id);
        Self {
            inner,
            _ctx_guard,
            _device_id,
        }
    }
}
1834
/// Python class `LinRegStream`: a thin wrapper that forwards to the Rust
/// `LinRegStream`.
#[cfg(feature = "python")]
#[pyclass(name = "LinRegStream")]
pub struct LinRegStreamPy {
    // The wrapped Rust stream; every Python call is delegated to it.
    inner: LinRegStream,
}
1840
#[cfg(feature = "python")]
#[pymethods]
impl LinRegStreamPy {
    /// Creates a stream for the given `period`.
    ///
    /// Raises `ValueError` with the message `"LinRegStream error: <cause>"`
    /// when `LinRegStream::try_new` rejects the parameters.
    #[new]
    pub fn new(period: usize) -> PyResult<Self> {
        LinRegStream::try_new(LinRegParams {
            period: Some(period),
        })
        .map(|inner| Self { inner })
        .map_err(|e| PyValueError::new_err(format!("LinRegStream error: {}", e)))
    }

    /// Feeds one value to the underlying stream and returns its result
    /// unchanged (`None` is surfaced to Python as `None`).
    pub fn update(&mut self, value: f64) -> Option<f64> {
        let stream = &mut self.inner;
        stream.update(value)
    }
}
1859
1860#[inline]
1861pub fn linreg_into_slice(
1862    dst: &mut [f64],
1863    input: &LinRegInput,
1864    kern: Kernel,
1865) -> Result<(), LinRegError> {
1866    let data: &[f64] = input.as_ref();
1867
1868    if dst.len() != data.len() {
1869        return Err(LinRegError::OutputLengthMismatch {
1870            expected: data.len(),
1871            got: dst.len(),
1872        });
1873    }
1874
1875    linreg_compute_into(input, kern, dst)
1876}
1877
1878#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1879use wasm_bindgen::prelude::*;
1880
/// JS entry point: computes LinReg over `data` with the given `period` and
/// returns a freshly allocated vector of the same length.
///
/// The computation always uses `Kernel::Scalar`. Errors are converted to a
/// `JsValue` string built from the error's `Display` output.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn linreg_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = LinRegInput::from_slice(
        data,
        LinRegParams {
            period: Some(period),
        },
    );

    let mut values = vec![0.0; data.len()];
    match linreg_into_slice(&mut values, &input, Kernel::Scalar) {
        Ok(()) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1896
/// Configuration object deserialized from the JS `config` argument of
/// `linreg_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct LinRegBatchConfig {
    // Forwarded verbatim to `LinRegBatchRange::period`; presumably
    // `(start, end, step)`, matching how `linreg_batch_into` fills that field.
    pub period_range: (usize, usize, usize),
}
1902
/// JS entry point (`linreg_batch`): runs a period sweep over `data`.
///
/// `config` is deserialized into [`LinRegBatchConfig`]; the batch result is
/// serialized back to a `JsValue`. The computation always uses
/// `Kernel::Scalar`.
///
/// # Errors
/// * `"Invalid config: …"` when `config` cannot be deserialized.
/// * The batch error's `Display` text when the computation fails.
/// * `"Serialization error: …"` when the output cannot be serialized.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = linreg_batch)]
pub fn linreg_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed = serde_wasm_bindgen::from_value::<LinRegBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = LinRegBatchRange {
        period: parsed.period_range,
    };

    let batch = match linreg_batch_slice(data, &sweep, Kernel::Scalar) {
        Ok(batch) => batch,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    match serde_wasm_bindgen::to_value(&batch) {
        Ok(js) => Ok(js),
        Err(e) => Err(JsValue::from_str(&format!("Serialization error: {}", e))),
    }
}
1919
/// Allocates an uninitialized buffer with capacity for `len` `f64` values and
/// returns its raw pointer, leaking it to the caller.
///
/// The allocation is never freed by Rust on its own; release it with
/// `linreg_free`, passing the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn linreg_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the Vec's destructor, so the allocation
    // outlives this call.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1928
/// Releases a buffer by rebuilding a `Vec<f64>` of length and capacity `len`
/// from `ptr` and dropping it. A null `ptr` is a no-op.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn linreg_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: the caller must pass a pointer obtained from a `Vec<f64>`
    // allocation whose capacity is exactly `len` (e.g. `linreg_alloc(len)`)
    // and must not use it afterwards.
    unsafe {
        drop(Vec::from_raw_parts(ptr, len, len));
    }
}
1938
/// Pointer-based JS entry point: computes LinReg over `len` values at
/// `in_ptr` and writes `len` results to `out_ptr`.
///
/// `in_ptr` and `out_ptr` may be the same pointer; in that case the result is
/// computed into a temporary buffer and copied back. The computation always
/// uses `Kernel::Scalar`.
///
/// # Errors
/// * `"null pointer passed to linreg_into"` when either pointer is null.
/// * `"Invalid period"` when `period` is `0` or greater than `len`.
/// * The computation error's `Display` text otherwise.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn linreg_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to linreg_into"));
    }

    // SAFETY: the caller must guarantee that both pointers are valid for
    // `len` `f64` values for the duration of this call.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);

        if !(1..=len).contains(&period) {
            return Err(JsValue::from_str("Invalid period"));
        }

        let input = LinRegInput::from_slice(
            data,
            LinRegParams {
                period: Some(period),
            },
        );

        let in_place = in_ptr == out_ptr as *const f64;
        if !in_place {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            return linreg_into_slice(out, &input, Kernel::Scalar)
                .map_err(|e| JsValue::from_str(&e.to_string()));
        }

        // Input and output share storage: compute into a scratch buffer
        // first, then copy the finished result over the input.
        let mut scratch = vec![0.0; len];
        linreg_into_slice(&mut scratch, &input, Kernel::Scalar)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);

        Ok(())
    }
}
1979
/// Pointer-based JS entry point for a period sweep.
///
/// Reads `len` values from `in_ptr`, expands the period range
/// `(period_start, period_end, period_step)` into its parameter combinations,
/// and writes `rows * len` results to `out_ptr`, where `rows` is the number
/// of combinations. Returns `rows` on success. The computation always uses
/// `Kernel::Scalar`.
///
/// # Errors
/// * `"null pointer passed to linreg_batch_into"` when either pointer is null.
/// * `"rows * cols overflow in linreg_batch_into"` when the output element
///   count does not fit in `usize`.
/// * The batch error's `Display` text when the computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn linreg_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to linreg_batch_into",
        ));
    }

    // SAFETY: the caller must guarantee that `in_ptr` is valid for `len`
    // `f64` values and `out_ptr` for `rows * len` `f64` values.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);

        let sweep = LinRegBatchRange {
            period: (period_start, period_end, period_step),
        };

        let combos = expand_grid_linreg(&sweep);
        let rows = combos.len();
        let cols = len;

        // `usize` is 32 bits on wasm32 and a plain `rows * cols` wraps in
        // release builds, which would build an output slice of the wrong
        // length. Fail explicitly instead.
        let total = rows
            .checked_mul(cols)
            .ok_or_else(|| JsValue::from_str("rows * cols overflow in linreg_batch_into"))?;

        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        // NOTE(review): the `false` argument is presumably the "parallel"
        // flag of `linreg_batch_inner_into` — confirm against its signature.
        linreg_batch_inner_into(data, &sweep, Kernel::Scalar, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        Ok(rows)
    }
}