trs_dataframe/dataframe/
python.rs

1use std::collections::HashMap;
2
3use crate::{DataFrame, DataValue, JoinRelation, Key};
4use data_value::Extract as _;
5use ndarray::Array1;
6use numpy::PyArray2;
7use pyo3::{
8    exceptions::PyTypeError,
9    prelude::*,
10    types::{PyBytes, PyList},
11};
12use tracing::trace;
13
14impl DataFrame {
15    fn select_data(
16        &self,
17        keys: Option<Vec<String>>,
18        transposed: Option<bool>,
19    ) -> Result<ndarray::Array2<DataValue>, crate::error::Error> {
20        let keys = keys
21            .unwrap_or(self.keys())
22            .into_iter()
23            .map(Key::from)
24            .collect::<Vec<Key>>();
25        if transposed.unwrap_or(false) {
26            self.select(Some(keys.as_slice()))
27        } else {
28            self.select_transposed(Some(keys.as_slice()))
29        }
30    }
31}
32
33enum DfOrDict {
34    DataFrame(DataFrame),
35    Dict(HashMap<String, DataValue>),
36}
37
38impl DfOrDict {
39    pub fn new(object: Bound<'_, PyAny>) -> Result<DfOrDict, PyErr> {
40        if let Ok(df) = object.extract::<DataFrame>() {
41            Ok(DfOrDict::DataFrame(df))
42        } else {
43            let dict: HashMap<String, DataValue> = object.extract()?;
44            Ok(DfOrDict::Dict(dict))
45        }
46    }
47}
48
49#[pymethods]
50impl DataFrame {
51    /// Create a new empty DataFrame.
52    #[new]
53    pub fn init() -> Self {
54        Self::default()
55    }
56
57    /// Create a DataFrame from a polars dataframe in python.
58    /// ```text
59    /// import polars as pl
60    /// df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
61    /// tei_df = tdf.DataFrame.from_polars(df)
62    /// ```
63    #[cfg(feature = "polars-df")]
64    #[staticmethod]
65    pub fn from_polars(df: pyo3_polars::PyDataFrame) -> Self {
66        df.0.into()
67    }
68
69    /// Create a DataFrame from a dictionary.
70    /// ```text
71    /// df = tdf.DataFrame.from_dict({"a": [1, 2, 3], "b": [4, 5, 6]})
72    /// ```
73    #[staticmethod]
74    pub fn from_dict(df: HashMap<String, Vec<DataValue>>) -> Self {
75        let mut result_df: Vec<(Key, Vec<DataValue>)> = Vec::new();
76        for (key, value) in df.into_iter() {
77            result_df.push((key.as_str().into(), value));
78        }
79        result_df.into()
80    }
81
82    /// Returns the keys of the DataFrame.
83    pub fn keys(&self) -> Vec<String> {
84        self.dataframe
85            .keys()
86            .iter()
87            .map(|x| x.name().to_string())
88            .collect()
89    }
90
91    /// Convert the DataFrame to polars DataFrame.
92    /// This requires the `polars-df` feature to be enabled.
93    /// ```text
94    /// import polars as pl
95    /// original_df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
96    /// df = tdf.DataFrame.from_polars(original_df)
97    /// polars_df = df.as_polars()
98    /// assert polars_df.frame_equal(original_df)
99    /// ```  
100    #[cfg(feature = "polars-df")]
101    #[pyo3(name = "as_polars")]
102    pub fn py_as_polars(&self) -> PyResult<pyo3_polars::PyDataFrame> {
103        let df = self
104            .as_polars()
105            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot prepare polars DF: {e}")))?;
106        Ok(pyo3_polars::PyDataFrame(df))
107    }
108
109    /// Apply a function to the DataFrame.
110    /// The function should accept a DataFrame and return a DataFrame.
111    /// ```text
112    /// def my_function(df):
113    ///    # Perform some operations on the DataFrame
114    ///   return df
115    /// /// df = tdf.DataFrame.init()
116    /// df.apply(my_function)
117    /// ```
118    pub fn apply(&mut self, function: Bound<'_, PyAny>) -> Result<(), PyErr> {
119        let df: DataFrame = pyo3::Python::attach(|py| {
120            let self_ = self
121                .clone()
122                .into_pyobject(py)
123                .expect("BUG: cannot convert to PyObject");
124            let result = function.call1((self_,)).expect("BUG: cannot call function");
125            result
126                .extract::<Bound<DataFrame>>()
127                .expect("BUG: cannot extract data frame")
128                .unbind()
129                .extract(py)
130                .expect("BUG: cannot extract data frame")
131        });
132        self.dataframe = df.dataframe;
133        Ok(())
134    }
135
136    /// Returns slice from dataframe as numpy.array of uint32 of the given keys.
137    /// If `transposed` is true, the keys will be transposed.
138    /// If `keys` is None, all keys will be used.
139    /// ```text
140    /// import numpy as np
141    /// df = tdf.DataFrame.init()
142    /// df.push({"key1": 1, "key2": 2})
143    /// df.push({"key1": 11, "key2": 21})
144    /// a_np = df.as_numpy_u32(['key1', 'key2'])
145    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.uint32))
146    /// ```
147    #[pyo3(signature = (keys=None, transposed=None))]
148    pub fn as_numpy_u32<'py>(
149        &self,
150        keys: Option<Vec<String>>,
151        transposed: Option<bool>,
152        py: Python<'py>,
153    ) -> PyResult<Bound<'py, numpy::PyArray2<u32>>> {
154        let data = self
155            .select_data(keys, transposed)
156            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
157        Ok(PyArray2::from_array(py, &data.mapv(|x| u32::extract(&x))))
158    }
159
160    /// Returns slice from dataframe as numpy.array of uint64 of the given keys.
161    /// If `transposed` is true, the keys will be transposed.
162    /// If `keys` is None, all keys will be used.
163    /// ```text
164    /// import numpy as np
165    /// df = tdf.DataFrame.init()
166    /// df.push({"key1": 1, "key2": 2})
167    /// df.push({"key1": 11, "key2": 21})
168    /// a_np = df.as_numpy_u64(['key1', 'key2'])
169    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.uint64))
170    /// ```
171    #[pyo3(signature = (keys=None, transposed=None))]
172    pub fn as_numpy_u64<'py>(
173        &self,
174        keys: Option<Vec<String>>,
175        transposed: Option<bool>,
176        py: Python<'py>,
177    ) -> PyResult<Bound<'py, numpy::PyArray2<u64>>> {
178        let data = self
179            .select_data(keys, transposed)
180            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
181        Ok(PyArray2::from_array(py, &data.mapv(|x| u64::extract(&x))))
182    }
183
184    /// Returns slice from dataframe as numpy.array of int32 of the given keys.
185    /// If `transposed` is true, the keys will be transposed.
186    /// If `keys` is None, all keys will be used.
187    /// ```text
188    /// import numpy as np
189    /// df = tdf.DataFrame.init()
190    /// df.push({"key1": 1, "key2": 2})
191    /// df.push({"key1": 11, "key2": 21})
192    /// a_np = df.as_numpy_i32(['key1', 'key2'])
193    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.int32))
194    /// ```
195    #[pyo3(signature = (keys=None, transposed=None))]
196    pub fn as_numpy_i32<'py>(
197        &self,
198        keys: Option<Vec<String>>,
199        transposed: Option<bool>,
200        py: Python<'py>,
201    ) -> PyResult<Bound<'py, numpy::PyArray2<i32>>> {
202        let data = self
203            .select_data(keys, transposed)
204            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
205        Ok(PyArray2::from_array(py, &data.mapv(|x| i32::extract(&x))))
206    }
207
208    /// Returns slice from dataframe as numpy.array of int64 of the given keys.
209    /// If `transposed` is true, the keys will be transposed.
210    /// If `keys` is None, all keys will be used.
211    /// ```text
212    /// import numpy as np
213    /// df = tdf.DataFrame.init()
214    /// df.push({"key1": 1, "key2": 2})
215    /// df.push({"key1": 11, "key2": 21})
216    /// a_np = df.as_numpy_i64(['key1', 'key2'])
217    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.int64))
218    /// ```
219    #[pyo3(signature = (keys=None, transposed=None))]
220    pub fn as_numpy_i64<'py>(
221        &self,
222        keys: Option<Vec<String>>,
223        transposed: Option<bool>,
224        py: Python<'py>,
225    ) -> PyResult<Bound<'py, numpy::PyArray2<i64>>> {
226        let data = self
227            .select_data(keys, transposed)
228            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
229        Ok(PyArray2::from_array(py, &data.mapv(|x| i64::extract(&x))))
230    }
231
232    /// Returns slice from dataframe as numpy.array of float32 of the given keys.
233    /// If `transposed` is true, the keys will be transposed.
234    /// If `keys` is None, all keys will be used.
235    /// ```text
236    /// import numpy as np
237    /// df = tdf.DataFrame.init()
238    /// df.push({"key1": 1, "key2": 2})
239    /// df.push({"key1": 11, "key2": 21})
240    /// a_np = df.as_numpy_f32(['key1', 'key2'])
241    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.float32))
242    /// ```
243    #[pyo3(signature = (keys=None, transposed=None))]
244    pub fn as_numpy_f32<'py>(
245        &self,
246        keys: Option<Vec<String>>,
247        transposed: Option<bool>,
248        py: Python<'py>,
249    ) -> PyResult<Bound<'py, numpy::PyArray2<f32>>> {
250        let data = self
251            .select_data(keys, transposed)
252            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
253        Ok(PyArray2::from_array(py, &data.mapv(|x| f32::extract(&x))))
254    }
255
256    /// Returns slice from dataframe as numpy.array of float64 of the given keys.
257    /// If `transposed` is true, the keys will be transposed.
258    /// If `keys` is None, all keys will be used.
259    /// ```text
260    /// import numpy as np
261    /// df = tdf.DataFrame.init()
262    /// df.push({"key1": 1, "key2": 2})
263    /// df.push({"key1": 11, "key2": 21})
264    /// a_np = df.as_numpy_f64(['key1', 'key2'])
265    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.float64))
266    /// ```
267    #[pyo3(signature = (keys=None, transposed=None))]
268    pub fn as_numpy_f64<'py>(
269        &self,
270        keys: Option<Vec<String>>,
271        transposed: Option<bool>,
272        py: Python<'py>,
273    ) -> PyResult<Bound<'py, numpy::PyArray2<f64>>> {
274        let data = self
275            .select_data(keys, transposed)
276            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
277        Ok(PyArray2::from_array(py, &data.mapv(|x| f64::extract(&x))))
278    }
279
280    #[pyo3(name = "shrink")]
281    pub fn py_shrink(&mut self) {
282        self.dataframe.shrink();
283    }
284
285    #[pyo3(name = "add_metadata")]
286    pub fn py_add_metadata(&mut self, key: String, value: DataValue) {
287        self.metadata.insert(key, value);
288    }
289
290    #[pyo3(name = "get_metadata")]
291    pub fn py_get_metadata(&self, key: &str) -> Option<DataValue> {
292        self.metadata.get(key).cloned()
293    }
294
295    #[pyo3(name = "rename_key")]
296    pub fn py_rename_key(&mut self, key: &str, new_name: &str) -> Result<(), PyErr> {
297        // fixme this may have a problem when the type is different and checked
298        self.dataframe
299            .rename_key(key, new_name.into())
300            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("{e}")))
301    }
302
303    #[pyo3(name = "add_alias")]
304    pub fn py_add_alias(&mut self, key: &str, new_name: &str) -> Result<(), PyErr> {
305        self.dataframe
306            .add_alias(key, new_name)
307            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("{e}")))
308    }
309
310    /// Selects data from the DataFrame.
311    /// If `keys` is None, all keys will be used.
312    /// If `keys` is provided, only the specified keys will be selected.
313    /// Returns a list of lists, where each inner list represents a row of data.
314    /// ```text
315    /// import trs_dataframe as tdf
316    /// df = tdf.DataFrame.init()
317    /// df.push({"key1": 1, "key2": 2})
318    /// df.push({"key1": 11, "key2": 21})
319    /// # selected = df.select(["key1", "key2"])
320    /// # assert selected == [[1, 2], [11, 21]]
321    /// # selected = df.select()
322    #[pyo3(name = "select", signature = (keys=None, transposed=None))]
323    pub fn py_select<'py>(
324        &self,
325        py: Python<'py>,
326        keys: Option<Vec<String>>,
327        transposed: Option<bool>,
328    ) -> Result<Bound<'py, PyList>, PyErr> {
329        let keys = keys
330            .unwrap_or(self.keys())
331            .into_iter()
332            .map(Key::from)
333            .collect::<Vec<Key>>();
334
335        let selected = if transposed.unwrap_or_default() {
336            self.select_transposed(Some(keys.as_slice()))
337                .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?
338        } else {
339            self.select(Some(keys.as_slice()))
340                .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?
341        };
342
343        let list = PyList::empty(py);
344        for rows in selected.rows() {
345            let row = PyList::empty(py);
346            for value in rows.iter() {
347                row.append(value.clone())
348                    .expect("BUG: cannot append to list");
349            }
350            list.append(row).expect("BUG: cannot append to list");
351        }
352        Ok(list)
353    }
354
355    /// Selects a column from the DataFrame.
356    /// If the column does not exist, it will raise a TypeError.
357    /// Returns a list of values in the selected column.
358    /// ```text
359    /// import trs_dataframe as tdf
360    /// df = tdf.DataFrame.init()
361    /// df.push({"key1": 1, "key2": 2})
362    /// df.push({"key1": 11, "key2": 21})
363    /// # selected = df.select_column("key1")
364    /// # assert selected == [1, 11]
365    /// # selected = df.select_column("key2")
366    /// # assert selected == [2, 21]
367    /// # selected = df.select_column("non_existing_key")  # Raises TypeError
368    /// ```
369    #[pyo3(name = "select_column")]
370    pub fn py_select_column<'py>(
371        &self,
372        py: Python<'py>,
373        key: String,
374    ) -> Result<Bound<'py, PyList>, PyErr> {
375        let selected = self
376            .select_column(Key::from(key))
377            .ok_or_else(|| PyErr::new::<PyTypeError, _>("Cannot select column"))?;
378
379        let list = PyList::empty(py);
380        for x in selected.to_vec().into_iter() {
381            list.append(x)?;
382        }
383
384        Ok(list)
385    }
386
387    /// Joins the current DataFrame with another DataFrame.
388    /// The join type is specified by the `join_type` parameter.
389    /// see [`JoinRelation`] for available join types.
390    /// ```text
391    /// import trs_dataframe as tdf
392    /// df1 = tdf.DataFrame.init()
393    /// df1.push({"key1": 1, "key2": 2})
394    /// df1.push({"key1": 11, "key2": 21})
395    /// df2 = tdf.DataFrame.init()
396    /// df2.push({"key1": 1, "key2": 3})
397    /// df2.push({"key1": 11, "key2": 23})
398    /// df1.join(df2, tei.JoinRelation.extend())
399    /// assert df1.select(["key1", "key2"]) == [[1, 2], [11, 21], [1, 3], [11, 23]]
400    /// ```
401    #[pyo3(name = "join")]
402    pub fn py_join(&mut self, other: DataFrame, join_type: JoinRelation) -> Result<(), PyErr> {
403        self.dataframe
404            .join(other.dataframe, &join_type)
405            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot join data: {e}")))?;
406
407        Ok(())
408    }
409
410    /// Pushes a new row of data into the DataFrame.
411    /// The data should be provided as a dictionary where keys are column names and values are the corresponding data values.
412    /// ```text
413    /// import trs_dataframe as tdf
414    /// df = tdf.DataFrame.init()
415    /// df.push({"key1": 1, "key2": 2})
416    /// df.push({"key1": 11, "key2": 21})
417    /// ```
418    #[pyo3(name = "push")]
419    pub fn py_push(&mut self, data: HashMap<Key, DataValue>) -> Result<(), PyErr> {
420        self.dataframe
421            .push(data)
422            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot join data: {e}")))?;
423        Ok(())
424    }
425
426    /// Adds a new column to the DataFrame.
427    /// The column is specified by a key and a vector of data values.
428    /// If the length of the data vector does not match the number of rows in the DataFrame, it will raise a TypeError.
429    /// ```text
430    /// import trs_dataframe as tdf
431    /// df = tdf.DataFrame.init()
432    /// df.push({"key1": 1, "key2": 2})
433    /// df.push({"key1": 11, "key2": 21})
434    /// df.add_column("key3", [3, 4])
435    /// assert df.select(["key1", "key2", "key3"]) == [[1, 2, 3], [11, 21, 4]]
436    /// ```
437    #[pyo3(name = "add_column")]
438    pub fn py_add_column(&mut self, key: Key, data: Vec<DataValue>) -> Result<(), PyErr> {
439        self.dataframe
440            .add_single_column(key, Array1::from_vec(data))
441            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot join data: {e}")))?;
442        Ok(())
443    }
444
445    pub fn add_constant(&mut self, key: Key, feature: DataValue) -> Result<(), PyErr> {
446        self.constants.insert(key, feature);
447        Ok(())
448    }
449
450    /// Filters the DataFrame by a given expression.
451    /// The expression should be a string that can be parsed by the DataFrame's filter method
452    ///
453    /// ```text
454    /// import trs_dataframe as tdf
455    /// df = tdf.DataFrame.init()
456    /// df.push({"key1": 1, "key2": 2})
457    /// df.push({"key1": 11, "key2": 21})
458    /// df.filter_by_expression("key1 > 5")
459    /// assert df.select(["key1", "key2"]) == [[11, 21 ]]
460    /// ```
461    pub fn filter_by_expression(&mut self, expression: String) -> Result<Self, PyErr> {
462        let filter = crate::filter::FilterRules::try_from(expression.as_str())
463            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot parse expression: {e}")))?;
464        self.filter(&filter)
465            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot filter data: {e}")))
466    }
467
468    fn __repr__(&self) -> String {
469        self.to_string()
470    }
471
472    fn __str__(&self) -> String {
473        self.to_string()
474    }
475
476    pub fn __iadd__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
477        trace!("{object:?}");
478        let df_or_dict = DfOrDict::new(object)?;
479        match df_or_dict {
480            DfOrDict::DataFrame(df) => {
481                self.dataframe += df.dataframe;
482            }
483            DfOrDict::Dict(dict) => {
484                self.dataframe += dict;
485            }
486        }
487        Ok(())
488    }
489
490    pub fn __isub__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
491        trace!("{object:?}");
492
493        let df_or_dict = DfOrDict::new(object)?;
494        match df_or_dict {
495            DfOrDict::DataFrame(df) => {
496                self.dataframe -= df.dataframe;
497            }
498            DfOrDict::Dict(dict) => {
499                self.dataframe -= dict;
500            }
501        }
502        Ok(())
503    }
504
505    pub fn __imul__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
506        trace!("{object:?}");
507        let df_or_dict = DfOrDict::new(object)?;
508        match df_or_dict {
509            DfOrDict::DataFrame(df) => {
510                self.dataframe *= df.dataframe;
511            }
512            DfOrDict::Dict(dict) => {
513                self.dataframe *= dict;
514            }
515        }
516        Ok(())
517    }
518
519    pub fn __itruediv__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
520        trace!("{object:?}");
521        let df_or_dict = DfOrDict::new(object)?;
522        match df_or_dict {
523            DfOrDict::DataFrame(df) => {
524                self.dataframe /= df.dataframe;
525            }
526            DfOrDict::Dict(dict) => {
527                self.dataframe /= dict;
528            }
529        }
530        Ok(())
531    }
532
533    pub fn __len__(&mut self) -> Result<usize, PyErr> {
534        Ok(self.dataframe.len())
535    }
536
537    pub fn serialize_to_json_string(&self) -> String {
538        serde_json::to_string(self).expect("Cannot serialize to strinng")
539    }
540
541    #[staticmethod]
542    pub fn deserialize_from_json_string(json_df: String) -> Self {
543        let mut df: DataFrame =
544            serde_json::from_str(json_df.as_str()).expect("Cannot deserialize from str");
545        let _ = df.dataframe.try_fix_dtype();
546
547        df
548    }
549
550    // derive Serialize and Deserialize
551    pub fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
552        let s: DataFrame = rmp_serde::decode::from_slice(state.as_bytes()).map_err(|e| {
553            pyo3::PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
554                "Cannot deserialize object {e}"
555            ))
556        })?;
557        *self = s;
558        self.dataframe.try_fix_dtype().map_err(|e| {
559            pyo3::PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
560                "Cannot deserialize object {e}"
561            ))
562        })?;
563        Ok(())
564    }
565    pub fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
566        let buf = rmp_serde::encode::to_vec(self).map_err(|e| {
567            // let buf = serde_json::to_string(self).map_err(|e| {
568            pyo3::PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
569                "Cannot deserialize object {e}"
570            ))
571        })?;
572        Ok(PyBytes::new(py, &buf))
573    }
574
575    pub fn __del__(&mut self) {
576        self.dataframe = Default::default();
577    }
578
579    // pub fn __reduce__(&self, py: Python<'_>) -> PyResult<(Py<PyAny>, Py<PyAny>)> {
580    //     let cls = py.get_type::<Self>();
581    //     Ok((
582    //         cls.into(),
583    //         pyo3::types::PyTuple::new(py, &[self.__getstate__(py)?])?
584    //             .into_any()
585    //             .unbind(),
586    //     ))
587    // }
588}
589
#[cfg(test)]
mod test {

