trs_dataframe/dataframe/
python.rs

1use std::collections::HashMap;
2
3use crate::{DataFrame, DataValue, JoinRelation, Key};
4use data_value::Extract as _;
5use ndarray::Array1;
6use numpy::{IntoPyArray, PyArray2};
7use pyo3::{
8    exceptions::PyTypeError,
9    prelude::*,
10    types::{PyBytes, PyList},
11    IntoPyObjectExt,
12};
13use tracing::trace;
14
15impl DataFrame {
16    fn select_data(
17        &self,
18        keys: Option<Vec<String>>,
19        transposed: Option<bool>,
20    ) -> Result<ndarray::Array2<DataValue>, crate::error::Error> {
21        let keys = keys
22            .map(|x| x.into_iter().map(Key::from).collect::<Vec<Key>>())
23            .unwrap_or(self.keys());
24        if transposed.unwrap_or(false) {
25            self.select(Some(keys.as_slice()))
26        } else {
27            self.select_transposed(Some(keys.as_slice()))
28        }
29    }
30}
31
/// Right-hand operand accepted by the in-place arithmetic operators
/// (`__iadd__`, `__isub__`, `__imul__`, `__itruediv__`).
enum DfOrDict {
    /// The operand was extracted as another [`DataFrame`].
    DataFrame(DataFrame),
    /// The operand was extracted as a mapping from string keys to single values.
    Dict(HashMap<String, DataValue>),
}
36
37impl DfOrDict {
38    pub fn new(object: Bound<'_, PyAny>) -> Result<DfOrDict, PyErr> {
39        if let Ok(df) = object.extract::<DataFrame>() {
40            Ok(DfOrDict::DataFrame(df))
41        } else {
42            let dict: HashMap<String, DataValue> = object.extract()?;
43            Ok(DfOrDict::Dict(dict))
44        }
45    }
46}
47
48#[pymethods]
49impl DataFrame {
50    /// Create a new empty DataFrame.
51    #[new]
52    pub fn init() -> Self {
53        Self::default()
54    }
55
56    /// Create a DataFrame from a polars dataframe in python.
57    /// ```text
58    /// import polars as pl
59    /// df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
60    /// tei_df = tdf.DataFrame.from_polars(df)
61    /// ```
62    #[cfg(feature = "polars-df")]
63    #[staticmethod]
64    pub fn from_polars(df: pyo3_polars::PyDataFrame) -> Self {
65        df.0.into()
66    }
67
68    /// Create a DataFrame from a dictionary.
69    /// ```text
70    /// df = tdf.DataFrame.from_dict({"a": [1, 2, 3], "b": [4, 5, 6]})
71    /// ```
72    #[staticmethod]
73    pub fn from_dict(df: HashMap<String, Vec<DataValue>>) -> Self {
74        let mut result_df: Vec<(Key, Vec<DataValue>)> = Vec::new();
75        for (key, value) in df.into_iter() {
76            let dtype = crate::detect_dtype_arr(&value);
77            let key = Key::new(key.as_str(), dtype);
78            result_df.push((key, value));
79        }
80
81        result_df.into()
82    }
83
    /// Returns the keys of the DataFrame.
    ///
    /// The keys of the underlying dataframe are copied into a new `Vec`, so the
    /// caller owns the result.
    pub fn keys(&self) -> Vec<Key> {
        self.dataframe.keys().to_vec()
    }
88
89    /// Convert the DataFrame to polars DataFrame.
90    /// ```text
91    /// df = tdf.DataFrame.from_dict({"a": [1, 2, 3], "b": [4, 5, 6]});
92    /// df.set_dtype_for_column("a", trs.DataType.I32)
93    /// ```  
94    pub fn set_dtype_for_column(&mut self, key: String, dtype: crate::DataType) -> PyResult<()> {
95        self.dataframe
96            .enforce_dtype_for_column(key.as_str(), dtype)
97            .map_err(|e| {
98                PyErr::new::<PyTypeError, _>(format!("Cannot set dtype for columnĀ {key}: {e}"))
99            })
100    }
101
102    /// Convert the DataFrame to polars DataFrame.
103    /// This requires the `polars-df` feature to be enabled.
104    /// ```text
105    /// import polars as pl
106    /// original_df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
107    /// df = tdf.DataFrame.from_polars(original_df)
108    /// polars_df = df.as_polars()
109    /// assert polars_df.frame_equal(original_df)
110    /// ```  
111    #[cfg(feature = "polars-df")]
112    #[pyo3(name = "as_polars")]
113    pub fn py_as_polars(&self) -> PyResult<pyo3_polars::PyDataFrame> {
114        let df = self
115            .as_polars()
116            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot prepare polars DF: {e}")))?;
117        Ok(pyo3_polars::PyDataFrame(df))
118    }
119
120    /// Apply a function to the DataFrame.
121    /// The function should accept a DataFrame and return a DataFrame.
122    /// ```text
123    /// def my_function(df):
124    ///    # Perform some operations on the DataFrame
125    ///   return df
126    /// /// df = tdf.DataFrame.init()
127    /// df.apply(my_function)
128    /// ```
129    pub fn apply(&mut self, function: Bound<'_, PyAny>) -> Result<(), PyErr> {
130        let df: DataFrame = pyo3::Python::attach(|py| {
131            let self_ = self
132                .clone()
133                .into_pyobject(py)
134                .expect("BUG: cannot convert to PyObject");
135            let result = function.call1((self_,)).expect("BUG: cannot call function");
136            result
137                .extract::<Bound<DataFrame>>()
138                .expect("BUG: cannot extract data frame")
139                .unbind()
140                .extract(py)
141                .expect("BUG: cannot extract data frame")
142        });
143        self.dataframe = df.dataframe;
144        Ok(())
145    }
146
147    /// Returns slice from dataframe as numpy.array of uint32 of the given keys.
148    /// If `transposed` is true, the keys will be transposed.
149    /// If `keys` is None, all keys will be used.
150    /// ```text
151    /// import numpy as np
152    /// df = tdf.DataFrame.init()
153    /// df.push({"key1": 1, "key2": 2})
154    /// df.push({"key1": 11, "key2": 21})
155    /// a_np = df.as_numpy_u32(['key1', 'key2'])
156    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.uint32))
157    /// ```
158    #[pyo3(signature = (keys=None, transposed=None))]
159    pub fn as_numpy_u32<'py>(
160        &self,
161        keys: Option<Vec<String>>,
162        transposed: Option<bool>,
163        py: Python<'py>,
164    ) -> PyResult<Bound<'py, numpy::PyArray2<u32>>> {
165        let data = self
166            .select_data(keys, transposed)
167            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
168        Ok(PyArray2::from_array(py, &data.mapv(|x| u32::extract(&x))))
169    }
170
171    /// Returns slice from dataframe as numpy.array of uint64 of the given keys.
172    /// If `transposed` is true, the keys will be transposed.
173    /// If `keys` is None, all keys will be used.
174    /// ```text
175    /// import numpy as np
176    /// df = tdf.DataFrame.init()
177    /// df.push({"key1": 1, "key2": 2})
178    /// df.push({"key1": 11, "key2": 21})
179    /// a_np = df.as_numpy_u64(['key1', 'key2'])
180    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.uint64))
181    /// ```
182    #[pyo3(signature = (keys=None, transposed=None))]
183    pub fn as_numpy_u64<'py>(
184        &self,
185        keys: Option<Vec<String>>,
186        transposed: Option<bool>,
187        py: Python<'py>,
188    ) -> PyResult<Bound<'py, numpy::PyArray2<u64>>> {
189        let data = self
190            .select_data(keys, transposed)
191            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
192        Ok(PyArray2::from_array(py, &data.mapv(|x| u64::extract(&x))))
193    }
194
195    /// Returns slice from dataframe as numpy.array of int32 of the given keys.
196    /// If `transposed` is true, the keys will be transposed.
197    /// If `keys` is None, all keys will be used.
198    /// ```text
199    /// import numpy as np
200    /// df = tdf.DataFrame.init()
201    /// df.push({"key1": 1, "key2": 2})
202    /// df.push({"key1": 11, "key2": 21})
203    /// a_np = df.as_numpy_i32(['key1', 'key2'])
204    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.int32))
205    /// ```
206    #[pyo3(signature = (keys=None, transposed=None))]
207    pub fn as_numpy_i32<'py>(
208        &self,
209        keys: Option<Vec<String>>,
210        transposed: Option<bool>,
211        py: Python<'py>,
212    ) -> PyResult<Bound<'py, numpy::PyArray2<i32>>> {
213        let data = self
214            .select_data(keys, transposed)
215            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
216        Ok(PyArray2::from_array(py, &data.mapv(|x| i32::extract(&x))))
217    }
218
219    /// Returns slice from dataframe as numpy.array of int64 of the given keys.
220    /// If `transposed` is true, the keys will be transposed.
221    /// If `keys` is None, all keys will be used.
222    /// ```text
223    /// import numpy as np
224    /// df = tdf.DataFrame.init()
225    /// df.push({"key1": 1, "key2": 2})
226    /// df.push({"key1": 11, "key2": 21})
227    /// a_np = df.as_numpy_i64(['key1', 'key2'])
228    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.int64))
229    /// ```
230    #[pyo3(signature = (keys=None, transposed=None))]
231    pub fn as_numpy_i64<'py>(
232        &self,
233        keys: Option<Vec<String>>,
234        transposed: Option<bool>,
235        py: Python<'py>,
236    ) -> PyResult<Bound<'py, numpy::PyArray2<i64>>> {
237        let data = self
238            .select_data(keys, transposed)
239            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
240        Ok(PyArray2::from_array(py, &data.mapv(|x| i64::extract(&x))))
241    }
242
243    /// Returns slice from dataframe as numpy.array of float32 of the given keys.
244    /// If `transposed` is true, the keys will be transposed.
245    /// If `keys` is None, all keys will be used.
246    /// ```text
247    /// import numpy as np
248    /// df = tdf.DataFrame.init()
249    /// df.push({"key1": 1, "key2": 2})
250    /// df.push({"key1": 11, "key2": 21})
251    /// a_np = df.as_numpy_f32(['key1', 'key2'])
252    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.float32))
253    /// ```
254    #[pyo3(signature = (keys=None, transposed=None))]
255    pub fn as_numpy_f32<'py>(
256        &self,
257        keys: Option<Vec<String>>,
258        transposed: Option<bool>,
259        py: Python<'py>,
260    ) -> PyResult<Bound<'py, numpy::PyArray2<f32>>> {
261        let data = self
262            .select_data(keys, transposed)
263            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
264        Ok(PyArray2::from_array(py, &data.mapv(|x| f32::extract(&x))))
265    }
266
267    /// Returns slice from dataframe as numpy.array of float64 of the given keys.
268    /// If `transposed` is true, the keys will be transposed.
269    /// If `keys` is None, all keys will be used.
270    /// ```text
271    /// import numpy as np
272    /// df = tdf.DataFrame.init()
273    /// df.push({"key1": 1, "key2": 2})
274    /// df.push({"key1": 11, "key2": 21})
275    /// a_np = df.as_numpy_f64(['key1', 'key2'])
276    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.float64))
277    /// ```
278    #[pyo3(signature = (keys=None, transposed=None))]
279    pub fn as_numpy_f64<'py>(
280        &self,
281        keys: Option<Vec<String>>,
282        transposed: Option<bool>,
283        py: Python<'py>,
284    ) -> PyResult<Bound<'py, numpy::PyArray2<f64>>> {
285        let data = self
286            .select_data(keys, transposed)
287            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
288        Ok(PyArray2::from_array(py, &data.mapv(|x| f64::extract(&x))))
289    }
290
    /// Returns slice from dataframe as numpy.array of objects of the given keys.
    /// This is a thin alias that forwards all arguments to `as_numpy`.
    /// If `transposed` is true, the keys will be transposed.
    /// If `keys` is None, all keys will be used.
    /// ```text
    /// import numpy as np
    /// df = tdf.DataFrame.init()
    /// df.push({"key1": "a", "key2": "b"})
    /// df.push({"key1": "d", "key2": "c"})
    /// a_np = df.as_numpy_str(['key1', 'key2'])
    /// assert np.array_equal(a_np, np.array([["a", "d"], ["b", "c"]], dtype=object))
    /// ```
    #[pyo3(signature = (keys=None, transposed=None))]
    pub fn as_numpy_str<'py>(
        &self,
        keys: Option<Vec<String>>,
        transposed: Option<bool>,
        py: Python<'py>,
    ) -> PyResult<Bound<'py, numpy::PyArray2<Py<PyAny>>>> {
        self.as_numpy(keys, transposed, py)
    }
