trs_dataframe/dataframe/
python.rs

1use std::collections::HashMap;
2
3use crate::{DataFrame, DataValue, JoinRelation, Key};
4use data_value::Extract as _;
5use ndarray::Array1;
6use numpy::PyArray2;
7use pyo3::{exceptions::PyTypeError, prelude::*, types::PyList};
8use tracing::trace;
9
10impl DataFrame {
11    fn select_data(
12        &self,
13        keys: Option<Vec<String>>,
14        transposed: Option<bool>,
15    ) -> Result<ndarray::Array2<DataValue>, crate::error::Error> {
16        let keys = keys
17            .unwrap_or(self.keys())
18            .into_iter()
19            .map(Key::from)
20            .collect::<Vec<Key>>();
21        if transposed.unwrap_or(false) {
22            self.select(Some(keys.as_slice()))
23        } else {
24            self.select_transposed(Some(keys.as_slice()))
25        }
26    }
27}
28
29enum DfOrDict {
30    DataFrame(DataFrame),
31    Dict(HashMap<String, DataValue>),
32}
33
34impl DfOrDict {
35    pub fn new(object: Bound<'_, PyAny>) -> Result<DfOrDict, PyErr> {
36        if let Ok(df) = object.extract::<DataFrame>() {
37            Ok(DfOrDict::DataFrame(df))
38        } else {
39            let dict: HashMap<String, DataValue> = object.extract()?;
40            Ok(DfOrDict::Dict(dict))
41        }
42    }
43}
44
45#[pymethods]
46impl DataFrame {
47    /// Create a new empty DataFrame.
48    #[new]
49    pub fn init() -> Self {
50        Self::default()
51    }
52
53    /// Create a DataFrame from a polars dataframe in python.
54    /// ```text
55    /// import polars as pl
56    /// df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
57    /// tei_df = tdf.DataFrame.from_polars(df)
58    /// ```
59    #[cfg(feature = "polars-df")]
60    #[staticmethod]
61    pub fn from_polars(df: pyo3_polars::PyDataFrame) -> Self {
62        df.0.into()
63    }
64
65    /// Create a DataFrame from a dictionary.
66    /// ```text
67    /// df = tdf.DataFrame.from_dict({"a": [1, 2, 3], "b": [4, 5, 6]})
68    /// ```
69    #[staticmethod]
70    pub fn from_dict(df: HashMap<String, Vec<DataValue>>) -> Self {
71        let mut result_df: Vec<(Key, Vec<DataValue>)> = Vec::new();
72        for (key, value) in df.into_iter() {
73            result_df.push((key.as_str().into(), value));
74        }
75        result_df.into()
76    }
77
78    /// Returns the keys of the DataFrame.
79    pub fn keys(&self) -> Vec<String> {
80        self.dataframe
81            .keys()
82            .iter()
83            .map(|x| x.name().to_string())
84            .collect()
85    }
86
87    /// Convert the DataFrame to polars DataFrame.
88    /// This requires the `polars-df` feature to be enabled.
89    /// ```text
90    /// import polars as pl
91    /// original_df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
92    /// df = tdf.DataFrame.from_polars(original_df)
93    /// polars_df = df.as_polars()
94    /// assert polars_df.frame_equal(original_df)
95    /// ```  
96    #[cfg(feature = "polars-df")]
97    #[pyo3(name = "as_polars")]
98    pub fn py_as_polars(&self) -> PyResult<pyo3_polars::PyDataFrame> {
99        let df = self
100            .as_polars()
101            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot prepare polars DF: {e}")))?;
102        Ok(pyo3_polars::PyDataFrame(df))
103    }
104
105    /// Apply a function to the DataFrame.
106    /// The function should accept a DataFrame and return a DataFrame.
107    /// ```text
108    /// def my_function(df):
109    ///    # Perform some operations on the DataFrame
110    ///   return df
111    /// /// df = tdf.DataFrame.init()
112    /// df.apply(my_function)
113    /// ```
114    pub fn apply(&mut self, function: Bound<'_, PyAny>) -> Result<(), PyErr> {
115        let df: DataFrame = pyo3::Python::with_gil(|py| {
116            let self_ = self
117                .clone()
118                .into_pyobject(py)
119                .expect("BUG: cannot convert to PyObject");
120            let result = function.call1((self_,)).expect("BUG: cannot call function");
121            result
122                .extract::<Bound<DataFrame>>()
123                .expect("BUG: cannot extract data frame")
124                .unbind()
125                .extract(py)
126                .expect("BUG: cannot extract data frame")
127        });
128        self.dataframe = df.dataframe;
129        Ok(())
130    }
131
132    /// Returns slice from dataframe as numpy.array of uint32 of the given keys.
133    /// If `transposed` is true, the keys will be transposed.
134    /// If `keys` is None, all keys will be used.
135    /// ```text
136    /// import numpy as np
137    /// df = tdf.DataFrame.init()
138    /// df.push({"key1": 1, "key2": 2})
139    /// df.push({"key1": 11, "key2": 21})
140    /// a_np = df.as_numpy_u32(['key1', 'key2'])
141    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.uint32))
142    /// ```
143    #[pyo3(signature = (keys=None, transposed=None))]
144    pub fn as_numpy_u32<'py>(
145        &self,
146        keys: Option<Vec<String>>,
147        transposed: Option<bool>,
148        py: Python<'py>,
149    ) -> PyResult<Bound<'py, numpy::PyArray2<u32>>> {
150        let data = self
151            .select_data(keys, transposed)
152            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
153        Ok(PyArray2::from_array(py, &data.mapv(|x| u32::extract(&x))))
154    }
155
156    /// Returns slice from dataframe as numpy.array of uint64 of the given keys.
157    /// If `transposed` is true, the keys will be transposed.
158    /// If `keys` is None, all keys will be used.
159    /// ```text
160    /// import numpy as np
161    /// df = tdf.DataFrame.init()
162    /// df.push({"key1": 1, "key2": 2})
163    /// df.push({"key1": 11, "key2": 21})
164    /// a_np = df.as_numpy_u64(['key1', 'key2'])
165    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.uint64))
166    /// ```
167    #[pyo3(signature = (keys=None, transposed=None))]
168    pub fn as_numpy_u64<'py>(
169        &self,
170        keys: Option<Vec<String>>,
171        transposed: Option<bool>,
172        py: Python<'py>,
173    ) -> PyResult<Bound<'py, numpy::PyArray2<u64>>> {
174        let data = self
175            .select_data(keys, transposed)
176            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
177        Ok(PyArray2::from_array(py, &data.mapv(|x| u64::extract(&x))))
178    }
179
180    /// Returns slice from dataframe as numpy.array of int32 of the given keys.
181    /// If `transposed` is true, the keys will be transposed.
182    /// If `keys` is None, all keys will be used.
183    /// ```text
184    /// import numpy as np
185    /// df = tdf.DataFrame.init()
186    /// df.push({"key1": 1, "key2": 2})
187    /// df.push({"key1": 11, "key2": 21})
188    /// a_np = df.as_numpy_i32(['key1', 'key2'])
189    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.int32))
190    /// ```
191    #[pyo3(signature = (keys=None, transposed=None))]
192    pub fn as_numpy_i32<'py>(
193        &self,
194        keys: Option<Vec<String>>,
195        transposed: Option<bool>,
196        py: Python<'py>,
197    ) -> PyResult<Bound<'py, numpy::PyArray2<i32>>> {
198        let data = self
199            .select_data(keys, transposed)
200            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
201        Ok(PyArray2::from_array(py, &data.mapv(|x| i32::extract(&x))))
202    }
203
204    /// Returns slice from dataframe as numpy.array of int64 of the given keys.
205    /// If `transposed` is true, the keys will be transposed.
206    /// If `keys` is None, all keys will be used.
207    /// ```text
208    /// import numpy as np
209    /// df = tdf.DataFrame.init()
210    /// df.push({"key1": 1, "key2": 2})
211    /// df.push({"key1": 11, "key2": 21})
212    /// a_np = df.as_numpy_i64(['key1', 'key2'])
213    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.int64))
214    /// ```
215    #[pyo3(signature = (keys=None, transposed=None))]
216    pub fn as_numpy_i64<'py>(
217        &self,
218        keys: Option<Vec<String>>,
219        transposed: Option<bool>,
220        py: Python<'py>,
221    ) -> PyResult<Bound<'py, numpy::PyArray2<i64>>> {
222        let data = self
223            .select_data(keys, transposed)
224            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
225        Ok(PyArray2::from_array(py, &data.mapv(|x| i64::extract(&x))))
226    }
227
228    /// Returns slice from dataframe as numpy.array of float32 of the given keys.
229    /// If `transposed` is true, the keys will be transposed.
230    /// If `keys` is None, all keys will be used.
231    /// ```text
232    /// import numpy as np
233    /// df = tdf.DataFrame.init()
234    /// df.push({"key1": 1, "key2": 2})
235    /// df.push({"key1": 11, "key2": 21})
236    /// a_np = df.as_numpy_f32(['key1', 'key2'])
237    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.float32))
238    /// ```
239    #[pyo3(signature = (keys=None, transposed=None))]
240    pub fn as_numpy_f32<'py>(
241        &self,
242        keys: Option<Vec<String>>,
243        transposed: Option<bool>,
244        py: Python<'py>,
245    ) -> PyResult<Bound<'py, numpy::PyArray2<f32>>> {
246        let data = self
247            .select_data(keys, transposed)
248            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
249        Ok(PyArray2::from_array(py, &data.mapv(|x| f32::extract(&x))))
250    }
251
252    /// Returns slice from dataframe as numpy.array of float64 of the given keys.
253    /// If `transposed` is true, the keys will be transposed.
254    /// If `keys` is None, all keys will be used.
255    /// ```text
256    /// import numpy as np
257    /// df = tdf.DataFrame.init()
258    /// df.push({"key1": 1, "key2": 2})
259    /// df.push({"key1": 11, "key2": 21})
260    /// a_np = df.as_numpy_f64(['key1', 'key2'])
261    /// assert np.array_equal(a_np, np.array([[1, 11], [2, 21]], dtype=np.float64))
262    /// ```
263    #[pyo3(signature = (keys=None, transposed=None))]
264    pub fn as_numpy_f64<'py>(
265        &self,
266        keys: Option<Vec<String>>,
267        transposed: Option<bool>,
268        py: Python<'py>,
269    ) -> PyResult<Bound<'py, numpy::PyArray2<f64>>> {
270        let data = self
271            .select_data(keys, transposed)
272            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
273        Ok(PyArray2::from_array(py, &data.mapv(|x| f64::extract(&x))))
274    }
275
276    #[pyo3(name = "shrink")]
277    pub fn py_shrink(&mut self) {
278        self.dataframe.shrink();
279    }
280
281    #[pyo3(name = "add_metadata")]
282    pub fn py_add_metadata(&mut self, key: String, value: DataValue) {
283        self.metadata.insert(key, value);
284    }
285
286    #[pyo3(name = "get_metadata")]
287    pub fn py_get_metadata(&self, key: &str) -> Option<DataValue> {
288        self.metadata.get(key).cloned()
289    }
290
291    #[pyo3(name = "rename_key")]
292    pub fn py_rename_key(&mut self, key: &str, new_name: &str) -> Result<(), PyErr> {
293        // fixme this may have a problem when the type is different and checked
294        self.dataframe
295            .rename_key(key, new_name.into())
296            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("{e}")))
297    }
298
299    #[pyo3(name = "add_alias")]
300    pub fn py_add_alias(&mut self, key: &str, new_name: &str) -> Result<(), PyErr> {
301        self.dataframe
302            .add_alias(key, new_name)
303            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("{e}")))
304    }
305
306    /// Selects data from the DataFrame.
307    /// If `keys` is None, all keys will be used.
308    /// If `keys` is provided, only the specified keys will be selected.
309    /// Returns a list of lists, where each inner list represents a row of data.
310    /// ```text
311    /// import trs_dataframe as tdf
312    /// df = tdf.DataFrame.init()
313    /// df.push({"key1": 1, "key2": 2})
314    /// df.push({"key1": 11, "key2": 21})
315    /// # selected = df.select(["key1", "key2"])
316    /// # assert selected == [[1, 2], [11, 21]]
317    /// # selected = df.select()
318    #[pyo3(name = "select", signature = (keys=None, transposed=None))]
319    pub fn py_select<'py>(
320        &self,
321        py: Python<'py>,
322        keys: Option<Vec<String>>,
323        transposed: Option<bool>,
324    ) -> Result<Bound<'py, PyList>, PyErr> {
325        let keys = keys
326            .unwrap_or(self.keys())
327            .into_iter()
328            .map(Key::from)
329            .collect::<Vec<Key>>();
330
331        let selected = if transposed.unwrap_or_default() {
332            self.select_transposed(Some(keys.as_slice()))
333                .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?
334        } else {
335            self.select(Some(keys.as_slice()))
336                .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?
337        };
338
339        let list = PyList::empty(py);
340        for rows in selected.rows() {
341            let row = PyList::empty(py);
342            for value in rows.iter() {
343                row.append(value.clone())
344                    .expect("BUG: cannot append to list");
345            }
346            list.append(row).expect("BUG: cannot append to list");
347        }
348        Ok(list)
349    }
350
351    /// Selects a column from the DataFrame.
352    /// If the column does not exist, it will raise a TypeError.
353    /// Returns a list of values in the selected column.
354    /// ```text
355    /// import trs_dataframe as tdf
356    /// df = tdf.DataFrame.init()
357    /// df.push({"key1": 1, "key2": 2})
358    /// df.push({"key1": 11, "key2": 21})
359    /// # selected = df.select_column("key1")
360    /// # assert selected == [1, 11]
361    /// # selected = df.select_column("key2")
362    /// # assert selected == [2, 21]
363    /// # selected = df.select_column("non_existing_key")  # Raises TypeError
364    /// ```
365    #[pyo3(name = "select_column")]
366    pub fn py_select_column<'py>(
367        &self,
368        py: Python<'py>,
369        key: String,
370    ) -> Result<Bound<'py, PyList>, PyErr> {
371        let selected = self
372            .select_column(Key::from(key))
373            .ok_or_else(|| PyErr::new::<PyTypeError, _>("Cannot select column"))?;
374
375        let list = PyList::empty(py);
376        for x in selected.to_vec().into_iter() {
377            list.append(x)?;
378        }
379
380        Ok(list)
381    }
382
383    /// Joins the current DataFrame with another DataFrame.
384    /// The join type is specified by the `join_type` parameter.
385    /// see [`JoinRelation`] for available join types.
386    /// ```text
387    /// import trs_dataframe as tdf
388    /// df1 = tdf.DataFrame.init()
389    /// df1.push({"key1": 1, "key2": 2})
390    /// df1.push({"key1": 11, "key2": 21})
391    /// df2 = tdf.DataFrame.init()
392    /// df2.push({"key1": 1, "key2": 3})
393    /// df2.push({"key1": 11, "key2": 23})
394    /// df1.join(df2, tei.JoinRelation.extend())
395    /// assert df1.select(["key1", "key2"]) == [[1, 2], [11, 21], [1, 3], [11, 23]]
396    /// ```
397    #[pyo3(name = "join")]
398    pub fn py_join(&mut self, other: DataFrame, join_type: JoinRelation) -> Result<(), PyErr> {
399        self.dataframe
400            .join(other.dataframe, &join_type)
401            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot join data: {e}")))?;
402
403        Ok(())
404    }
405
406    /// Pushes a new row of data into the DataFrame.
407    /// The data should be provided as a dictionary where keys are column names and values are the corresponding data values.
408    /// ```text
409    /// import trs_dataframe as tdf
410    /// df = tdf.DataFrame.init()
411    /// df.push({"key1": 1, "key2": 2})
412    /// df.push({"key1": 11, "key2": 21})
413    /// ```
414    #[pyo3(name = "push")]
415    pub fn py_push(&mut self, data: HashMap<Key, DataValue>) -> Result<(), PyErr> {
416        self.dataframe
417            .push(data)
418            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot join data: {e}")))?;
419        Ok(())
420    }
421
422    /// Adds a new column to the DataFrame.
423    /// The column is specified by a key and a vector of data values.
424    /// If the length of the data vector does not match the number of rows in the DataFrame, it will raise a TypeError.
425    /// ```text
426    /// import trs_dataframe as tdf
427    /// df = tdf.DataFrame.init()
428    /// df.push({"key1": 1, "key2": 2})
429    /// df.push({"key1": 11, "key2": 21})
430    /// df.add_column("key3", [3, 4])
431    /// assert df.select(["key1", "key2", "key3"]) == [[1, 2, 3], [11, 21, 4]]
432    /// ```
433    #[pyo3(name = "add_column")]
434    pub fn py_add_column(&mut self, key: Key, data: Vec<DataValue>) -> Result<(), PyErr> {
435        self.dataframe
436            .add_single_column(key, Array1::from_vec(data))
437            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot join data: {e}")))?;
438        Ok(())
439    }
440
441    pub fn add_constant(&mut self, key: Key, feature: DataValue) -> Result<(), PyErr> {
442        self.constants.insert(key, feature);
443        Ok(())
444    }
445
446    /// Filters the DataFrame by a given expression.
447    /// The expression should be a string that can be parsed by the DataFrame's filter method
448    ///
449    /// ```text
450    /// import trs_dataframe as tdf
451    /// df = tdf.DataFrame.init()
452    /// df.push({"key1": 1, "key2": 2})
453    /// df.push({"key1": 11, "key2": 21})
454    /// df.filter_by_expression("key1 > 5")
455    /// assert df.select(["key1", "key2"]) == [[11, 21 ]]
456    /// ```
457    pub fn filter_by_expression(&mut self, expression: String) -> Result<Self, PyErr> {
458        let filter = crate::filter::FilterRules::try_from(expression.as_str())
459            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot parse expression: {e}")))?;
460        self.filter(&filter)
461            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot filter data: {e}")))
462    }
463
464    fn __repr__(&self) -> String {
465        self.to_string()
466    }
467
468    fn __str__(&self) -> String {
469        self.to_string()
470    }
471
472    pub fn __iadd__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
473        trace!("{object:?}");
474        let df_or_dict = DfOrDict::new(object)?;
475        match df_or_dict {
476            DfOrDict::DataFrame(df) => {
477                self.dataframe += df.dataframe;
478            }
479            DfOrDict::Dict(dict) => {
480                self.dataframe += dict;
481            }
482        }
483        Ok(())
484    }
485
486    pub fn __isub__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
487        trace!("{object:?}");
488
489        let df_or_dict = DfOrDict::new(object)?;
490        match df_or_dict {
491            DfOrDict::DataFrame(df) => {
492                self.dataframe -= df.dataframe;
493            }
494            DfOrDict::Dict(dict) => {
495                self.dataframe -= dict;
496            }
497        }
498        Ok(())
499    }
500
501    pub fn __imul__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
502        trace!("{object:?}");
503        let df_or_dict = DfOrDict::new(object)?;
504        match df_or_dict {
505            DfOrDict::DataFrame(df) => {
506                self.dataframe *= df.dataframe;
507            }
508            DfOrDict::Dict(dict) => {
509                self.dataframe *= dict;
510            }
511        }
512        Ok(())
513    }
514
515    pub fn __itruediv__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
516        trace!("{object:?}");
517        let df_or_dict = DfOrDict::new(object)?;
518        match df_or_dict {
519            DfOrDict::DataFrame(df) => {
520                self.dataframe /= df.dataframe;
521            }
522            DfOrDict::Dict(dict) => {
523                self.dataframe /= dict;
524            }
525        }
526        Ok(())
527    }
528
529    pub fn __len__(&mut self) -> Result<usize, PyErr> {
530        Ok(self.dataframe.len())
531    }
532}
533
#[cfg(test)]
mod test {