    use super::*;
    use crate::DataType;
    use data_value::{stdhashmap, DataValue};
    use halfbrown::hashmap;
    use pyo3::ffi::c_str;
    use rstest::*;
    use tracing_test::traced_test;

    /// Two-row, two-column frame: key1 = [1, 11], key2 = [2, 21] (all U32).
    #[fixture]
    fn df() -> DataFrame {
        let mut df = DataFrame::init();
        assert!(df
            .push(hashmap! {
                Key::new("key1", DataType::U32) => DataValue::U32(1),
                Key::new("key2", DataType::U32) => DataValue::U32(2),
            })
            .is_ok());
        assert!(df
            .push(hashmap! {
                Key::from("key1") => DataValue::U32(11),
                Key::from("key2") => DataValue::U32(21),
            })
            .is_ok());
        df
    }

    /// Dict operand for the in-place arithmetic tests.
    #[fixture]
    fn hm() -> HashMap<String, DataValue> {
        stdhashmap!(
            "key1".to_string() => DataValue::U32(2),
            "key2".to_string() => DataValue::U32(3),
        )
    }

    #[rstest]
    fn serde_py(df: DataFrame) {
        let str_df = df.serialize_to_json_string();
        assert!(!str_df.is_empty());

        let loaded = DataFrame::deserialize_from_json_string(str_df);

        assert_eq!(loaded, df);
    }