311
312    /// Returns slice from dataframe as numpy.array of float64 of the given keys.
313    /// If `transposed` is true, the keys will be transposed.
314    /// If `keys` is None, all keys will be used.
315    /// ```text
316    /// import numpy as np
317    /// df = tdf.DataFrame.init()
318    /// df.push({"key1": "a", "key2": 1})
319    /// df.push({"key1": "d", "key2": 2})
320    /// a_np = df.as_numpy_f64(['key1', 'key2'])
321    /// assert np.array_equal(a_np, np.array([["a", d""], [1, 2]], dtype=np.object))
322    /// ```
323    #[pyo3(signature = (keys=None, transposed=None))]
324    pub fn as_numpy<'py>(
325        &self,
326        keys: Option<Vec<String>>,
327        transposed: Option<bool>,
328        py: Python<'py>,
329    ) -> PyResult<Bound<'py, numpy::PyArray2<Py<PyAny>>>> {
330        let data = self
331            .select_data(keys, transposed)
332            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
333        let data = data.mapv(|x| {
334            String::extract(&x)
335                .into_py_any(py)
336                .expect("cannot convert string to py object")
337        });
338        Ok(data.into_pyarray(py))
339    }
340
    /// Exposed to Python as `shrink`; forwards to `shrink` on the underlying
    /// dataframe.
    #[pyo3(name = "shrink")]
    pub fn py_shrink(&mut self) {
        self.dataframe.shrink();
    }