    use super::*;
    use data_value::{stdhashmap, DataValue};
    use halfbrown::hashmap;
    use pyo3::ffi::c_str;
    use rstest::*;
    use tracing_test::traced_test;

    /// Two-row fixture: `key1 = [1, 11]`, `key2 = [2, 21]` (all `U32`).
    #[fixture]
    fn df() -> DataFrame {
        let mut df = DataFrame::init();
        assert!(df
            .push(hashmap! {
                Key::from("key1") => DataValue::U32(1),
                Key::from("key2") => DataValue::U32(2),
            })
            .is_ok());
        assert!(df
            .push(hashmap! {
                Key::from("key1") => DataValue::U32(11),
                Key::from("key2") => DataValue::U32(21),
            })
            .is_ok());
        df
    }

    /// Operand used by the in-place arithmetic tests.
    #[fixture]
    fn hm() -> HashMap<String, DataValue> {
        stdhashmap!(
            "key1".to_string() => DataValue::U32(2),
            "key2".to_string() => DataValue::U32(3),
        )
    }

    #[rstest]
    fn test_select_data(df: DataFrame) {
        // Default / `false` layout: one row per key.
        let data = df.select_data(Some(vec!["key1".into(), "key2".into()]), Some(false));
        assert!(data.is_ok());
        assert_eq!(
            data.unwrap(),
            ndarray::array![[1u32.into(), 11u32.into()], [2u32.into(), 21u32.into()]]
        );

        // `true` layout: one row per record.
        let data = df.select_data(Some(vec!["key1".into(), "key2".into()]), Some(true));
        assert!(data.is_ok());
        assert_eq!(
            data.unwrap(),
            ndarray::array![[1u32.into(), 2u32.into()], [11u32.into(), 21u32.into()]]
        );
    }

