Skip to main content

vector_ta/indicators/
linearreg_intercept.rs

1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5    make_uninit_matrix,
6};
7#[cfg(feature = "python")]
8use crate::utilities::kernel_validation::validate_kernel;
9use aligned_vec::{AVec, CACHELINE_ALIGN};
10#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
11use core::arch::x86_64::*;
12#[cfg(feature = "python")]
13use pyo3::exceptions::PyValueError;
14#[cfg(feature = "python")]
15use pyo3::prelude::*;
16#[cfg(not(target_arch = "wasm32"))]
17use rayon::prelude::*;
18#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
19use serde::{Deserialize, Serialize};
20use std::convert::AsRef;
21use std::error::Error;
22use thiserror::Error;
23#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
24use wasm_bindgen::prelude::*;
25
26impl<'a> AsRef<[f64]> for LinearRegInterceptInput<'a> {
27    #[inline(always)]
28    fn as_ref(&self) -> &[f64] {
29        match &self.data {
30            LinearRegInterceptData::Slice(slice) => slice,
31            LinearRegInterceptData::Candles { candles, source } => source_type(candles, source),
32        }
33    }
34}
35
/// Where the indicator reads its input series from.
#[derive(Debug, Clone)]
pub enum LinearRegInterceptData<'a> {
    /// One column of a candle set, resolved by name through `source_type`.
    Candles {
        /// Candle set to read from.
        candles: &'a Candles,
        /// Source column name handed to `source_type` (the builders default to "close").
        source: &'a str,
    },
    /// A caller-supplied series of values.
    Slice(&'a [f64]),
}

/// Result of a single-parameter LINEARREG_INTERCEPT run.
#[derive(Debug, Clone)]
pub struct LinearRegInterceptOutput {
    /// One value per input element, aligned by index with the input series.
    pub values: Vec<f64>,
}
49
/// Parameters for LINEARREG_INTERCEPT.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct LinearRegInterceptParams {
    /// Regression window length; readers such as `get_period` and the stream
    /// constructor fall back to 14 when this is `None`.
    pub period: Option<usize>,
}

impl Default for LinearRegInterceptParams {
    /// Default parameter set: a 14-sample window.
    fn default() -> Self {
        let period = Some(14);
        LinearRegInterceptParams { period }
    }
}
64
65#[derive(Debug, Clone)]
66pub struct LinearRegInterceptInput<'a> {
67    pub data: LinearRegInterceptData<'a>,
68    pub params: LinearRegInterceptParams,
69}
70
71impl<'a> LinearRegInterceptInput<'a> {
72    #[inline]
73    pub fn from_candles(c: &'a Candles, s: &'a str, p: LinearRegInterceptParams) -> Self {
74        Self {
75            data: LinearRegInterceptData::Candles {
76                candles: c,
77                source: s,
78            },
79            params: p,
80        }
81    }
82    #[inline]
83    pub fn from_slice(sl: &'a [f64], p: LinearRegInterceptParams) -> Self {
84        Self {
85            data: LinearRegInterceptData::Slice(sl),
86            params: p,
87        }
88    }
89    #[inline]
90    pub fn with_default_candles(c: &'a Candles) -> Self {
91        Self::from_candles(c, "close", LinearRegInterceptParams::default())
92    }
93    #[inline]
94    pub fn get_period(&self) -> usize {
95        self.params.period.unwrap_or(14)
96    }
97}
98
99#[derive(Copy, Clone, Debug)]
100pub struct LinearRegInterceptBuilder {
101    period: Option<usize>,
102    kernel: Kernel,
103}
104
105impl Default for LinearRegInterceptBuilder {
106    fn default() -> Self {
107        Self {
108            period: None,
109            kernel: Kernel::Auto,
110        }
111    }
112}
113
114impl LinearRegInterceptBuilder {
115    #[inline(always)]
116    pub fn new() -> Self {
117        Self::default()
118    }
119    #[inline(always)]
120    pub fn period(mut self, n: usize) -> Self {
121        self.period = Some(n);
122        self
123    }
124    #[inline(always)]
125    pub fn kernel(mut self, k: Kernel) -> Self {
126        self.kernel = k;
127        self
128    }
129    #[inline(always)]
130    pub fn apply(self, c: &Candles) -> Result<LinearRegInterceptOutput, LinearRegInterceptError> {
131        let p = LinearRegInterceptParams {
132            period: self.period,
133        };
134        let i = LinearRegInterceptInput::from_candles(c, "close", p);
135        linearreg_intercept_with_kernel(&i, self.kernel)
136    }
137    #[inline(always)]
138    pub fn apply_slice(
139        self,
140        d: &[f64],
141    ) -> Result<LinearRegInterceptOutput, LinearRegInterceptError> {
142        let p = LinearRegInterceptParams {
143            period: self.period,
144        };
145        let i = LinearRegInterceptInput::from_slice(d, p);
146        linearreg_intercept_with_kernel(&i, self.kernel)
147    }
148    #[inline(always)]
149    pub fn into_stream(self) -> Result<LinearRegInterceptStream, LinearRegInterceptError> {
150        let p = LinearRegInterceptParams {
151            period: self.period,
152        };
153        LinearRegInterceptStream::try_new(p)
154    }
155}
156
/// Errors reported by the LINEARREG_INTERCEPT entry points in this module.
#[derive(Debug, Error)]
pub enum LinearRegInterceptError {
    /// The input series has no elements.
    #[error("linearreg_intercept: Input data slice is empty.")]
    EmptyInputData,
    /// Every input value is NaN, so there is no first valid index.
    #[error("linearreg_intercept: All values are NaN.")]
    AllValuesNaN,
    /// `period` is 0 or longer than the input. The stream constructor reports
    /// `data_len: 0` because it has no data yet.
    #[error("linearreg_intercept: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer than `needed` values remain from the first non-NaN input onward.
    #[error("linearreg_intercept: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A caller-provided output buffer has the wrong length.
    #[error("linearreg_intercept: Output length mismatch: expected {expected}, got {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A batch period sweep could not be expanded, or its `rows * cols` size overflowed.
    #[error("linearreg_intercept: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: String,
        end: String,
        step: String,
    },
    /// A non-batch kernel was passed to a batch entry point.
    #[error("linearreg_intercept: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