345
    /// Exposed to Python as `add_metadata`; stores `value` under `key` in the
    /// DataFrame's metadata map. The previous value returned by `insert`, if any,
    /// is discarded.
    #[pyo3(name = "add_metadata")]
    pub fn py_add_metadata(&mut self, key: String, value: DataValue) {
        self.metadata.insert(key, value);
    }
350
    /// Exposed to Python as `get_metadata`; returns a clone of the metadata value
    /// stored under `key`, or `None` when the key is absent.
    #[pyo3(name = "get_metadata")]
    pub fn py_get_metadata(&self, key: &str) -> Option<DataValue> {
        self.metadata.get(key).cloned()
    }
355
356    #[pyo3(name = "rename_key")]
357    pub fn py_rename_key(&mut self, key: &str, new_name: &str) -> Result<(), PyErr> {
358        // fixme this may have a problem when the type is different and checked
359        self.dataframe
360            .rename_key(key, new_name.into())
361            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("{e}")))
362    }
363
364    #[pyo3(name = "add_alias")]
365    pub fn py_add_alias(&mut self, key: &str, new_name: &str) -> Result<(), PyErr> {
366        self.dataframe
367            .add_alias(key, new_name)
368            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("{e}")))
369    }
370
371    /// Selects data from the DataFrame.
372    /// If `keys` is None, all keys will be used.
373    /// If `keys` is provided, only the specified keys will be selected.
374    /// Returns a list of lists, where each inner list represents a row of data.
375    /// ```text
376    /// import trs_dataframe as tdf
377    /// df = tdf.DataFrame.init()
378    /// df.push({"key1": 1, "key2": 2})
379    /// df.push({"key1": 11, "key2": 21})
380    /// # selected = df.select(["key1", "key2"])
381    /// # assert selected == [[1, 2], [11, 21]]
382    /// # selected = df.select()
383    #[pyo3(name = "select", signature = (keys=None, transposed=None))]
384    pub fn py_select<'py>(
385        &self,
386        py: Python<'py>,
387        keys: Option<Vec<String>>,
388        transposed: Option<bool>,
389    ) -> Result<Bound<'py, PyList>, PyErr> {
390        let keys = keys
391            .map(|x| x.into_iter().map(Key::from).collect::<Vec<Key>>())
392            .unwrap_or(self.keys());
393
394        let selected = if transposed.unwrap_or_default() {
395            self.select_transposed(Some(keys.as_slice()))
396                .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?
397        } else {
398            self.select(Some(keys.as_slice()))
399                .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?
400        };
401
402        let list = PyList::empty(py);
403        for rows in selected.rows() {
404            let row = PyList::empty(py);
405            for value in rows.iter() {
406                row.append(value.clone())
407                    .expect("BUG: cannot append to list");
408            }
409            list.append(row).expect("BUG: cannot append to list");
410        }
411        Ok(list)
412    }
413
414    /// Selects a column from the DataFrame.
415    /// If the column does not exist, it will raise a TypeError.
416    /// Returns a list of values in the selected column.
417    /// ```text
418    /// import trs_dataframe as tdf
419    /// df = tdf.DataFrame.init()
420    /// df.push({"key1": 1, "key2": 2})
421    /// df.push({"key1": 11, "key2": 21})
422    /// # selected = df.select_column("key1")
423    /// # assert selected == [1, 11]
424    /// # selected = df.select_column("key2")
425    /// # assert selected == [2, 21]
426    /// # selected = df.select_column("non_existing_key")  # Raises TypeError
427    /// ```
428    #[pyo3(name = "select_column")]
429    pub fn py_select_column<'py>(
430        &self,
431        py: Python<'py>,
432        key: String,
433    ) -> Result<Bound<'py, PyList>, PyErr> {
434        let selected = self
435            .select_column(Key::from(key))
436            .ok_or_else(|| PyErr::new::<PyTypeError, _>("Cannot select column"))?;
437
438        let list = PyList::empty(py);
439        for x in selected.to_vec().into_iter() {
440            list.append(x)?;
441        }
442
443        Ok(list)
444    }
445
446    /// Joins the current DataFrame with another DataFrame.
447    /// The join type is specified by the `join_type` parameter.
448    /// see [`JoinRelation`] for available join types.
449    /// ```text
450    /// import trs_dataframe as tdf
451    /// df1 = tdf.DataFrame.init()
452    /// df1.push({"key1": 1, "key2": 2})
453    /// df1.push({"key1": 11, "key2": 21})
454    /// df2 = tdf.DataFrame.init()
455    /// df2.push({"key1": 1, "key2": 3})
456    /// df2.push({"key1": 11, "key2": 23})
457    /// df1.join(df2, tei.JoinRelation.extend())
458    /// assert df1.select(["key1", "key2"]) == [[1, 2], [11, 21], [1, 3], [11, 23]]
459    /// ```
460    #[pyo3(name = "join")]
461    pub fn py_join(&mut self, other: DataFrame, join_type: JoinRelation) -> Result<(), PyErr> {
462        self.dataframe
463            .join(other.dataframe, &join_type)
464            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot join data: {e}")))?;
465
466        Ok(())
467    }
468
469    /// Pushes a new row of data into the DataFrame.
470    /// The data should be provided as a dictionary where keys are column names and values are the corresponding data values.
471    /// ```text
472    /// import trs_dataframe as tdf
473    /// df = tdf.DataFrame.init()
474    /// df.push({"key1": 1, "key2": 2})
475    /// df.push({"key1": 11, "key2": 21})
476    /// ```
477    #[pyo3(name = "push")]
478    pub fn py_push(&mut self, data: HashMap<Key, DataValue>) -> Result<(), PyErr> {
479        self.dataframe
480            .push(data)
481            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot join data: {e}")))?;
482        Ok(())
483    }
484
485    /// Adds a new column to the DataFrame.
486    /// The column is specified by a key and a vector of data values.
487    /// If the length of the data vector does not match the number of rows in the DataFrame, it will raise a TypeError.
488    /// ```text
489    /// import trs_dataframe as tdf
490    /// df = tdf.DataFrame.init()
491    /// df.push({"key1": 1, "key2": 2})
492    /// df.push({"key1": 11, "key2": 21})
493    /// df.add_column("key3", [3, 4])
494    /// assert df.select(["key1", "key2", "key3"]) == [[1, 2, 3], [11, 21, 4]]
495    /// ```
496    #[pyo3(name = "add_column")]
497    pub fn py_add_column(&mut self, key: Key, data: Vec<DataValue>) -> Result<(), PyErr> {
498        self.dataframe
499            .add_single_column(key, Array1::from_vec(data))
500            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot join data: {e}")))?;
501        Ok(())
502    }
503
    /// Stores `feature` under `key` in the DataFrame's `constants` collection.
    ///
    /// Always returns `Ok(())`; the result of `insert` is discarded.
    pub fn add_constant(&mut self, key: Key, feature: DataValue) -> Result<(), PyErr> {
        self.constants.insert(key, feature);
        Ok(())
    }