    #[cfg(feature = "python")]
    #[rstest]
    fn test_from_create() {
        pyo3::Python::with_gil(|_py| {
            let mut hm: HashMap<String, Vec<DataValue>> = Default::default();
            let value: Vec<DataValue> = vec![1i32.into(), 22i32.into()];
            hm.insert("a".into(), value);

            let df = DataFrame::from_dict(hm);
            assert_eq!(
                df.select(Some(&["a".into()])),
                Ok(ndarray::array![
                    [DataValue::from(1i32)],
                    [DataValue::from(22i32)]
                ]),
            );
        });
        #[cfg(feature = "polars-df")]
        {
            let pdf = polars::df!(
                "a" => [1u64, 2u64, 3u64],
                "b" => [4f64, 5f64, 6f64],
                "c" => [7i64, 8i64, 9i64]
            )
            .expect("BUG: should be ok");
            let df = DataFrame::from_polars(pyo3_polars::PyDataFrame(pdf));
            assert_eq!(
                df.select(Some(&["a".into(), "b".into(), "c".into()])),
                crate::df! {
                    "a" => [1u64, 2u64, 3u64],
                    "b" => [4f64, 5f64, 6f64],
                    "c" => [7i64, 8i64, 9i64]
                }
                .select(Some(&["a".into(), "b".into(), "c".into()])),
            );
        }
    }