    #[cfg(feature = "python")]
    #[rstest]
    fn pickle_py(df: DataFrame) {
        pyo3::Python::attach(|py| {
            let bytes = df.__getstate__(py);
            assert!(bytes.is_ok());

            let mut deser = DataFrame::default();
            assert!(deser.__setstate__(bytes.unwrap().into()).is_ok());
            assert_eq!(deser, df);
        });
    }

    #[rstest]
    fn test_select_data(df: DataFrame) {
        let data = df.select_data(Some(vec!["key1".into(), "key2".into()]), Some(false));
        assert!(data.is_ok());
        assert_eq!(
            data.unwrap(),
            ndarray::array![[1u32.into(), 11u32.into()], [2u32.into(), 21u32.into()]]
        );

        let data = df.select_data(Some(vec!["key1".into(), "key2".into()]), Some(true));
        assert!(data.is_ok());
        assert_eq!(
            data.unwrap(),
            ndarray::array![[1u32.into(), 2u32.into()], [11u32.into(), 21u32.into()]]
        );
    }

    #[cfg(feature = "python")]
    #[rstest]
    fn test_from_create() {
        pyo3::Python::attach(|_py| {
            let mut hm: HashMap<String, Vec<DataValue>> = Default::default();
            let value: Vec<DataValue> = vec![1i32.into(), 22i32.into()];
            hm.insert("a".into(), value);

            let df = DataFrame::from_dict(hm);
            assert_eq!(
                df.select(Some(&["a".into()])),
                Ok(ndarray::array![
                    [DataValue::from(1i32)],
                    [DataValue::from(22i32)]
                ]),
            );
        });
        #[cfg(feature = "polars-df")]
        {
            let pdf = polars::df!(
                "a" => [1u64, 2u64, 3u64],
                "b" => [4f64, 5f64, 6f64],
                "c" => [7i64, 8i64, 9i64]
            )
            .expect("BUG: should be ok");
            let df = DataFrame::from_polars(pyo3_polars::PyDataFrame(pdf));
            assert_eq!(
                df.select(Some(&["a".into(), "b".into(), "c".into()])),
                crate::df! {
                    "a" => [1u64, 2u64, 3u64],
                    "b" => [4f64, 5f64, 6f64],
                    "c" => [7i64, 8i64, 9i64]
                }
                .select(Some(&["a".into(), "b".into(), "c".into()])),
            );
        }
    }