508
509    /// Filters the DataFrame by a given expression.
510    /// The expression should be a string that can be parsed by the DataFrame's filter method
511    ///
512    /// ```text
513    /// import trs_dataframe as tdf
514    /// df = tdf.DataFrame.init()
515    /// df.push({"key1": 1, "key2": 2})
516    /// df.push({"key1": 11, "key2": 21})
517    /// df.filter_by_expression("key1 > 5")
518    /// assert df.select(["key1", "key2"]) == [[11, 21 ]]
519    /// ```
520    pub fn filter_by_expression(&mut self, expression: String) -> Result<Self, PyErr> {
521        let filter = crate::filter::FilterRules::try_from(expression.as_str())
522            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot parse expression: {e}")))?;
523        self.filter(&filter)
524            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot filter data: {e}")))
525    }
526
    /// Python `repr()`: the same text as the Rust `to_string()` of the DataFrame.
    fn __repr__(&self) -> String {
        self.to_string()
    }
530
    /// Python `str()`: the same text as the Rust `to_string()` of the DataFrame.
    fn __str__(&self) -> String {
        self.to_string()
    }
534
535    pub fn __iadd__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
536        trace!("{object:?}");
537        let df_or_dict = DfOrDict::new(object)?;
538        match df_or_dict {
539            DfOrDict::DataFrame(df) => {
540                self.dataframe += df.dataframe;
541            }
542            DfOrDict::Dict(dict) => {
543                self.dataframe += dict;
544            }
545        }
546        Ok(())
547    }
548
549    pub fn __isub__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
550        trace!("{object:?}");
551
552        let df_or_dict = DfOrDict::new(object)?;
553        match df_or_dict {
554            DfOrDict::DataFrame(df) => {
555                self.dataframe -= df.dataframe;
556            }
557            DfOrDict::Dict(dict) => {
558                self.dataframe -= dict;
559            }
560        }
561        Ok(())
562    }
563
564    pub fn __imul__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
565        trace!("{object:?}");
566        let df_or_dict = DfOrDict::new(object)?;
567        match df_or_dict {
568            DfOrDict::DataFrame(df) => {
569                self.dataframe *= df.dataframe;
570            }
571            DfOrDict::Dict(dict) => {
572                self.dataframe *= dict;
573            }
574        }
575        Ok(())
576    }
577
578    pub fn __itruediv__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
579        trace!("{object:?}");
580        let df_or_dict = DfOrDict::new(object)?;
581        match df_or_dict {
582            DfOrDict::DataFrame(df) => {
583                self.dataframe /= df.dataframe;
584            }
585            DfOrDict::Dict(dict) => {
586                self.dataframe /= dict;
587            }
588        }
589        Ok(())
590    }
591
592    pub fn __len__(&mut self) -> Result<usize, PyErr> {
593        Ok(self.dataframe.len())
594    }
595
    /// Serializes the DataFrame to a JSON string using `serde_json`.
    ///
    /// # Panics
    /// Panics if `serde_json` fails to serialize the DataFrame.
    pub fn serialize_to_json_string(&self) -> String {
        serde_json::to_string(self).expect("Cannot serialize to strinng")
    }
