trs_dataframe/dataframe/python.rs

1use std::collections::HashMap;
2
3use crate::{DataFrame, DataValue, JoinRelation, Key};
4use data_value::Extract as _;
5use ndarray::Array1;
6use numpy::{IntoPyArray, PyArray2};
7use pyo3::{
8    exceptions::PyTypeError,
9    prelude::*,
10    types::{PyBytes, PyList},
11    IntoPyObjectExt,
12};
13use tracing::trace;
14
15impl DataFrame {
16    fn select_data(
17        &self,
18        keys: Option<Vec<String>>,
19        transposed: Option<bool>,
20    ) -> Result<ndarray::Array2<DataValue>, crate::error::Error> {
21        let keys = keys
22            .map(|x| x.into_iter().map(Key::from).collect::<Vec<Key>>())
23            .unwrap_or(self.keys());
24        if transposed.unwrap_or(false) {
25            self.select(Some(keys.as_slice()))
26        } else {
27            self.select_transposed(Some(keys.as_slice()))
28        }
29    }
30}
31
32enum DfOrDict {
33    DataFrame(DataFrame),
34    Dict(HashMap<String, DataValue>),
35}
36
37impl DfOrDict {
38    pub fn new(object: Bound<'_, PyAny>) -> Result<DfOrDict, PyErr> {
39        if let Ok(df) = object.extract::<DataFrame>() {
40            Ok(DfOrDict::DataFrame(df))
41        } else {
42            let dict: HashMap<String, DataValue> = object.extract()?;
43            Ok(DfOrDict::Dict(dict))
44        }
45    }
46}
47
48#[pymethods]
49impl DataFrame {
50    /// Create a new empty DataFrame.
51    #[new]
52    pub fn init() -> Self {
53        Self::default()
54    }
55
56    /// Create a DataFrame from a polars dataframe in python.
57    /// ```text
58    /// import polars as pl
59    /// df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
60    /// tei_df = tdf.DataFrame.from_polars(df)
61    /// ```
62    #[cfg(feature = "polars-df")]
63    #[staticmethod]
64    pub fn from_polars(df: pyo3_polars::PyDataFrame) -> Self {
65        df.0.into()
66    }
67
68    /// Create a DataFrame from a dictionary.
69    /// ```text
70    /// df = tdf.DataFrame.from_dict({"a": [1, 2, 3], "b": [4, 5, 6]})
71    /// ```
72    #[staticmethod]
73    pub fn from_dict(df: HashMap<String, Vec<DataValue>>) -> Self {
74        let mut result_df: Vec<(Key, Vec<DataValue>)> = Vec::new();
75        for (key, value) in df.into_iter() {
76            result_df.push((key.as_str().into(), value));
77        }
78        result_df.into()
79    }
80
81    /// Returns the keys of the DataFrame.
82    pub fn keys(&self) -> Vec<Key> {
83        self.dataframe.keys().to_vec()
84    }
85
86    /// Convert the DataFrame to polars DataFrame.
87    /// This requires the `polars-df` feature to be enabled.
88    /// ```text
89    /// import polars as pl
90    /// original_df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
91    /// df = tdf.DataFrame.from_polars(original_df)
92    /// polars_df = df.as_polars()
93    /// assert polars_df.frame_equal(original_df)
94    /// ```  
95    #[cfg(feature = "polars-df")]
96    #[pyo3(name = "as_polars")]
97    pub fn py_as_polars(&self) -> PyResult<pyo3_polars::PyDataFrame> {
98        let df = self
99            .as_polars()
100            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot prepare polars DF: {e}")))?;
101        Ok(pyo3_polars::PyDataFrame(df))
102    }
103
104    /// Apply a function to the DataFrame.
105    /// The function should accept a DataFrame and return a DataFrame.
106    /// ```text
107    /// def my_function(df):
108    ///    # Perform some operations on the DataFrame
109    ///   return df
110    /// /// df = tdf.DataFrame.init()
111    /// df.apply(my_function)
112    /// ```
113    pub fn apply(&mut self, function: Bound<'_, PyAny>) -> Result<(), PyErr> {
114        let df: DataFrame = pyo3::Python::attach(|py| {
115            let self_ = self
116                .clone()
117                .into_pyobject(py)
118                .expect("BUG: cannot convert to PyObject");
119            let result = function.call1((self_,)).expect("BUG: cannot call function");
120            result
121                .extract::<Bound<DataFrame>>()
122                .expect("BUG: cannot extract data frame")
123                .unbind()
124                .extract(py)
125                .expect("BUG: cannot extract data frame")
126        });
127        self.dataframe = df.dataframe;
128        Ok(())
129    }
130
131    /// Returns slice from dataframe as numpy.array of uint32 of the given keys.
132    /// If `transposed` is true, the keys will be transposed.
133    /// If `keys` is None, all keys will be used.
134    /// ```text
135    /// import numpy as np
136    /// df = tdf.DataFrame.init()
137    /// df.push({"key1": 1, "key2": 2})
138    /// df.push({"key1": 11, "key2": 21})
139    /// a_np = df.as_numpy_u32(['key1', 'key2'])
140    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.uint32))
141    /// ```
142    #[pyo3(signature = (keys=None, transposed=None))]
143    pub fn as_numpy_u32<'py>(
144        &self,
145        keys: Option<Vec<String>>,
146        transposed: Option<bool>,
147        py: Python<'py>,
148    ) -> PyResult<Bound<'py, numpy::PyArray2<u32>>> {
149        let data = self
150            .select_data(keys, transposed)
151            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
152        Ok(PyArray2::from_array(py, &data.mapv(|x| u32::extract(&x))))
153    }
154
155    /// Returns slice from dataframe as numpy.array of uint64 of the given keys.
156    /// If `transposed` is true, the keys will be transposed.
157    /// If `keys` is None, all keys will be used.
158    /// ```text
159    /// import numpy as np
160    /// df = tdf.DataFrame.init()
161    /// df.push({"key1": 1, "key2": 2})
162    /// df.push({"key1": 11, "key2": 21})
163    /// a_np = df.as_numpy_u64(['key1', 'key2'])
164    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.uint64))
165    /// ```
166    #[pyo3(signature = (keys=None, transposed=None))]
167    pub fn as_numpy_u64<'py>(
168        &self,
169        keys: Option<Vec<String>>,
170        transposed: Option<bool>,
171        py: Python<'py>,
172    ) -> PyResult<Bound<'py, numpy::PyArray2<u64>>> {
173        let data = self
174            .select_data(keys, transposed)
175            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
176        Ok(PyArray2::from_array(py, &data.mapv(|x| u64::extract(&x))))
177    }
178
179    /// Returns slice from dataframe as numpy.array of int32 of the given keys.
180    /// If `transposed` is true, the keys will be transposed.
181    /// If `keys` is None, all keys will be used.
182    /// ```text
183    /// import numpy as np
184    /// df = tdf.DataFrame.init()
185    /// df.push({"key1": 1, "key2": 2})
186    /// df.push({"key1": 11, "key2": 21})
187    /// a_np = df.as_numpy_i32(['key1', 'key2'])
188    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.int32))
189    /// ```
190    #[pyo3(signature = (keys=None, transposed=None))]
191    pub fn as_numpy_i32<'py>(
192        &self,
193        keys: Option<Vec<String>>,
194        transposed: Option<bool>,
195        py: Python<'py>,
196    ) -> PyResult<Bound<'py, numpy::PyArray2<i32>>> {
197        let data = self
198            .select_data(keys, transposed)
199            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
200        Ok(PyArray2::from_array(py, &data.mapv(|x| i32::extract(&x))))
201    }
202
203    /// Returns slice from dataframe as numpy.array of int64 of the given keys.
204    /// If `transposed` is true, the keys will be transposed.
205    /// If `keys` is None, all keys will be used.
206    /// ```text
207    /// import numpy as np
208    /// df = tdf.DataFrame.init()
209    /// df.push({"key1": 1, "key2": 2})
210    /// df.push({"key1": 11, "key2": 21})
211    /// a_np = df.as_numpy_i64(['key1', 'key2'])
212    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.int64))
213    /// ```
214    #[pyo3(signature = (keys=None, transposed=None))]
215    pub fn as_numpy_i64<'py>(
216        &self,
217        keys: Option<Vec<String>>,
218        transposed: Option<bool>,
219        py: Python<'py>,
220    ) -> PyResult<Bound<'py, numpy::PyArray2<i64>>> {
221        let data = self
222            .select_data(keys, transposed)
223            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
224        Ok(PyArray2::from_array(py, &data.mapv(|x| i64::extract(&x))))
225    }
226
227    /// Returns slice from dataframe as numpy.array of float32 of the given keys.
228    /// If `transposed` is true, the keys will be transposed.
229    /// If `keys` is None, all keys will be used.
230    /// ```text
231    /// import numpy as np
232    /// df = tdf.DataFrame.init()
233    /// df.push({"key1": 1, "key2": 2})
234    /// df.push({"key1": 11, "key2": 21})
235    /// a_np = df.as_numpy_f32(['key1', 'key2'])
236    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.float32))
237    /// ```
238    #[pyo3(signature = (keys=None, transposed=None))]
239    pub fn as_numpy_f32<'py>(
240        &self,
241        keys: Option<Vec<String>>,
242        transposed: Option<bool>,
243        py: Python<'py>,
244    ) -> PyResult<Bound<'py, numpy::PyArray2<f32>>> {
245        let data = self
246            .select_data(keys, transposed)
247            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
248        Ok(PyArray2::from_array(py, &data.mapv(|x| f32::extract(&x))))
249    }
250
251    /// Returns slice from dataframe as numpy.array of float64 of the given keys.
252    /// If `transposed` is true, the keys will be transposed.
253    /// If `keys` is None, all keys will be used.
254    /// ```text
255    /// import numpy as np
256    /// df = tdf.DataFrame.init()
257    /// df.push({"key1": 1, "key2": 2})
258    /// df.push({"key1": 11, "key2": 21})
259    /// a_np = df.as_numpy_f64(['key1', 'key2'])
260    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.float64))
261    /// ```
262    #[pyo3(signature = (keys=None, transposed=None))]
263    pub fn as_numpy_f64<'py>(
264        &self,
265        keys: Option<Vec<String>>,
266        transposed: Option<bool>,
267        py: Python<'py>,
268    ) -> PyResult<Bound<'py, numpy::PyArray2<f64>>> {
269        let data = self
270            .select_data(keys, transposed)
271            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
272        Ok(PyArray2::from_array(py, &data.mapv(|x| f64::extract(&x))))
273    }
274
275    /// Returns slice from dataframe as numpy.array of objects of the given keys.
276    /// If `transposed` is true, the keys will be transposed.
277    /// If `keys` is None, all keys will be used.
278    /// ```text
279    /// import numpy as np
280    /// df = tdf.DataFrame.init()
281    /// df.push({"key1": "a", "key2": "b"})
282    /// df.push({"key1": "d", "key2": "c"})
283    /// a_np = df.as_numpy_f64(['key1', 'key2'])
284    /// assert np.array_equal(a_np, np.array([["a", d""], ["b", "c"]], dtype=np.object))
285    /// ```
286    #[pyo3(signature = (keys=None, transposed=None))]
287    pub fn as_numpy_str<'py>(
288        &self,
289        keys: Option<Vec<String>>,
290        transposed: Option<bool>,
291        py: Python<'py>,
292    ) -> PyResult<Bound<'py, numpy::PyArray2<Py<PyAny>>>> {
293        self.as_numpy(keys, transposed, py)
294    }
295
296    /// Returns slice from dataframe as numpy.array of float64 of the given keys.
297    /// If `transposed` is true, the keys will be transposed.
298    /// If `keys` is None, all keys will be used.
299    /// ```text
300    /// import numpy as np
301    /// df = tdf.DataFrame.init()
302    /// df.push({"key1": "a", "key2": 1})
303    /// df.push({"key1": "d", "key2": 2})
304    /// a_np = df.as_numpy_f64(['key1', 'key2'])
305    /// assert np.array_equal(a_np, np.array([["a", d""], [1, 2]], dtype=np.object))
306    /// ```
307    #[pyo3(signature = (keys=None, transposed=None))]
308    pub fn as_numpy<'py>(
309        &self,
310        keys: Option<Vec<String>>,
311        transposed: Option<bool>,
312        py: Python<'py>,
313    ) -> PyResult<Bound<'py, numpy::PyArray2<Py<PyAny>>>> {
314        let data = self
315            .select_data(keys, transposed)
316            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
317        let data = data.mapv(|x| {
318            String::extract(&x)
319                .into_py_any(py)
320                .expect("cannot convert string to py object")
321        });
322        Ok(data.into_pyarray(py))
323    }
324
325    #[pyo3(name = "shrink")]
326    pub fn py_shrink(&mut self) {
327        self.dataframe.shrink();
328    }
329
330    #[pyo3(name = "add_metadata")]
331    pub fn py_add_metadata(&mut self, key: String, value: DataValue) {
332        self.metadata.insert(key, value);
333    }
334
335    #[pyo3(name = "get_metadata")]
336    pub fn py_get_metadata(&self, key: &str) -> Option<DataValue> {
337        self.metadata.get(key).cloned()
338    }
339
340    #[pyo3(name = "rename_key")]
341    pub fn py_rename_key(&mut self, key: &str, new_name: &str) -> Result<(), PyErr> {
342        // fixme this may have a problem when the type is different and checked
343        self.dataframe
344            .rename_key(key, new_name.into())
345            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("{e}")))
346    }
347
348    #[pyo3(name = "add_alias")]
349    pub fn py_add_alias(&mut self, key: &str, new_name: &str) -> Result<(), PyErr> {
350        self.dataframe
351            .add_alias(key, new_name)
352            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("{e}")))
353    }
354
355    /// Selects data from the DataFrame.
356    /// If `keys` is None, all keys will be used.
357    /// If `keys` is provided, only the specified keys will be selected.
358    /// Returns a list of lists, where each inner list represents a row of data.
359    /// ```text
360    /// import trs_dataframe as tdf
361    /// df = tdf.DataFrame.init()
362    /// df.push({"key1": 1, "key2": 2})
363    /// df.push({"key1": 11, "key2": 21})
364    /// # selected = df.select(["key1", "key2"])
365    /// # assert selected == [[1, 2], [11, 21]]
366    /// # selected = df.select()
367    #[pyo3(name = "select", signature = (keys=None, transposed=None))]
368    pub fn py_select<'py>(
369        &self,
370        py: Python<'py>,
371        keys: Option<Vec<String>>,
372        transposed: Option<bool>,
373    ) -> Result<Bound<'py, PyList>, PyErr> {
374        let keys = keys
375            .map(|x| x.into_iter().map(Key::from).collect::<Vec<Key>>())
376            .unwrap_or(self.keys());
377
378        let selected = if transposed.unwrap_or_default() {
379            self.select_transposed(Some(keys.as_slice()))
380                .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?
381        } else {
382            self.select(Some(keys.as_slice()))
383                .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?
384        };
385
386        let list = PyList::empty(py);
387        for rows in selected.rows() {
388            let row = PyList::empty(py);
389            for value in rows.iter() {
390                row.append(value.clone())
391                    .expect("BUG: cannot append to list");
392            }
393            list.append(row).expect("BUG: cannot append to list");
394        }
395        Ok(list)
396    }
397
398    /// Selects a column from the DataFrame.
399    /// If the column does not exist, it will raise a TypeError.
400    /// Returns a list of values in the selected column.
401    /// ```text
402    /// import trs_dataframe as tdf
403    /// df = tdf.DataFrame.init()
404    /// df.push({"key1": 1, "key2": 2})
405    /// df.push({"key1": 11, "key2": 21})
406    /// # selected = df.select_column("key1")
407    /// # assert selected == [1, 11]
408    /// # selected = df.select_column("key2")
409    /// # assert selected == [2, 21]
410    /// # selected = df.select_column("non_existing_key")  # Raises TypeError
411    /// ```
412    #[pyo3(name = "select_column")]
413    pub fn py_select_column<'py>(
414        &self,
415        py: Python<'py>,
416        key: String,
417    ) -> Result<Bound<'py, PyList>, PyErr> {
418        let selected = self
419            .select_column(Key::from(key))
420            .ok_or_else(|| PyErr::new::<PyTypeError, _>("Cannot select column"))?;
421
422        let list = PyList::empty(py);
423        for x in selected.to_vec().into_iter() {
424            list.append(x)?;
425        }
426
427        Ok(list)
428    }
429
430    /// Joins the current DataFrame with another DataFrame.
431    /// The join type is specified by the `join_type` parameter.
432    /// see [`JoinRelation`] for available join types.
433    /// ```text
434    /// import trs_dataframe as tdf
435    /// df1 = tdf.DataFrame.init()
436    /// df1.push({"key1": 1, "key2": 2})
437    /// df1.push({"key1": 11, "key2": 21})
438    /// df2 = tdf.DataFrame.init()
439    /// df2.push({"key1": 1, "key2": 3})
440    /// df2.push({"key1": 11, "key2": 23})
441    /// df1.join(df2, tei.JoinRelation.extend())
442    /// assert df1.select(["key1", "key2"]) == [[1, 2], [11, 21], [1, 3], [11, 23]]
443    /// ```
444    #[pyo3(name = "join")]
445    pub fn py_join(&mut self, other: DataFrame, join_type: JoinRelation) -> Result<(), PyErr> {
446        self.dataframe
447            .join(other.dataframe, &join_type)
448            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot join data: {e}")))?;
449
450        Ok(())
451    }
452
453    /// Pushes a new row of data into the DataFrame.
454    /// The data should be provided as a dictionary where keys are column names and values are the corresponding data values.
455    /// ```text
456    /// import trs_dataframe as tdf
457    /// df = tdf.DataFrame.init()
458    /// df.push({"key1": 1, "key2": 2})
459    /// df.push({"key1": 11, "key2": 21})
460    /// ```
461    #[pyo3(name = "push")]
462    pub fn py_push(&mut self, data: HashMap<Key, DataValue>) -> Result<(), PyErr> {
463        self.dataframe
464            .push(data)
465            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot join data: {e}")))?;
466        Ok(())
467    }
468
469    /// Adds a new column to the DataFrame.
470    /// The column is specified by a key and a vector of data values.
471    /// If the length of the data vector does not match the number of rows in the DataFrame, it will raise a TypeError.
472    /// ```text
473    /// import trs_dataframe as tdf
474    /// df = tdf.DataFrame.init()
475    /// df.push({"key1": 1, "key2": 2})
476    /// df.push({"key1": 11, "key2": 21})
477    /// df.add_column("key3", [3, 4])
478    /// assert df.select(["key1", "key2", "key3"]) == [[1, 2, 3], [11, 21, 4]]
479    /// ```
480    #[pyo3(name = "add_column")]
481    pub fn py_add_column(&mut self, key: Key, data: Vec<DataValue>) -> Result<(), PyErr> {
482        self.dataframe
483            .add_single_column(key, Array1::from_vec(data))
484            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot join data: {e}")))?;
485        Ok(())
486    }
487
488    pub fn add_constant(&mut self, key: Key, feature: DataValue) -> Result<(), PyErr> {
489        self.constants.insert(key, feature);
490        Ok(())
491    }
492
493    /// Filters the DataFrame by a given expression.
494    /// The expression should be a string that can be parsed by the DataFrame's filter method
495    ///
496    /// ```text
497    /// import trs_dataframe as tdf
498    /// df = tdf.DataFrame.init()
499    /// df.push({"key1": 1, "key2": 2})
500    /// df.push({"key1": 11, "key2": 21})
501    /// df.filter_by_expression("key1 > 5")
502    /// assert df.select(["key1", "key2"]) == [[11, 21 ]]
503    /// ```
504    pub fn filter_by_expression(&mut self, expression: String) -> Result<Self, PyErr> {
505        let filter = crate::filter::FilterRules::try_from(expression.as_str())
506            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot parse expression: {e}")))?;
507        self.filter(&filter)
508            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot filter data: {e}")))
509    }
510
511    fn __repr__(&self) -> String {
512        self.to_string()
513    }
514
515    fn __str__(&self) -> String {
516        self.to_string()
517    }
518
519    pub fn __iadd__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
520        trace!("{object:?}");
521        let df_or_dict = DfOrDict::new(object)?;
522        match df_or_dict {
523            DfOrDict::DataFrame(df) => {
524                self.dataframe += df.dataframe;
525            }
526            DfOrDict::Dict(dict) => {
527                self.dataframe += dict;
528            }
529        }
530        Ok(())
531    }
532
533    pub fn __isub__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
534        trace!("{object:?}");
535
536        let df_or_dict = DfOrDict::new(object)?;
537        match df_or_dict {
538            DfOrDict::DataFrame(df) => {
539                self.dataframe -= df.dataframe;
540            }
541            DfOrDict::Dict(dict) => {
542                self.dataframe -= dict;
543            }
544        }
545        Ok(())
546    }
547
548    pub fn __imul__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
549        trace!("{object:?}");
550        let df_or_dict = DfOrDict::new(object)?;
551        match df_or_dict {
552            DfOrDict::DataFrame(df) => {
553                self.dataframe *= df.dataframe;
554            }
555            DfOrDict::Dict(dict) => {
556                self.dataframe *= dict;
557            }
558        }
559        Ok(())
560    }
561
562    pub fn __itruediv__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
563        trace!("{object:?}");
564        let df_or_dict = DfOrDict::new(object)?;
565        match df_or_dict {
566            DfOrDict::DataFrame(df) => {
567                self.dataframe /= df.dataframe;
568            }
569            DfOrDict::Dict(dict) => {
570                self.dataframe /= dict;
571            }
572        }
573        Ok(())
574    }
575
576    pub fn __len__(&mut self) -> Result<usize, PyErr> {
577        Ok(self.dataframe.len())
578    }
579
580    pub fn serialize_to_json_string(&self) -> String {
581        serde_json::to_string(self).expect("Cannot serialize to strinng")
582    }
583
584    #[staticmethod]
585    pub fn deserialize_from_json_string(json_df: String) -> Self {
586        let mut df: DataFrame =
587            serde_json::from_str(json_df.as_str()).expect("Cannot deserialize from str");
588        let _ = df.dataframe.try_fix_dtype();
589
590        df
591    }
592
593    // derive Serialize and Deserialize
594    pub fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
595        let s: DataFrame = rmp_serde::decode::from_slice(state.as_bytes()).map_err(|e| {
596            pyo3::PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
597                "Cannot deserialize object {e}"
598            ))
599        })?;
600        *self = s;
601        self.dataframe.try_fix_dtype().map_err(|e| {
602            pyo3::PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
603                "Cannot deserialize object {e}"
604            ))
605        })?;
606        Ok(())
607    }
608    pub fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
609        let buf = rmp_serde::encode::to_vec(self).map_err(|e| {
610            // let buf = serde_json::to_string(self).map_err(|e| {
611            pyo3::PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
612                "Cannot deserialize object {e}"
613            ))
614        })?;
615        Ok(PyBytes::new(py, &buf))
616    }
617
618    pub fn __del__(&mut self) {
619        self.dataframe = Default::default();
620    }
621
622    // pub fn __reduce__(&self, py: Python<'_>) -> PyResult<(Py<PyAny>, Py<PyAny>)> {
623    //     let cls = py.get_type::<Self>();
624    //     Ok((
625    //         cls.into(),
626    //         pyo3::types::PyTuple::new(py, &[self.__getstate__(py)?])?
627    //             .into_any()
628    //             .unbind(),
629    //     ))
630    // }
631}
632
633#[cfg(test)]
634mod test {
635
636    use super::*;
637    use crate::DataType;
638    use data_value::{stdhashmap, DataValue};
639    use halfbrown::hashmap;
640    use pyo3::ffi::c_str;
641    use rstest::*;
642    use tracing_test::traced_test;
643
644    #[fixture]
645    fn df() -> DataFrame {
646        let mut df = DataFrame::init();
647        assert!(df
648            .push(hashmap! {
649                Key::new("key1", DataType::U32) => DataValue::U32(1),
650                Key::new("key2", DataType::U32) => DataValue::U32(2),
651            })
652            .is_ok());
653        assert!(df
654            .push(hashmap! {
655                Key::from("key1") => DataValue::U32(11),
656                Key::from("key2") => DataValue::U32(21),
657            })
658            .is_ok());
659        df
660    }
661
662    #[fixture]
663    fn hm() -> HashMap<String, DataValue> {
664        stdhashmap!(
665            "key1".to_string() => DataValue::U32(2),
666            "key2".to_string() => DataValue::U32(3),
667        )
668    }
669
670    #[rstest]
671    fn serde_py(df: DataFrame) {
672        let str_df = df.serialize_to_json_string();
673        assert!(!str_df.is_empty());
674
675        let loaded = DataFrame::deserialize_from_json_string(str_df);
676
677        assert_eq!(loaded, df);
678    }
679    #[cfg(feature = "python")]
680    #[rstest]
681    fn pickle_py(df: DataFrame) {
682        pyo3::Python::attach(|py| {
683            let bytes = df.__getstate__(py);
684            assert!(bytes.is_ok());
685
686            let mut deser = DataFrame::default();
687            assert!(deser.__setstate__(bytes.unwrap().into()).is_ok());
688            assert_eq!(deser, df);
689        });
690    }
691    #[rstest]
692    fn test_select_data(df: DataFrame) {
693        let data = df.select_data(Some(vec!["key1".into(), "key2".into()]), Some(false));
694        assert!(data.is_ok());
695        assert_eq!(
696            data.unwrap(),
697            ndarray::array![[1u32.into(), 11u32.into()], [2u32.into(), 21u32.into()]]
698        );
699
700        let data = df.select_data(Some(vec!["key1".into(), "key2".into()]), Some(true));
701        assert!(data.is_ok());
702        assert_eq!(
703            data.unwrap(),
704            ndarray::array![[1u32.into(), 2u32.into()], [11u32.into(), 21u32.into()]]
705        );
706    }
707
708    #[cfg(feature = "python")]
709    #[rstest]
710    fn test_from_create() {
711        pyo3::Python::attach(|_py| {
712            let mut hm: HashMap<String, Vec<DataValue>> = Default::default();
713            let value: Vec<DataValue> = vec![1i32.into(), 22i32.into()];
714            hm.insert("a".into(), value);
715
716            let df = DataFrame::from_dict(hm);
717            assert_eq!(
718                df.select(Some(&["a".into()])),
719                Ok(ndarray::array![
720                    [DataValue::from(1i32)],
721                    [DataValue::from(22i32)]
722                ]),
723            );
724        });
725        #[cfg(feature = "polars-df")]
726        {
727            let pdf = polars::df!(
728                "a" => [1u64, 2u64, 3u64],
729                "b" => [4f64, 5f64, 6f64],
730                "c" => [7i64, 8i64, 9i64]
731            )
732            .expect("BUG: should be ok");
733            let df = DataFrame::from_polars(pyo3_polars::PyDataFrame(pdf));
734            assert_eq!(
735                df.select(Some(&["a".into(), "b".into(), "c".into()])),
736                crate::df! {
737                    "a" => [1u64, 2u64, 3u64],
738                    "b" => [4f64, 5f64, 6f64],
739                    "c" => [7i64, 8i64, 9i64]
740                }
741                .select(Some(&["a".into(), "b".into(), "c".into()])),
742            );
743        }
744    }
745
746    #[rstest]
747    #[traced_test]
748    fn basic_ops_add(mut df: DataFrame, hm: HashMap<String, DataValue>) {
749        let mut df_expect = df.clone();
750        let df2 = df.clone();
751        let exec = Python::attach(|py| -> PyResult<()> {
752            df.__iadd__(df.clone().into_pyobject(py)?.into_any())?;
753            df_expect.dataframe += df2.dataframe;
754            tracing::trace!("{} vs {}", df, df_expect);
755            assert_eq!(df.dataframe, df_expect.dataframe);
756
757            df.__iadd__(hm.clone().into_pyobject(py)?.into_any())?;
758            df_expect.dataframe += hm;
759            tracing::trace!("{} vs {}", df, df_expect);
760            assert_eq!(df.dataframe, df_expect.dataframe);
761
762            Ok(())
763        });
764
765        assert!(exec.is_ok(), "{:?}", exec);
766    }
767
768    #[rstest]
769    #[traced_test]
770    fn basic_ops_sub(mut df: DataFrame, hm: HashMap<String, DataValue>) {
771        let mut df_expect = df.clone();
772        let df2 = df.clone();
773        let exec = Python::attach(|py| -> PyResult<()> {
774            df.__isub__(df.clone().into_pyobject(py)?.into_any())?;
775            df_expect.dataframe -= df2.dataframe;
776            tracing::trace!("{} vs {}", df, df_expect);
777            assert_eq!(df.dataframe, df_expect.dataframe);
778
779            df.__isub__(hm.clone().into_pyobject(py)?.into_any())?;
780            df_expect.dataframe -= hm;
781            tracing::trace!("{} vs {}", df, df_expect);
782            assert_eq!(df.dataframe, df_expect.dataframe);
783
784            Ok(())
785        });
786
787        assert!(exec.is_ok(), "{:?}", exec);
788    }
789
790    #[rstest]
791    #[traced_test]
792    fn basic_ops_mul(mut df: DataFrame, hm: HashMap<String, DataValue>) {
793        let mut df_expect = df.clone();
794        let df2 = df.clone();
795        let exec = Python::attach(|py| -> PyResult<()> {
796            df.__imul__(df.clone().into_pyobject(py)?.into_any())?;
797            df_expect.dataframe *= df2.dataframe;
798            tracing::trace!("{} vs {}", df, df_expect);
799            assert_eq!(df.dataframe, df_expect.dataframe);
800
801            df.__imul__(hm.clone().into_pyobject(py)?.into_any())?;
802            df_expect.dataframe *= hm;
803            tracing::trace!("{} vs {}", df, df_expect);
804            assert_eq!(df.dataframe, df_expect.dataframe);
805            Ok(())
806        });
807
808        assert!(exec.is_ok(), "{:?}", exec);
809    }
810
811    #[rstest]
812    #[traced_test]
813    fn basic_ops_div(mut df: DataFrame, hm: HashMap<String, DataValue>) {
814        let mut df_expect = df.clone();
815        let df2 = df.clone();
816        let exec = Python::attach(|py| -> PyResult<()> {
817            df.__itruediv__(df.clone().into_pyobject(py)?.into_any())?;
818            df_expect.dataframe /= df2.dataframe;
819            tracing::trace!("{} vs {}", df, df_expect);
820            assert_eq!(df.dataframe, df_expect.dataframe);
821
822            df.__itruediv__(hm.clone().into_pyobject(py)?.into_any())?;
823            df_expect.dataframe /= hm;
824            tracing::trace!("{} vs {}", df, df_expect);
825            assert_eq!(df.dataframe, df_expect.dataframe);
826            Ok(())
827        });
828
829        assert!(exec.is_ok(), "{:?}", exec);
830    }
831
832    #[rstest]
833    #[traced_test]
834    #[rstest]
835    fn test_numpy(mut df: DataFrame) {
836        let exec = Python::attach(|py| -> PyResult<()> {
837            let code = c_str!(
838                r#"
839def example(df):
840    import numpy as np
841    a_np = df.as_numpy_f32(['key1', 'key2'])
842    print(a_np)
843    b_np = df.as_numpy_u32(['key1', 'key'])
844    print(b_np)
845    b_np = df.as_numpy_i32(['key1', 'key'])
846    print(b_np)
847    b_np = df.as_numpy_i64(['key1', 'key'])
848    print(b_np)
849    b_np = df.as_numpy_u64(['key1', 'key'])
850    print(b_np)
851    b_np = df.as_numpy_f64(['key1', 'key'])
852    print(b_np)
853    b_np = df.as_numpy_f64(['key1', 'key'], transposed=True)
854    print(b_np)
855    b_np = df.as_numpy(['key1', 'key'], transposed=True)
856    print(b_np)
857    b_np = df.as_numpy_str(['key1', 'key'], transposed=True)
858    print(b_np)
859    return df
860            "#
861            );
862            let fun: Py<PyAny> = PyModule::from_code(py, code, c_str!(""), c_str!(""))?
863                .getattr("example")?
864                .into();
865            let result = fun.call1(py, (df.clone(),));
866            assert!(df.py_join(df.clone(), JoinRelation::default()).is_ok());
867            // user may not have installed polars, we need to get an error in that
868            // case
869            if py.import("numpy").is_ok() {
870                assert!(result.is_ok(), "{:?}", result);
871            } else {
872                assert!(result.is_err(), "{:?}", result);
873            }
874            Ok(())
875        });
876        assert!(exec.is_ok(), "{:?}", exec);
877    }
878
879    #[rstest]
880    #[traced_test]
881    #[rstest]
882    fn test_fill_from_python(df: DataFrame) {
883        let exec = Python::attach(|_py| -> PyResult<()> {
884            let hm = stdhashmap!(
885                Key::from("key1") => DataValue::U32(1),
886                Key::from("key2") => DataValue::U32(2),
887            );
888            let mut df2 = DataFrame::init();
889            assert!(df2.py_push(hm).is_ok());
890            assert!(df2
891                .py_push(stdhashmap!(
892                    Key::from("key1") => DataValue::U32(11),
893                    Key::from("key2") => DataValue::U32(21),
894                ))
895                .is_ok());
896
897            assert_eq!(df, df2);
898
899            let mut df2 = DataFrame::init();
900            assert!(df2
901                .py_add_column(
902                    Key::from("key1"),
903                    vec![DataValue::U32(1), DataValue::U32(11)]
904                )
905                .is_ok());
906            assert!(df2
907                .py_add_column(
908                    Key::from("key2"),
909                    vec![DataValue::U32(2), DataValue::U32(21)]
910                )
911                .is_ok());
912
913            assert_eq!(df, df2);
914            Ok(())
915        });
916        assert!(exec.is_ok(), "{:?}", exec);
917    }
918
919    #[rstest]
920    fn basic_python_dataframe(mut df: DataFrame) {
921        let exec = Python::attach(|py| -> PyResult<()> {
922            let fun: Py<PyAny> = PyModule::from_code(
923                py,
924                c_str!(
925                    "
926def example(df):
927    print(df)
928    df.shrink()
929    assert len(df) == 2
930    df.add_alias('key1', 'key1-alias')
931    a = df.select(['key1', 'key2'])
932    print(a)
933    b = df.select(['key1-alias', 'key2'])
934    print(b)
935    df.rename_key('key1', 'key1new')
936    df.rename_key('key1new', 'key1')
937    assert a == [[1, 2], [11, 21]]
938    assert a == b
939    df.add_metadata('test', 1)
940    m = df.get_metadata('test')
941    assert m == 1
942    b = df.select_transposed(['key1', 'key2'])
943    print(b)
944    assert b == [[1, 11], [2, 21]]
945    c = df.select_column('key1')
946    print(c)
947    assert c == [1, 11]
948
949    a += b
950    print(a)
951    assert a == [[2, 13], [4, 23]]
952    a -= b
953    print(a)
954    assert e == a
955    f = e * b
956    print(f)
957    assert f == [[1, 22], [44, 441]]
958    g = f / b
959    print(g)
960    assert g == e
961
962                "
963                ),
964                c_str!(""),
965                c_str!(""),
966            )?
967            .getattr("example")?
968            .into();
969            let _ = fun.call1(py, (df.clone(),));
970            assert!(df.py_join(df.clone(), JoinRelation::default()).is_ok());
971            Ok(())
972        });
973        assert!(exec.is_ok(), "{:?}", exec);
974    }
975
976    #[rstest]
977    fn dummy_test_apply(mut df: DataFrame) {
978        let exec = Python::attach(|py| -> PyResult<()> {
979            let fun: Py<PyAny> = PyModule::from_code(
980                py,
981                c_str!(
982                    r#"
983def multiply_by_ten(x):
984    print(x)
985    x *= {"key1": 10}
986    print(x)
987    return x
988
989def example(df):
990    print(df)
991    df.apply(multiply_by_ten)
992                "#
993                ),
994                c_str!(""),
995                c_str!(""),
996            )?
997            .getattr("example")?
998            .into();
999            let _ = fun.call1(py, (df.clone(),));
1000            assert!(df.py_join(df.clone(), JoinRelation::default()).is_ok());
1001            Ok(())
1002        });
1003        assert!(exec.is_ok(), "{:?}", exec);
1004    }
1005}