//! statoxide/lib.rs — Python bindings for the StatOxide statistics library.
1use ndarray::Array1;
2use pyo3::exceptions::{PyRuntimeError, PyValueError};
3use pyo3::prelude::*;
4use pyo3::types::{PyDict, PyList};
5use std::collections::HashMap;
6
7use so_core::data::{DataFrame, Series};
8use so_core::formula::Formula;
9
10// Import models
11use so_models::glm::{Family, GLM as RustGLM, GLMModelBuilder, GLMResults, Link};
12use so_models::regression::OLS;
13
14// Import time series
15use so_tsa::TimeSeries;
16use so_tsa::arima::ARIMAResults;
17use so_tsa::garch::GARCHResults;
18
19// Import statistical tests
20use so_stats::tests::{
21    Alternative, TestResult, anova_one_way as anova_one_way_rs,
22    chi_square_test_independence as chi_square_test_independence_rs,
23    shapiro_wilk_test as shapiro_wilk_test_rs, t_test_one_sample as t_test_one_sample_rs,
24    t_test_paired as t_test_paired_rs, t_test_two_sample as t_test_two_sample_rs,
25};
26
/// Python wrapper for StatOxide Series
#[pyclass(name = "Series")]
struct PySeries {
    // Underlying Rust Series; all Python-visible methods delegate to it.
    inner: Series,
}
32
33#[pymethods]
34impl PySeries {
35    /// Create a new Series from Python list or array
36    #[new]
37    fn new(name: String, data: Vec<f64>) -> PyResult<Self> {
38        let array = Array1::from_vec(data);
39        Ok(PySeries {
40            inner: Series::new(name, array),
41        })
42    }
43
44    /// Get the name of the series
45    #[getter]
46    fn name(&self) -> String {
47        self.inner.name().to_string()
48    }
49
50    /// Get the length of the series
51    #[getter]
52    fn len(&self) -> usize {
53        self.inner.len()
54    }
55
56    /// Check if series is empty
57    fn is_empty(&self) -> bool {
58        self.inner.is_empty()
59    }
60
61    /// Compute mean of series
62    fn mean(&self) -> Option<f64> {
63        self.inner.mean()
64    }
65
66    /// Compute standard deviation
67    fn std(&self, ddof: f64) -> Option<f64> {
68        self.inner.std(ddof)
69    }
70
71    /// Compute variance
72    fn var(&self, ddof: f64) -> Option<f64> {
73        self.inner.var(ddof)
74    }
75
76    /// Get minimum value
77    fn min(&self) -> Option<f64> {
78        self.inner.min()
79    }
80
81    /// Get maximum value
82    fn max(&self) -> Option<f64> {
83        self.inner.max()
84    }
85
86    /// Compute quantile
87    fn quantile(&self, q: f64) -> Option<f64> {
88        self.inner.quantile(q)
89    }
90
91    /// Convert to Python list
92    fn to_list(&self) -> Vec<f64> {
93        self.inner.data().to_vec()
94    }
95
96    /// String representation
97    fn __repr__(&self) -> String {
98        format!(
99            "Series(name='{}', len={})",
100            self.inner.name(),
101            self.inner.len()
102        )
103    }
104}
105
/// Python wrapper for StatOxide DataFrame
#[pyclass(name = "DataFrame")]
struct PyDataFrame {
    // Wrapped Rust DataFrame; Python methods delegate to it.
    inner: DataFrame,
}
111
112#[pymethods]
113impl PyDataFrame {
114    /// Create a new DataFrame from a dictionary of columns
115    #[new]
116    fn new(data: HashMap<String, Vec<f64>>) -> PyResult<Self> {
117        let mut columns = HashMap::new();
118
119        for (name, values) in data {
120            let array = Array1::from_vec(values);
121            columns.insert(name.clone(), Series::new(name, array));
122        }
123
124        match DataFrame::from_series(columns) {
125            Ok(df) => Ok(PyDataFrame { inner: df }),
126            Err(e) => Err(PyValueError::new_err(format!(
127                "Error creating DataFrame: {:?}",
128                e
129            ))),
130        }
131    }
132
133    /// Get number of rows
134    #[getter]
135    fn n_rows(&self) -> usize {
136        self.inner.n_rows()
137    }
138
139    /// Get number of columns
140    #[getter]
141    fn n_cols(&self) -> usize {
142        self.inner.n_cols()
143    }
144
145    /// Get column names
146    fn columns(&self) -> Vec<String> {
147        self.inner.column_names()
148    }
149
150    /// Get a column by name
151    fn get_column(&self, name: &str) -> PyResult<PySeries> {
152        match self.inner.column(name) {
153            Some(series) => Ok(PySeries {
154                inner: series.clone(),
155            }),
156            None => Err(PyValueError::new_err(format!(
157                "Column '{}' not found",
158                name
159            ))),
160        }
161    }
162
163    /// Add a column to the DataFrame
164    fn with_column(&mut self, series: &PySeries) -> PyResult<()> {
165        let df = self
166            .inner
167            .clone()
168            .with_column(series.inner.clone())
169            .map_err(|e| PyValueError::new_err(format!("Error adding column: {:?}", e)))?;
170        self.inner = df;
171        Ok(())
172    }
173
174    /// String representation
175    fn __repr__(&self) -> String {
176        format!(
177            "DataFrame(rows={}, cols={})",
178            self.inner.n_rows(),
179            self.inner.n_cols()
180        )
181    }
182}
183
/// Python wrapper for StatOxide Formula
#[pyclass(name = "Formula")]
struct PyFormula {
    // Parsed Rust formula object.
    inner: Formula,
}
189
190#[pymethods]
191impl PyFormula {
192    /// Parse a formula string
193    #[new]
194    fn new(formula: String) -> PyResult<Self> {
195        match Formula::parse(&formula) {
196            Ok(f) => Ok(PyFormula { inner: f }),
197            Err(e) => Err(PyValueError::new_err(format!(
198                "Error parsing formula: {:?}",
199                e
200            ))),
201        }
202    }
203
204    /// Get all variable names in the formula
205    fn variables(&self) -> Vec<String> {
206        self.inner.variables().into_iter().collect()
207    }
208
209    /// String representation
210    fn __repr__(&self) -> String {
211        format!("Formula({:?})", self.inner)
212    }
213}
214
/// Python wrapper for GLM Family
#[pyclass(name = "Family")]
#[derive(Clone)]
struct PyFamily {
    // Wrapped distribution-family enum from so_models.
    inner: Family,
}
221
222#[pymethods]
223impl PyFamily {
224    /// Create Gaussian family
225    #[staticmethod]
226    fn gaussian() -> Self {
227        PyFamily {
228            inner: Family::Gaussian,
229        }
230    }
231
232    /// Create Binomial family
233    #[staticmethod]
234    fn binomial() -> Self {
235        PyFamily {
236            inner: Family::Binomial,
237        }
238    }
239
240    /// Create Poisson family
241    #[staticmethod]
242    fn poisson() -> Self {
243        PyFamily {
244            inner: Family::Poisson,
245        }
246    }
247
248    /// Create Gamma family
249    #[staticmethod]
250    fn gamma() -> Self {
251        PyFamily {
252            inner: Family::Gamma,
253        }
254    }
255
256    /// Create Inverse Gaussian family
257    #[staticmethod]
258    fn inverse_gaussian() -> Self {
259        PyFamily {
260            inner: Family::InverseGaussian,
261        }
262    }
263
264    /// Get family name
265    fn name(&self) -> String {
266        self.inner.name().to_string()
267    }
268
269    /// String representation
270    fn __repr__(&self) -> String {
271        format!("Family({})", self.name())
272    }
273}
274
/// Python wrapper for GLM Link function
#[pyclass(name = "Link")]
#[derive(Clone)]
struct PyLink {
    // Wrapped link-function enum from so_models.
    inner: Link,
}
281
282#[pymethods]
283impl PyLink {
284    /// Identity link: η = μ
285    #[staticmethod]
286    fn identity() -> Self {
287        PyLink {
288            inner: Link::Identity,
289        }
290    }
291
292    /// Logit link: η = log(μ/(1-μ))
293    #[staticmethod]
294    fn logit() -> Self {
295        PyLink { inner: Link::Logit }
296    }
297
298    /// Probit link: η = Φ⁻¹(μ)
299    #[staticmethod]
300    fn probit() -> Self {
301        PyLink {
302            inner: Link::Probit,
303        }
304    }
305
306    /// Log link: η = log(μ)
307    #[staticmethod]
308    fn log() -> Self {
309        PyLink { inner: Link::Log }
310    }
311
312    /// Inverse link: η = 1/μ
313    #[staticmethod]
314    fn inverse() -> Self {
315        PyLink {
316            inner: Link::Inverse,
317        }
318    }
319
320    /// String representation
321    fn __repr__(&self) -> String {
322        match self.inner {
323            Link::Identity => "Link(identity)".to_string(),
324            Link::Logit => "Link(logit)".to_string(),
325            Link::Probit => "Link(probit)".to_string(),
326            Link::Cloglog => "Link(cloglog)".to_string(),
327            Link::Log => "Link(log)".to_string(),
328            Link::Inverse => "Link(inverse)".to_string(),
329            Link::InverseSquare => "Link(inverse-square)".to_string(),
330            Link::Sqrt => "Link(sqrt)".to_string(),
331        }
332    }
333}
334
/// Python wrapper for GLM model builder
#[pyclass(name = "GLMBuilder")]
struct PyGLMBuilder {
    // `Option` because the underlying builder's setters and `build()` take it
    // by value: it is taken out, transformed, and put back (or left `None`
    // once consumed by `build()`).
    inner: Option<GLMModelBuilder>,
}
340
341#[pymethods]
342impl PyGLMBuilder {
343    /// Create a new GLM builder
344    #[new]
345    fn new() -> Self {
346        PyGLMBuilder {
347            inner: Some(GLMModelBuilder::new()),
348        }
349    }
350
351    /// Set the distribution family
352    fn family(&mut self, family: &PyFamily) -> PyResult<()> {
353        if let Some(inner) = self.inner.take() {
354            self.inner = Some(inner.family(family.inner));
355        }
356        Ok(())
357    }
358
359    /// Set the link function
360    fn link(&mut self, link: &PyLink) -> PyResult<()> {
361        if let Some(inner) = self.inner.take() {
362            self.inner = Some(inner.link(link.inner));
363        }
364        Ok(())
365    }
366
367    /// Set whether to include intercept
368    fn intercept(&mut self, intercept: bool) -> PyResult<()> {
369        if let Some(inner) = self.inner.take() {
370            self.inner = Some(inner.intercept(intercept));
371        }
372        Ok(())
373    }
374
375    /// Set maximum iterations
376    fn max_iter(&mut self, max_iter: usize) -> PyResult<()> {
377        if let Some(inner) = self.inner.take() {
378            self.inner = Some(inner.max_iter(max_iter));
379        }
380        Ok(())
381    }
382
383    /// Set convergence tolerance
384    fn tol(&mut self, tol: f64) -> PyResult<()> {
385        if let Some(inner) = self.inner.take() {
386            self.inner = Some(inner.tol(tol));
387        }
388        Ok(())
389    }
390
391    /// Set fixed scale parameter
392    fn scale(&mut self, scale: f64) -> PyResult<()> {
393        if let Some(inner) = self.inner.take() {
394            self.inner = Some(inner.scale(scale));
395        }
396        Ok(())
397    }
398
399    /// Build the GLM model
400    fn build(&mut self) -> PyResult<PyGLM> {
401        if let Some(inner) = self.inner.take() {
402            Ok(PyGLM {
403                inner: inner.build(),
404            })
405        } else {
406            Err(PyRuntimeError::new_err("GLM builder not available"))
407        }
408    }
409}
410
/// Python wrapper for GLM model
#[pyclass(name = "GLM")]
struct PyGLM {
    // Configured (unfitted) Rust GLM model.
    inner: RustGLM,
}
416
417#[pymethods]
418impl PyGLM {
419    /// Create a new GLM builder
420    #[staticmethod]
421    fn new() -> PyGLMBuilder {
422        PyGLMBuilder::new()
423    }
424
425    /// Fit the GLM using formula and DataFrame
426    fn fit(&self, formula: &str, data: &PyDataFrame) -> PyResult<PyGLMResults> {
427        match self.inner.fit(formula, &data.inner) {
428            Ok(results) => Ok(PyGLMResults { inner: results }),
429            Err(e) => Err(PyRuntimeError::new_err(format!(
430                "GLM fitting failed: {:?}",
431                e
432            ))),
433        }
434    }
435
436    /// Fit the GLM with design matrix X and response y
437    fn fit_matrix(&self, x: Vec<Vec<f64>>, y: Vec<f64>) -> PyResult<PyGLMResults> {
438        // Convert to DataFrame
439        let n_rows = x.len();
440        if n_rows == 0 {
441            return Err(PyValueError::new_err("X must have at least one row"));
442        }
443        if n_rows != y.len() {
444            return Err(PyValueError::new_err(
445                "X and y must have same number of rows",
446            ));
447        }
448
449        let n_cols = x[0].len();
450
451        // Check all rows have same number of columns
452        for (i, row) in x.iter().enumerate() {
453            if row.len() != n_cols {
454                return Err(PyValueError::new_err(format!(
455                    "Row {} has {} columns, expected {}",
456                    i,
457                    row.len(),
458                    n_cols
459                )));
460            }
461        }
462
463        // Create column names
464        let mut col_names = Vec::new();
465        for i in 0..n_cols {
466            col_names.push(format!("x{}", i));
467        }
468        col_names.push("y".to_string());
469
470        // Create Series for each column
471        let mut columns = HashMap::new();
472
473        // Create X columns
474        for i in 0..n_cols {
475            let mut col_data = Vec::with_capacity(n_rows);
476            for row in &x {
477                col_data.push(row[i]);
478            }
479            let series = Series::new(format!("x{}", i), ndarray::Array1::from_vec(col_data));
480            columns.insert(format!("x{}", i), series);
481        }
482
483        // Create y column
484        let y_series = Series::new("y".to_string(), ndarray::Array1::from_vec(y.clone()));
485        columns.insert("y".to_string(), y_series);
486
487        // Create DataFrame from Series
488        let df = DataFrame::from_series(columns)
489            .map_err(|e| PyValueError::new_err(format!("Failed to create DataFrame: {}", e)))?;
490
491        // Create formula: y ~ x0 + x1 + ... + x{n-1}
492        let formula_str = if n_cols == 0 {
493            "y ~ 1".to_string()
494        } else {
495            let mut formula = "y ~ ".to_string();
496            for i in 0..n_cols {
497                formula.push_str(&format!("x{}", i));
498                if i < n_cols - 1 {
499                    formula.push_str(" + ");
500                }
501            }
502            formula
503        };
504
505        // Use the GLM's fit method (takes formula string and DataFrame)
506        match self.inner.fit(&formula_str, &df) {
507            Ok(results) => Ok(PyGLMResults { inner: results }),
508            Err(e) => Err(PyRuntimeError::new_err(format!(
509                "GLM fitting failed: {:?}",
510                e
511            ))),
512        }
513    }
514}
515
/// Python wrapper for GLM results
#[pyclass(name = "GLMResults")]
struct PyGLMResults {
    // Fitted-model results produced by the Rust GLM fit.
    inner: GLMResults,
}
521
522#[pymethods]
523impl PyGLMResults {
524    /// Get coefficients
525    #[getter]
526    fn coefficients(&self) -> Vec<f64> {
527        self.inner.coefficients.to_vec()
528    }
529
530    /// Get standard errors
531    #[getter]
532    fn std_errors(&self) -> Vec<f64> {
533        self.inner.std_errors.to_vec()
534    }
535
536    /// Get z-values (Wald test statistics)
537    #[getter]
538    fn z_values(&self) -> Vec<f64> {
539        self.inner.z_values.to_vec()
540    }
541
542    /// Get p-values
543    #[getter]
544    fn p_values(&self) -> Vec<f64> {
545        self.inner.p_values.to_vec()
546    }
547
548    /// Get deviance
549    #[getter]
550    fn deviance(&self) -> f64 {
551        self.inner.deviance
552    }
553
554    /// Get null deviance
555    #[getter]
556    fn null_deviance(&self) -> f64 {
557        self.inner.null_deviance
558    }
559
560    /// Get AIC
561    #[getter]
562    fn aic(&self) -> f64 {
563        self.inner.aic
564    }
565
566    /// Get BIC
567    #[getter]
568    fn bic(&self) -> f64 {
569        self.inner.bic
570    }
571
572    /// Get degrees of freedom
573    #[getter]
574    fn df_residual(&self) -> usize {
575        self.inner.df_residual
576    }
577
578    /// Get degrees of freedom for null model
579    #[getter]
580    fn df_null(&self) -> usize {
581        self.inner.df_null
582    }
583
584    /// Get scale parameter
585    #[getter]
586    fn scale(&self) -> f64 {
587        self.inner.scale
588    }
589
590    /// Get fitted values
591    #[getter]
592    fn fitted_values(&self) -> Vec<f64> {
593        self.inner.fitted_values.to_vec()
594    }
595
596    /// Get Pearson residuals
597    #[getter]
598    fn pearson_residuals(&self) -> Vec<f64> {
599        self.inner.pearson_residuals.to_vec()
600    }
601
602    /// Get raw residuals (response scale)
603    #[getter]
604    fn residuals(&self) -> Vec<f64> {
605        self.inner.residuals.to_vec()
606    }
607
608    /// Get diagonal of hat matrix (leverage values)
609    #[getter]
610    fn hat_matrix_diag(&self) -> Vec<f64> {
611        self.inner.hat_matrix_diag.to_vec()
612    }
613
614    /// Get number of iterations
615    #[getter]
616    fn iterations(&self) -> usize {
617        self.inner.iterations
618    }
619
620    /// Check if model converged
621    #[getter]
622    fn converged(&self) -> bool {
623        self.inner.converged
624    }
625
626    /// Predict using the fitted model
627    fn predict(&self, x: Vec<Vec<f64>>) -> PyResult<Vec<f64>> {
628        let n_rows = x.len();
629        if n_rows == 0 {
630            return Ok(Vec::new());
631        }
632        let n_cols = x[0].len();
633
634        // Check dimensions
635        if n_cols != self.inner.coefficients.len() {
636            return Err(PyValueError::new_err(format!(
637                "X has {} columns but model has {} coefficients",
638                n_cols,
639                self.inner.coefficients.len()
640            )));
641        }
642
643        // Simple linear prediction: y = Xβ
644        let mut predictions = Vec::with_capacity(n_rows);
645        for row in x {
646            if row.len() != n_cols {
647                return Err(PyValueError::new_err(
648                    "All rows must have same number of columns",
649                ));
650            }
651            let mut pred = 0.0;
652            for (i, &xi) in row.iter().enumerate() {
653                pred += xi * self.inner.coefficients[i];
654            }
655            // TODO: Apply inverse link function based on family and link
656            predictions.push(pred);
657        }
658
659        Ok(predictions)
660    }
661
662    /// Get summary string
663    fn summary(&self) -> String {
664        // Simple summary for now
665        format!(
666            "GLM Results:\n  Coefficients: {:?}\n  AIC: {:.2}\n  BIC: {:.2}\n  Deviance: {:.2}\n  Scale: {:.2}",
667            self.coefficients(),
668            self.aic(),
669            self.bic(),
670            self.deviance(),
671            self.scale()
672        )
673    }
674}
675
/// Python wrapper for TimeSeries
#[pyclass(name = "TimeSeries")]
struct PyTimeSeries {
    // Wrapped Rust time series (timestamps + values).
    inner: TimeSeries,
}
681
682#[pymethods]
683impl PyTimeSeries {
684    /// Create a TimeSeries from a DataFrame
685    #[staticmethod]
686    fn from_dataframe(df: &PyDataFrame, value_col: &str, date_col: &str) -> PyResult<Self> {
687        match TimeSeries::from_dataframe(&df.inner, value_col, date_col) {
688            Ok(ts) => Ok(PyTimeSeries { inner: ts }),
689            Err(e) => Err(PyRuntimeError::new_err(format!(
690                "Failed to create TimeSeries: {:?}",
691                e
692            ))),
693        }
694    }
695
696    /// Create a TimeSeries from vectors
697    #[staticmethod]
698    fn from_vectors(values: Vec<f64>, _dates: Vec<String>) -> PyResult<Self> {
699        // Simple implementation - use index as timestamps
700        // In practice, would parse dates string to timestamps
701        let timestamps: Vec<i64> = (0..values.len() as i64).collect();
702        let values_array = ndarray::Array1::from_vec(values);
703
704        match TimeSeries::new("series", timestamps, values_array, None) {
705            Ok(ts) => Ok(PyTimeSeries { inner: ts }),
706            Err(e) => Err(PyRuntimeError::new_err(format!(
707                "Failed to create TimeSeries: {:?}",
708                e
709            ))),
710        }
711    }
712
713    /// Get values
714    #[getter]
715    fn values(&self) -> Vec<f64> {
716        self.inner.values().to_vec()
717    }
718
719    /// Get length
720    #[getter]
721    fn len(&self) -> usize {
722        self.inner.len()
723    }
724
725    /// Check if empty
726    fn is_empty(&self) -> bool {
727        self.inner.is_empty()
728    }
729
730    /// Compute mean
731    fn mean(&self) -> Option<f64> {
732        Some(self.inner.stats().mean)
733    }
734
735    /// Compute standard deviation
736    fn std(&self, _ddof: f64) -> Option<f64> {
737        // Note: ddof is ignored for now, uses population std
738        Some(self.inner.stats().std)
739    }
740
741    /// Compute variance
742    fn var(&self, _ddof: f64) -> Option<f64> {
743        // Note: ddof is ignored for now, uses population variance
744        Some(self.inner.stats().variance)
745    }
746
747    /// String representation
748    fn __repr__(&self) -> String {
749        format!("TimeSeries(len={})", self.len())
750    }
751}
752
/// Python wrapper for ARIMA builder
#[pyclass(name = "ARIMA")]
struct PyARIMA {
    // `Option` because the builder's setters take it by value and `fit`
    // consumes it; `None` once the model has been fitted.
    builder: Option<so_tsa::arima::ARIMABuilder>,
}
758
759#[pymethods]
760impl PyARIMA {
761    /// Create a new ARIMA model
762    #[new]
763    fn new(p: usize, d: usize, q: usize) -> Self {
764        use so_tsa::arima::ARIMABuilder;
765        PyARIMA {
766            builder: Some(ARIMABuilder::new(p, d, q)),
767        }
768    }
769
770    /// Set seasonal parameters
771    fn seasonal(&mut self, _p: usize, _d: usize, _q: usize, _s: usize) -> PyResult<()> {
772        // SARIMA not fully implemented in Python bindings yet
773        // For now, just do nothing
774        // TODO: Implement proper SARIMA support
775        Ok(())
776    }
777
778    /// Include constant term
779    fn with_constant(&mut self, include: bool) -> PyResult<()> {
780        if let Some(builder) = self.builder.take() {
781            self.builder = Some(builder.with_constant(include));
782        }
783        Ok(())
784    }
785
786    /// Set estimation method
787    fn method(&mut self, method: String) -> PyResult<()> {
788        use so_tsa::arima::EstimationMethod;
789        if let Some(builder) = self.builder.take() {
790            let est_method = match method.to_lowercase().as_str() {
791                "css" => EstimationMethod::CSS,
792                "ml" => EstimationMethod::ML,
793                "exactml" => EstimationMethod::ExactML,
794                _ => EstimationMethod::CSS,
795            };
796            self.builder = Some(builder.method(est_method));
797        }
798        Ok(())
799    }
800
801    /// Set maximum iterations
802    fn max_iter(&mut self, max_iter: usize) -> PyResult<()> {
803        if let Some(builder) = self.builder.take() {
804            self.builder = Some(builder.max_iter(max_iter));
805        }
806        Ok(())
807    }
808
809    /// Set convergence tolerance
810    fn tol(&mut self, tol: f64) -> PyResult<()> {
811        if let Some(builder) = self.builder.take() {
812            self.builder = Some(builder.tol(tol));
813        }
814        Ok(())
815    }
816
817    /// Fit the ARIMA model
818    ///
819    /// Accepts multiple input types:
820    /// - PyTimeSeries object
821    /// - List of floats (Vec<f64>)
822    /// - Any object convertible to a list of floats
823    fn fit(&mut self, py: Python, data: Py<PyAny>) -> PyResult<PyARIMAResults> {
824        // Get reference to Python object
825        let data_ref = data.bind(py);
826
827        // Try to convert input to TimeSeries
828        let timeseries = if let Ok(ts) = data_ref.extract::<PyRef<PyTimeSeries>>() {
829            // Already a TimeSeries
830            ts.inner.clone()
831        } else if let Ok(vec) = data_ref.extract::<Vec<f64>>() {
832            // Vector of floats - create TimeSeries with index as timestamps
833            let timestamps: Vec<i64> = (0..vec.len() as i64).collect();
834            let values_array = ndarray::Array1::from_vec(vec);
835            TimeSeries::new("series", timestamps, values_array, None).map_err(|e| {
836                PyRuntimeError::new_err(format!("Failed to create TimeSeries: {:?}", e))
837            })?
838        } else if let Ok(list) = data_ref.cast::<PyList>() {
839            // Python list - extract as floats
840            let mut vec = Vec::with_capacity(list.len());
841            for i in 0..list.len() {
842                let item = list.get_item(i)?;
843                let val: f64 = item
844                    .extract()
845                    .map_err(|_| PyValueError::new_err("List must contain only numeric values"))?;
846                vec.push(val);
847            }
848            let timestamps: Vec<i64> = (0..vec.len() as i64).collect();
849            let values_array = ndarray::Array1::from_vec(vec);
850            TimeSeries::new("series", timestamps, values_array, None).map_err(|e| {
851                PyRuntimeError::new_err(format!("Failed to create TimeSeries: {:?}", e))
852            })?
853        } else {
854            return Err(PyValueError::new_err(
855                "Input must be a TimeSeries, list of floats, or convertible to list of floats",
856            ));
857        };
858
859        if let Some(builder) = self.builder.take() {
860            match builder.fit(&timeseries) {
861                Ok(results) => Ok(PyARIMAResults { inner: results }),
862                Err(e) => Err(PyRuntimeError::new_err(format!(
863                    "ARIMA fitting failed: {:?}",
864                    e
865                ))),
866            }
867        } else {
868            Err(PyRuntimeError::new_err("ARIMA builder not available"))
869        }
870    }
871}
872
/// Python wrapper for ARIMA results
#[pyclass(name = "ARIMAResults")]
struct PyARIMAResults {
    // Fitted-model results produced by the Rust ARIMA fit.
    inner: ARIMAResults,
}
878
879#[pymethods]
880impl PyARIMAResults {
881    /// Get AR coefficients
882    #[getter]
883    fn ar_coef(&self) -> Option<Vec<f64>> {
884        self.inner.ar_coef.as_ref().map(|coef| coef.to_vec())
885    }
886
887    /// Get MA coefficients
888    #[getter]
889    fn ma_coef(&self) -> Option<Vec<f64>> {
890        self.inner.ma_coef.as_ref().map(|coef| coef.to_vec())
891    }
892
893    /// Get constant term
894    #[getter]
895    fn constant(&self) -> Option<f64> {
896        self.inner.constant
897    }
898
899    /// Get AIC
900    #[getter]
901    fn aic(&self) -> f64 {
902        self.inner.aic
903    }
904
905    /// Get BIC
906    #[getter]
907    fn bic(&self) -> f64 {
908        self.inner.bic
909    }
910
911    /// Get log-likelihood
912    #[getter]
913    fn log_likelihood(&self) -> f64 {
914        self.inner.log_likelihood
915    }
916
917    /// Get sigma2 (innovation variance)
918    #[getter]
919    fn sigma2(&self) -> f64 {
920        self.inner.sigma2
921    }
922
923    /// Get number of observations
924    #[getter]
925    fn n_obs(&self) -> usize {
926        self.inner.n_obs
927    }
928
929    /// Forecast future values
930    fn forecast(&self, steps: usize) -> Vec<f64> {
931        // Simple forecast - in practice would use proper forecasting method
932        // For now, return last value repeated
933        let last_value = if self.inner.fitted.len() > 0 {
934            self.inner.fitted[self.inner.fitted.len() - 1]
935        } else {
936            0.0
937        };
938        vec![last_value; steps]
939    }
940
941    /// Get fitted values
942    #[getter]
943    fn fitted(&self) -> Vec<f64> {
944        self.inner.fitted.to_vec()
945    }
946
947    /// Get residuals
948    #[getter]
949    fn residuals(&self) -> Vec<f64> {
950        self.inner.residuals.to_vec()
951    }
952
953    /// Get summary string
954    fn summary(&self) -> String {
955        format!(
956            "ARIMA Results:\n  AIC: {:.2}\n  BIC: {:.2}\n  Log-Likelihood: {:.2}\n  Sigma2: {:.4}",
957            self.aic(),
958            self.bic(),
959            self.log_likelihood(),
960            self.sigma2()
961        )
962    }
963}
964
/// Python wrapper for GARCH model
#[pyclass(name = "GARCH")]
struct PyGARCH {
    // `Option` because the builder's setters take it by value and `fit`
    // consumes it; `None` once the model has been fitted.
    builder: Option<so_tsa::garch::GARCHBuilder>,
}
970
971#[pymethods]
972impl PyGARCH {
973    /// Create a new GARCH model
974    #[new]
975    fn new(p: usize, q: usize) -> Self {
976        use so_tsa::garch::GARCHBuilder;
977        PyGARCH {
978            builder: Some(GARCHBuilder::new(p, q)),
979        }
980    }
981
982    /// Create an ARCH model (GARCH with p=0)
983    #[staticmethod]
984    fn arch(q: usize) -> Self {
985        use so_tsa::garch::GARCHBuilder;
986        PyGARCH {
987            builder: Some(GARCHBuilder::arch(q)),
988        }
989    }
990
991    /// Set distribution for innovations
992    fn distribution(&mut self, distribution: String) -> PyResult<()> {
993        use so_tsa::garch::GARCHDistribution;
994        if let Some(builder) = self.builder.take() {
995            let dist = match distribution.to_lowercase().as_str() {
996                "normal" => GARCHDistribution::Normal,
997                "t" | "studentst" => GARCHDistribution::StudentsT(5.0), // Default df=5.0
998                "ged" => GARCHDistribution::GED(1.5),                   // Default shape=1.5
999                _ => GARCHDistribution::Normal,
1000            };
1001            self.builder = Some(builder.distribution(dist));
1002        }
1003        Ok(())
1004    }
1005
1006    /// Fit the GARCH model to residuals
1007    ///
1008    /// Accepts multiple input types:
1009    /// - List of floats (Vec<f64>) - residuals
1010    /// - PyTimeSeries object
1011    /// - Any object convertible to a list of floats
1012    fn fit(&mut self, py: Python, data: Py<PyAny>) -> PyResult<PyGARCHResults> {
1013        // Get reference to Python object
1014        let data_ref = data.bind(py);
1015
1016        // Try to convert input to TimeSeries
1017        let timeseries = if let Ok(ts) = data_ref.extract::<PyRef<PyTimeSeries>>() {
1018            // Already a TimeSeries
1019            ts.inner.clone()
1020        } else if let Ok(vec) = data_ref.extract::<Vec<f64>>() {
1021            // Vector of floats - create TimeSeries with index as timestamps
1022            let timestamps: Vec<i64> = (0..vec.len() as i64).collect();
1023            let values_array = ndarray::Array1::from_vec(vec);
1024            TimeSeries::new("residuals", timestamps, values_array, None).map_err(|e| {
1025                PyRuntimeError::new_err(format!("Failed to create TimeSeries: {:?}", e))
1026            })?
1027        } else if let Ok(list) = data_ref.cast::<PyList>() {
1028            // Python list - extract as floats
1029            let mut vec = Vec::with_capacity(list.len());
1030            for i in 0..list.len() {
1031                let item = list.get_item(i)?;
1032                let val: f64 = item
1033                    .extract()
1034                    .map_err(|_| PyValueError::new_err("List must contain only numeric values"))?;
1035                vec.push(val);
1036            }
1037            let timestamps: Vec<i64> = (0..vec.len() as i64).collect();
1038            let values_array = ndarray::Array1::from_vec(vec);
1039            TimeSeries::new("residuals", timestamps, values_array, None).map_err(|e| {
1040                PyRuntimeError::new_err(format!("Failed to create TimeSeries: {:?}", e))
1041            })?
1042        } else {
1043            return Err(PyValueError::new_err(
1044                "Input must be a TimeSeries, list of floats, or convertible to list of floats",
1045            ));
1046        };
1047
1048        if let Some(builder) = self.builder.take() {
1049            match builder.fit(&timeseries) {
1050                Ok(results) => Ok(PyGARCHResults { inner: results }),
1051                Err(e) => Err(PyRuntimeError::new_err(format!(
1052                    "GARCH fitting failed: {:?}",
1053                    e
1054                ))),
1055            }
1056        } else {
1057            Err(PyRuntimeError::new_err("GARCH builder not available"))
1058        }
1059    }
1060}
1061
/// Python wrapper for GARCH results
///
/// Thin `#[pyclass]` wrapper that owns a Rust-side `GARCHResults` and
/// exposes its fields to Python through the getters in the accompanying
/// `#[pymethods]` impl. Read-only from the Python side.
#[pyclass(name = "GARCHResults")]
struct PyGARCHResults {
    // Fitted results produced by so_tsa::garch; never mutated after fit.
    inner: GARCHResults,
}
1067
1068#[pymethods]
1069impl PyGARCHResults {
1070    /// Get omega (constant in variance equation)
1071    #[getter]
1072    fn omega(&self) -> f64 {
1073        self.inner.omega
1074    }
1075
1076    /// Get ARCH coefficients (α₁, ..., α_q)
1077    #[getter]
1078    fn arch_coef(&self) -> Vec<f64> {
1079        self.inner.arch_coef.to_vec()
1080    }
1081
1082    /// Get GARCH coefficients (β₁, ..., β_p)
1083    #[getter]
1084    fn garch_coef(&self) -> Vec<f64> {
1085        self.inner.garch_coef.to_vec()
1086    }
1087
1088    /// Get mu (constant in mean equation, if included)
1089    #[getter]
1090    fn mu(&self) -> Option<f64> {
1091        self.inner.mu
1092    }
1093
1094    /// Get degrees of freedom (for t/GED distributions)
1095    #[getter]
1096    fn df(&self) -> Option<f64> {
1097        self.inner.df
1098    }
1099
1100    /// Get AIC
1101    #[getter]
1102    fn aic(&self) -> f64 {
1103        self.inner.aic
1104    }
1105
1106    /// Get BIC
1107    #[getter]
1108    fn bic(&self) -> f64 {
1109        self.inner.bic
1110    }
1111
1112    /// Get log-likelihood
1113    #[getter]
1114    fn log_likelihood(&self) -> f64 {
1115        self.inner.log_likelihood
1116    }
1117
1118    /// Get number of observations
1119    #[getter]
1120    fn n_obs(&self) -> usize {
1121        self.inner.n_obs
1122    }
1123
1124    /// Get residuals (εₜ)
1125    #[getter]
1126    fn residuals(&self) -> Vec<f64> {
1127        self.inner.residuals.to_vec()
1128    }
1129
1130    /// Get conditional variances (σₜ²)
1131    #[getter]
1132    fn conditional_variances(&self) -> Vec<f64> {
1133        self.inner.conditional_variances.to_vec()
1134    }
1135
1136    /// Get standardized residuals (zₜ = εₜ/σₜ)
1137    #[getter]
1138    fn standardized_residuals(&self) -> Vec<f64> {
1139        self.inner.standardized_residuals.to_vec()
1140    }
1141
1142    /// Get summary string
1143    fn summary(&self) -> String {
1144        format!(
1145            "GARCH Results:\n  AIC: {:.2}\n  BIC: {:.2}\n  Log-Likelihood: {:.2}",
1146            self.aic(),
1147            self.bic(),
1148            self.log_likelihood()
1149        )
1150    }
1151}
1152
/// StatOxide Python module
///
/// Assembles the `statoxide` extension module: registers the pyclass
/// wrappers, builds the `stats`, `models`, `tsa`, and `utils` submodules,
/// and re-exports a handful of convenience functions at the top level.
#[pymodule]
#[pyo3(name = "statoxide")]
fn statoxide(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Register core classes
    m.add_class::<PySeries>()?;
    m.add_class::<PyDataFrame>()?;
    m.add_class::<PyFormula>()?;

    // Register GLM classes
    m.add_class::<PyFamily>()?;
    m.add_class::<PyLink>()?;
    m.add_class::<PyGLMBuilder>()?;
    m.add_class::<PyGLM>()?;
    m.add_class::<PyGLMResults>()?;

    // Register TSA classes
    m.add_class::<PyTimeSeries>()?;
    m.add_class::<PyARIMA>()?;
    m.add_class::<PyARIMAResults>()?;
    m.add_class::<PyGARCH>()?;
    m.add_class::<PyGARCHResults>()?;

    // Basic functions module
    // NOTE(review): `add_submodule` attaches these as attributes of the
    // parent module (`statoxide.stats` works as an attribute access);
    // `import statoxide.stats` as a standalone statement may additionally
    // need a `sys.modules` entry — confirm against the Python packaging.
    let stats_module = PyModule::new(m.py(), "stats")?;
    stats_module.add_function(wrap_pyfunction!(mean, &stats_module)?)?;
    stats_module.add_function(wrap_pyfunction!(std_dev, &stats_module)?)?;
    stats_module.add_function(wrap_pyfunction!(correlation, &stats_module)?)?;
    stats_module.add_function(wrap_pyfunction!(descriptive_summary, &stats_module)?)?;
    // Statistical tests
    stats_module.add_function(wrap_pyfunction!(t_test_one_sample, &stats_module)?)?;
    stats_module.add_function(wrap_pyfunction!(t_test_two_sample, &stats_module)?)?;
    stats_module.add_function(wrap_pyfunction!(t_test_paired, &stats_module)?)?;
    stats_module.add_function(wrap_pyfunction!(
        chi_square_test_independence,
        &stats_module
    )?)?;
    stats_module.add_function(wrap_pyfunction!(anova_one_way, &stats_module)?)?;
    stats_module.add_function(wrap_pyfunction!(shapiro_wilk_test, &stats_module)?)?;
    m.add_submodule(&stats_module)?;

    // Models module
    // GLM classes are registered here as well as at top level so both
    // `statoxide.GLM` and `statoxide.models.GLM` resolve.
    let models_module = PyModule::new(m.py(), "models")?;
    models_module.add_function(wrap_pyfunction!(linear_regression, &models_module)?)?;
    models_module.add_function(wrap_pyfunction!(mixed_effects, &models_module)?)?;
    models_module.add_class::<PyFamily>()?;
    models_module.add_class::<PyLink>()?;
    models_module.add_class::<PyGLMBuilder>()?;
    models_module.add_class::<PyGLM>()?;
    models_module.add_class::<PyGLMResults>()?;
    m.add_submodule(&models_module)?;

    // TSA module
    let tsa_module = PyModule::new(m.py(), "tsa")?;
    tsa_module.add_function(wrap_pyfunction!(fit_arima, &tsa_module)?)?;
    tsa_module.add_class::<PyTimeSeries>()?;
    tsa_module.add_class::<PyARIMA>()?;
    tsa_module.add_class::<PyARIMAResults>()?;
    tsa_module.add_class::<PyGARCH>()?;
    tsa_module.add_class::<PyGARCHResults>()?;
    m.add_submodule(&tsa_module)?;

    // Utilities module
    let utils_module = PyModule::new(m.py(), "utils")?;
    utils_module.add_function(wrap_pyfunction!(train_test_split, &utils_module)?)?;
    m.add_submodule(&utils_module)?;

    // Top-level functions
    m.add_function(wrap_pyfunction!(version, m)?)?;

    // Add commonly used functions to top level for convenience
    m.add_function(wrap_pyfunction!(mean, m)?)?;
    m.add_function(wrap_pyfunction!(std_dev, m)?)?;
    m.add_function(wrap_pyfunction!(correlation, m)?)?;
    m.add_function(wrap_pyfunction!(descriptive_summary, m)?)?;
    m.add_function(wrap_pyfunction!(train_test_split, m)?)?;

    Ok(())
}
1232
1233/// Compute mean of data
1234#[pyfunction]
1235fn mean(data: Vec<f64>) -> PyResult<f64> {
1236    if data.is_empty() {
1237        return Ok(f64::NAN);
1238    }
1239    Ok(data.iter().sum::<f64>() / data.len() as f64)
1240}
1241
1242/// Compute standard deviation of data
1243#[pyfunction]
1244fn std_dev(data: Vec<f64>) -> PyResult<f64> {
1245    if data.len() < 2 {
1246        return Ok(f64::NAN);
1247    }
1248    let mean_val = mean(data.clone())?;
1249    let variance =
1250        data.iter().map(|&x| (x - mean_val).powi(2)).sum::<f64>() / (data.len() as f64 - 1.0);
1251    Ok(variance.sqrt())
1252}
1253
1254/// Compute correlation between two variables
1255#[pyfunction]
1256fn correlation(x: Vec<f64>, y: Vec<f64>) -> PyResult<f64> {
1257    if x.len() != y.len() || x.len() < 2 {
1258        return Ok(f64::NAN);
1259    }
1260
1261    let x_mean = x.iter().sum::<f64>() / x.len() as f64;
1262    let y_mean = y.iter().sum::<f64>() / y.len() as f64;
1263
1264    let covariance = x
1265        .iter()
1266        .zip(y.iter())
1267        .map(|(&xi, &yi)| (xi - x_mean) * (yi - y_mean))
1268        .sum::<f64>()
1269        / (x.len() as f64 - 1.0);
1270
1271    let x_std = std_dev(x.clone())?;
1272    let y_std = std_dev(y.clone())?;
1273
1274    if x_std == 0.0 || y_std == 0.0 {
1275        Ok(0.0)
1276    } else {
1277        Ok(covariance / (x_std * y_std))
1278    }
1279}
1280
1281/// Compute descriptive statistics summary
1282#[pyfunction]
1283fn descriptive_summary(py: Python, data: Vec<f64>) -> PyResult<Py<PyDict>> {
1284    let dict = PyDict::new(py);
1285    dict.set_item("count", data.len())?;
1286
1287    if !data.is_empty() {
1288        let mean_val = data.iter().sum::<f64>() / data.len() as f64;
1289        dict.set_item("mean", mean_val)?;
1290
1291        if data.len() >= 2 {
1292            let variance = data.iter().map(|&x| (x - mean_val).powi(2)).sum::<f64>()
1293                / (data.len() as f64 - 1.0);
1294            dict.set_item("std", variance.sqrt())?;
1295            dict.set_item("variance", variance)?;
1296
1297            // Min and max
1298            if let (Some(min), Some(max)) = (
1299                data.iter().min_by(|a, b| a.partial_cmp(b).unwrap()),
1300                data.iter().max_by(|a, b| a.partial_cmp(b).unwrap()),
1301            ) {
1302                dict.set_item("min", *min)?;
1303                dict.set_item("max", *max)?;
1304            }
1305        }
1306    }
1307
1308    Ok(dict.into())
1309}
1310
1311/// Convert alternative hypothesis string to Rust enum
1312fn parse_alternative(alternative: &str) -> PyResult<Alternative> {
1313    match alternative.to_lowercase().as_str() {
1314        "two-sided" | "two_sided" | "two.sided" => Ok(Alternative::TwoSided),
1315        "less" | "smaller" => Ok(Alternative::Less),
1316        "greater" | "larger" => Ok(Alternative::Greater),
1317        _ => Err(PyValueError::new_err(
1318            "alternative must be 'two-sided', 'less', or 'greater'",
1319        )),
1320    }
1321}
1322
1323/// Convert TestResult to Python dictionary
1324fn test_result_to_dict(py: Python, result: &TestResult) -> PyResult<Py<PyDict>> {
1325    let dict = PyDict::new(py);
1326    dict.set_item("statistic", result.statistic)?;
1327    dict.set_item("p_value", result.p_value)?;
1328    dict.set_item("df", result.df)?;
1329    dict.set_item(
1330        "alternative",
1331        match result.alternative {
1332            Alternative::TwoSided => "two-sided",
1333            Alternative::Less => "less",
1334            Alternative::Greater => "greater",
1335        },
1336    )?;
1337    dict.set_item("null_value", result.null_value)?;
1338    Ok(dict.into())
1339}
1340
1341/// One-sample t-test
1342#[pyfunction]
1343fn t_test_one_sample(
1344    py: Python,
1345    data: Vec<f64>,
1346    mu: f64,
1347    alternative: String,
1348) -> PyResult<Py<PyDict>> {
1349    let data_array = ndarray::Array1::from_vec(data);
1350    let alt = parse_alternative(&alternative)?;
1351    let result = t_test_one_sample_rs(&data_array, mu, alt)
1352        .map_err(|e| PyRuntimeError::new_err(format!("t-test failed: {:?}", e)))?;
1353    test_result_to_dict(py, &result)
1354}
1355
1356/// Two-sample t-test (independent samples, equal variance assumed)
1357#[pyfunction]
1358fn t_test_two_sample(
1359    py: Python,
1360    x: Vec<f64>,
1361    y: Vec<f64>,
1362    alternative: String,
1363) -> PyResult<Py<PyDict>> {
1364    let x_array = ndarray::Array1::from_vec(x);
1365    let y_array = ndarray::Array1::from_vec(y);
1366    let alt = parse_alternative(&alternative)?;
1367    let result = t_test_two_sample_rs(&x_array, &y_array, alt)
1368        .map_err(|e| PyRuntimeError::new_err(format!("t-test failed: {:?}", e)))?;
1369    test_result_to_dict(py, &result)
1370}
1371
1372/// Paired t-test
1373#[pyfunction]
1374fn t_test_paired(
1375    py: Python,
1376    x: Vec<f64>,
1377    y: Vec<f64>,
1378    alternative: String,
1379) -> PyResult<Py<PyDict>> {
1380    let x_array = ndarray::Array1::from_vec(x);
1381    let y_array = ndarray::Array1::from_vec(y);
1382    let alt = parse_alternative(&alternative)?;
1383    let result = t_test_paired_rs(&x_array, &y_array, alt)
1384        .map_err(|e| PyRuntimeError::new_err(format!("paired t-test failed: {:?}", e)))?;
1385    test_result_to_dict(py, &result)
1386}
1387
1388/// Chi-square test of independence
1389#[pyfunction]
1390fn chi_square_test_independence(py: Python, observed: Vec<Vec<f64>>) -> PyResult<Py<PyDict>> {
1391    // Convert to ndarray matrix
1392    let n_rows = observed.len();
1393    if n_rows == 0 {
1394        return Err(PyValueError::new_err("observed must have at least one row"));
1395    }
1396    let n_cols = observed[0].len();
1397    let mut flat = Vec::new();
1398    for row in observed {
1399        if row.len() != n_cols {
1400            return Err(PyValueError::new_err("All rows must have same length"));
1401        }
1402        flat.extend(row);
1403    }
1404    let matrix = ndarray::Array2::from_shape_vec((n_rows, n_cols), flat)
1405        .map_err(|e| PyValueError::new_err(format!("Failed to create matrix: {}", e)))?;
1406
1407    let result = chi_square_test_independence_rs(&matrix)
1408        .map_err(|e| PyRuntimeError::new_err(format!("chi-square test failed: {:?}", e)))?;
1409    test_result_to_dict(py, &result)
1410}
1411
1412/// One-way ANOVA
1413#[pyfunction]
1414fn anova_one_way(py: Python, groups: Vec<Vec<f64>>) -> PyResult<Py<PyDict>> {
1415    let arrays: Vec<_> = groups
1416        .into_iter()
1417        .map(|g| ndarray::Array1::from_vec(g))
1418        .collect();
1419    let result = anova_one_way_rs(&arrays)
1420        .map_err(|e| PyRuntimeError::new_err(format!("ANOVA failed: {:?}", e)))?;
1421    test_result_to_dict(py, &result)
1422}
1423
1424/// Shapiro-Wilk test for normality
1425#[pyfunction]
1426fn shapiro_wilk_test(py: Python, data: Vec<f64>) -> PyResult<Py<PyDict>> {
1427    let data_array = ndarray::Array1::from_vec(data);
1428    let result = shapiro_wilk_test_rs(&data_array)
1429        .map_err(|e| PyRuntimeError::new_err(format!("Shapiro-Wilk test failed: {:?}", e)))?;
1430    test_result_to_dict(py, &result)
1431}
1432
1433/// Fit linear regression model
1434#[pyfunction]
1435fn linear_regression(py: Python, x: Vec<Vec<f64>>, y: Vec<f64>) -> PyResult<Py<PyDict>> {
1436    // Convert to ndarray
1437    let n_rows = x.len();
1438    if n_rows == 0 {
1439        return Err(PyValueError::new_err("X must have at least one row"));
1440    }
1441    if n_rows != y.len() {
1442        return Err(PyValueError::new_err(
1443            "X and y must have same number of rows",
1444        ));
1445    }
1446
1447    let n_cols = x[0].len();
1448    let x_array =
1449        ndarray::Array2::from_shape_vec((n_rows, n_cols), x.into_iter().flatten().collect())
1450            .map_err(|e| PyValueError::new_err(format!("Failed to create X matrix: {}", e)))?;
1451
1452    let y_array = ndarray::Array1::from_vec(y);
1453
1454    // Fit OLS model
1455    let model = OLS::new();
1456    match model.fit(&x_array, &y_array) {
1457        Ok(results) => {
1458            let dict = PyDict::new(py);
1459            dict.set_item("coefficients", results.coefficients.to_vec())?;
1460            dict.set_item("r_squared", results.r_squared)?;
1461            dict.set_item("r_squared_adj", results.r_squared_adj)?;
1462            dict.set_item("sigma", results.sigma)?;
1463            dict.set_item("df_residual", results.df_residual)?;
1464            dict.set_item("df_model", results.df_model)?;
1465
1466            if let Some(std_errors) = &results.std_errors {
1467                dict.set_item("std_errors", std_errors.to_vec())?;
1468            }
1469
1470            if let Some(t_values) = &results.t_values {
1471                dict.set_item("t_values", t_values.to_vec())?;
1472            }
1473
1474            if let Some(p_values) = &results.p_values {
1475                dict.set_item("p_values", p_values.to_vec())?;
1476            }
1477
1478            if let Some(f_statistic) = &results.f_statistic {
1479                dict.set_item("f_statistic", f_statistic)?;
1480            }
1481
1482            if let Some(f_p_value) = &results.f_p_value {
1483                dict.set_item("f_p_value", f_p_value)?;
1484            }
1485
1486            Ok(dict.into())
1487        }
1488        Err(e) => Err(PyRuntimeError::new_err(format!(
1489            "Linear regression failed: {:?}",
1490            e
1491        ))),
1492    }
1493}
1494
/// Fit mixed effects model (linear mixed model)
///
/// NOTE(review): placeholder — no model is actually fit. The returned dict
/// mirrors the intended result schema (fixed_effects, random_variances,
/// residual_variance, log_likelihood, aic, bic) with empty/zero values,
/// plus the echoed formula, a message flagging the stub, and the row count
/// of `data` as n_obs.
#[pyfunction]
fn mixed_effects(py: Python, data: &PyDataFrame, formula: String) -> PyResult<Py<PyDict>> {
    // TODO: Implement actual mixed effects model using so-models
    let dict = PyDict::new(py);

    // For now, return a placeholder result
    dict.set_item("fixed_effects", Vec::<f64>::new())?;
    dict.set_item("random_variances", Vec::<f64>::new())?;
    dict.set_item("residual_variance", 0.0)?;
    dict.set_item("log_likelihood", 0.0)?;
    dict.set_item("aic", 0.0)?;
    dict.set_item("bic", 0.0)?;
    dict.set_item(
        "message",
        "Mixed effects model placeholder - implement using so-models::mixed",
    )?;
    dict.set_item("formula", formula)?;
    dict.set_item("n_obs", data.n_rows())?;

    Ok(dict.into())
}
1517
1518/// Fit ARIMA model
1519#[pyfunction]
1520fn fit_arima(py: Python, data: Vec<f64>, p: usize, d: usize, q: usize) -> PyResult<Py<PyDict>> {
1521    // Create TimeSeries from data with index as timestamps
1522    let values = ndarray::Array1::from_vec(data.clone());
1523    let timestamps: Vec<i64> = (0..data.len() as i64).collect();
1524    let ts = TimeSeries::new("series", timestamps, values, None)
1525        .map_err(|e| PyRuntimeError::new_err(format!("Failed to create TimeSeries: {:?}", e)))?;
1526
1527    // Create and fit ARIMA model using builder
1528    use so_tsa::arima::ARIMABuilder;
1529    let builder = ARIMABuilder::new(p, d, q);
1530
1531    match builder.fit(&ts) {
1532        Ok(results) => {
1533            let dict = PyDict::new(py);
1534            dict.set_item("order", (p, d, q))?;
1535
1536            // Collect coefficients
1537            let mut coefficients = Vec::new();
1538            if let Some(ar_coef) = &results.ar_coef {
1539                coefficients.extend_from_slice(&ar_coef.to_vec());
1540            }
1541            if let Some(ma_coef) = &results.ma_coef {
1542                coefficients.extend_from_slice(&ma_coef.to_vec());
1543            }
1544            if let Some(constant) = results.constant {
1545                coefficients.push(constant);
1546            }
1547            dict.set_item("coefficients", coefficients)?;
1548
1549            dict.set_item("aic", results.aic)?;
1550            dict.set_item("bic", results.bic)?;
1551            dict.set_item("log_likelihood", results.log_likelihood)?;
1552            dict.set_item("sigma2", results.sigma2)?;
1553
1554            // Add fitted values and residuals
1555            dict.set_item("fitted", results.fitted.to_vec())?;
1556            dict.set_item("residuals", results.residuals.to_vec())?;
1557            dict.set_item("n_obs", results.n_obs)?;
1558
1559            Ok(dict.into())
1560        }
1561        Err(e) => Err(PyRuntimeError::new_err(format!(
1562            "ARIMA fitting failed: {:?}",
1563            e
1564        ))),
1565    }
1566}
1567
1568/// Split data into training and test sets
1569#[pyfunction]
1570fn train_test_split(data: Vec<f64>, test_size: f64) -> PyResult<(Vec<f64>, Vec<f64>)> {
1571    if test_size <= 0.0 || test_size >= 1.0 {
1572        return Err(PyValueError::new_err("test_size must be between 0 and 1"));
1573    }
1574
1575    let split_idx = (data.len() as f64 * (1.0 - test_size)) as usize;
1576    let train = data[..split_idx].to_vec();
1577    let test = data[split_idx..].to_vec();
1578
1579    Ok((train, test))
1580}
1581
1582/// Get library version
1583#[pyfunction]
1584fn version() -> PyResult<String> {
1585    Ok("0.2.0".to_string())
1586}