178
179#[inline]
180pub fn linearreg_intercept(
181    input: &LinearRegInterceptInput,
182) -> Result<LinearRegInterceptOutput, LinearRegInterceptError> {
183    linearreg_intercept_with_kernel(input, Kernel::Auto)
184}
185
186pub fn linearreg_intercept_with_kernel(
187    input: &LinearRegInterceptInput,
188    kernel: Kernel,
189) -> Result<LinearRegInterceptOutput, LinearRegInterceptError> {
190    let data: &[f64] = input.as_ref();
191
192    if data.is_empty() {
193        return Err(LinearRegInterceptError::EmptyInputData);
194    }
195
196    let first = data
197        .iter()
198        .position(|x| !x.is_nan())
199        .ok_or(LinearRegInterceptError::AllValuesNaN)?;
200    let len = data.len();
201    let period = input.get_period();
202
203    if period == 0 || period > len {
204        return Err(LinearRegInterceptError::InvalidPeriod {
205            period,
206            data_len: len,
207        });
208    }
209    if (len - first) < period {
210        return Err(LinearRegInterceptError::NotEnoughValidData {
211            needed: period,
212            valid: len - first,
213        });
214    }
215
216    let mut out = alloc_with_nan_prefix(len, first + period - 1);
217
218    let chosen = match kernel {
219        Kernel::Auto => Kernel::Scalar,
220        other => other,
221    };
222
223    unsafe {
224        match chosen {
225            Kernel::Scalar | Kernel::ScalarBatch => {
226                linearreg_intercept_scalar(data, period, first, &mut out)
227            }
228            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
229            Kernel::Avx2 | Kernel::Avx2Batch => {
230                linearreg_intercept_avx2(data, period, first, &mut out)
231            }
232            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
233            Kernel::Avx512 | Kernel::Avx512Batch => {
234                linearreg_intercept_avx512(data, period, first, &mut out)
235            }
236            _ => unreachable!(),
237        }
238    }
239
240    Ok(LinearRegInterceptOutput { values: out })
241}
242
243#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
244#[inline]
245pub fn linearreg_intercept_into(
246    input: &LinearRegInterceptInput,
247    dst: &mut [f64],
248) -> Result<(), LinearRegInterceptError> {
249    let data: &[f64] = input.as_ref();
250
251    if data.is_empty() {
252        return Err(LinearRegInterceptError::EmptyInputData);
253    }
254
255    let first = data
256        .iter()
257        .position(|x| !x.is_nan())
258        .ok_or(LinearRegInterceptError::AllValuesNaN)?;
259    let len = data.len();
260    let period = input.get_period();
261
262    if period == 0 || period > len {
263        return Err(LinearRegInterceptError::InvalidPeriod {
264            period,
265            data_len: len,
266        });
267    }
268    if (len - first) < period {
269        return Err(LinearRegInterceptError::NotEnoughValidData {
270            needed: period,
271            valid: len - first,
272        });
273    }
274
275    if dst.len() != data.len() {
276        return Err(LinearRegInterceptError::OutputLengthMismatch {
277            expected: data.len(),
278            got: dst.len(),
279        });
280    }
281
282    let warmup_end = first + period - 1;
283    for v in &mut dst[..warmup_end] {
284        *v = f64::from_bits(0x7ff8_0000_0000_0000);
285    }
286
287    let chosen = detect_best_kernel();
288
289    unsafe {
290        match chosen {
291            Kernel::Scalar | Kernel::ScalarBatch => {
292                linearreg_intercept_scalar(data, period, first, dst)
293            }
294            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
295            Kernel::Avx2 | Kernel::Avx2Batch => linearreg_intercept_avx2(data, period, first, dst),
296            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
297            Kernel::Avx512 | Kernel::Avx512Batch => {
298                linearreg_intercept_avx512(data, period, first, dst)
299            }
300            _ => unreachable!(),
301        }
302    }
303
304    Ok(())
305}
306
307#[inline]
308pub fn linearreg_intercept_into_slice(
309    dst: &mut [f64],
310    input: &LinearRegInterceptInput,
311    kern: Kernel,
312) -> Result<(), LinearRegInterceptError> {
313    let data: &[f64] = input.as_ref();
314
315    if data.is_empty() {
316        return Err(LinearRegInterceptError::EmptyInputData);
317    }
318
319    let first = data
320        .iter()
321        .position(|x| !x.is_nan())
322        .ok_or(LinearRegInterceptError::AllValuesNaN)?;
323    let len = data.len();
324    let period = input.get_period();
325
326    if period == 0 || period > len {
327        return Err(LinearRegInterceptError::InvalidPeriod {
328            period,
329            data_len: len,
330        });
331    }
332    if (len - first) < period {
333        return Err(LinearRegInterceptError::NotEnoughValidData {
334            needed: period,
335            valid: len - first,
336        });
337    }
338
339    if dst.len() != data.len() {
340        return Err(LinearRegInterceptError::OutputLengthMismatch {
341            expected: data.len(),
342            got: dst.len(),
343        });
344    }
345
346    let chosen = match kern {
347        Kernel::Auto => Kernel::Scalar,
348        other => other,
349    };
350
351    unsafe {
352        match chosen {
353            Kernel::Scalar | Kernel::ScalarBatch => {
354                linearreg_intercept_scalar(data, period, first, dst)
355            }
356            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
357            Kernel::Avx2 | Kernel::Avx2Batch => linearreg_intercept_avx2(data, period, first, dst),
358            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
359            Kernel::Avx512 | Kernel::Avx512Batch => {
360                linearreg_intercept_avx512(data, period, first, dst)
361            }
362            _ => unreachable!(),
363        }
364    }
365
366    let warmup_end = first + period - 1;
367    for v in &mut dst[..warmup_end] {
368        *v = f64::NAN;
369    }
370
371    Ok(())
372}
373
/// Scalar LINEARREG_INTERCEPT kernel using O(1) rolling sums.
///
/// For each full window ending at index `i >= first_val + period - 1`, the samples
/// are assigned x = 1..=period (oldest to newest). The kernel writes
/// `b * (1 - mean_x) + mean_y`, where `b` is the least-squares slope. Since
/// `a = mean_y - b * mean_x`, this equals `a + b`: the fitted line evaluated at the
/// window's oldest sample.
///
/// Sliding the window shifts every x down by one, so
/// `sum_xy' = sum_xy - sum_y_old + n * y_in`.
///
/// Only indices from `first_val + period - 1` onward are written (from `first_val`
/// when `period == 1`, which copies the input). Nothing is written when the data
/// cannot fill one window. Panics if `out` is shorter than `data` where it is written.
#[inline]
pub fn linearreg_intercept_scalar(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    let len = data.len();

    if period == 1 {
        // A one-sample window fits the sample exactly.
        if first_val < len {
            out[first_val..len].copy_from_slice(&data[first_val..]);
        }
        return;
    }

    let n = period as f64;
    let inv_n = 1.0 / n;
    let sum_x = 0.5_f64 * n * (n + 1.0);
    let sum_x2 = (n * (n + 1.0) * (2.0 * n + 1.0)) / 6.0;
    let bd = 1.0 / (n * sum_x2 - sum_x * sum_x);
    let k = 1.0 - sum_x * inv_n;

    if len == 0 || len < first_val + period {
        return;
    }

    let fitted = |sy: f64, sxy: f64| ((n * sxy - sum_x * sy) * bd) * k + sy * inv_n;

    let (mut sum_y, mut sum_xy) = (0.0f64, 0.0f64);
    for (offset, &y) in data[first_val..first_val + period].iter().enumerate() {
        sum_y += y;
        sum_xy += y * ((offset as f64) + 1.0);
    }

    let warm = first_val + period - 1;
    out[warm] = fitted(sum_y, sum_xy);

    for idx in (warm + 1)..len {
        let y_in = data[idx];
        let y_out = data[idx - period];
        let prev = sum_y;
        sum_y = prev + y_in - y_out;
        sum_xy = (sum_xy - prev) + n * y_in;
        out[idx] = fitted(sum_y, sum_xy);
    }
}
422
/// AVX-512 entry point: windows longer than 32 samples go to the "long" variant,
/// all others to the "short" variant.
///
/// NOTE(review): this is a safe fn wrapping `unsafe` kernels that write through raw
/// pointers without bounds checks. Callers in this file validate
/// `out.len() == data.len()` (or allocate it) before dispatching here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn linearreg_intercept_avx512(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    let long_window = period > 32;
    unsafe {
        if long_window {
            linearreg_intercept_avx512_long(data, period, first_val, out)
        } else {
            linearreg_intercept_avx512_short(data, period, first_val, out)
        }
    }
}
432
/// FMA flavour of [`linearreg_intercept_scalar`].
///
/// Uses the same rolling-sum recurrence, but builds `sum_xy`, the slope numerator
/// and the final combination with `f64::mul_add`. Results can therefore differ from
/// the scalar kernel in the last bits.
///
/// # Safety
/// Reads and writes through unchecked raw pointers. The caller must guarantee
/// `out.len() >= data.len()`. Only indices from `first_val + period - 1` onward are
/// written (from `first_val` when `period == 1`).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn linearreg_intercept_avx2(
    data: &[f64],
    period: usize,
    first_val: usize,
    out: &mut [f64],
) {
    let end = data.len();
    let src = data.as_ptr();
    let dst = out.as_mut_ptr();

    if period == 1 {
        // A one-sample window fits the sample exactly.
        for idx in first_val..end {
            *dst.add(idx) = *src.add(idx);
        }
        return;
    }

    let n = period as f64;
    let inv_n = 1.0 / n;
    let sum_x = 0.5_f64 * n * (n + 1.0);
    let sum_x2 = (n * (n + 1.0) * (2.0 * n + 1.0)) / 6.0;
    let bd = 1.0 / n.mul_add(sum_x2, -sum_x * sum_x);
    let k = 1.0 - sum_x * inv_n;

    let start = first_val;
    if end == 0 || end < start + period {
        return;
    }

    let fitted = |sy: f64, sxy: f64| {
        let slope = n.mul_add(sxy, -sum_x * sy) * bd;
        slope.mul_add(k, sy * inv_n)
    };

    let mut sum_y = 0.0f64;
    let mut sum_xy = 0.0f64;
    let mut x = 1.0f64;
    for j in 0..period {
        let y = *src.add(start + j);
        sum_y += y;
        sum_xy = y.mul_add(x, sum_xy);
        x += 1.0;
    }

    let warm = start + period - 1;
    *dst.add(warm) = fitted(sum_y, sum_xy);

    for idx in (warm + 1)..end {
        let y_in = *src.add(idx);
        let y_out = *src.add(idx - period);
        let prev = sum_y;
        sum_y = prev + y_in - y_out;
        sum_xy = (sum_xy - prev) + n * y_in;
        *dst.add(idx) = fitted(sum_y, sum_xy);
    }
}
499
/// AVX-512 kernel for short windows (chosen when `period <= 32`).
///
/// Currently forwards to [`linearreg_intercept_avx2`]; there is no AVX-512-specific
/// code yet.
///
/// # Safety
/// Same contract as `linearreg_intercept_avx2`: `out.len() >= data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn linearreg_intercept_avx512_short(
    data: &[f64],
    period: usize,
    first_val: usize,
    out: &mut [f64],
) {
    linearreg_intercept_avx2(data, period, first_val, out);
}