    #[rstest]
    #[traced_test]
    fn basic_ops_add(mut df: DataFrame, hm: HashMap<String, DataValue>) {
        let mut df_expect = df.clone();
        let df2 = df.clone();
        let exec = Python::attach(|py| -> PyResult<()> {
            df.__iadd__(df.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe += df2.dataframe;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);

            df.__iadd__(hm.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe += hm;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);

            Ok(())
        });

        assert!(exec.is_ok(), "{:?}", exec);
    }

    #[rstest]
    #[traced_test]
    fn basic_ops_sub(mut df: DataFrame, hm: HashMap<String, DataValue>) {
        let mut df_expect = df.clone();
        let df2 = df.clone();
        let exec = Python::attach(|py| -> PyResult<()> {
            df.__isub__(df.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe -= df2.dataframe;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);

            df.__isub__(hm.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe -= hm;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);

            Ok(())
        });

        assert!(exec.is_ok(), "{:?}", exec);
    }

    #[rstest]
    #[traced_test]
    fn basic_ops_mul(mut df: DataFrame, hm: HashMap<String, DataValue>) {
        let mut df_expect = df.clone();
        let df2 = df.clone();
        let exec = Python::attach(|py| -> PyResult<()> {
            df.__imul__(df.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe *= df2.dataframe;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);

            df.__imul__(hm.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe *= hm;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);
            Ok(())
        });

        assert!(exec.is_ok(), "{:?}", exec);
    }