    #[rstest]
    #[traced_test]
    fn basic_ops_add(mut df: DataFrame, hm: HashMap<String, DataValue>) {
        let mut df_expect = df.clone();
        let df2 = df.clone();
        let exec = Python::with_gil(|py| -> PyResult<()> {
            df.__iadd__(df.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe += df2.dataframe;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);

            df.__iadd__(hm.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe += hm;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);

            Ok(())
        });

        assert!(exec.is_ok(), "{:?}", exec);
    }

    #[rstest]
    #[traced_test]
    fn basic_ops_sub(mut df: DataFrame, hm: HashMap<String, DataValue>) {
        let mut df_expect = df.clone();
        let df2 = df.clone();
        let exec = Python::with_gil(|py| -> PyResult<()> {
            df.__isub__(df.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe -= df2.dataframe;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);

            df.__isub__(hm.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe -= hm;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);

            Ok(())
        });

        assert!(exec.is_ok(), "{:?}", exec);
    }

    #[rstest]
    #[traced_test]
    fn basic_ops_mul(mut df: DataFrame, hm: HashMap<String, DataValue>) {
        let mut df_expect = df.clone();
        let df2 = df.clone();
        let exec = Python::with_gil(|py| -> PyResult<()> {
            df.__imul__(df.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe *= df2.dataframe;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);

            df.__imul__(hm.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe *= hm;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);
            Ok(())
        });

        assert!(exec.is_ok(), "{:?}", exec);
    }