/// AVX-512 kernel for long windows (chosen when `period > 32`).
///
/// Currently forwards to [`linearreg_intercept_avx2`]; there is no AVX-512-specific
/// code yet.
///
/// # Safety
/// Same contract as `linearreg_intercept_avx2`: `out.len() >= data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn linearreg_intercept_avx512_long(
    data: &[f64],
    period: usize,
    first_val: usize,
    out: &mut [f64],
) {
    linearreg_intercept_avx2(data, period, first_val, out);
}
521
522#[derive(Debug, Clone)]
523pub struct LinearRegInterceptStream {
524    period: usize,
525    buffer: Vec<f64>,
526    head: usize,
527    filled: bool,
528    sum_x: f64,
529    sum_x2: f64,
530    n: f64,
531    bd: f64,
532    sum_y: f64,
533    sum_xy: f64,
534}
535
536impl LinearRegInterceptStream {
537    #[inline]
538    pub fn try_new(params: LinearRegInterceptParams) -> Result<Self, LinearRegInterceptError> {
539        let period = params.period.unwrap_or(14);
540        if period == 0 {
541            return Err(LinearRegInterceptError::InvalidPeriod {
542                period,
543                data_len: 0,
544            });
545        }
546
547        let n = period as f64;
548        let sum_x = 0.5_f64 * n * (n + 1.0);
549        let sum_x2 = (n * (n + 1.0) * (2.0 * n + 1.0)) / 6.0;
550        let denom = n * sum_x2 - sum_x * sum_x;
551        let bd = if period == 1 { 0.0 } else { 1.0 / denom };
552
553        Ok(Self {
554            period,
555            buffer: vec![f64::NAN; period],
556            head: 0,
557            filled: false,
558            sum_x,
559            sum_x2,
560            n,
561            bd,
562            sum_y: 0.0,
563            sum_xy: 0.0,
564        })
565    }
566
567    #[inline(always)]
568    pub fn update(&mut self, value: f64) -> Option<f64> {
569        if self.period == 1 {
570            return Some(value);
571        }
572
573        let tail = self.head;
574        let y_out = self.buffer[tail];
575
576        self.buffer[tail] = value;
577        self.head = if self.head + 1 == self.period {
578            0
579        } else {
580            self.head + 1
581        };
582
583        if !self.filled {
584            let x = (tail as f64) + 1.0;
585            self.sum_y += value;
586            self.sum_xy = value.mul_add(x, self.sum_xy);
587
588            if self.head == 0 {
589                self.filled = true;
590            } else {
591                return None;
592            }
593        } else {
594            let sum_y_old = self.sum_y;
595            self.sum_y = sum_y_old + value - y_out;
596
597            self.sum_xy = (self.sum_xy - sum_y_old) + self.n * value;
598        }
599
600        let inv_n = 1.0 / self.n;
601        let k = 1.0 - self.sum_x * inv_n;
602
603        let t = self.n.mul_add(self.sum_xy, -(self.sum_x * self.sum_y));
604        let b = t * self.bd;
605        let y = self.sum_y.mul_add(inv_n, b * k);
606        Some(y)
607    }
608}
609
/// `(start, end, step)` sweep of periods for batch evaluation.
///
/// A step of 0, or `start == end`, yields just `start`. Otherwise `end` is included
/// when the stepping lands on it.
#[derive(Clone, Debug)]
pub struct LinearRegInterceptBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for LinearRegInterceptBatchRange {
    /// Periods 14 through 263 in steps of 1.
    fn default() -> Self {
        let (start, end, step) = (14, 263, 1);
        LinearRegInterceptBatchRange {
            period: (start, end, step),
        }
    }
}
622
623#[derive(Clone, Debug, Default)]
624pub struct LinearRegInterceptBatchBuilder {
625    range: LinearRegInterceptBatchRange,
626    kernel: Kernel,
627}
628
629impl LinearRegInterceptBatchBuilder {
630    pub fn new() -> Self {
631        Self::default()
632    }
633    pub fn kernel(mut self, k: Kernel) -> Self {
634        self.kernel = k;
635        self
636    }
637    #[inline]
638    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
639        self.range.period = (start, end, step);
640        self
641    }
642    #[inline]
643    pub fn period_static(mut self, p: usize) -> Self {
644        self.range.period = (p, p, 0);
645        self
646    }
647    pub fn apply_slice(
648        self,
649        data: &[f64],
650    ) -> Result<LinearRegInterceptBatchOutput, LinearRegInterceptError> {
651        linearreg_intercept_batch_with_kernel(data, &self.range, self.kernel)
652    }
653    pub fn with_default_slice(
654        data: &[f64],
655        k: Kernel,
656    ) -> Result<LinearRegInterceptBatchOutput, LinearRegInterceptError> {
657        LinearRegInterceptBatchBuilder::new()
658            .kernel(k)
659            .apply_slice(data)
660    }
661    pub fn apply_candles(
662        self,
663        c: &Candles,
664        src: &str,
665    ) -> Result<LinearRegInterceptBatchOutput, LinearRegInterceptError> {
666        let slice = source_type(c, src);
667        self.apply_slice(slice)
668    }
669    pub fn with_default_candles(
670        c: &Candles,
671    ) -> Result<LinearRegInterceptBatchOutput, LinearRegInterceptError> {
672        LinearRegInterceptBatchBuilder::new()
673            .kernel(Kernel::Auto)
674            .apply_candles(c, "close")
675    }
676}
677
678pub fn linearreg_intercept_batch_with_kernel(
679    data: &[f64],
680    sweep: &LinearRegInterceptBatchRange,
681    k: Kernel,
682) -> Result<LinearRegInterceptBatchOutput, LinearRegInterceptError> {
683    let kernel = match k {
684        Kernel::Auto => detect_best_batch_kernel(),
685        other if other.is_batch() => other,
686        other => return Err(LinearRegInterceptError::InvalidKernelForBatch(other)),
687    };
688
689    let simd = match kernel {
690        Kernel::Avx512Batch => Kernel::Avx512,
691        Kernel::Avx2Batch => Kernel::Avx2,
692        Kernel::ScalarBatch => Kernel::Scalar,
693        _ => unreachable!(),
694    };
695    linearreg_intercept_batch_par_slice(data, sweep, simd)
696}
697
698#[derive(Clone, Debug)]
699pub struct LinearRegInterceptBatchOutput {
700    pub values: Vec<f64>,
701    pub combos: Vec<LinearRegInterceptParams>,
702    pub rows: usize,
703    pub cols: usize,
704}
705impl LinearRegInterceptBatchOutput {
706    pub fn row_for_params(&self, p: &LinearRegInterceptParams) -> Option<usize> {
707        self.combos
708            .iter()
709            .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
710    }
711    pub fn values_for(&self, p: &LinearRegInterceptParams) -> Option<&[f64]> {
712        self.row_for_params(p).map(|row| {
713            let start = row * self.cols;
714            &self.values[start..start + self.cols]
715        })
716    }
717}
718
719#[inline(always)]
720fn expand_grid(
721    r: &LinearRegInterceptBatchRange,
722) -> Result<Vec<LinearRegInterceptParams>, LinearRegInterceptError> {
723    fn axis_usize(
724        (start, end, step): (usize, usize, usize),
725    ) -> Result<Vec<usize>, LinearRegInterceptError> {
726        if step == 0 || start == end {
727            return Ok(vec![start]);
728        }
729
730        let mut values = Vec::new();
731        let step_u = step;
732
733        if start <= end {
734            let mut v = start;
735            loop {
736                if v > end {
737                    break;
738                }
739                values.push(v);
740                match v.checked_add(step_u) {
741                    Some(next) => v = next,
742                    None => break,
743                }
744            }
745        } else {
746            let mut v = start;
747            loop {
748                if v < end {
749                    break;
750                }
751                values.push(v);
752                match v.checked_sub(step_u) {
753                    Some(next) => v = next,
754                    None => break,
755                }
756            }
757        }
758
759        if values.is_empty() {
760            return Err(LinearRegInterceptError::InvalidRange {
761                start: start.to_string(),
762                end: end.to_string(),
763                step: step.to_string(),
764            });
765        }
766
767        Ok(values)
768    }
769
770    let periods = axis_usize(r.period)?;
771
772    let mut out = Vec::with_capacity(periods.len());
773    for p in periods {
774        out.push(LinearRegInterceptParams { period: Some(p) });
775    }
776    Ok(out)
777}
778
779#[inline(always)]
780pub fn linearreg_intercept_batch_slice(
781    data: &[f64],
782    sweep: &LinearRegInterceptBatchRange,
783    kern: Kernel,
784) -> Result<LinearRegInterceptBatchOutput, LinearRegInterceptError> {
785    linearreg_intercept_batch_inner(data, sweep, kern, false)
786}
787
788#[inline(always)]
789pub fn linearreg_intercept_batch_par_slice(
790    data: &[f64],
791    sweep: &LinearRegInterceptBatchRange,
792    kern: Kernel,
793) -> Result<LinearRegInterceptBatchOutput, LinearRegInterceptError> {
794    linearreg_intercept_batch_inner(data, sweep, kern, true)
795}
796
797#[inline(always)]
798fn linearreg_intercept_batch_inner_into(
799    data: &[f64],
800    sweep: &LinearRegInterceptBatchRange,
801    kern: Kernel,
802    parallel: bool,
803    out: &mut [f64],
804) -> Result<Vec<LinearRegInterceptParams>, LinearRegInterceptError> {
805    if data.is_empty() {
806        return Err(LinearRegInterceptError::EmptyInputData);
807    }
808
809    let combos = expand_grid(sweep)?;
810
811    let first = data
812        .iter()
813        .position(|x| !x.is_nan())
814        .ok_or(LinearRegInterceptError::AllValuesNaN)?;
815    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
816    if data.len() - first < max_p {
817        return Err(LinearRegInterceptError::NotEnoughValidData {
818            needed: max_p,
819            valid: data.len() - first,
820        });
821    }
822
823    let cols = data.len();
824
825    let chosen = match kern {
826        Kernel::Auto => Kernel::Scalar,
827        other => other,
828    };
829
830    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
831        let period = combos[row].period.unwrap();
832        match chosen {
833            Kernel::Scalar | Kernel::ScalarBatch => {
834                linearreg_intercept_row_scalar(data, first, period, out_row)
835            }
836            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
837            Kernel::Avx2 | Kernel::Avx2Batch => {
838                linearreg_intercept_row_avx2(data, first, period, out_row)
839            }
840            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
841            Kernel::Avx512 | Kernel::Avx512Batch => {
842                linearreg_intercept_row_avx512(data, first, period, out_row)
843            }
844            _ => unreachable!(),
845        }
846    };
847
848    if parallel {
849        #[cfg(not(target_arch = "wasm32"))]
850        {
851            out.par_chunks_mut(cols)
852                .enumerate()
853                .for_each(|(row, slice)| do_row(row, slice));
854        }
855
856        #[cfg(target_arch = "wasm32")]
857        {
858            for (row, slice) in out.chunks_mut(cols).enumerate() {
859                do_row(row, slice);
860            }
861        }
862    } else {
863        for (row, slice) in out.chunks_mut(cols).enumerate() {
864            do_row(row, slice);
865        }
866    }
867
868    Ok(combos)
869}
870
871#[inline(always)]
872fn linearreg_intercept_batch_inner(
873    data: &[f64],
874    sweep: &LinearRegInterceptBatchRange,
875    kern: Kernel,
876    parallel: bool,
877) -> Result<LinearRegInterceptBatchOutput, LinearRegInterceptError> {
878    if data.is_empty() {
879        return Err(LinearRegInterceptError::EmptyInputData);
880    }
881
882    let combos = expand_grid(sweep)?;
883
884    let first = data
885        .iter()
886        .position(|x| !x.is_nan())
887        .ok_or(LinearRegInterceptError::AllValuesNaN)?;
888    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
889    if data.len() - first < max_p {
890        return Err(LinearRegInterceptError::NotEnoughValidData {
891            needed: max_p,
892            valid: data.len() - first,
893        });
894    }
895
896    let rows = combos.len();
897    let cols = data.len();
898
899    let total = rows
900        .checked_mul(cols)
901        .ok_or_else(|| LinearRegInterceptError::InvalidRange {
902            start: sweep.period.0.to_string(),
903            end: sweep.period.1.to_string(),
904            step: sweep.period.2.to_string(),
905        })?;
906
907    let mut buf_mu = make_uninit_matrix(rows, cols);
908
909    let warm: Vec<usize> = combos
910        .iter()
911        .map(|c| first + c.period.unwrap() - 1)
912        .collect();
913    init_matrix_prefixes(&mut buf_mu, cols, &warm);
914
915    let mut values = unsafe {
916        let ptr = buf_mu.as_mut_ptr() as *mut f64;
917        std::mem::forget(buf_mu);
918        Vec::from_raw_parts(ptr, total, total)
919    };
920
921    let chosen = match kern {
922        Kernel::Auto => Kernel::Scalar,
923        other => other,
924    };
925
926    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
927        let period = combos[row].period.unwrap();
928        match chosen {
929            Kernel::Scalar | Kernel::ScalarBatch => {
930                linearreg_intercept_row_scalar(data, first, period, out_row)
931            }
932            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
933            Kernel::Avx2 | Kernel::Avx2Batch => {
934                linearreg_intercept_row_avx2(data, first, period, out_row)
935            }
936            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
937            Kernel::Avx512 | Kernel::Avx512Batch => {
938                linearreg_intercept_row_avx512(data, first, period, out_row)
939            }
940            _ => unreachable!(),
941        }
942    };
943
944    if parallel {
945        #[cfg(not(target_arch = "wasm32"))]
946        {
947            values
948                .par_chunks_mut(cols)
949                .enumerate()
950                .for_each(|(row, slice)| do_row(row, slice));
951        }
952
953        #[cfg(target_arch = "wasm32")]
954        {
955            for (row, slice) in values.chunks_mut(cols).enumerate() {
956                do_row(row, slice);
957            }
958        }
959    } else {
960        for (row, slice) in values.chunks_mut(cols).enumerate() {
961            do_row(row, slice);
962        }
963    }
964
965    Ok(LinearRegInterceptBatchOutput {
966        values,
967        combos,
968        rows,
969        cols,
970    })
971}
972
973#[inline(always)]
974unsafe fn linearreg_intercept_row_scalar(
975    data: &[f64],
976    first: usize,
977    period: usize,
978    out: &mut [f64],
979) {
980    linearreg_intercept_scalar(data, period, first, out)
981}
982
983#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
984#[inline(always)]
985unsafe fn linearreg_intercept_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
986    linearreg_intercept_avx2(data, period, first, out)
987}
988
989#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
990#[inline(always)]
991pub unsafe fn linearreg_intercept_row_avx512(
992    data: &[f64],
993    first: usize,
994    period: usize,
995    out: &mut [f64],
996) {
997    linearreg_intercept_avx512(data, period, first, out)
998}
999
1000#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1001#[inline(always)]
1002pub unsafe fn linearreg_intercept_row_avx512_short(
1003    data: &[f64],
1004    first: usize,
1005    period: usize,
1006    out: &mut [f64],
1007) {
1008    linearreg_intercept_avx512_short(data, period, first, out)
1009}
1010
1011#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1012#[inline(always)]
1013pub unsafe fn linearreg_intercept_row_avx512_long(
1014    data: &[f64],
1015    first: usize,
1016    period: usize,
1017    out: &mut [f64],
1018) {
1019    linearreg_intercept_avx512_long(data, period, first, out)
1020}
1021
1022#[inline(always)]
1023fn expand_grid_reg(r: &LinearRegInterceptBatchRange) -> Vec<LinearRegInterceptParams> {
1024    expand_grid(r).unwrap_or_else(|_| Vec::new())
1025}
1026
1027#[cfg(all(feature = "python", feature = "cuda"))]
1028use crate::cuda::moving_averages::DeviceArrayF32;
1029#[cfg(all(feature = "python", feature = "cuda"))]
1030use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
1031#[cfg(all(feature = "python", feature = "cuda"))]
1032use cust::context::Context;
1033#[cfg(all(feature = "python", feature = "cuda"))]
1034use cust::memory::DeviceBuffer;
1035#[cfg(all(feature = "python", feature = "cuda"))]
1036use std::sync::Arc;
1037
// Python-side handle to a `rows x cols` f32 matrix held in a CUDA `DeviceBuffer`. It is
// returned by the `linearreg_intercept_cuda_*_dev` functions in this file and read from Python
// through `__cuda_array_interface__` or (once) `__dlpack__`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct LinearRegInterceptDeviceArrayF32Py {
    // Device storage; becomes `None` after `__dlpack__` has moved it out.
    pub(crate) buf: Option<DeviceBuffer<f32>>,
    // Matrix dimensions; the interface methods report a row-major layout.
    pub(crate) rows: usize,
    pub(crate) cols: usize,
    // NOTE(review): presumably kept so the CUDA context outlives `buf` — confirm. Rust drops
    // struct fields in declaration order, so `buf` is released before `ctx`; keep `buf` first.
    pub(crate) ctx: Arc<Context>,
    // Device id reported to Python by `__dlpack_device__`.
    pub(crate) device_id: u32,
}
1047
// Python interop protocols for the device matrix: the CUDA Array Interface (zero-copy view)
// and DLPack (one-shot ownership transfer of the buffer).
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl LinearRegInterceptDeviceArrayF32Py {
    // Builds a CUDA Array Interface (version 3) dict describing the buffer:
    // shape `(rows, cols)`, little-endian f32 (`"<f4"`), row-major byte strides, and
    // `data = (device pointer, read_only = false)`. Fails once `__dlpack__` has taken `buf`.
    #[getter]
    fn __cuda_array_interface__<'py>(
        &self,
        py: Python<'py>,
    ) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
        let d = pyo3::types::PyDict::new(py);
        d.set_item("shape", (self.rows, self.cols))?;
        d.set_item("typestr", "<f4")?;
        // Strides are in bytes: one row spans `cols` f32 elements.
        d.set_item(
            "strides",
            (
                self.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        let ptr = self
            .buf
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?
            .as_device_ptr()
            .as_raw() as usize;
        d.set_item("data", (ptr, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    // DLPack device tuple: device type 2 is `kDLCUDA` in DLPack's `DLDeviceType`.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    // Exports the buffer as a DLPack capsule, moving it out of `self` (only one call succeeds).
    // If `dl_device` is given and parses as `(type, id)`, it must match this buffer's device;
    // cross-device export is refused whether or not `copy` is requested. An unparsable
    // `dl_device` is ignored. `stream` is accepted but unused.
    // NOTE(review): `copy=True` with a matching device returns the original buffer rather than
    // a copy — confirm this is acceptable for callers.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    // Only the error message depends on `copy`; neither path copies.
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }
        let _ = stream;

        // Validation happens before `take()`, so a rejected call leaves `buf` in place.
        let buf = self
            .buf
            .take()
            .ok_or_else(|| PyValueError::new_err("__dlpack__ may only be called once"))?;

        let rows = self.rows;
        let cols = self.cols;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1124
// Python entry point: runs the period sweep `period_range` over one f32 series on the GPU and
// returns the resulting device-resident matrix (one row per swept parameter combination,
// presumably — confirm against `linearreg_intercept_batch_dev`). Every failure, including
// CUDA being unavailable, surfaces as a Python `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "linearreg_intercept_cuda_batch_dev")]
#[pyo3(signature = (data, period_range, device_id=0))]
pub fn linearreg_intercept_cuda_batch_dev_py(
    py: Python<'_>,
    data: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<LinearRegInterceptDeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use crate::cuda::CudaLinregIntercept;

    // Maps any stringifiable backend error onto a Python `ValueError`.
    fn value_err<E: ToString>(e: E) -> pyo3::PyErr {
        PyValueError::new_err(e.to_string())
    }

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let host = data.as_slice()?;
    let sweep = LinearRegInterceptBatchRange {
        period: period_range,
    };

    // The CUDA calls run inside `allow_threads`, i.e. without holding the GIL.
    let (device_array, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaLinregIntercept::new(device_id).map_err(value_err)?;
        let (device_array, _combos) = cuda
            .linearreg_intercept_batch_dev(host, &sweep)
            .map_err(value_err)?;
        Ok((device_array, cuda.context_arc(), cuda.device_id()))
    })?;

    let DeviceArrayF32 { buf, rows, cols } = device_array;
    Ok(LinearRegInterceptDeviceArrayF32Py {
        buf: Some(buf),
        rows,
        cols,
        ctx,
        device_id: dev_id,
    })
}
1161
// Python entry point: applies one `period` to many series stored time-major in a flat f32
// buffer of `rows * cols` elements, computing on the GPU and returning the device-resident
// result. Shape problems and backend failures surface as Python `ValueError`s.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "linearreg_intercept_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm, cols, rows, period, device_id=0))]
pub fn linearreg_intercept_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    device_id: usize,
) -> PyResult<LinearRegInterceptDeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use crate::cuda::CudaLinregIntercept;

    // Maps any stringifiable backend error onto a Python `ValueError`.
    fn value_err<E: ToString>(e: E) -> pyo3::PyErr {
        PyValueError::new_err(e.to_string())
    }

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let host_tm = data_tm.as_slice()?;

    // The flat buffer must hold exactly `cols * rows` values, and that product must not overflow.
    match cols.checked_mul(rows) {
        None => return Err(PyValueError::new_err("rows*cols overflow")),
        Some(expected) if expected != host_tm.len() => {
            return Err(PyValueError::new_err("time-major input length mismatch"));
        }
        Some(_) => {}
    }

    let params = LinearRegInterceptParams {
        period: Some(period),
    };

    // The CUDA calls run inside `allow_threads`, i.e. without holding the GIL.
    let (device_array, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaLinregIntercept::new(device_id).map_err(value_err)?;
        let device_array = cuda
            .linearreg_intercept_many_series_one_param_time_major_dev(host_tm, cols, rows, &params)
            .map_err(value_err)?;
        Ok((device_array, cuda.context_arc(), cuda.device_id()))
    })?;

    Ok(LinearRegInterceptDeviceArrayF32Py {
        buf: Some(device_array.buf),
        rows: device_array.rows,
        cols: device_array.cols,
        ctx,
        device_id: dev_id,
    })
}
1205
1206#[cfg(test)]
1207mod tests {
1208    use super::*;
1209    use crate::skip_if_unsupported;
1210    use crate::utilities::data_loader::read_candles_from_csv;
1211    #[cfg(feature = "proptest")]
1212    use proptest::prelude::*;
1213
1214    fn check_linreg_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1215        skip_if_unsupported!(kernel, test_name);
1216        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1217        let candles = read_candles_from_csv(file_path)?;
1218        let default_params = LinearRegInterceptParams { period: None };
1219        let input = LinearRegInterceptInput::from_candles(&candles, "close", default_params);
1220        let output = linearreg_intercept_with_kernel(&input, kernel)?;
1221        assert_eq!(output.values.len(), candles.close.len());
1222        Ok(())
1223    }
1224
1225    fn check_linreg_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1226        skip_if_unsupported!(kernel, test_name);
1227        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1228        let candles = read_candles_from_csv(file_path)?;
1229        let input = LinearRegInterceptInput::from_candles(
1230            &candles,
1231            "close",
1232            LinearRegInterceptParams::default(),
1233        );
1234        let result = linearreg_intercept_with_kernel(&input, kernel)?;
1235        let expected_last_five = [
1236            60000.91428571429,
1237            59947.142857142855,
1238            59754.57142857143,
1239            59318.4,
1240            59321.91428571429,
1241        ];
1242        let start = result.values.len().saturating_sub(5);
1243        for (i, &val) in result.values[start..].iter().enumerate() {
1244            let diff = (val - expected_last_five[i]).abs();
1245            assert!(
1246                diff < 1e-1,
1247                "[{}] LinReg {:?} mismatch at idx {}: got {}, expected {}",
1248                test_name,
1249                kernel,
1250                i,
1251                val,
1252                expected_last_five[i]
1253            );
1254        }
1255        Ok(())
1256    }
1257
1258    fn check_linreg_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1259        skip_if_unsupported!(kernel, test_name);
1260        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1261        let candles = read_candles_from_csv(file_path)?;
1262        let input = LinearRegInterceptInput::with_default_candles(&candles);
1263        match input.data {
1264            LinearRegInterceptData::Candles { source, .. } => assert_eq!(source, "close"),
1265            _ => panic!("Expected LinearRegInterceptData::Candles"),
1266        }
1267        let output = linearreg_intercept_with_kernel(&input, kernel)?;
1268        assert_eq!(output.values.len(), candles.close.len());
1269        Ok(())
1270    }
1271
1272    fn check_linreg_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1273        skip_if_unsupported!(kernel, test_name);
1274        let input_data = [10.0, 20.0, 30.0];
1275        let params = LinearRegInterceptParams { period: Some(0) };
1276        let input = LinearRegInterceptInput::from_slice(&input_data, params);
1277        let res = linearreg_intercept_with_kernel(&input, kernel);
1278        assert!(
1279            res.is_err(),
1280            "[{}] LinReg should fail with zero period",
1281            test_name
1282        );
1283        Ok(())
1284    }
1285
1286    fn check_linreg_period_exceeds_length(
1287        test_name: &str,
1288        kernel: Kernel,
1289    ) -> Result<(), Box<dyn Error>> {
1290        skip_if_unsupported!(kernel, test_name);
1291        let data_small = [10.0, 20.0, 30.0];
1292        let params = LinearRegInterceptParams { period: Some(10) };
1293        let input = LinearRegInterceptInput::from_slice(&data_small, params);
1294        let res = linearreg_intercept_with_kernel(&input, kernel);
1295        assert!(
1296            res.is_err(),
1297            "[{}] LinReg should fail with period exceeding length",
1298            test_name
1299        );
1300        Ok(())
1301    }
1302
1303    fn check_linreg_very_small_dataset(
1304        test_name: &str,
1305        kernel: Kernel,
1306    ) -> Result<(), Box<dyn Error>> {
1307        skip_if_unsupported!(kernel, test_name);
1308        let single_point = [42.0];
1309        let params = LinearRegInterceptParams { period: Some(14) };
1310        let input = LinearRegInterceptInput::from_slice(&single_point, params);
1311        let res = linearreg_intercept_with_kernel(&input, kernel);
1312        assert!(
1313            res.is_err(),
1314            "[{}] LinReg should fail with insufficient data",
1315            test_name
1316        );
1317        Ok(())
1318    }
1319
1320    fn check_linreg_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1321        skip_if_unsupported!(kernel, test_name);
1322        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1323        let candles = read_candles_from_csv(file_path)?;
1324        let first_params = LinearRegInterceptParams { period: Some(14) };
1325        let first_input = LinearRegInterceptInput::from_candles(&candles, "close", first_params);
1326        let first_result = linearreg_intercept_with_kernel(&first_input, kernel)?;
1327        let second_params = LinearRegInterceptParams { period: Some(14) };
1328        let second_input = LinearRegInterceptInput::from_slice(&first_result.values, second_params);
1329        let second_result = linearreg_intercept_with_kernel(&second_input, kernel)?;
1330        assert_eq!(second_result.values.len(), first_result.values.len());
1331
1332        let start = second_result
1333            .values
1334            .iter()
1335            .position(|v| !v.is_nan())
1336            .unwrap_or(second_result.values.len());
1337
1338        for (i, v) in second_result.values[start..].iter().enumerate() {
1339            assert!(
1340                !v.is_nan(),
1341                "[{}] Unexpected NaN at index {} after reinput",
1342                test_name,
1343                start + i
1344            );
1345        }
1346        Ok(())
1347    }
1348
1349    fn check_linreg_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1350        skip_if_unsupported!(kernel, test_name);
1351        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1352        let candles = read_candles_from_csv(file_path)?;
1353        let input = LinearRegInterceptInput::from_candles(
1354            &candles,
1355            "close",
1356            LinearRegInterceptParams { period: Some(14) },
1357        );
1358        let res = linearreg_intercept_with_kernel(&input, kernel)?;
1359        assert_eq!(res.values.len(), candles.close.len());
1360        if res.values.len() > 40 {
1361            for (i, &val) in res.values[40..].iter().enumerate() {
1362                assert!(
1363                    !val.is_nan(),
1364                    "[{}] Found unexpected NaN at out-index {}",
1365                    test_name,
1366                    40 + i
1367                );
1368            }
1369        }
1370        Ok(())
1371    }
1372
1373    #[cfg(debug_assertions)]
1374    fn check_linreg_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1375        skip_if_unsupported!(kernel, test_name);
1376
1377        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1378        let candles = read_candles_from_csv(file_path)?;
1379
1380        let test_params = vec![
1381            LinearRegInterceptParams::default(),
1382            LinearRegInterceptParams { period: Some(2) },
1383            LinearRegInterceptParams { period: Some(5) },
1384            LinearRegInterceptParams { period: Some(7) },
1385            LinearRegInterceptParams { period: Some(10) },
1386            LinearRegInterceptParams { period: Some(20) },
1387            LinearRegInterceptParams { period: Some(30) },
1388            LinearRegInterceptParams { period: Some(50) },
1389            LinearRegInterceptParams { period: Some(100) },
1390            LinearRegInterceptParams { period: Some(150) },
1391            LinearRegInterceptParams { period: Some(200) },
1392        ];
1393
1394        for (param_idx, params) in test_params.iter().enumerate() {
1395            let input = LinearRegInterceptInput::from_candles(&candles, "close", params.clone());
1396            let output = linearreg_intercept_with_kernel(&input, kernel)?;
1397
1398            for (i, &val) in output.values.iter().enumerate() {
1399                if val.is_nan() {
1400                    continue;
1401                }
1402
1403                let bits = val.to_bits();
1404
1405                if bits == 0x11111111_11111111 {
1406                    panic!(
1407                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1408						 with params: period={} (param set {})",
1409                        test_name,
1410                        val,
1411                        bits,
1412                        i,
1413                        params.period.unwrap_or(14),
1414                        param_idx
1415                    );
1416                }
1417
1418                if bits == 0x22222222_22222222 {
1419                    panic!(
1420                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1421						 with params: period={} (param set {})",
1422                        test_name,
1423                        val,
1424                        bits,
1425                        i,
1426                        params.period.unwrap_or(14),
1427                        param_idx
1428                    );
1429                }
1430
1431                if bits == 0x33333333_33333333 {
1432                    panic!(
1433                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1434						 with params: period={} (param set {})",
1435                        test_name,
1436                        val,
1437                        bits,
1438                        i,
1439                        params.period.unwrap_or(14),
1440                        param_idx
1441                    );
1442                }
1443            }
1444        }
1445
1446        Ok(())
1447    }
1448
1449    #[cfg(not(debug_assertions))]
1450    fn check_linreg_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1451        Ok(())
1452    }
1453
1454    #[cfg(feature = "proptest")]
1455    #[allow(clippy::float_cmp)]
1456    fn check_linearreg_intercept_property(
1457        test_name: &str,
1458        kernel: Kernel,
1459    ) -> Result<(), Box<dyn std::error::Error>> {
1460        skip_if_unsupported!(kernel, test_name);
1461
1462        fn calculate_expected_linreg_intercept(
1463            window_start_idx: usize,
1464            period: usize,
1465            data_slope: f64,
1466            data_intercept: f64,
1467        ) -> f64 {
1468            data_slope * window_start_idx as f64 + data_intercept
1469        }
1470
1471        let strat = (1usize..=100, 50usize..500, 0usize..5, any::<u64>()).prop_map(
1472            |(period, len, scenario, seed)| {
1473                let mut rng_state = seed.wrapping_mul(1664525).wrapping_add(1013904223);
1474                let mut data = Vec::with_capacity(len);
1475
1476                match scenario {
1477                    0 => {
1478                        for _ in 0..len {
1479                            rng_state = rng_state.wrapping_mul(1664525).wrapping_add(1013904223);
1480                            let val = (rng_state as f64 / u64::MAX as f64) * 200.0 - 100.0;
1481                            data.push(val);
1482                        }
1483                    }
1484                    1 => {
1485                        let constant = 42.0;
1486                        data.resize(len, constant);
1487                    }
1488                    2 => {
1489                        for i in 0..len {
1490                            data.push(2.0 * i as f64 + 10.0);
1491                        }
1492                    }
1493                    3 => {
1494                        for i in 0..len {
1495                            data.push(-1.5 * i as f64 + 100.0);
1496                        }
1497                    }
1498                    _ => {
1499                        for i in 0..len {
1500                            rng_state = rng_state.wrapping_mul(1664525).wrapping_add(1013904223);
1501                            let noise = ((rng_state as f64 / u64::MAX as f64) - 0.5) * 10.0;
1502                            data.push(0.5 * i as f64 + 50.0 + noise);
1503                        }
1504                    }
1505                }
1506
1507                (data, period, scenario)
1508            },
1509        );
1510
1511        proptest::test_runner::TestRunner::default().run(&strat, |(data, period, scenario)| {
1512            let params = LinearRegInterceptParams {
1513                period: Some(period),
1514            };
1515            let input = LinearRegInterceptInput::from_slice(&data, params);
1516
1517            let output = linearreg_intercept_with_kernel(&input, kernel)?;
1518
1519            let ref_output = linearreg_intercept_with_kernel(&input, Kernel::Scalar)?;
1520
1521            prop_assert_eq!(
1522                output.values.len(),
1523                data.len(),
1524                "[{}] Output length mismatch",
1525                test_name
1526            );
1527
1528            if period == 1 {
1529                for i in 0..data.len() {
1530                    let expected = data[i];
1531                    let actual = output.values[i];
1532                    prop_assert!(
1533                        (actual - expected).abs() < 1e-9,
1534                        "[{}] Period=1: expected {}, got {} at index {}",
1535                        test_name,
1536                        expected,
1537                        actual,
1538                        i
1539                    );
1540                }
1541            } else {
1542                for i in 0..(period - 1) {
1543                    prop_assert!(
1544                        output.values[i].is_nan(),
1545                        "[{}] Expected NaN during warmup at index {}",
1546                        test_name,
1547                        i
1548                    );
1549                }
1550
1551                if period <= data.len() {
1552                    prop_assert!(
1553                        !output.values[period - 1].is_nan(),
1554                        "[{}] Expected valid value at index {}",
1555                        test_name,
1556                        period - 1
1557                    );
1558                }
1559            }
1560
1561            if scenario == 1 && period < data.len() {
1562                for i in (period - 1)..data.len() {
1563                    let intercept = output.values[i];
1564                    if !intercept.is_nan() {
1565                        prop_assert!(
1566                            (intercept - 42.0).abs() < 1e-9,
1567                            "[{}] Constant data: expected 42.0, got {} at index {}",
1568                            test_name,
1569                            intercept,
1570                            i
1571                        );
1572                    }
1573                }
1574            }
1575
1576            if (scenario == 2 || scenario == 3) && period > 1 && period < data.len() {
1577                let (data_slope, data_intercept) = match scenario {
1578                    2 => (2.0, 10.0),
1579                    3 => (-1.5, 100.0),
1580                    _ => unreachable!(),
1581                };
1582
1583                for i in (period - 1)..data.len().min(period * 5) {
1584                    let actual = output.values[i];
1585                    if !actual.is_nan() {
1586                        let window_start = i + 1 - period;
1587                        let expected = calculate_expected_linreg_intercept(
1588                            window_start,
1589                            period,
1590                            data_slope,
1591                            data_intercept,
1592                        );
1593
1594                        prop_assert!((actual - expected).abs() < 1e-9,
1595								"[{}] Linear trend (scenario {}): expected {:.6}, got {:.6} at index {} (window start {})",
1596								test_name, scenario, expected, actual, i, window_start);
1597                    }
1598                }
1599            }
1600
1601            for i in 0..output.values.len() {
1602                let y = output.values[i];
1603                let r = ref_output.values[i];
1604
1605                let bits = y.to_bits();
1606                prop_assert!(
1607                    bits != 0x11111111_11111111
1608                        && bits != 0x22222222_22222222
1609                        && bits != 0x33333333_33333333,
1610                    "[{}] Found poison value at index {}: 0x{:016X}",
1611                    test_name,
1612                    i,
1613                    bits
1614                );
1615
1616                if y.is_nan() && r.is_nan() {
1617                    continue;
1618                }
1619
1620                if y.is_finite() && r.is_finite() {
1621                    prop_assert!(
1622                        (y - r).abs() <= 1e-9,
1623                        "[{}] Kernel mismatch at index {}: {} vs {} (diff: {})",
1624                        test_name,
1625                        i,
1626                        y,
1627                        r,
1628                        (y - r).abs()
1629                    );
1630                } else {
1631                    prop_assert_eq!(
1632                        y.is_nan(),
1633                        r.is_nan(),
1634                        "[{}] NaN mismatch at index {}: {} vs {}",
1635                        test_name,
1636                        i,
1637                        y,
1638                        r
1639                    );
1640                }
1641            }
1642
1643            if period > 1 {
1644                for i in (period - 1)..output.values.len() {
1645                    let val = output.values[i];
1646                    prop_assert!(
1647                        val.is_finite(),
1648                        "[{}] Non-finite value {} at index {}",
1649                        test_name,
1650                        val,
1651                        i
1652                    );
1653                }
1654            }
1655
1656            Ok(())
1657        })?;
1658
1659        Ok(())
1660    }
1661
1662    macro_rules! generate_all_linreg_tests {
1663        ($($test_fn:ident),*) => {
1664            paste::paste! {
1665                $(
1666                    #[test]
1667                    fn [<$test_fn _scalar_f64>]() {
1668                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1669                    }
1670                )*
1671                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1672                $(
1673                    #[test]
1674                    fn [<$test_fn _avx2_f64>]() {
1675                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1676                    }
1677                    #[test]
1678                    fn [<$test_fn _avx512_f64>]() {
1679                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1680                    }
1681                )*
1682            }
1683        }
1684    }
1685
1686    generate_all_linreg_tests!(
1687        check_linreg_partial_params,
1688        check_linreg_accuracy,
1689        check_linreg_default_candles,
1690        check_linreg_zero_period,
1691        check_linreg_period_exceeds_length,
1692        check_linreg_very_small_dataset,
1693        check_linreg_reinput,
1694        check_linreg_nan_handling,
1695        check_linreg_no_poison
1696    );
1697
1698    #[cfg(feature = "proptest")]
1699    generate_all_linreg_tests!(check_linearreg_intercept_property);
1700
1701    #[test]
1702    fn test_linearreg_intercept_into_matches_api() -> Result<(), Box<dyn Error>> {
1703        let len = 256usize;
1704        let mut data = Vec::with_capacity(len);
1705        for i in 0..len {
1706            if i < 10 {
1707                data.push(f64::NAN);
1708            } else {
1709                let x = i as f64;
1710                data.push((0.1 * x).sin() * 3.0 + 0.05 * x + 2.0);
1711            }
1712        }
1713
1714        let input = LinearRegInterceptInput::from_slice(&data, LinearRegInterceptParams::default());
1715
1716        let baseline = linearreg_intercept(&input)?.values;
1717
1718        let mut out = vec![0.0; data.len()];
1719        #[allow(unused_variables)]
1720        {
1721            #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1722            {
1723                linearreg_intercept_into(&input, &mut out)?;
1724            }
1725            #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1726            {
1727                linearreg_intercept_into_slice(&mut out, &input, Kernel::Auto)?;
1728            }
1729        }
1730
1731        assert_eq!(baseline.len(), out.len());
1732
1733        fn eq_or_both_nan_eps(a: f64, b: f64) -> bool {
1734            (a.is_nan() && b.is_nan()) || (a - b).abs() <= 1e-12
1735        }
1736
1737        for i in 0..out.len() {
1738            assert!(
1739                eq_or_both_nan_eps(baseline[i], out[i]),
1740                "mismatch at index {}: baseline={} out={}",
1741                i,
1742                baseline[i],
1743                out[i]
1744            );
1745        }
1746
1747        Ok(())
1748    }
1749
1750    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1751        skip_if_unsupported!(kernel, test);
1752
1753        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1754        let c = read_candles_from_csv(file)?;
1755
1756        let output = LinearRegInterceptBatchBuilder::new()
1757            .kernel(kernel)
1758            .apply_candles(&c, "close")?;
1759
1760        let def = LinearRegInterceptParams::default();
1761        let row = output.values_for(&def).expect("default row missing");
1762
1763        assert_eq!(row.len(), c.close.len());
1764
1765        let expected = [
1766            60000.91428571429,
1767            59947.142857142855,
1768            59754.57142857143,
1769            59318.4,
1770            59321.91428571429,
1771        ];
1772        let start = row.len() - 5;
1773        for (i, &v) in row[start..].iter().enumerate() {
1774            assert!(
1775                (v - expected[i]).abs() < 1e-1,
1776                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1777            );
1778        }
1779        Ok(())
1780    }
1781
1782    #[cfg(debug_assertions)]
1783    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1784        skip_if_unsupported!(kernel, test);
1785
1786        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1787        let c = read_candles_from_csv(file)?;
1788
1789        let test_configs = vec![
1790            (2, 10, 2),
1791            (5, 25, 5),
1792            (10, 10, 0),
1793            (2, 5, 1),
1794            (30, 60, 15),
1795            (2, 14, 3),
1796            (50, 100, 25),
1797            (100, 200, 50),
1798        ];
1799
1800        for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
1801            let output = LinearRegInterceptBatchBuilder::new()
1802                .kernel(kernel)
1803                .period_range(p_start, p_end, p_step)
1804                .apply_candles(&c, "close")?;
1805
1806            for (idx, &val) in output.values.iter().enumerate() {
1807                if val.is_nan() {
1808                    continue;
1809                }
1810
1811                let bits = val.to_bits();
1812                let row = idx / output.cols;
1813                let col = idx % output.cols;
1814                let combo = &output.combos[row];
1815
1816                if bits == 0x11111111_11111111 {
1817                    panic!(
1818                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1819						 at row {} col {} (flat index {}) with params: period={}",
1820                        test,
1821                        cfg_idx,
1822                        val,
1823                        bits,
1824                        row,
1825                        col,
1826                        idx,
1827                        combo.period.unwrap_or(14)
1828                    );
1829                }
1830
1831                if bits == 0x22222222_22222222 {
1832                    panic!(
1833                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1834						 at row {} col {} (flat index {}) with params: period={}",
1835                        test,
1836                        cfg_idx,
1837                        val,
1838                        bits,
1839                        row,
1840                        col,
1841                        idx,
1842                        combo.period.unwrap_or(14)
1843                    );
1844                }
1845
1846                if bits == 0x33333333_33333333 {
1847                    panic!(
1848                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1849						 at row {} col {} (flat index {}) with params: period={}",
1850                        test,
1851                        cfg_idx,
1852                        val,
1853                        bits,
1854                        row,
1855                        col,
1856                        idx,
1857                        combo.period.unwrap_or(14)
1858                    );
1859                }
1860            }
1861        }
1862
1863        Ok(())
1864    }
1865
1866    #[cfg(not(debug_assertions))]
1867    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1868        Ok(())
1869    }
1870
1871    macro_rules! gen_batch_tests {
1872        ($fn_name:ident) => {
1873            paste::paste! {
1874                #[test] fn [<$fn_name _scalar>]()      {
1875                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1876                }
1877                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1878                #[test] fn [<$fn_name _avx2>]()        {
1879                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1880                }
1881                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1882                #[test] fn [<$fn_name _avx512>]()      {
1883                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1884                }
1885                #[test] fn [<$fn_name _auto_detect>]() {
1886                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1887                }
1888            }
1889        };
1890    }
1891    gen_batch_tests!(check_batch_default_row);
1892    gen_batch_tests!(check_batch_no_poison);
1893}
1894
1895#[cfg(feature = "python")]
1896#[pyfunction(name = "linearreg_intercept")]
1897#[pyo3(signature = (data, period, kernel=None))]
1898pub fn linearreg_intercept_py<'py>(
1899    py: Python<'py>,
1900    data: numpy::PyReadonlyArray1<'py, f64>,
1901    period: usize,
1902    kernel: Option<&str>,
1903) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
1904    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
1905
1906    let slice_in = data.as_slice()?;
1907    let kern = validate_kernel(kernel, false)?;
1908    let params = LinearRegInterceptParams {
1909        period: Some(period),
1910    };
1911    let input = LinearRegInterceptInput::from_slice(slice_in, params);
1912
1913    let result_vec: Vec<f64> = py
1914        .allow_threads(|| linearreg_intercept_with_kernel(&input, kern).map(|o| o.values))
1915        .map_err(|e| PyValueError::new_err(e.to_string()))?;
1916
1917    Ok(result_vec.into_pyarray(py))
1918}
1919
1920#[cfg(feature = "python")]
1921#[pyclass(name = "LinearRegInterceptStream")]
1922pub struct LinearRegInterceptStreamPy {
1923    stream: LinearRegInterceptStream,
1924}
1925
1926#[cfg(feature = "python")]
1927#[pymethods]
1928impl LinearRegInterceptStreamPy {
1929    #[new]
1930    fn new(period: usize) -> PyResult<Self> {
1931        let params = LinearRegInterceptParams {
1932            period: Some(period),
1933        };
1934        let stream = LinearRegInterceptStream::try_new(params)
1935            .map_err(|e| PyValueError::new_err(e.to_string()))?;
1936        Ok(LinearRegInterceptStreamPy { stream })
1937    }
1938
1939    fn update(&mut self, value: f64) -> Option<f64> {
1940        self.stream.update(value)
1941    }
1942}
1943
1944#[cfg(feature = "python")]
1945#[pyfunction(name = "linearreg_intercept_batch")]
1946#[pyo3(signature = (data, period_range, kernel=None))]
1947pub fn linearreg_intercept_batch_py<'py>(
1948    py: Python<'py>,
1949    data: numpy::PyReadonlyArray1<'py, f64>,
1950    period_range: (usize, usize, usize),
1951    kernel: Option<&str>,
1952) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
1953    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
1954    use pyo3::types::PyDict;
1955
1956    let slice_in = data.as_slice()?;
1957    let sweep = LinearRegInterceptBatchRange {
1958        period: period_range,
1959    };
1960
1961    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
1962    let rows = combos.len();
1963    let cols = slice_in.len();
1964
1965    let total = rows
1966        .checked_mul(cols)
1967        .ok_or_else(|| PyValueError::new_err("linearreg_intercept_batch: rows*cols overflow"))?;
1968
1969    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
1970    let slice_out = unsafe { out_arr.as_slice_mut()? };
1971
1972    if !combos.is_empty() && cols > 0 {
1973        let first = slice_in.iter().position(|x| !x.is_nan()).unwrap_or(0);
1974        for (r, prm) in combos.iter().enumerate() {
1975            let warm = (first + prm.period.unwrap() - 1).min(cols);
1976            for v in &mut slice_out[r * cols..r * cols + warm] {
1977                *v = f64::NAN;
1978            }
1979        }
1980    }
1981
1982    let kern = validate_kernel(kernel, true)?;
1983    let combos = py
1984        .allow_threads(|| {
1985            let kernel = match kern {
1986                Kernel::Auto => detect_best_batch_kernel(),
1987                k => k,
1988            };
1989            let simd = match kernel {
1990                Kernel::Avx512Batch => Kernel::Avx512,
1991                Kernel::Avx2Batch => Kernel::Avx2,
1992                Kernel::ScalarBatch => Kernel::Scalar,
1993                _ => unreachable!(),
1994            };
1995            linearreg_intercept_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
1996        })
1997        .map_err(|e| PyValueError::new_err(e.to_string()))?;
1998
1999    let dict = PyDict::new(py);
2000    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
2001    dict.set_item(
2002        "periods",
2003        combos
2004            .iter()
2005            .map(|p| p.period.unwrap() as u64)
2006            .collect::<Vec<_>>()
2007            .into_pyarray(py),
2008    )?;
2009    Ok(dict)
2010}
2011
2012#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2013#[wasm_bindgen]
2014pub fn linearreg_intercept_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
2015    let params = LinearRegInterceptParams {
2016        period: Some(period),
2017    };
2018    let input = LinearRegInterceptInput::from_slice(data, params);
2019
2020    let mut output = vec![0.0; data.len()];
2021    linearreg_intercept_into_slice(&mut output, &input, Kernel::Auto)
2022        .map_err(|e| JsValue::from_str(&e.to_string()))?;
2023
2024    Ok(output)
2025}
2026
2027#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2028#[wasm_bindgen]
2029pub fn linearreg_intercept_alloc(len: usize) -> *mut f64 {
2030    let mut vec = Vec::<f64>::with_capacity(len);
2031    let ptr = vec.as_mut_ptr();
2032    std::mem::forget(vec);
2033    ptr
2034}
2035
2036#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2037#[wasm_bindgen]
2038pub fn linearreg_intercept_free(ptr: *mut f64, len: usize) {
2039    if !ptr.is_null() {
2040        unsafe {
2041            let _ = Vec::from_raw_parts(ptr, len, len);
2042        }
2043    }
2044}
2045
2046#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2047#[wasm_bindgen]
2048pub fn linearreg_intercept_into(
2049    in_ptr: *const f64,
2050    out_ptr: *mut f64,
2051    len: usize,
2052    period: usize,
2053) -> Result<(), JsValue> {
2054    if in_ptr.is_null() || out_ptr.is_null() {
2055        return Err(JsValue::from_str("Null pointer provided"));
2056    }
2057
2058    unsafe {
2059        let data = std::slice::from_raw_parts(in_ptr, len);
2060        let params = LinearRegInterceptParams {
2061            period: Some(period),
2062        };
2063        let input = LinearRegInterceptInput::from_slice(data, params);
2064
2065        if in_ptr == out_ptr {
2066            let mut temp = vec![0.0; len];
2067            linearreg_intercept_into_slice(&mut temp, &input, Kernel::Auto)
2068                .map_err(|e| JsValue::from_str(&e.to_string()))?;
2069            let out = std::slice::from_raw_parts_mut(out_ptr, len);
2070            out.copy_from_slice(&temp);
2071        } else {
2072            let out = std::slice::from_raw_parts_mut(out_ptr, len);
2073            linearreg_intercept_into_slice(out, &input, Kernel::Auto)
2074                .map_err(|e| JsValue::from_str(&e.to_string()))?;
2075        }
2076        Ok(())
2077    }
2078}
2079
2080#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2081#[derive(Serialize, Deserialize)]
2082pub struct LinearRegInterceptBatchConfig {
2083    pub period_range: (usize, usize, usize),
2084}
2085
2086#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2087#[derive(Serialize, Deserialize)]
2088pub struct LinearRegInterceptBatchJsOutput {
2089    pub values: Vec<f64>,
2090    pub combos: Vec<LinearRegInterceptParams>,
2091    pub rows: usize,
2092    pub cols: usize,
2093}
2094
2095#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2096#[wasm_bindgen(js_name = linearreg_intercept_batch)]
2097pub fn linearreg_intercept_batch_unified_js(
2098    data: &[f64],
2099    config: JsValue,
2100) -> Result<JsValue, JsValue> {
2101    let config: LinearRegInterceptBatchConfig = serde_wasm_bindgen::from_value(config)
2102        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
2103
2104    let sweep = LinearRegInterceptBatchRange {
2105        period: config.period_range,
2106    };
2107
2108    let batch_output = linearreg_intercept_batch_with_kernel(data, &sweep, Kernel::Auto)
2109        .map_err(|e| JsValue::from_str(&e.to_string()))?;
2110
2111    let rows = batch_output.values.len() / data.len();
2112    let result = LinearRegInterceptBatchJsOutput {
2113        values: batch_output.values,
2114        combos: batch_output.combos,
2115        rows,
2116        cols: data.len(),
2117    };
2118
2119    serde_wasm_bindgen::to_value(&result).map_err(|e| JsValue::from_str(&e.to_string()))
2120}
2121
2122#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2123#[wasm_bindgen]
2124pub fn linearreg_intercept_batch_into(
2125    in_ptr: *const f64,
2126    out_ptr: *mut f64,
2127    len: usize,
2128    period_start: usize,
2129    period_end: usize,
2130    period_step: usize,
2131) -> Result<usize, JsValue> {
2132    if in_ptr.is_null() || out_ptr.is_null() {
2133        return Err(JsValue::from_str("Null pointer provided"));
2134    }
2135
2136    unsafe {
2137        let data = std::slice::from_raw_parts(in_ptr, len);
2138        let sweep = LinearRegInterceptBatchRange {
2139            period: (period_start, period_end, period_step),
2140        };
2141        let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
2142        let rows = combos.len();
2143        let cols = len;
2144        let total = rows.checked_mul(cols).ok_or_else(|| {
2145            JsValue::from_str("linearreg_intercept_batch_into: rows*cols overflow")
2146        })?;
2147
2148        let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
2149
2150        if in_ptr == out_ptr {
2151            let mut temp = vec![0.0; total];
2152
2153            for (r, prm) in combos.iter().enumerate() {
2154                let warm = (first + prm.period.unwrap() - 1).min(cols);
2155                temp[r * cols..r * cols + warm].fill(f64::NAN);
2156            }
2157
2158            linearreg_intercept_batch_inner_into(data, &sweep, Kernel::Auto, true, &mut temp)
2159                .map_err(|e| JsValue::from_str(&e.to_string()))?;
2160            let out = std::slice::from_raw_parts_mut(out_ptr, total);
2161            out.copy_from_slice(&temp);
2162        } else {
2163            let out = std::slice::from_raw_parts_mut(out_ptr, total);
2164
2165            for (r, prm) in combos.iter().enumerate() {
2166                let warm = (first + prm.period.unwrap() - 1).min(cols);
2167                out[r * cols..r * cols + warm].fill(f64::NAN);
2168            }
2169
2170            linearreg_intercept_batch_inner_into(data, &sweep, Kernel::Auto, true, out)
2171                .map_err(|e| JsValue::from_str(&e.to_string()))?;
2172        }
2173        Ok(rows)
2174    }
2175}