    #[rstest]
    #[traced_test]
    fn basic_ops_div(mut df: DataFrame, hm: HashMap<String, DataValue>) {
        let mut df_expect = df.clone();
        let df2 = df.clone();
        let exec = Python::attach(|py| -> PyResult<()> {
            df.__itruediv__(df.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe /= df2.dataframe;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);

            df.__itruediv__(hm.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe /= hm;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);
            Ok(())
        });

        assert!(exec.is_ok(), "{:?}", exec);
    }

    #[rstest]
    #[traced_test]
    fn test_numpy(mut df: DataFrame) {
        let exec = Python::attach(|py| -> PyResult<()> {
            let code = c_str!(
                r#"
def example(df):
    import numpy as np
    a_np = df.as_numpy_f32(['key1', 'key2'])
    print(a_np)
    b_np = df.as_numpy_u32(['key1', 'key2'])
    print(b_np)
    b_np = df.as_numpy_i32(['key1', 'key2'])
    print(b_np)
    b_np = df.as_numpy_i64(['key1', 'key2'])
    print(b_np)
    b_np = df.as_numpy_u64(['key1', 'key2'])
    print(b_np)
    b_np = df.as_numpy_f64(['key1', 'key2'])
    print(b_np)
    b_np = df.as_numpy_f64(['key1', 'key2'], transposed=True)
    print(b_np)
    return df
            "#
            );
            let fun: Py<PyAny> = PyModule::from_code(py, code, c_str!(""), c_str!(""))?
                .getattr("example")?
                .into();
            let result = fun.call1(py, (df.clone(),));
            assert!(df.py_join(df.clone(), JoinRelation::default()).is_ok());
            // The user may not have numpy installed; in that case the call must fail.
            if py.import("numpy").is_ok() {
                assert!(result.is_ok(), "{:?}", result);
            } else {
                assert!(result.is_err(), "{:?}", result);
            }
            Ok(())
        });
        assert!(exec.is_ok(), "{:?}", exec);
    }