    #[rstest]
    #[traced_test]
    fn basic_ops_div(mut df: DataFrame, hm: HashMap<String, DataValue>) {
        let mut df_expect = df.clone();
        let df2 = df.clone();
        let exec = Python::with_gil(|py| -> PyResult<()> {
            df.__itruediv__(df.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe /= df2.dataframe;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);

            df.__itruediv__(hm.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe /= hm;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);
            Ok(())
        });

        assert!(exec.is_ok(), "{:?}", exec);
    }

    #[rstest]
    #[traced_test]
    fn test_numpy(mut df: DataFrame) {
        let exec = Python::with_gil(|py| -> PyResult<()> {
            let code = c_str!(
                r#"
def example(df):
    import numpy as np
    a_np = df.as_numpy_f32(['key1', 'key2'])
    print(a_np)
    b_np = df.as_numpy_u32(['key1', 'key2'])
    print(b_np)
    b_np = df.as_numpy_i32(['key1', 'key2'])
    print(b_np)
    b_np = df.as_numpy_i64(['key1', 'key2'])
    print(b_np)
    b_np = df.as_numpy_u64(['key1', 'key2'])
    print(b_np)
    b_np = df.as_numpy_f64(['key1', 'key2'])
    print(b_np)
    b_np = df.as_numpy_f64(['key1', 'key2'], transposed=True)
    print(b_np)
    return df
            "#
            );
            let fun: Py<PyAny> = PyModule::from_code(py, code, c_str!(""), c_str!(""))?
                .getattr("example")?
                .into();
            let result = fun.call1(py, (df.clone(),));
            assert!(df.py_join(df.clone(), JoinRelation::default()).is_ok());
            // The user may not have numpy installed; in that case the call must fail.
            if py.import("numpy").is_ok() {
                assert!(result.is_ok(), "{:?}", result);
            } else {
                assert!(result.is_err(), "{:?}", result);
            }
            Ok(())
        });
        assert!(exec.is_ok(), "{:?}", exec);
    }

