//! statoxide — PyO3 Python bindings for the StatOxide statistics library (lib.rs).
1use ndarray::Array1;
2use pyo3::exceptions::{PyRuntimeError, PyValueError};
3use pyo3::prelude::*;
4use pyo3::types::{PyDict, PyList};
5use std::collections::HashMap;
6
7use so_core::data::{DataFrame, Series};
8use so_core::formula::Formula;
9
10// Import models
11use so_models::glm::{Family, GLM as RustGLM, GLMModelBuilder, GLMResults, Link};
12use so_models::regression::OLS;
13use so_models::robust::{LeastTrimmedSquares, MEstimator as RustMEstimator, RobustRegressionResults as RustRobustResults, LossFunction, ScaleEstimator};
14use so_models::nonparametric::{Kernel, KernelRegression, KernelRegressionResults, LocalRegression, LocalRegressionResults, SmoothingSpline, SmoothingSplineResults, BandwidthMethod};
15
16// Import time series
17use so_tsa::TimeSeries;
18use so_tsa::arima::ARIMAResults;
19use so_tsa::garch::GARCHResults;
20
21// Import statistical tests
22use so_stats::tests::{
23    Alternative, TestResult, anova_one_way as anova_one_way_rs,
24    chi_square_test_independence as chi_square_test_independence_rs,
25    shapiro_wilk_test as shapiro_wilk_test_rs, t_test_one_sample as t_test_one_sample_rs,
26    t_test_paired as t_test_paired_rs, t_test_two_sample as t_test_two_sample_rs,
27};
28
/// Python wrapper for StatOxide Series
#[pyclass(name = "Series")]
struct PySeries {
    /// The wrapped Rust-side `Series`; every Python method delegates to it.
    inner: Series,
}
34
35#[pymethods]
36impl PySeries {
37    /// Create a new Series from Python list or array
38    #[new]
39    fn new(name: String, data: Vec<f64>) -> PyResult<Self> {
40        let array = Array1::from_vec(data);
41        Ok(PySeries {
42            inner: Series::new(name, array),
43        })
44    }
45
46    /// Get the name of the series
47    #[getter]
48    fn name(&self) -> String {
49        self.inner.name().to_string()
50    }
51
52    /// Get the length of the series
53    #[getter]
54    fn len(&self) -> usize {
55        self.inner.len()
56    }
57
58    /// Check if series is empty
59    fn is_empty(&self) -> bool {
60        self.inner.is_empty()
61    }
62
63    /// Compute mean of series
64    fn mean(&self) -> Option<f64> {
65        self.inner.mean()
66    }
67
68    /// Compute standard deviation
69    fn std(&self, ddof: f64) -> Option<f64> {
70        self.inner.std(ddof)
71    }
72
73    /// Compute variance
74    fn var(&self, ddof: f64) -> Option<f64> {
75        self.inner.var(ddof)
76    }
77
78    /// Get minimum value
79    fn min(&self) -> Option<f64> {
80        self.inner.min()
81    }
82
83    /// Get maximum value
84    fn max(&self) -> Option<f64> {
85        self.inner.max()
86    }
87
88    /// Compute quantile
89    fn quantile(&self, q: f64) -> Option<f64> {
90        self.inner.quantile(q)
91    }
92
93    /// Convert to Python list
94    fn to_list(&self) -> Vec<f64> {
95        self.inner.data().to_vec()
96    }
97
98    /// String representation
99    fn __repr__(&self) -> String {
100        format!(
101            "Series(name='{}', len={})",
102            self.inner.name(),
103            self.inner.len()
104        )
105    }
106}
107
/// Python wrapper for StatOxide DataFrame
#[pyclass(name = "DataFrame")]
struct PyDataFrame {
    /// The wrapped Rust-side `DataFrame`; every Python method delegates to it.
    inner: DataFrame,
}
113
114#[pymethods]
115impl PyDataFrame {
116    /// Create a new DataFrame from a dictionary of columns
117    #[new]
118    fn new(data: HashMap<String, Vec<f64>>) -> PyResult<Self> {
119        let mut columns = HashMap::new();
120
121        for (name, values) in data {
122            let array = Array1::from_vec(values);
123            columns.insert(name.clone(), Series::new(name, array));
124        }
125
126        match DataFrame::from_series(columns) {
127            Ok(df) => Ok(PyDataFrame { inner: df }),
128            Err(e) => Err(PyValueError::new_err(format!(
129                "Error creating DataFrame: {:?}",
130                e
131            ))),
132        }
133    }
134
135    /// Get number of rows
136    #[getter]
137    fn n_rows(&self) -> usize {
138        self.inner.n_rows()
139    }
140
141    /// Get number of columns
142    #[getter]
143    fn n_cols(&self) -> usize {
144        self.inner.n_cols()
145    }
146
147    /// Get column names
148    fn columns(&self) -> Vec<String> {
149        self.inner.column_names()
150    }
151
152    /// Get a column by name
153    fn get_column(&self, name: &str) -> PyResult<PySeries> {
154        match self.inner.column(name) {
155            Some(series) => Ok(PySeries {
156                inner: series.clone(),
157            }),
158            None => Err(PyValueError::new_err(format!(
159                "Column '{}' not found",
160                name
161            ))),
162        }
163    }
164
165    /// Add a column to the DataFrame
166    fn with_column(&mut self, series: &PySeries) -> PyResult<()> {
167        let df = self
168            .inner
169            .clone()
170            .with_column(series.inner.clone())
171            .map_err(|e| PyValueError::new_err(format!("Error adding column: {:?}", e)))?;
172        self.inner = df;
173        Ok(())
174    }
175
176    /// String representation
177    fn __repr__(&self) -> String {
178        format!(
179            "DataFrame(rows={}, cols={})",
180            self.inner.n_rows(),
181            self.inner.n_cols()
182        )
183    }
184}
185
/// Python wrapper for StatOxide Formula
#[pyclass(name = "Formula")]
struct PyFormula {
    /// The parsed Rust-side model formula.
    inner: Formula,
}
191
192#[pymethods]
193impl PyFormula {
194    /// Parse a formula string
195    #[new]
196    fn new(formula: String) -> PyResult<Self> {
197        match Formula::parse(&formula) {
198            Ok(f) => Ok(PyFormula { inner: f }),
199            Err(e) => Err(PyValueError::new_err(format!(
200                "Error parsing formula: {:?}",
201                e
202            ))),
203        }
204    }
205
206    /// Get all variable names in the formula
207    fn variables(&self) -> Vec<String> {
208        self.inner.variables().into_iter().collect()
209    }
210
211    /// String representation
212    fn __repr__(&self) -> String {
213        format!("Formula({:?})", self.inner)
214    }
215}
216
/// Python wrapper for GLM Family
#[pyclass(name = "Family")]
#[derive(Clone)]
struct PyFamily {
    /// The wrapped distribution-family enum from so_models.
    inner: Family,
}
223
224#[pymethods]
225impl PyFamily {
226    /// Create Gaussian family
227    #[staticmethod]
228    fn gaussian() -> Self {
229        PyFamily {
230            inner: Family::Gaussian,
231        }
232    }
233
234    /// Create Binomial family
235    #[staticmethod]
236    fn binomial() -> Self {
237        PyFamily {
238            inner: Family::Binomial,
239        }
240    }
241
242    /// Create Poisson family
243    #[staticmethod]
244    fn poisson() -> Self {
245        PyFamily {
246            inner: Family::Poisson,
247        }
248    }
249
250    /// Create Gamma family
251    #[staticmethod]
252    fn gamma() -> Self {
253        PyFamily {
254            inner: Family::Gamma,
255        }
256    }
257
258    /// Create Inverse Gaussian family
259    #[staticmethod]
260    fn inverse_gaussian() -> Self {
261        PyFamily {
262            inner: Family::InverseGaussian,
263        }
264    }
265
266    /// Get family name
267    fn name(&self) -> String {
268        self.inner.name().to_string()
269    }
270
271    /// String representation
272    fn __repr__(&self) -> String {
273        format!("Family({})", self.name())
274    }
275}
276
/// Python wrapper for GLM Link function
#[pyclass(name = "Link")]
#[derive(Clone)]
struct PyLink {
    /// The wrapped link-function enum from so_models.
    inner: Link,
}
283
284#[pymethods]
285impl PyLink {
286    /// Identity link: η = μ
287    #[staticmethod]
288    fn identity() -> Self {
289        PyLink {
290            inner: Link::Identity,
291        }
292    }
293
294    /// Logit link: η = log(μ/(1-μ))
295    #[staticmethod]
296    fn logit() -> Self {
297        PyLink { inner: Link::Logit }
298    }
299
300    /// Probit link: η = Φ⁻¹(μ)
301    #[staticmethod]
302    fn probit() -> Self {
303        PyLink {
304            inner: Link::Probit,
305        }
306    }
307
308    /// Log link: η = log(μ)
309    #[staticmethod]
310    fn log() -> Self {
311        PyLink { inner: Link::Log }
312    }
313
314    /// Inverse link: η = 1/μ
315    #[staticmethod]
316    fn inverse() -> Self {
317        PyLink {
318            inner: Link::Inverse,
319        }
320    }
321
322    /// String representation
323    fn __repr__(&self) -> String {
324        match self.inner {
325            Link::Identity => "Link(identity)".to_string(),
326            Link::Logit => "Link(logit)".to_string(),
327            Link::Probit => "Link(probit)".to_string(),
328            Link::Cloglog => "Link(cloglog)".to_string(),
329            Link::Log => "Link(log)".to_string(),
330            Link::Inverse => "Link(inverse)".to_string(),
331            Link::InverseSquare => "Link(inverse-square)".to_string(),
332            Link::Sqrt => "Link(sqrt)".to_string(),
333        }
334    }
335}
336
/// Python wrapper for GLM model builder
#[pyclass(name = "GLMBuilder")]
struct PyGLMBuilder {
    /// Builder state; `None` once `build()` has consumed it.
    inner: Option<GLMModelBuilder>,
}
342
343#[pymethods]
344impl PyGLMBuilder {
345    /// Create a new GLM builder
346    #[new]
347    fn new() -> Self {
348        PyGLMBuilder {
349            inner: Some(GLMModelBuilder::new()),
350        }
351    }
352
353    /// Set the distribution family
354    fn family(&mut self, family: &PyFamily) -> PyResult<()> {
355        if let Some(inner) = self.inner.take() {
356            self.inner = Some(inner.family(family.inner));
357        }
358        Ok(())
359    }
360
361    /// Set the link function
362    fn link(&mut self, link: &PyLink) -> PyResult<()> {
363        if let Some(inner) = self.inner.take() {
364            self.inner = Some(inner.link(link.inner));
365        }
366        Ok(())
367    }
368
369    /// Set whether to include intercept
370    fn intercept(&mut self, intercept: bool) -> PyResult<()> {
371        if let Some(inner) = self.inner.take() {
372            self.inner = Some(inner.intercept(intercept));
373        }
374        Ok(())
375    }
376
377    /// Set maximum iterations
378    fn max_iter(&mut self, max_iter: usize) -> PyResult<()> {
379        if let Some(inner) = self.inner.take() {
380            self.inner = Some(inner.max_iter(max_iter));
381        }
382        Ok(())
383    }
384
385    /// Set convergence tolerance
386    fn tol(&mut self, tol: f64) -> PyResult<()> {
387        if let Some(inner) = self.inner.take() {
388            self.inner = Some(inner.tol(tol));
389        }
390        Ok(())
391    }
392
393    /// Set fixed scale parameter
394    fn scale(&mut self, scale: f64) -> PyResult<()> {
395        if let Some(inner) = self.inner.take() {
396            self.inner = Some(inner.scale(scale));
397        }
398        Ok(())
399    }
400
401    /// Build the GLM model
402    fn build(&mut self) -> PyResult<PyGLM> {
403        if let Some(inner) = self.inner.take() {
404            Ok(PyGLM {
405                inner: inner.build(),
406            })
407        } else {
408            Err(PyRuntimeError::new_err("GLM builder not available"))
409        }
410    }
411}
412
/// Python wrapper for GLM model
#[pyclass(name = "GLM")]
struct PyGLM {
    /// The wrapped, fully-configured Rust GLM model.
    inner: RustGLM,
}
418
419#[pymethods]
420impl PyGLM {
421    /// Create a new GLM builder
422    #[staticmethod]
423    fn new() -> PyGLMBuilder {
424        PyGLMBuilder::new()
425    }
426
427    /// Fit the GLM using formula and DataFrame
428    fn fit(&self, formula: &str, data: &PyDataFrame) -> PyResult<PyGLMResults> {
429        match self.inner.fit(formula, &data.inner) {
430            Ok(results) => Ok(PyGLMResults { inner: results }),
431            Err(e) => Err(PyRuntimeError::new_err(format!(
432                "GLM fitting failed: {:?}",
433                e
434            ))),
435        }
436    }
437
438    /// Fit the GLM with design matrix X and response y
439    fn fit_matrix(&self, x: Vec<Vec<f64>>, y: Vec<f64>) -> PyResult<PyGLMResults> {
440        // Convert to DataFrame
441        let n_rows = x.len();
442        if n_rows == 0 {
443            return Err(PyValueError::new_err("X must have at least one row"));
444        }
445        if n_rows != y.len() {
446            return Err(PyValueError::new_err(
447                "X and y must have same number of rows",
448            ));
449        }
450
451        let n_cols = x[0].len();
452
453        // Check all rows have same number of columns
454        for (i, row) in x.iter().enumerate() {
455            if row.len() != n_cols {
456                return Err(PyValueError::new_err(format!(
457                    "Row {} has {} columns, expected {}",
458                    i,
459                    row.len(),
460                    n_cols
461                )));
462            }
463        }
464
465        // Create column names
466        let mut col_names = Vec::new();
467        for i in 0..n_cols {
468            col_names.push(format!("x{}", i));
469        }
470        col_names.push("y".to_string());
471
472        // Create Series for each column
473        let mut columns = HashMap::new();
474
475        // Create X columns
476        for i in 0..n_cols {
477            let mut col_data = Vec::with_capacity(n_rows);
478            for row in &x {
479                col_data.push(row[i]);
480            }
481            let series = Series::new(format!("x{}", i), ndarray::Array1::from_vec(col_data));
482            columns.insert(format!("x{}", i), series);
483        }
484
485        // Create y column
486        let y_series = Series::new("y".to_string(), ndarray::Array1::from_vec(y.clone()));
487        columns.insert("y".to_string(), y_series);
488
489        // Create DataFrame from Series
490        let df = DataFrame::from_series(columns)
491            .map_err(|e| PyValueError::new_err(format!("Failed to create DataFrame: {}", e)))?;
492
493        // Create formula: y ~ x0 + x1 + ... + x{n-1}
494        let formula_str = if n_cols == 0 {
495            "y ~ 1".to_string()
496        } else {
497            let mut formula = "y ~ ".to_string();
498            for i in 0..n_cols {
499                formula.push_str(&format!("x{}", i));
500                if i < n_cols - 1 {
501                    formula.push_str(" + ");
502                }
503            }
504            formula
505        };
506
507        // Use the GLM's fit method (takes formula string and DataFrame)
508        match self.inner.fit(&formula_str, &df) {
509            Ok(results) => Ok(PyGLMResults { inner: results }),
510            Err(e) => Err(PyRuntimeError::new_err(format!(
511                "GLM fitting failed: {:?}",
512                e
513            ))),
514        }
515    }
516}
517
/// Python wrapper for GLM results
#[pyclass(name = "GLMResults")]
struct PyGLMResults {
    /// The wrapped Rust-side fit results; getters delegate to its fields.
    inner: GLMResults,
}
523
524#[pymethods]
525impl PyGLMResults {
526    /// Get coefficients
527    #[getter]
528    fn coefficients(&self) -> Vec<f64> {
529        self.inner.coefficients.to_vec()
530    }
531
532    /// Get standard errors
533    #[getter]
534    fn std_errors(&self) -> Vec<f64> {
535        self.inner.std_errors.to_vec()
536    }
537
538    /// Get z-values (Wald test statistics)
539    #[getter]
540    fn z_values(&self) -> Vec<f64> {
541        self.inner.z_values.to_vec()
542    }
543
544    /// Get p-values
545    #[getter]
546    fn p_values(&self) -> Vec<f64> {
547        self.inner.p_values.to_vec()
548    }
549
550    /// Get deviance
551    #[getter]
552    fn deviance(&self) -> f64 {
553        self.inner.deviance
554    }
555
556    /// Get null deviance
557    #[getter]
558    fn null_deviance(&self) -> f64 {
559        self.inner.null_deviance
560    }
561
562    /// Get AIC
563    #[getter]
564    fn aic(&self) -> f64 {
565        self.inner.aic
566    }
567
568    /// Get BIC
569    #[getter]
570    fn bic(&self) -> f64 {
571        self.inner.bic
572    }
573
574    /// Get degrees of freedom
575    #[getter]
576    fn df_residual(&self) -> usize {
577        self.inner.df_residual
578    }
579
580    /// Get degrees of freedom for null model
581    #[getter]
582    fn df_null(&self) -> usize {
583        self.inner.df_null
584    }
585
586    /// Get scale parameter
587    #[getter]
588    fn scale(&self) -> f64 {
589        self.inner.scale
590    }
591
592    /// Get fitted values
593    #[getter]
594    fn fitted_values(&self) -> Vec<f64> {
595        self.inner.fitted_values.to_vec()
596    }
597
598    /// Get Pearson residuals
599    #[getter]
600    fn pearson_residuals(&self) -> Vec<f64> {
601        self.inner.pearson_residuals.to_vec()
602    }
603
604    /// Get raw residuals (response scale)
605    #[getter]
606    fn residuals(&self) -> Vec<f64> {
607        self.inner.residuals.to_vec()
608    }
609
610    /// Get diagonal of hat matrix (leverage values)
611    #[getter]
612    fn hat_matrix_diag(&self) -> Vec<f64> {
613        self.inner.hat_matrix_diag.to_vec()
614    }
615
616    /// Get number of iterations
617    #[getter]
618    fn iterations(&self) -> usize {
619        self.inner.iterations
620    }
621
622    /// Check if model converged
623    #[getter]
624    fn converged(&self) -> bool {
625        self.inner.converged
626    }
627
628    /// Predict using the fitted model
629    fn predict(&self, x: Vec<Vec<f64>>) -> PyResult<Vec<f64>> {
630        let n_rows = x.len();
631        if n_rows == 0 {
632            return Ok(Vec::new());
633        }
634        let n_cols = x[0].len();
635
636        // Check dimensions
637        if n_cols != self.inner.coefficients.len() {
638            return Err(PyValueError::new_err(format!(
639                "X has {} columns but model has {} coefficients",
640                n_cols,
641                self.inner.coefficients.len()
642            )));
643        }
644
645        // Simple linear prediction: y = Xβ
646        let mut predictions = Vec::with_capacity(n_rows);
647        for row in x {
648            if row.len() != n_cols {
649                return Err(PyValueError::new_err(
650                    "All rows must have same number of columns",
651                ));
652            }
653            let mut pred = 0.0;
654            for (i, &xi) in row.iter().enumerate() {
655                pred += xi * self.inner.coefficients[i];
656            }
657            // TODO: Apply inverse link function based on family and link
658            predictions.push(pred);
659        }
660
661        Ok(predictions)
662    }
663
664    /// Get summary string
665    fn summary(&self) -> String {
666        // Simple summary for now
667        format!(
668            "GLM Results:\n  Coefficients: {:?}\n  AIC: {:.2}\n  BIC: {:.2}\n  Deviance: {:.2}\n  Scale: {:.2}",
669            self.coefficients(),
670            self.aic(),
671            self.bic(),
672            self.deviance(),
673            self.scale()
674        )
675    }
676}
677
/// Python wrapper for TimeSeries
#[pyclass(name = "TimeSeries")]
struct PyTimeSeries {
    /// The wrapped Rust-side time series; methods delegate to it.
    inner: TimeSeries,
}
683
684#[pymethods]
685impl PyTimeSeries {
686    /// Create a TimeSeries from a DataFrame
687    #[staticmethod]
688    fn from_dataframe(df: &PyDataFrame, value_col: &str, date_col: &str) -> PyResult<Self> {
689        match TimeSeries::from_dataframe(&df.inner, value_col, date_col) {
690            Ok(ts) => Ok(PyTimeSeries { inner: ts }),
691            Err(e) => Err(PyRuntimeError::new_err(format!(
692                "Failed to create TimeSeries: {:?}",
693                e
694            ))),
695        }
696    }
697
698    /// Create a TimeSeries from vectors
699    #[staticmethod]
700    fn from_vectors(values: Vec<f64>, _dates: Vec<String>) -> PyResult<Self> {
701        // Simple implementation - use index as timestamps
702        // In practice, would parse dates string to timestamps
703        let timestamps: Vec<i64> = (0..values.len() as i64).collect();
704        let values_array = ndarray::Array1::from_vec(values);
705
706        match TimeSeries::new("series", timestamps, values_array, None) {
707            Ok(ts) => Ok(PyTimeSeries { inner: ts }),
708            Err(e) => Err(PyRuntimeError::new_err(format!(
709                "Failed to create TimeSeries: {:?}",
710                e
711            ))),
712        }
713    }
714
715    /// Get values
716    #[getter]
717    fn values(&self) -> Vec<f64> {
718        self.inner.values().to_vec()
719    }
720
721    /// Get length
722    #[getter]
723    fn len(&self) -> usize {
724        self.inner.len()
725    }
726
727    /// Check if empty
728    fn is_empty(&self) -> bool {
729        self.inner.is_empty()
730    }
731
732    /// Compute mean
733    fn mean(&self) -> Option<f64> {
734        Some(self.inner.stats().mean)
735    }
736
737    /// Compute standard deviation
738    fn std(&self, _ddof: f64) -> Option<f64> {
739        // Note: ddof is ignored for now, uses population std
740        Some(self.inner.stats().std)
741    }
742
743    /// Compute variance
744    fn var(&self, _ddof: f64) -> Option<f64> {
745        // Note: ddof is ignored for now, uses population variance
746        Some(self.inner.stats().variance)
747    }
748
749    /// String representation
750    fn __repr__(&self) -> String {
751        format!("TimeSeries(len={})", self.len())
752    }
753}
754
/// Python wrapper for ARIMA builder
#[pyclass(name = "ARIMA")]
struct PyARIMA {
    /// Builder state; `None` once `fit()` has consumed it.
    builder: Option<so_tsa::arima::ARIMABuilder>,
}
760
761#[pymethods]
762impl PyARIMA {
763    /// Create a new ARIMA model
764    #[new]
765    fn new(p: usize, d: usize, q: usize) -> Self {
766        use so_tsa::arima::ARIMABuilder;
767        PyARIMA {
768            builder: Some(ARIMABuilder::new(p, d, q)),
769        }
770    }
771
772    /// Set seasonal parameters
773    fn seasonal(&mut self, _p: usize, _d: usize, _q: usize, _s: usize) -> PyResult<()> {
774        // SARIMA not fully implemented in Python bindings yet
775        // For now, just do nothing
776        // TODO: Implement proper SARIMA support
777        Ok(())
778    }
779
780    /// Include constant term
781    fn with_constant(&mut self, include: bool) -> PyResult<()> {
782        if let Some(builder) = self.builder.take() {
783            self.builder = Some(builder.with_constant(include));
784        }
785        Ok(())
786    }
787
788    /// Set estimation method
789    fn method(&mut self, method: String) -> PyResult<()> {
790        use so_tsa::arima::EstimationMethod;
791        if let Some(builder) = self.builder.take() {
792            let est_method = match method.to_lowercase().as_str() {
793                "css" => EstimationMethod::CSS,
794                "ml" => EstimationMethod::ML,
795                "exactml" => EstimationMethod::ExactML,
796                _ => EstimationMethod::CSS,
797            };
798            self.builder = Some(builder.method(est_method));
799        }
800        Ok(())
801    }
802
803    /// Set maximum iterations
804    fn max_iter(&mut self, max_iter: usize) -> PyResult<()> {
805        if let Some(builder) = self.builder.take() {
806            self.builder = Some(builder.max_iter(max_iter));
807        }
808        Ok(())
809    }
810
811    /// Set convergence tolerance
812    fn tol(&mut self, tol: f64) -> PyResult<()> {
813        if let Some(builder) = self.builder.take() {
814            self.builder = Some(builder.tol(tol));
815        }
816        Ok(())
817    }
818
819    /// Fit the ARIMA model
820    ///
821    /// Accepts multiple input types:
822    /// - PyTimeSeries object
823    /// - List of floats (Vec<f64>)
824    /// - Any object convertible to a list of floats
825    fn fit(&mut self, py: Python, data: Py<PyAny>) -> PyResult<PyARIMAResults> {
826        // Get reference to Python object
827        let data_ref = data.bind(py);
828
829        // Try to convert input to TimeSeries
830        let timeseries = if let Ok(ts) = data_ref.extract::<PyRef<PyTimeSeries>>() {
831            // Already a TimeSeries
832            ts.inner.clone()
833        } else if let Ok(vec) = data_ref.extract::<Vec<f64>>() {
834            // Vector of floats - create TimeSeries with index as timestamps
835            let timestamps: Vec<i64> = (0..vec.len() as i64).collect();
836            let values_array = ndarray::Array1::from_vec(vec);
837            TimeSeries::new("series", timestamps, values_array, None).map_err(|e| {
838                PyRuntimeError::new_err(format!("Failed to create TimeSeries: {:?}", e))
839            })?
840        } else if let Ok(list) = data_ref.cast::<PyList>() {
841            // Python list - extract as floats
842            let mut vec = Vec::with_capacity(list.len());
843            for i in 0..list.len() {
844                let item = list.get_item(i)?;
845                let val: f64 = item
846                    .extract()
847                    .map_err(|_| PyValueError::new_err("List must contain only numeric values"))?;
848                vec.push(val);
849            }
850            let timestamps: Vec<i64> = (0..vec.len() as i64).collect();
851            let values_array = ndarray::Array1::from_vec(vec);
852            TimeSeries::new("series", timestamps, values_array, None).map_err(|e| {
853                PyRuntimeError::new_err(format!("Failed to create TimeSeries: {:?}", e))
854            })?
855        } else {
856            return Err(PyValueError::new_err(
857                "Input must be a TimeSeries, list of floats, or convertible to list of floats",
858            ));
859        };
860
861        if let Some(builder) = self.builder.take() {
862            match builder.fit(&timeseries) {
863                Ok(results) => Ok(PyARIMAResults { inner: results }),
864                Err(e) => Err(PyRuntimeError::new_err(format!(
865                    "ARIMA fitting failed: {:?}",
866                    e
867                ))),
868            }
869        } else {
870            Err(PyRuntimeError::new_err("ARIMA builder not available"))
871        }
872    }
873}
874
/// Python wrapper for ARIMA results
#[pyclass(name = "ARIMAResults")]
struct PyARIMAResults {
    /// The wrapped Rust-side fit results; getters delegate to its fields.
    inner: ARIMAResults,
}
880
881#[pymethods]
882impl PyARIMAResults {
883    /// Get AR coefficients
884    #[getter]
885    fn ar_coef(&self) -> Option<Vec<f64>> {
886        self.inner.ar_coef.as_ref().map(|coef| coef.to_vec())
887    }
888
889    /// Get MA coefficients
890    #[getter]
891    fn ma_coef(&self) -> Option<Vec<f64>> {
892        self.inner.ma_coef.as_ref().map(|coef| coef.to_vec())
893    }
894
895    /// Get constant term
896    #[getter]
897    fn constant(&self) -> Option<f64> {
898        self.inner.constant
899    }
900
901    /// Get AIC
902    #[getter]
903    fn aic(&self) -> f64 {
904        self.inner.aic
905    }
906
907    /// Get BIC
908    #[getter]
909    fn bic(&self) -> f64 {
910        self.inner.bic
911    }
912
913    /// Get log-likelihood
914    #[getter]
915    fn log_likelihood(&self) -> f64 {
916        self.inner.log_likelihood
917    }
918
919    /// Get sigma2 (innovation variance)
920    #[getter]
921    fn sigma2(&self) -> f64 {
922        self.inner.sigma2
923    }
924
925    /// Get number of observations
926    #[getter]
927    fn n_obs(&self) -> usize {
928        self.inner.n_obs
929    }
930
931    /// Forecast future values
932    fn forecast(&self, steps: usize) -> Vec<f64> {
933        // Simple forecast - in practice would use proper forecasting method
934        // For now, return last value repeated
935        let last_value = if self.inner.fitted.len() > 0 {
936            self.inner.fitted[self.inner.fitted.len() - 1]
937        } else {
938            0.0
939        };
940        vec![last_value; steps]
941    }
942
943    /// Get fitted values
944    #[getter]
945    fn fitted(&self) -> Vec<f64> {
946        self.inner.fitted.to_vec()
947    }
948
949    /// Get residuals
950    #[getter]
951    fn residuals(&self) -> Vec<f64> {
952        self.inner.residuals.to_vec()
953    }
954
955    /// Get summary string
956    fn summary(&self) -> String {
957        format!(
958            "ARIMA Results:\n  AIC: {:.2}\n  BIC: {:.2}\n  Log-Likelihood: {:.2}\n  Sigma2: {:.4}",
959            self.aic(),
960            self.bic(),
961            self.log_likelihood(),
962            self.sigma2()
963        )
964    }
965}
966
/// Python wrapper for GARCH model
#[pyclass(name = "GARCH")]
struct PyGARCH {
    /// Builder state; `None` once `fit()` has consumed it.
    builder: Option<so_tsa::garch::GARCHBuilder>,
}
972
973#[pymethods]
974impl PyGARCH {
975    /// Create a new GARCH model
976    #[new]
977    fn new(p: usize, q: usize) -> Self {
978        use so_tsa::garch::GARCHBuilder;
979        PyGARCH {
980            builder: Some(GARCHBuilder::new(p, q)),
981        }
982    }
983
984    /// Create an ARCH model (GARCH with p=0)
985    #[staticmethod]
986    fn arch(q: usize) -> Self {
987        use so_tsa::garch::GARCHBuilder;
988        PyGARCH {
989            builder: Some(GARCHBuilder::arch(q)),
990        }
991    }
992
993    /// Set distribution for innovations
994    fn distribution(&mut self, distribution: String) -> PyResult<()> {
995        use so_tsa::garch::GARCHDistribution;
996        if let Some(builder) = self.builder.take() {
997            let dist = match distribution.to_lowercase().as_str() {
998                "normal" => GARCHDistribution::Normal,
999                "t" | "studentst" => GARCHDistribution::StudentsT(5.0), // Default df=5.0
1000                "ged" => GARCHDistribution::GED(1.5),                   // Default shape=1.5
1001                _ => GARCHDistribution::Normal,
1002            };
1003            self.builder = Some(builder.distribution(dist));
1004        }
1005        Ok(())
1006    }
1007
    /// Fit the GARCH model to residuals.
    ///
    /// Accepted inputs:
    /// - a `TimeSeries` object (used as-is)
    /// - a `Vec<f64>` / list of floats, wrapped in a synthetic `TimeSeries`
    ///   with 0-based integer timestamps
    ///
    /// NOTE(review): fitting takes the internal builder out of `self`, so a
    /// second `fit` call on the same GARCH object fails with
    /// "GARCH builder not available" — confirm this one-shot design is
    /// intended.
    fn fit(&mut self, py: Python, data: Py<PyAny>) -> PyResult<PyGARCHResults> {
        // Borrow the Python object for the duration of this call.
        let data_ref = data.bind(py);

        // Try the accepted input shapes in order of specificity.
        let timeseries = if let Ok(ts) = data_ref.extract::<PyRef<PyTimeSeries>>() {
            // Already a TimeSeries
            ts.inner.clone()
        } else if let Ok(vec) = data_ref.extract::<Vec<f64>>() {
            // Vector of floats - create TimeSeries with index as timestamps
            let timestamps: Vec<i64> = (0..vec.len() as i64).collect();
            let values_array = ndarray::Array1::from_vec(vec);
            TimeSeries::new("residuals", timestamps, values_array, None).map_err(|e| {
                PyRuntimeError::new_err(format!("Failed to create TimeSeries: {:?}", e))
            })?
        } else if let Ok(list) = data_ref.cast::<PyList>() {
            // Fallback for lists that failed the Vec<f64> extraction above;
            // extracting item-by-item yields a clearer error message when a
            // non-numeric element is present.
            let mut vec = Vec::with_capacity(list.len());
            for i in 0..list.len() {
                let item = list.get_item(i)?;
                let val: f64 = item
                    .extract()
                    .map_err(|_| PyValueError::new_err("List must contain only numeric values"))?;
                vec.push(val);
            }
            let timestamps: Vec<i64> = (0..vec.len() as i64).collect();
            let values_array = ndarray::Array1::from_vec(vec);
            TimeSeries::new("residuals", timestamps, values_array, None).map_err(|e| {
                PyRuntimeError::new_err(format!("Failed to create TimeSeries: {:?}", e))
            })?
        } else {
            return Err(PyValueError::new_err(
                "Input must be a TimeSeries, list of floats, or convertible to list of floats",
            ));
        };

        // The builder is consumed by fit(); see the NOTE above.
        if let Some(builder) = self.builder.take() {
            match builder.fit(&timeseries) {
                Ok(results) => Ok(PyGARCHResults { inner: results }),
                Err(e) => Err(PyRuntimeError::new_err(format!(
                    "GARCH fitting failed: {:?}",
                    e
                ))),
            }
        } else {
            Err(PyRuntimeError::new_err("GARCH builder not available"))
        }
    }
1062}
1063
/// Python wrapper for GARCH results
///
/// Thin `#[pyclass]` holder around the Rust `GARCHResults`; all fields are
/// exposed read-only through getters on the `#[pymethods]` impl.
#[pyclass(name = "GARCHResults")]
struct PyGARCHResults {
    inner: GARCHResults,
}
1069
1070#[pymethods]
1071impl PyGARCHResults {
1072    /// Get omega (constant in variance equation)
1073    #[getter]
1074    fn omega(&self) -> f64 {
1075        self.inner.omega
1076    }
1077
1078    /// Get ARCH coefficients (α₁, ..., α_q)
1079    #[getter]
1080    fn arch_coef(&self) -> Vec<f64> {
1081        self.inner.arch_coef.to_vec()
1082    }
1083
1084    /// Get GARCH coefficients (β₁, ..., β_p)
1085    #[getter]
1086    fn garch_coef(&self) -> Vec<f64> {
1087        self.inner.garch_coef.to_vec()
1088    }
1089
1090    /// Get mu (constant in mean equation, if included)
1091    #[getter]
1092    fn mu(&self) -> Option<f64> {
1093        self.inner.mu
1094    }
1095
1096    /// Get degrees of freedom (for t/GED distributions)
1097    #[getter]
1098    fn df(&self) -> Option<f64> {
1099        self.inner.df
1100    }
1101
1102    /// Get AIC
1103    #[getter]
1104    fn aic(&self) -> f64 {
1105        self.inner.aic
1106    }
1107
1108    /// Get BIC
1109    #[getter]
1110    fn bic(&self) -> f64 {
1111        self.inner.bic
1112    }
1113
1114    /// Get log-likelihood
1115    #[getter]
1116    fn log_likelihood(&self) -> f64 {
1117        self.inner.log_likelihood
1118    }
1119
1120    /// Get number of observations
1121    #[getter]
1122    fn n_obs(&self) -> usize {
1123        self.inner.n_obs
1124    }
1125
1126    /// Get residuals (εₜ)
1127    #[getter]
1128    fn residuals(&self) -> Vec<f64> {
1129        self.inner.residuals.to_vec()
1130    }
1131
1132    /// Get conditional variances (σₜ²)
1133    #[getter]
1134    fn conditional_variances(&self) -> Vec<f64> {
1135        self.inner.conditional_variances.to_vec()
1136    }
1137
1138    /// Get standardized residuals (zₜ = εₜ/σₜ)
1139    #[getter]
1140    fn standardized_residuals(&self) -> Vec<f64> {
1141        self.inner.standardized_residuals.to_vec()
1142    }
1143
1144    /// Get summary string
1145    fn summary(&self) -> String {
1146        format!(
1147            "GARCH Results:\n  AIC: {:.2}\n  BIC: {:.2}\n  Log-Likelihood: {:.2}",
1148            self.aic(),
1149            self.bic(),
1150            self.log_likelihood()
1151        )
1152    }
1153}
1154
/// StatOxide Python module
///
/// Entry point for the `statoxide` extension. Classes and convenience
/// functions are registered both at the top level and inside the relevant
/// submodule ("stats", "models", "tsa", "utils"), so either import path
/// works from Python.
#[pymodule]
#[pyo3(name = "statoxide")]
fn statoxide(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Register core classes
    m.add_class::<PySeries>()?;
    m.add_class::<PyDataFrame>()?;
    m.add_class::<PyFormula>()?;

    // Register GLM classes
    m.add_class::<PyFamily>()?;
    m.add_class::<PyLink>()?;
    m.add_class::<PyGLMBuilder>()?;
    m.add_class::<PyGLM>()?;
    m.add_class::<PyGLMResults>()?;

    // Register TSA classes
    m.add_class::<PyTimeSeries>()?;
    m.add_class::<PyARIMA>()?;
    m.add_class::<PyARIMAResults>()?;
    m.add_class::<PyGARCH>()?;
    m.add_class::<PyGARCHResults>()?;

    // "stats" submodule: descriptive statistics and hypothesis tests
    let stats_module = PyModule::new(m.py(), "stats")?;
    stats_module.add_function(wrap_pyfunction!(mean, &stats_module)?)?;
    stats_module.add_function(wrap_pyfunction!(std_dev, &stats_module)?)?;
    stats_module.add_function(wrap_pyfunction!(correlation, &stats_module)?)?;
    stats_module.add_function(wrap_pyfunction!(descriptive_summary, &stats_module)?)?;
    // Statistical tests
    stats_module.add_function(wrap_pyfunction!(t_test_one_sample, &stats_module)?)?;
    stats_module.add_function(wrap_pyfunction!(t_test_two_sample, &stats_module)?)?;
    stats_module.add_function(wrap_pyfunction!(t_test_paired, &stats_module)?)?;
    stats_module.add_function(wrap_pyfunction!(
        chi_square_test_independence,
        &stats_module
    )?)?;
    stats_module.add_function(wrap_pyfunction!(anova_one_way, &stats_module)?)?;
    stats_module.add_function(wrap_pyfunction!(shapiro_wilk_test, &stats_module)?)?;
    m.add_submodule(&stats_module)?;

    // "models" submodule: regression models (OLS, GLM, robust, nonparametric)
    let models_module = PyModule::new(m.py(), "models")?;
    models_module.add_function(wrap_pyfunction!(linear_regression, &models_module)?)?;
    models_module.add_function(wrap_pyfunction!(mixed_effects, &models_module)?)?;
    models_module.add_class::<PyFamily>()?;
    models_module.add_class::<PyLink>()?;
    models_module.add_class::<PyGLMBuilder>()?;
    models_module.add_class::<PyGLM>()?;
    models_module.add_class::<PyGLMResults>()?;
    // Robust regression classes
    models_module.add_class::<PyMEstimator>()?;
    models_module.add_class::<PyRobustRegressionResults>()?;
    models_module.add_class::<PyLeastTrimmedSquares>()?;
    // Nonparametric regression classes
    models_module.add_class::<PyKernelRegression>()?;
    models_module.add_class::<PyKernelRegressionResults>()?;
    models_module.add_class::<PyLocalRegression>()?;
    models_module.add_class::<PyLocalRegressionResults>()?;
    m.add_submodule(&models_module)?;

    // "tsa" submodule: time-series analysis (ARIMA, GARCH)
    let tsa_module = PyModule::new(m.py(), "tsa")?;
    tsa_module.add_function(wrap_pyfunction!(fit_arima, &tsa_module)?)?;
    tsa_module.add_class::<PyTimeSeries>()?;
    tsa_module.add_class::<PyARIMA>()?;
    tsa_module.add_class::<PyARIMAResults>()?;
    tsa_module.add_class::<PyGARCH>()?;
    tsa_module.add_class::<PyGARCHResults>()?;
    m.add_submodule(&tsa_module)?;

    // "utils" submodule
    let utils_module = PyModule::new(m.py(), "utils")?;
    utils_module.add_function(wrap_pyfunction!(train_test_split, &utils_module)?)?;
    m.add_submodule(&utils_module)?;

    // Top-level functions
    m.add_function(wrap_pyfunction!(version, m)?)?;

    // Add commonly used functions to top level for convenience
    m.add_function(wrap_pyfunction!(mean, m)?)?;
    m.add_function(wrap_pyfunction!(std_dev, m)?)?;
    m.add_function(wrap_pyfunction!(correlation, m)?)?;
    m.add_function(wrap_pyfunction!(descriptive_summary, m)?)?;
    m.add_function(wrap_pyfunction!(train_test_split, m)?)?;

    Ok(())
}
1243
1244/// Compute mean of data
1245#[pyfunction]
1246fn mean(data: Vec<f64>) -> PyResult<f64> {
1247    if data.is_empty() {
1248        return Ok(f64::NAN);
1249    }
1250    Ok(data.iter().sum::<f64>() / data.len() as f64)
1251}
1252
1253/// Compute standard deviation of data
1254#[pyfunction]
1255fn std_dev(data: Vec<f64>) -> PyResult<f64> {
1256    if data.len() < 2 {
1257        return Ok(f64::NAN);
1258    }
1259    let mean_val = mean(data.clone())?;
1260    let variance =
1261        data.iter().map(|&x| (x - mean_val).powi(2)).sum::<f64>() / (data.len() as f64 - 1.0);
1262    Ok(variance.sqrt())
1263}
1264
1265/// Compute correlation between two variables
1266#[pyfunction]
1267fn correlation(x: Vec<f64>, y: Vec<f64>) -> PyResult<f64> {
1268    if x.len() != y.len() || x.len() < 2 {
1269        return Ok(f64::NAN);
1270    }
1271
1272    let x_mean = x.iter().sum::<f64>() / x.len() as f64;
1273    let y_mean = y.iter().sum::<f64>() / y.len() as f64;
1274
1275    let covariance = x
1276        .iter()
1277        .zip(y.iter())
1278        .map(|(&xi, &yi)| (xi - x_mean) * (yi - y_mean))
1279        .sum::<f64>()
1280        / (x.len() as f64 - 1.0);
1281
1282    let x_std = std_dev(x.clone())?;
1283    let y_std = std_dev(y.clone())?;
1284
1285    if x_std == 0.0 || y_std == 0.0 {
1286        Ok(0.0)
1287    } else {
1288        Ok(covariance / (x_std * y_std))
1289    }
1290}
1291
1292/// Compute descriptive statistics summary
1293#[pyfunction]
1294fn descriptive_summary(py: Python, data: Vec<f64>) -> PyResult<Py<PyDict>> {
1295    let dict = PyDict::new(py);
1296    dict.set_item("count", data.len())?;
1297
1298    if !data.is_empty() {
1299        let mean_val = data.iter().sum::<f64>() / data.len() as f64;
1300        dict.set_item("mean", mean_val)?;
1301
1302        if data.len() >= 2 {
1303            let variance = data.iter().map(|&x| (x - mean_val).powi(2)).sum::<f64>()
1304                / (data.len() as f64 - 1.0);
1305            dict.set_item("std", variance.sqrt())?;
1306            dict.set_item("variance", variance)?;
1307
1308            // Min and max
1309            if let (Some(min), Some(max)) = (
1310                data.iter().min_by(|a, b| a.partial_cmp(b).unwrap()),
1311                data.iter().max_by(|a, b| a.partial_cmp(b).unwrap()),
1312            ) {
1313                dict.set_item("min", *min)?;
1314                dict.set_item("max", *max)?;
1315            }
1316        }
1317    }
1318
1319    Ok(dict.into())
1320}
1321
1322/// Convert alternative hypothesis string to Rust enum
1323fn parse_alternative(alternative: &str) -> PyResult<Alternative> {
1324    match alternative.to_lowercase().as_str() {
1325        "two-sided" | "two_sided" | "two.sided" => Ok(Alternative::TwoSided),
1326        "less" | "smaller" => Ok(Alternative::Less),
1327        "greater" | "larger" => Ok(Alternative::Greater),
1328        _ => Err(PyValueError::new_err(
1329            "alternative must be 'two-sided', 'less', or 'greater'",
1330        )),
1331    }
1332}
1333
1334/// Convert TestResult to Python dictionary
1335fn test_result_to_dict(py: Python, result: &TestResult) -> PyResult<Py<PyDict>> {
1336    let dict = PyDict::new(py);
1337    dict.set_item("statistic", result.statistic)?;
1338    dict.set_item("p_value", result.p_value)?;
1339    dict.set_item("df", result.df)?;
1340    dict.set_item(
1341        "alternative",
1342        match result.alternative {
1343            Alternative::TwoSided => "two-sided",
1344            Alternative::Less => "less",
1345            Alternative::Greater => "greater",
1346        },
1347    )?;
1348    dict.set_item("null_value", result.null_value)?;
1349    Ok(dict.into())
1350}
1351
1352/// One-sample t-test
1353#[pyfunction]
1354fn t_test_one_sample(
1355    py: Python,
1356    data: Vec<f64>,
1357    mu: f64,
1358    alternative: String,
1359) -> PyResult<Py<PyDict>> {
1360    let data_array = ndarray::Array1::from_vec(data);
1361    let alt = parse_alternative(&alternative)?;
1362    let result = t_test_one_sample_rs(&data_array, mu, alt)
1363        .map_err(|e| PyRuntimeError::new_err(format!("t-test failed: {:?}", e)))?;
1364    test_result_to_dict(py, &result)
1365}
1366
1367/// Two-sample t-test (independent samples, equal variance assumed)
1368#[pyfunction]
1369fn t_test_two_sample(
1370    py: Python,
1371    x: Vec<f64>,
1372    y: Vec<f64>,
1373    alternative: String,
1374) -> PyResult<Py<PyDict>> {
1375    let x_array = ndarray::Array1::from_vec(x);
1376    let y_array = ndarray::Array1::from_vec(y);
1377    let alt = parse_alternative(&alternative)?;
1378    let result = t_test_two_sample_rs(&x_array, &y_array, alt)
1379        .map_err(|e| PyRuntimeError::new_err(format!("t-test failed: {:?}", e)))?;
1380    test_result_to_dict(py, &result)
1381}
1382
1383/// Paired t-test
1384#[pyfunction]
1385fn t_test_paired(
1386    py: Python,
1387    x: Vec<f64>,
1388    y: Vec<f64>,
1389    alternative: String,
1390) -> PyResult<Py<PyDict>> {
1391    let x_array = ndarray::Array1::from_vec(x);
1392    let y_array = ndarray::Array1::from_vec(y);
1393    let alt = parse_alternative(&alternative)?;
1394    let result = t_test_paired_rs(&x_array, &y_array, alt)
1395        .map_err(|e| PyRuntimeError::new_err(format!("paired t-test failed: {:?}", e)))?;
1396    test_result_to_dict(py, &result)
1397}
1398
1399/// Chi-square test of independence
1400#[pyfunction]
1401fn chi_square_test_independence(py: Python, observed: Vec<Vec<f64>>) -> PyResult<Py<PyDict>> {
1402    // Convert to ndarray matrix
1403    let n_rows = observed.len();
1404    if n_rows == 0 {
1405        return Err(PyValueError::new_err("observed must have at least one row"));
1406    }
1407    let n_cols = observed[0].len();
1408    let mut flat = Vec::new();
1409    for row in observed {
1410        if row.len() != n_cols {
1411            return Err(PyValueError::new_err("All rows must have same length"));
1412        }
1413        flat.extend(row);
1414    }
1415    let matrix = ndarray::Array2::from_shape_vec((n_rows, n_cols), flat)
1416        .map_err(|e| PyValueError::new_err(format!("Failed to create matrix: {}", e)))?;
1417
1418    let result = chi_square_test_independence_rs(&matrix)
1419        .map_err(|e| PyRuntimeError::new_err(format!("chi-square test failed: {:?}", e)))?;
1420    test_result_to_dict(py, &result)
1421}
1422
1423/// One-way ANOVA
1424#[pyfunction]
1425fn anova_one_way(py: Python, groups: Vec<Vec<f64>>) -> PyResult<Py<PyDict>> {
1426    let arrays: Vec<_> = groups
1427        .into_iter()
1428        .map(|g| ndarray::Array1::from_vec(g))
1429        .collect();
1430    let result = anova_one_way_rs(&arrays)
1431        .map_err(|e| PyRuntimeError::new_err(format!("ANOVA failed: {:?}", e)))?;
1432    test_result_to_dict(py, &result)
1433}
1434
1435/// Shapiro-Wilk test for normality
1436#[pyfunction]
1437fn shapiro_wilk_test(py: Python, data: Vec<f64>) -> PyResult<Py<PyDict>> {
1438    let data_array = ndarray::Array1::from_vec(data);
1439    let result = shapiro_wilk_test_rs(&data_array)
1440        .map_err(|e| PyRuntimeError::new_err(format!("Shapiro-Wilk test failed: {:?}", e)))?;
1441    test_result_to_dict(py, &result)
1442}
1443
1444/// Fit linear regression model
1445#[pyfunction]
1446fn linear_regression(py: Python, x: Vec<Vec<f64>>, y: Vec<f64>) -> PyResult<Py<PyDict>> {
1447    // Convert to ndarray
1448    let n_rows = x.len();
1449    if n_rows == 0 {
1450        return Err(PyValueError::new_err("X must have at least one row"));
1451    }
1452    if n_rows != y.len() {
1453        return Err(PyValueError::new_err(
1454            "X and y must have same number of rows",
1455        ));
1456    }
1457
1458    let n_cols = x[0].len();
1459    let x_array =
1460        ndarray::Array2::from_shape_vec((n_rows, n_cols), x.into_iter().flatten().collect())
1461            .map_err(|e| PyValueError::new_err(format!("Failed to create X matrix: {}", e)))?;
1462
1463    let y_array = ndarray::Array1::from_vec(y);
1464
1465    // Fit OLS model
1466    let model = OLS::new();
1467    match model.fit(&x_array, &y_array) {
1468        Ok(results) => {
1469            let dict = PyDict::new(py);
1470            dict.set_item("coefficients", results.coefficients.to_vec())?;
1471            dict.set_item("r_squared", results.r_squared)?;
1472            dict.set_item("r_squared_adj", results.r_squared_adj)?;
1473            dict.set_item("sigma", results.sigma)?;
1474            dict.set_item("df_residual", results.df_residual)?;
1475            dict.set_item("df_model", results.df_model)?;
1476
1477            if let Some(std_errors) = &results.std_errors {
1478                dict.set_item("std_errors", std_errors.to_vec())?;
1479            }
1480
1481            if let Some(t_values) = &results.t_values {
1482                dict.set_item("t_values", t_values.to_vec())?;
1483            }
1484
1485            if let Some(p_values) = &results.p_values {
1486                dict.set_item("p_values", p_values.to_vec())?;
1487            }
1488
1489            if let Some(f_statistic) = &results.f_statistic {
1490                dict.set_item("f_statistic", f_statistic)?;
1491            }
1492
1493            if let Some(f_p_value) = &results.f_p_value {
1494                dict.set_item("f_p_value", f_p_value)?;
1495            }
1496
1497            Ok(dict.into())
1498        }
1499        Err(e) => Err(PyRuntimeError::new_err(format!(
1500            "Linear regression failed: {:?}",
1501            e
1502        ))),
1503    }
1504}
1505
/// Fit mixed effects model (linear mixed model)
///
/// NOTE(review): this is a stub — no model is fit. It returns a dict of
/// zero/empty placeholder values plus the inputs echoed back ("formula",
/// "n_obs") and a "message" key flagging the placeholder status. Do not
/// rely on the numeric output until so-models::mixed is wired in.
#[pyfunction]
fn mixed_effects(py: Python, data: &PyDataFrame, formula: String) -> PyResult<Py<PyDict>> {
    // TODO: Implement actual mixed effects model using so-models
    let dict = PyDict::new(py);

    // For now, return a placeholder result
    dict.set_item("fixed_effects", Vec::<f64>::new())?;
    dict.set_item("random_variances", Vec::<f64>::new())?;
    dict.set_item("residual_variance", 0.0)?;
    dict.set_item("log_likelihood", 0.0)?;
    dict.set_item("aic", 0.0)?;
    dict.set_item("bic", 0.0)?;
    dict.set_item(
        "message",
        "Mixed effects model placeholder - implement using so-models::mixed",
    )?;
    dict.set_item("formula", formula)?;
    dict.set_item("n_obs", data.n_rows())?;

    Ok(dict.into())
}
1528
1529/// Fit ARIMA model
1530#[pyfunction]
1531fn fit_arima(py: Python, data: Vec<f64>, p: usize, d: usize, q: usize) -> PyResult<Py<PyDict>> {
1532    // Create TimeSeries from data with index as timestamps
1533    let values = ndarray::Array1::from_vec(data.clone());
1534    let timestamps: Vec<i64> = (0..data.len() as i64).collect();
1535    let ts = TimeSeries::new("series", timestamps, values, None)
1536        .map_err(|e| PyRuntimeError::new_err(format!("Failed to create TimeSeries: {:?}", e)))?;
1537
1538    // Create and fit ARIMA model using builder
1539    use so_tsa::arima::ARIMABuilder;
1540    let builder = ARIMABuilder::new(p, d, q);
1541
1542    match builder.fit(&ts) {
1543        Ok(results) => {
1544            let dict = PyDict::new(py);
1545            dict.set_item("order", (p, d, q))?;
1546
1547            // Collect coefficients
1548            let mut coefficients = Vec::new();
1549            if let Some(ar_coef) = &results.ar_coef {
1550                coefficients.extend_from_slice(&ar_coef.to_vec());
1551            }
1552            if let Some(ma_coef) = &results.ma_coef {
1553                coefficients.extend_from_slice(&ma_coef.to_vec());
1554            }
1555            if let Some(constant) = results.constant {
1556                coefficients.push(constant);
1557            }
1558            dict.set_item("coefficients", coefficients)?;
1559
1560            dict.set_item("aic", results.aic)?;
1561            dict.set_item("bic", results.bic)?;
1562            dict.set_item("log_likelihood", results.log_likelihood)?;
1563            dict.set_item("sigma2", results.sigma2)?;
1564
1565            // Add fitted values and residuals
1566            dict.set_item("fitted", results.fitted.to_vec())?;
1567            dict.set_item("residuals", results.residuals.to_vec())?;
1568            dict.set_item("n_obs", results.n_obs)?;
1569
1570            Ok(dict.into())
1571        }
1572        Err(e) => Err(PyRuntimeError::new_err(format!(
1573            "ARIMA fitting failed: {:?}",
1574            e
1575        ))),
1576    }
1577}
1578
1579/// Split data into training and test sets
1580#[pyfunction]
1581fn train_test_split(data: Vec<f64>, test_size: f64) -> PyResult<(Vec<f64>, Vec<f64>)> {
1582    if test_size <= 0.0 || test_size >= 1.0 {
1583        return Err(PyValueError::new_err("test_size must be between 0 and 1"));
1584    }
1585
1586    let split_idx = (data.len() as f64 * (1.0 - test_size)) as usize;
1587    let train = data[..split_idx].to_vec();
1588    let test = data[split_idx..].to_vec();
1589
1590    Ok((train, test))
1591}
1592
1593// ============================================================================
1594// Robust Regression
1595// ============================================================================
1596
/// Python wrapper for robust M-estimator
///
/// `inner` is an Option because the builder-style setters and `fit` take
/// the estimator out of `self` and (for setters) put a new one back.
#[pyclass(name = "MEstimator")]
struct PyMEstimator {
    inner: Option<RustMEstimator>,
}
1602
1603#[pymethods]
1604impl PyMEstimator {
1605    /// Create a new Huber M-estimator (k=1.345 gives 95% efficiency)
1606    #[staticmethod]
1607    fn huber(k: f64) -> Self {
1608        PyMEstimator {
1609            inner: Some(RustMEstimator::huber(k)),
1610        }
1611    }
1612
1613    /// Create a new Tukey's biweight M-estimator (c=4.685 gives 95% efficiency)
1614    #[staticmethod]
1615    fn tukey(c: f64) -> Self {
1616        PyMEstimator {
1617            inner: Some(RustMEstimator::tukey(c)),
1618        }
1619    }
1620
1621    /// Set maximum iterations
1622    fn max_iterations(&mut self, max_iter: usize) -> PyResult<()> {
1623        if let Some(inner) = self.inner.take() {
1624            self.inner = Some(inner.max_iterations(max_iter));
1625        }
1626        Ok(())
1627    }
1628
1629    /// Set convergence tolerance
1630    fn tolerance(&mut self, tol: f64) -> PyResult<()> {
1631        if let Some(inner) = self.inner.take() {
1632            self.inner = Some(inner.tolerance(tol));
1633        }
1634        Ok(())
1635    }
1636
1637    /// Fit robust regression to data
1638    fn fit(&mut self, X: Vec<Vec<f64>>, y: Vec<f64>) -> PyResult<PyRobustRegressionResults> {
1639        if X.is_empty() || X[0].is_empty() {
1640            return Err(PyValueError::new_err("X must be non-empty"));
1641        }
1642        
1643        if X.len() != y.len() {
1644            return Err(PyValueError::new_err(
1645                "X and y must have the same number of rows",
1646            ));
1647        }
1648
1649        // Convert to ndarray
1650        let n_rows = X.len();
1651        let n_cols = X[0].len();
1652        let mut X_array = ndarray::Array2::zeros((n_rows, n_cols));
1653        
1654        for i in 0..n_rows {
1655            if X[i].len() != n_cols {
1656                return Err(PyValueError::new_err(
1657                    "All rows of X must have the same length",
1658                ));
1659            }
1660            for j in 0..n_cols {
1661                X_array[[i, j]] = X[i][j];
1662            }
1663        }
1664        
1665        let y_array = ndarray::Array1::from_vec(y);
1666
1667        if let Some(inner) = self.inner.take() {
1668            match inner.fit(&X_array, &y_array) {
1669                Ok(results) => Ok(PyRobustRegressionResults { inner: results }),
1670                Err(e) => Err(PyRuntimeError::new_err(format!(
1671                    "Robust regression failed: {:?}",
1672                    e
1673                ))),
1674            }
1675        } else {
1676            Err(PyRuntimeError::new_err("MEstimator not available"))
1677        }
1678    }
1679}
1680
/// Python wrapper for robust regression results
///
/// Read-only view over the Rust `RobustRegressionResults`; shared by the
/// M-estimator and LTS fits.
#[pyclass(name = "RobustRegressionResults")]
struct PyRobustRegressionResults {
    inner: RustRobustResults,
}
1686
1687#[pymethods]
1688impl PyRobustRegressionResults {
1689    /// Get robust coefficients
1690    #[getter]
1691    fn coefficients(&self) -> Vec<f64> {
1692        self.inner.coefficients.to_vec()
1693    }
1694
1695    /// Get robust standard errors
1696    #[getter]
1697    fn standard_errors(&self) -> Vec<f64> {
1698        self.inner.standard_errors.to_vec()
1699    }
1700
1701    /// Get robust scale estimate
1702    #[getter]
1703    fn scale(&self) -> f64 {
1704        self.inner.scale
1705    }
1706
1707    /// Get number of iterations
1708    #[getter]
1709    fn iterations(&self) -> usize {
1710        self.inner.iterations
1711    }
1712
1713    /// Get weights (can identify outliers)
1714    #[getter]
1715    fn weights(&self) -> Vec<f64> {
1716        self.inner.weights.to_vec()
1717    }
1718
1719    /// Get breakdown point
1720    #[getter]
1721    fn breakdown_point(&self) -> f64 {
1722        self.inner.breakdown_point
1723    }
1724
1725    /// Get efficiency relative to OLS
1726    #[getter]
1727    fn efficiency(&self) -> f64 {
1728        self.inner.efficiency
1729    }
1730
1731    /// Get summary as dictionary
1732    fn summary(&self) -> PyResult<Py<PyDict>> {
1733        Python::with_gil(|py| {
1734            let dict = PyDict::new(py);
1735            
1736            dict.set_item("coefficients", self.coefficients())?;
1737            dict.set_item("standard_errors", self.standard_errors())?;
1738            dict.set_item("scale", self.scale())?;
1739            dict.set_item("iterations", self.iterations())?;
1740            dict.set_item("breakdown_point", self.breakdown_point())?;
1741            dict.set_item("efficiency", self.efficiency())?;
1742            
1743            // Compute t-values
1744            let coefficients = &self.inner.coefficients;
1745            let standard_errors = &self.inner.standard_errors;
1746            let mut t_values = Vec::new();
1747            
1748            for i in 0..coefficients.len() {
1749                if standard_errors[i] > 0.0 {
1750                    let t = coefficients[i] / standard_errors[i];
1751                    t_values.push(t);
1752                } else {
1753                    t_values.push(f64::NAN);
1754                }
1755            }
1756            
1757            dict.set_item("t_values", t_values)?;
1758            // Note: p-values require t-distribution CDF, omitted for simplicity
1759            // Users can compute p-values using scipy.stats if needed
1760            
1761            Ok(dict.into())
1762        })
1763    }
1764}
1765
/// Python wrapper for Least Trimmed Squares (high breakdown)
///
/// `inner` is an Option because `fit` consumes the underlying estimator.
#[pyclass(name = "LeastTrimmedSquares")]
struct PyLeastTrimmedSquares {
    inner: Option<LeastTrimmedSquares>,
}
1771
1772#[pymethods]
1773impl PyLeastTrimmedSquares {
1774    /// Create a new LTS estimator (coverage=0.5)
1775    #[new]
1776    fn new(coverage: Option<f64>) -> Self {
1777        PyLeastTrimmedSquares {
1778            inner: Some(LeastTrimmedSquares::new(coverage.unwrap_or(0.5))),
1779        }
1780    }
1781
1782    /// Fit LTS regression
1783    fn fit(&mut self, X: Vec<Vec<f64>>, y: Vec<f64>) -> PyResult<PyRobustRegressionResults> {
1784        if X.is_empty() || X[0].is_empty() {
1785            return Err(PyValueError::new_err("X must be non-empty"));
1786        }
1787        
1788        if X.len() != y.len() {
1789            return Err(PyValueError::new_err(
1790                "X and y must have the same number of rows",
1791            ));
1792        }
1793
1794        let n_rows = X.len();
1795        let n_cols = X[0].len();
1796        let mut X_array = ndarray::Array2::zeros((n_rows, n_cols));
1797        
1798        for i in 0..n_rows {
1799            if X[i].len() != n_cols {
1800                return Err(PyValueError::new_err(
1801                    "All rows of X must have the same length",
1802                ));
1803            }
1804            for j in 0..n_cols {
1805                X_array[[i, j]] = X[i][j];
1806            }
1807        }
1808        
1809        let y_array = ndarray::Array1::from_vec(y);
1810
1811        if let Some(inner) = self.inner.take() {
1812            match inner.fit(&X_array, &y_array) {
1813                Ok(results) => Ok(PyRobustRegressionResults { inner: results }),
1814                Err(e) => Err(PyRuntimeError::new_err(format!(
1815                    "LTS regression failed: {:?}",
1816                    e
1817                ))),
1818            }
1819        } else {
1820            Err(PyRuntimeError::new_err("LTS estimator not available"))
1821        }
1822    }
1823}
1824
1825// ============================================================================
1826// Nonparametric Methods
1827// ============================================================================
1828
/// Python wrapper for kernel regression
///
/// `inner` is an Option because the builder-style setters and `fit` take
/// the model out of `self` and (for setters) put a new one back.
#[pyclass(name = "KernelRegression")]
struct PyKernelRegression {
    inner: Option<KernelRegression>,
}
1834
1835#[pymethods]
1836impl PyKernelRegression {
1837    /// Create a new kernel regression with Gaussian kernel
1838    #[new]
1839    fn new() -> Self {
1840        PyKernelRegression {
1841            inner: Some(KernelRegression::new()),
1842        }
1843    }
1844
1845    /// Set kernel type
1846    fn kernel(&mut self, kernel: &str) -> PyResult<()> {
1847        let kernel_enum = match kernel.to_lowercase().as_str() {
1848            "gaussian" => Kernel::Gaussian,
1849            "epanechnikov" => Kernel::Epanechnikov,
1850            "uniform" => Kernel::Uniform,
1851            "triangular" => Kernel::Triangular,
1852            "biweight" => Kernel::Biweight,
1853            "triweight" => Kernel::Triweight,
1854            "cosine" => Kernel::Cosine,
1855            _ => return Err(PyValueError::new_err(
1856                format!("Unknown kernel: {}. Valid options: gaussian, epanechnikov, uniform, triangular, biweight, triweight, cosine", kernel)
1857            )),
1858        };
1859
1860        if let Some(inner) = self.inner.take() {
1861            self.inner = Some(inner.kernel(kernel_enum));
1862        }
1863        Ok(())
1864    }
1865
1866    /// Set bandwidth directly
1867    fn bandwidth(&mut self, bandwidth: f64) -> PyResult<()> {
1868        if bandwidth <= 0.0 {
1869            return Err(PyValueError::new_err("Bandwidth must be positive"));
1870        }
1871        
1872        if let Some(inner) = self.inner.take() {
1873            self.inner = Some(inner.bandwidth(bandwidth));
1874        }
1875        Ok(())
1876    }
1877
1878    /// Fit kernel regression model
1879    fn fit(&mut self, x: Vec<f64>, y: Vec<f64>) -> PyResult<PyKernelRegressionResults> {
1880        if x.len() != y.len() {
1881            return Err(PyValueError::new_err("x and y must have the same length"));
1882        }
1883        
1884        if x.len() < 3 {
1885            return Err(PyValueError::new_err("Need at least 3 observations for kernel regression"));
1886        }
1887
1888        let x_array = Array1::from_vec(x);
1889        let y_array = Array1::from_vec(y);
1890
1891        if let Some(inner) = self.inner.take() {
1892            match inner.fit(&x_array, &y_array) {
1893                Ok(results) => Ok(PyKernelRegressionResults { inner: results }),
1894                Err(e) => Err(PyRuntimeError::new_err(format!(
1895                    "Kernel regression failed: {:?}",
1896                    e
1897                ))),
1898            }
1899        } else {
1900            Err(PyRuntimeError::new_err("KernelRegression not available"))
1901        }
1902    }
1903}
1904
/// Python wrapper for kernel regression results
///
/// Read-only view over the Rust `KernelRegressionResults`.
#[pyclass(name = "KernelRegressionResults")]
struct PyKernelRegressionResults {
    inner: KernelRegressionResults,
}
1910
1911#[pymethods]
1912impl PyKernelRegressionResults {
1913    /// Get fitted values
1914    #[getter]
1915    fn fitted_values(&self) -> Vec<f64> {
1916        self.inner.fitted_values.to_vec()
1917    }
1918
1919    /// Get evaluation points (sorted x values)
1920    #[getter]
1921    fn evaluation_points(&self) -> Vec<f64> {
1922        self.inner.evaluation_points.to_vec()
1923    }
1924
1925    /// Get bandwidth used
1926    #[getter]
1927    fn bandwidth(&self) -> f64 {
1928        self.inner.bandwidth
1929    }
1930
1931    /// Get effective degrees of freedom
1932    #[getter]
1933    fn df(&self) -> f64 {
1934        self.inner.df
1935    }
1936
1937    /// Get residual sum of squares
1938    #[getter]
1939    fn rss(&self) -> f64 {
1940        self.inner.rss
1941    }
1942}
1943
/// Python wrapper for local regression (LOESS)
#[pyclass(name = "LocalRegression")]
struct PyLocalRegression {
    // Held in an Option because the builder-style configuration methods and
    // `fit` consume the model by value; None once `fit` has consumed it.
    inner: Option<LocalRegression>,
}
1949
1950#[pymethods]
1951impl PyLocalRegression {
1952    /// Create a new local regression
1953    #[new]
1954    fn new() -> Self {
1955        PyLocalRegression {
1956            inner: Some(LocalRegression::new()),
1957        }
1958    }
1959
1960    /// Set polynomial degree (0, 1, or 2)
1961    fn degree(&mut self, degree: usize) -> PyResult<()> {
1962        if degree > 2 {
1963            return Err(PyValueError::new_err("Degree must be 0, 1, or 2"));
1964        }
1965        
1966        if let Some(inner) = self.inner.take() {
1967            self.inner = Some(inner.degree(degree));
1968        }
1969        Ok(())
1970    }
1971
1972    /// Set span (proportion of data used locally, 0.1 to 1.0)
1973    fn span(&mut self, span: f64) -> PyResult<()> {
1974        if span < 0.1 || span > 1.0 {
1975            return Err(PyValueError::new_err("Span must be between 0.1 and 1.0"));
1976        }
1977        
1978        if let Some(inner) = self.inner.take() {
1979            self.inner = Some(inner.span(span));
1980        }
1981        Ok(())
1982    }
1983
1984    /// Enable robust fitting
1985    fn robust(&mut self, robust: bool) -> PyResult<()> {
1986        if let Some(inner) = self.inner.take() {
1987            self.inner = Some(inner.robust(robust));
1988        }
1989        Ok(())
1990    }
1991
1992    /// Fit local regression model
1993    fn fit(&mut self, x: Vec<f64>, y: Vec<f64>) -> PyResult<PyLocalRegressionResults> {
1994        if x.len() != y.len() {
1995            return Err(PyValueError::new_err("x and y must have the same length"));
1996        }
1997        
1998        if x.len() < 3 {
1999            return Err(PyValueError::new_err("Need at least 3 observations for local regression"));
2000        }
2001
2002        let x_array = Array1::from_vec(x);
2003        let y_array = Array1::from_vec(y);
2004
2005        if let Some(inner) = self.inner.take() {
2006            match inner.fit(&x_array, &y_array) {
2007                Ok(results) => Ok(PyLocalRegressionResults { inner: results }),
2008                Err(e) => Err(PyRuntimeError::new_err(format!(
2009                    "Local regression failed: {:?}",
2010                    e
2011                ))),
2012            }
2013        } else {
2014            Err(PyRuntimeError::new_err("LocalRegression not available"))
2015        }
2016    }
2017}
2018
/// Python wrapper for local regression results
#[pyclass(name = "LocalRegressionResults")]
struct PyLocalRegressionResults {
    // Owned Rust-side results produced by a successful fit; fields are
    // exposed to Python through #[getter] methods rather than directly.
    inner: LocalRegressionResults,
}
2024
2025#[pymethods]
2026impl PyLocalRegressionResults {
2027    /// Get fitted values
2028    #[getter]
2029    fn fitted_values(&self) -> Vec<f64> {
2030        self.inner.fitted_values.to_vec()
2031    }
2032
2033    /// Get evaluation points (sorted x values)
2034    #[getter]
2035    fn evaluation_points(&self) -> Vec<f64> {
2036        self.inner.evaluation_points.to_vec()
2037    }
2038
2039    /// Get polynomial degree used
2040    #[getter]
2041    fn degree(&self) -> usize {
2042        self.inner.degree
2043    }
2044
2045    /// Get span used
2046    #[getter]
2047    fn span(&self) -> f64 {
2048        self.inner.span
2049    }
2050
2051    /// Get residual sum of squares
2052    #[getter]
2053    fn rss(&self) -> f64 {
2054        self.inner.rss
2055    }
2056}
2057
2058/// Get library version
2059#[pyfunction]
2060fn version() -> PyResult<String> {
2061    Ok("0.3.0".to_string())
2062}