    #[rstest]
    #[traced_test]
    fn test_fill_from_python(df: DataFrame) {
        let exec = Python::attach(|_py| -> PyResult<()> {
            let hm = stdhashmap!(
                Key::from("key1") => DataValue::U32(1),
                Key::from("key2") => DataValue::U32(2),
            );
            let mut df2 = DataFrame::init();
            assert!(df2.py_push(hm).is_ok());
            assert!(df2
                .py_push(stdhashmap!(
                    Key::from("key1") => DataValue::U32(11),
                    Key::from("key2") => DataValue::U32(21),
                ))
                .is_ok());

            assert_eq!(df, df2);

            let mut df2 = DataFrame::init();
            assert!(df2
                .py_add_column(
                    Key::from("key1"),
                    vec![DataValue::U32(1), DataValue::U32(11)]
                )
                .is_ok());
            assert!(df2
                .py_add_column(
                    Key::from("key2"),
                    vec![DataValue::U32(2), DataValue::U32(21)]
                )
                .is_ok());

            assert_eq!(df, df2);
            Ok(())
        });
        assert!(exec.is_ok(), "{:?}", exec);
    }

    #[rstest]
    fn basic_python_dataframe(mut df: DataFrame) {
        let exec = Python::attach(|py| -> PyResult<()> {
            let fun: Py<PyAny> = PyModule::from_code(
                py,
                c_str!(
                    "
def example(df):
    print(df)
    df.shrink()
    assert len(df) == 2
    df.add_alias('key1', 'key1-alias')
    a = df.select(['key1', 'key2'])
    print(a)
    b = df.select(['key1-alias', 'key2'])
    print(b)
    df.rename_key('key1', 'key1new')
    df.rename_key('key1new', 'key1')
    assert a == [[1, 2], [11, 21]]
    assert a == b
    df.add_metadata('test', 1)
    m = df.get_metadata('test')
    assert m == 1
    t = df.select(['key1', 'key2'], transposed=True)
    print(t)
    assert t == [[1, 11], [2, 21]]
    c = df.select_column('key1')
    print(c)
    assert c == [1, 11]
                "
                ),
                c_str!(""),
                c_str!(""),
            )?
            .getattr("example")?
            .into();
            // Surface failures from the embedded script instead of discarding them.
            let result = fun.call1(py, (df.clone(),));
            assert!(result.is_ok(), "{:?}", result);
            assert!(df.py_join(df.clone(), JoinRelation::default()).is_ok());
            Ok(())
        });
        assert!(exec.is_ok(), "{:?}", exec);
    }

    #[rstest]
    fn dummy_test_apply(mut df: DataFrame) {
        let exec = Python::attach(|py| -> PyResult<()> {
            let fun: Py<PyAny> = PyModule::from_code(
                py,
                c_str!(
                    r#"
def multiply_by_ten(x):
    print(x)
    x *= {"key1": 10}
    print(x)
    return x

def example(df):
    print(df)
    df.apply(multiply_by_ten)
                "#
                ),
                c_str!(""),
                c_str!(""),
            )?
            .getattr("example")?
            .into();
            let result = fun.call1(py, (df.clone(),));
            assert!(result.is_ok(), "{:?}", result);
            assert!(df.py_join(df.clone(), JoinRelation::default()).is_ok());
            Ok(())
        });
        assert!(exec.is_ok(), "{:?}", exec);
    }
}