    #[rstest]
    #[traced_test]
    fn test_fill_from_python(df: DataFrame) {
        let exec = Python::with_gil(|_py| -> PyResult<()> {
            let hm = stdhashmap!(
                Key::from("key1") => DataValue::U32(1),
                Key::from("key2") => DataValue::U32(2),
            );
            let mut df2 = DataFrame::init();
            assert!(df2.py_push(hm).is_ok());
            assert!(df2
                .py_push(stdhashmap!(
                    Key::from("key1") => DataValue::U32(11),
                    Key::from("key2") => DataValue::U32(21),
                ))
                .is_ok());

            assert_eq!(df, df2);

            let mut df2 = DataFrame::init();
            assert!(df2
                .py_add_column(
                    Key::from("key1"),
                    vec![DataValue::U32(1), DataValue::U32(11)]
                )
                .is_ok());
            assert!(df2
                .py_add_column(
                    Key::from("key2"),
                    vec![DataValue::U32(2), DataValue::U32(21)]
                )
                .is_ok());

            assert_eq!(df, df2);
            Ok(())
        });
        assert!(exec.is_ok(), "{:?}", exec);
    }

    #[rstest]
    fn basic_python_dataframe(mut df: DataFrame) {
        let exec = Python::with_gil(|py| -> PyResult<()> {
            let fun: Py<PyAny> = PyModule::from_code(
                py,
                c_str!(
                    r#"
def example(df):
    print(df)
    df.shrink()
    assert len(df) == 2
    df.add_alias('key1', 'key1-alias')
    a = df.select(['key1', 'key2'])
    print(a)
    b = df.select(['key1-alias', 'key2'])
    print(b)
    df.rename_key('key1', 'key1new')
    df.rename_key('key1new', 'key1')
    assert a == [[1, 2], [11, 21]]
    assert a == b
    df.add_metadata('test', 1)
    m = df.get_metadata('test')
    assert m == 1
    t = df.select(['key1', 'key2'], transposed=True)
    print(t)
    assert t == [[1, 11], [2, 21]]
    c = df.select_column('key1')
    print(c)
    assert c == [1, 11]
                "#
                ),
                c_str!(""),
                c_str!(""),
            )?
            .getattr("example")?
            .into();
            // Check the script's result instead of silently discarding a failure.
            let result = fun.call1(py, (df.clone(),));
            assert!(result.is_ok(), "{:?}", result);
            assert!(df.py_join(df.clone(), JoinRelation::default()).is_ok());
            Ok(())
        });
        assert!(exec.is_ok(), "{:?}", exec);
    }

    #[rstest]
    fn dummy_test_apply(mut df: DataFrame) {
        let exec = Python::with_gil(|py| -> PyResult<()> {
            let fun: Py<PyAny> = PyModule::from_code(
                py,
                c_str!(
                    r#"
def multiply_by_ten(x):
    print(x)
    x *= {"key1": 10}
    print(x)
    return x

def example(df):
    print(df)
    df.apply(multiply_by_ten)
                "#
                ),
                c_str!(""),
                c_str!(""),
            )?
            .getattr("example")?
            .into();
            let result = fun.call1(py, (df.clone(),));
            assert!(result.is_ok(), "{:?}", result);
            assert!(df.py_join(df.clone(), JoinRelation::default()).is_ok());
            Ok(())
        });
        assert!(exec.is_ok(), "{:?}", exec);
    }
}