599
    /// Rebuilds a DataFrame from a JSON string produced by
    /// `serialize_to_json_string`.
    ///
    /// After parsing, `try_fix_dtype` is run on the underlying dataframe on a
    /// best-effort basis.
    ///
    /// # Panics
    /// Panics if `json_df` cannot be deserialized into a DataFrame.
    #[staticmethod]
    pub fn deserialize_from_json_string(json_df: String) -> Self {
        let mut df: DataFrame =
            serde_json::from_str(json_df.as_str()).expect("Cannot deserialize from str");
        // Best effort: a failure to fix dtypes is deliberately ignored here
        // (`__setstate__`, by contrast, reports it).
        let _ = df.dataframe.try_fix_dtype();

        df
    }
608
609    // derive Serialize and Deserialize
610    pub fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
611        let s: DataFrame = rmp_serde::decode::from_slice(state.as_bytes()).map_err(|e| {
612            pyo3::PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
613                "Cannot deserialize object {e}"
614            ))
615        })?;
616        *self = s;
617        self.dataframe.try_fix_dtype().map_err(|e| {
618            pyo3::PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
619                "Cannot deserialize object {e}"
620            ))
621        })?;
622        Ok(())
623    }
624    pub fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
625        let buf = rmp_serde::encode::to_vec(self).map_err(|e| {
626            // let buf = serde_json::to_string(self).map_err(|e| {
627            pyo3::PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
628                "Cannot deserialize object {e}"
629            ))
630        })?;
631        Ok(PyBytes::new(py, &buf))
632    }
633
    /// Replaces the underlying dataframe with its default value, dropping the
    /// previous contents.
    ///
    /// NOTE(review): this is exposed as an ordinary method named `__del__`;
    /// whether the interpreter ever invokes it as a finalizer is not established
    /// by this file - confirm against the PyO3 documentation.
    pub fn __del__(&mut self) {
        self.dataframe = Default::default();
    }
637
638    // pub fn __reduce__(&self, py: Python<'_>) -> PyResult<(Py<PyAny>, Py<PyAny>)> {
639    //     let cls = py.get_type::<Self>();
640    //     Ok((
641    //         cls.into(),
642    //         pyo3::types::PyTuple::new(py, &[self.__getstate__(py)?])?
643    //             .into_any()
644    //             .unbind(),
645    //     ))
646    // }
647}
648
649#[cfg(test)]
650mod test {
651
652    use super::*;
653    use crate::DataType;
654    use data_value::{stdhashmap, DataValue};
655    use halfbrown::hashmap;
656    use pyo3::ffi::c_str;
657    use rstest::*;
658    use tracing_test::traced_test;
659
    /// Fixture: a DataFrame with two `U32` columns (`key1`, `key2`) and two rows,
    /// `(1, 2)` and `(11, 21)`.
    #[fixture]
    fn df() -> DataFrame {
        let mut df = DataFrame::init();
        // First row uses keys with an explicit dtype.
        assert!(df
            .push(hashmap! {
                Key::new("key1", DataType::U32) => DataValue::U32(1),
                Key::new("key2", DataType::U32) => DataValue::U32(2),
            })
            .is_ok());
        // Second row uses keys built from the bare names.
        assert!(df
            .push(hashmap! {
                Key::from("key1") => DataValue::U32(11),
                Key::from("key2") => DataValue::U32(21),
            })
            .is_ok());
        df
    }
677
    /// Fixture: a dictionary operand (`key1 -> 2`, `key2 -> 3`) for the in-place
    /// arithmetic tests.
    #[fixture]
    fn hm() -> HashMap<String, DataValue> {
        stdhashmap!(
            "key1".to_string() => DataValue::U32(2),
            "key2".to_string() => DataValue::U32(3),
        )
    }
685
    /// JSON round trip: serializing and then deserializing yields an equal frame.
    #[rstest]
    fn serde_py(df: DataFrame) {
        let str_df = df.serialize_to_json_string();
        assert!(!str_df.is_empty());

        let loaded = DataFrame::deserialize_from_json_string(str_df);

        assert_eq!(loaded, df);
    }
    /// Pickle round trip: `__getstate__` followed by `__setstate__` on a fresh
    /// frame yields an equal frame.
    #[cfg(feature = "python")]
    #[rstest]
    fn pickle_py(df: DataFrame) {
        pyo3::Python::attach(|py| {
            let bytes = df.__getstate__(py);
            assert!(bytes.is_ok());

            let mut deser = DataFrame::default();
            assert!(deser.__setstate__(bytes.unwrap().into()).is_ok());
            assert_eq!(deser, df);
        });
    }
    /// Pins the orientation contract of `select_data`: with `transposed = false`
    /// each row holds one key's values; with `transposed = true` each row holds
    /// one pushed record.
    #[rstest]
    fn test_select_data(df: DataFrame) {
        let data = df.select_data(Some(vec!["key1".into(), "key2".into()]), Some(false));
        assert!(data.is_ok());
        assert_eq!(
            data.unwrap(),
            ndarray::array![[1u32.into(), 11u32.into()], [2u32.into(), 21u32.into()]]
        );

        let data = df.select_data(Some(vec!["key1".into(), "key2".into()]), Some(true));
        assert!(data.is_ok());
        assert_eq!(
            data.unwrap(),
            ndarray::array![[1u32.into(), 2u32.into()], [11u32.into(), 21u32.into()]]
        );
    }
723
    /// Covers the constructors: `from_dict` (including a later dtype change via
    /// `set_dtype_for_column`) and, with the `polars-df` feature, `from_polars`.
    #[cfg(feature = "python")]
    #[rstest]
    fn test_from_create() {
        pyo3::Python::attach(|_py| {
            let mut hm: HashMap<String, Vec<DataValue>> = Default::default();
            let value: Vec<DataValue> = vec![1i32.into(), 22i32.into()];
            hm.insert("a".into(), value);

            // Values start out as i32 ...
            let mut df = DataFrame::from_dict(hm);
            assert_eq!(
                df.select(Some(&["a".into()])),
                Ok(ndarray::array![
                    [DataValue::from(1i32)],
                    [DataValue::from(22i32)]
                ]),
            );
            // ... and are u32 after the dtype has been enforced.
            assert!(df.set_dtype_for_column("a".into(), DataType::U32).is_ok());
            assert_eq!(
                df.select(Some(&["a".into()])),
                Ok(ndarray::array![
                    [DataValue::from(1u32)],
                    [DataValue::from(22u32)]
                ]),
            );
        });
        #[cfg(feature = "polars-df")]
        {
            let pdf = polars::df!(
                "a" => [1u64, 2u64, 3u64],
                "b" => [4f64, 5f64, 6f64],
                "c" => [7i64, 8i64, 9i64]
            )
            .expect("BUG: should be ok");
            let df = DataFrame::from_polars(pyo3_polars::PyDataFrame(pdf));
            // The converted frame selects the same data as one built with `crate::df!`.
            assert_eq!(
                df.select(Some(&["a".into(), "b".into(), "c".into()])),
                crate::df! {
                    "a" => [1u64, 2u64, 3u64],
                    "b" => [4f64, 5f64, 6f64],
                    "c" => [7i64, 8i64, 9i64]
                }
                .select(Some(&["a".into(), "b".into(), "c".into()])),
            );
            // Column dtypes are carried over into the keys.
            let keys = df.keys();
            assert_eq!(
                keys,
                vec![
                    Key::new("a", DataType::U64),
                    Key::new("b", DataType::F64),
                    Key::new("c", DataType::I64),
                ]
            )
        }
    }
778
    /// `__iadd__` with a DataFrame operand and with a dict operand must match the
    /// native `+=` on the underlying dataframe.
    #[rstest]
    #[traced_test]
    fn basic_ops_add(mut df: DataFrame, hm: HashMap<String, DataValue>) {
        let mut df_expect = df.clone();
        let df2 = df.clone();
        let exec = Python::attach(|py| -> PyResult<()> {
            // DataFrame operand.
            df.__iadd__(df.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe += df2.dataframe;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);

            // Dict operand.
            df.__iadd__(hm.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe += hm;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);

            Ok(())
        });

        assert!(exec.is_ok(), "{:?}", exec);
    }
800
    /// `__isub__` with a DataFrame operand and with a dict operand must match the
    /// native `-=` on the underlying dataframe.
    #[rstest]
    #[traced_test]
    fn basic_ops_sub(mut df: DataFrame, hm: HashMap<String, DataValue>) {
        let mut df_expect = df.clone();
        let df2 = df.clone();
        let exec = Python::attach(|py| -> PyResult<()> {
            // DataFrame operand.
            df.__isub__(df.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe -= df2.dataframe;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);

            // Dict operand.
            df.__isub__(hm.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe -= hm;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);

            Ok(())
        });

        assert!(exec.is_ok(), "{:?}", exec);
    }
822
    /// `__imul__` with a DataFrame operand and with a dict operand must match the
    /// native `*=` on the underlying dataframe.
    #[rstest]
    #[traced_test]
    fn basic_ops_mul(mut df: DataFrame, hm: HashMap<String, DataValue>) {
        let mut df_expect = df.clone();
        let df2 = df.clone();
        let exec = Python::attach(|py| -> PyResult<()> {
            // DataFrame operand.
            df.__imul__(df.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe *= df2.dataframe;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);

            // Dict operand.
            df.__imul__(hm.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe *= hm;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);
            Ok(())
        });

        assert!(exec.is_ok(), "{:?}", exec);
    }
843
844    #[rstest]
845    #[traced_test]
846    fn basic_ops_div(mut df: DataFrame, hm: HashMap<String, DataValue>) {
847        let mut df_expect = df.clone();
848        let df2 = df.clone();
849        let exec = Python::attach(|py| -> PyResult<()> {
850            df.__itruediv__(df.clone().into_pyobject(py)?.into_any())?;
851            df_expect.dataframe /= df2.dataframe;
852            tracing::trace!("{} vs {}", df, df_expect);
853            assert_eq!(df.dataframe, df_expect.dataframe);
854
855            df.__itruediv__(hm.clone().into_pyobject(py)?.into_any())?;
856            df_expect.dataframe /= hm;
857            tracing::trace!("{} vs {}", df, df_expect);
858            assert_eq!(df.dataframe, df_expect.dataframe);
859            Ok(())
860        });
861
862        assert!(exec.is_ok(), "{:?}", exec);
863    }
864
865    #[rstest]
866    #[traced_test]
867    #[rstest]
868    fn test_numpy(mut df: DataFrame) {
869        let exec = Python::attach(|py| -> PyResult<()> {
870            let code = c_str!(
871                r#"
872def example(df):
873    import numpy as np
874    a_np = df.as_numpy_f32(['key1', 'key2'])
875    print(a_np)
876    b_np = df.as_numpy_u32(['key1', 'key'])
877    print(b_np)
878    b_np = df.as_numpy_i32(['key1', 'key'])
879    print(b_np)
880    b_np = df.as_numpy_i64(['key1', 'key'])
881    print(b_np)
882    b_np = df.as_numpy_u64(['key1', 'key'])
883    print(b_np)
884    b_np = df.as_numpy_f64(['key1', 'key'])
885    print(b_np)
886    b_np = df.as_numpy_f64(['key1', 'key'], transposed=True)
887    print(b_np)
888    b_np = df.as_numpy(['key1', 'key'], transposed=True)
889    print(b_np)
890    b_np = df.as_numpy_str(['key1', 'key'], transposed=True)
891    print(b_np)
892    return df
893            "#
894            );
895            let fun: Py<PyAny> = PyModule::from_code(py, code, c_str!(""), c_str!(""))?
896                .getattr("example")?
897                .into();
898            let result = fun.call1(py, (df.clone(),));
899            assert!(df.py_join(df.clone(), JoinRelation::default()).is_ok());
900            // user may not have installed polars, we need to get an error in that
901            // case
902            if py.import("numpy").is_ok() {
903                assert!(result.is_ok(), "{:?}", result);
904            } else {
905                assert!(result.is_err(), "{:?}", result);
906            }
907            Ok(())
908        });
909        assert!(exec.is_ok(), "{:?}", exec);
910    }
911
912    #[rstest]
913    #[traced_test]
914    #[rstest]
915    fn test_fill_from_python(df: DataFrame) {
916        let exec = Python::attach(|_py| -> PyResult<()> {
917            let hm = stdhashmap!(
918                Key::from("key1") => DataValue::U32(1),
919                Key::from("key2") => DataValue::U32(2),
920            );
921            let mut df2 = DataFrame::init();
922            assert!(df2.py_push(hm).is_ok());
923            assert!(df2
924                .py_push(stdhashmap!(
925                    Key::from("key1") => DataValue::U32(11),
926                    Key::from("key2") => DataValue::U32(21),
927                ))
928                .is_ok());
929
930            assert_eq!(df, df2);
931
932            let mut df2 = DataFrame::init();
933            assert!(df2
934                .py_add_column(
935                    Key::from("key1"),
936                    vec![DataValue::U32(1), DataValue::U32(11)]
937                )
938                .is_ok());
939            assert!(df2
940                .py_add_column(
941                    Key::from("key2"),
942                    vec![DataValue::U32(2), DataValue::U32(21)]
943                )
944                .is_ok());
945
946            assert_eq!(df, df2);
947            Ok(())
948        });
949        assert!(exec.is_ok(), "{:?}", exec);
950    }
951
952    #[rstest]
953    fn basic_python_dataframe(mut df: DataFrame) {
954        let exec = Python::attach(|py| -> PyResult<()> {
955            let fun: Py<PyAny> = PyModule::from_code(
956                py,
957                c_str!(
958                    "
959def example(df):
960    print(df)
961    df.shrink()
962    assert len(df) == 2
963    df.add_alias('key1', 'key1-alias')
964    a = df.select(['key1', 'key2'])
965    print(a)
966    b = df.select(['key1-alias', 'key2'])
967    print(b)
968    df.rename_key('key1', 'key1new')
969    df.rename_key('key1new', 'key1')
970    assert a == [[1, 2], [11, 21]]
971    assert a == b
972    df.add_metadata('test', 1)
973    m = df.get_metadata('test')
974    assert m == 1
975    b = df.select_transposed(['key1', 'key2'])
976    print(b)
977    assert b == [[1, 11], [2, 21]]
978    c = df.select_column('key1')
979    print(c)
980    assert c == [1, 11]
981
982    a += b
983    print(a)
984    assert a == [[2, 13], [4, 23]]
985    a -= b
986    print(a)
987    assert e == a
988    f = e * b
989    print(f)
990    assert f == [[1, 22], [44, 441]]
991    g = f / b
992    print(g)
993    assert g == e
994
995                "
996                ),
997                c_str!(""),
998                c_str!(""),
999            )?
1000            .getattr("example")?
1001            .into();
1002            let _ = fun.call1(py, (df.clone(),));
1003            assert!(df.py_join(df.clone(), JoinRelation::default()).is_ok());
1004            Ok(())
1005        });
1006        assert!(exec.is_ok(), "{:?}", exec);
1007    }
1008
1009    #[rstest]
1010    fn dummy_test_apply(mut df: DataFrame) {
1011        let exec = Python::attach(|py| -> PyResult<()> {
1012            let fun: Py<PyAny> = PyModule::from_code(
1013                py,
1014                c_str!(
1015                    r#"
1016def multiply_by_ten(x):
1017    print(x)
1018    x *= {"key1": 10}
1019    print(x)
1020    return x
1021
1022def example(df):
1023    print(df)
1024    df.apply(multiply_by_ten)
1025                "#
1026                ),
1027                c_str!(""),
1028                c_str!(""),
1029            )?
1030            .getattr("example")?
1031            .into();
1032            let _ = fun.call1(py, (df.clone(),));
1033            assert!(df.py_join(df.clone(), JoinRelation::default()).is_ok());
1034            Ok(())
1035        });
1036        assert!(exec.is_ok(), "{:?}", exec);
1037    }
1038}