trs_dataframe/dataframe/
python.rs

1use std::collections::HashMap;
2
3use crate::{DataFrame, DataValue, JoinRelation, Key};
4use data_value::Extract as _;
5use ndarray::Array1;
6use numpy::PyArray2;
7use pyo3::{exceptions::PyTypeError, prelude::*, types::PyList};
8use tracing::trace;
9
10impl DataFrame {
11    fn select_data(
12        &self,
13        keys: Option<Vec<String>>,
14        transposed: Option<bool>,
15    ) -> Result<ndarray::Array2<DataValue>, crate::error::Error> {
16        let keys = keys
17            .unwrap_or(self.keys())
18            .into_iter()
19            .map(Key::from)
20            .collect::<Vec<Key>>();
21        if transposed.unwrap_or(false) {
22            self.select(Some(keys.as_slice()))
23        } else {
24            self.select_transposed(Some(keys.as_slice()))
25        }
26    }
27}
28
29enum DfOrDict {
30    DataFrame(DataFrame),
31    Dict(HashMap<String, DataValue>),
32}
33
34impl DfOrDict {
35    pub fn new(object: Bound<'_, PyAny>) -> Result<DfOrDict, PyErr> {
36        if let Ok(df) = object.extract::<DataFrame>() {
37            Ok(DfOrDict::DataFrame(df))
38        } else {
39            let dict: HashMap<String, DataValue> = object.extract()?;
40            Ok(DfOrDict::Dict(dict))
41        }
42    }
43}
44
45#[pymethods]
46impl DataFrame {
47    #[new]
48    pub fn init() -> Self {
49        Self::default()
50    }
51
52    pub fn keys(&self) -> Vec<String> {
53        self.dataframe
54            .keys()
55            .iter()
56            .map(|x| x.name().to_string())
57            .collect()
58    }
59
60    #[cfg(feature = "polars-df")]
61    #[pyo3(name = "as_polars")]
62    pub fn py_as_polars(&self) -> PyResult<polars_python::dataframe::PyDataFrame> {
63        let df = self
64            .as_polars()
65            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot prepare polars DF: {e}")))?;
66        Ok(df.into())
67    }
68
69    pub fn apply(&mut self, function: Bound<'_, PyAny>) -> Result<(), PyErr> {
70        let df: DataFrame = pyo3::Python::with_gil(|py| {
71            let self_ = self
72                .clone()
73                .into_pyobject(py)
74                .expect("BUG: cannot convert to PyObject");
75            let result = function.call1((self_,)).expect("BUG: cannot call function");
76            result
77                .extract::<Bound<DataFrame>>()
78                .expect("BUG: cannot extract data frame")
79                .unbind()
80                .extract(py)
81                .expect("BUG: cannot extract data frame")
82        });
83        self.dataframe = df.dataframe;
84        Ok(())
85    }
86
87    #[pyo3(signature = (keys=None, transposed=None))]
88    pub fn as_numpy_u32<'py>(
89        &self,
90        keys: Option<Vec<String>>,
91        transposed: Option<bool>,
92        py: Python<'py>,
93    ) -> PyResult<Bound<'py, numpy::PyArray2<u32>>> {
94        let data = self
95            .select_data(keys, transposed)
96            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
97        Ok(PyArray2::from_array(py, &data.mapv(|x| u32::extract(&x))))
98    }
99
100    #[pyo3(signature = (keys=None, transposed=None))]
101    pub fn as_numpy_u64<'py>(
102        &self,
103        keys: Option<Vec<String>>,
104        transposed: Option<bool>,
105        py: Python<'py>,
106    ) -> PyResult<Bound<'py, numpy::PyArray2<u64>>> {
107        let data = self
108            .select_data(keys, transposed)
109            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
110        Ok(PyArray2::from_array(py, &data.mapv(|x| u64::extract(&x))))
111    }
112
113    #[pyo3(signature = (keys=None, transposed=None))]
114    pub fn as_numpy_i32<'py>(
115        &self,
116        keys: Option<Vec<String>>,
117        transposed: Option<bool>,
118        py: Python<'py>,
119    ) -> PyResult<Bound<'py, numpy::PyArray2<i32>>> {
120        let data = self
121            .select_data(keys, transposed)
122            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
123        Ok(PyArray2::from_array(py, &data.mapv(|x| i32::extract(&x))))
124    }
125
126    #[pyo3(signature = (keys=None, transposed=None))]
127    pub fn as_numpy_i64<'py>(
128        &self,
129        keys: Option<Vec<String>>,
130        transposed: Option<bool>,
131        py: Python<'py>,
132    ) -> PyResult<Bound<'py, numpy::PyArray2<i64>>> {
133        let data = self
134            .select_data(keys, transposed)
135            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
136        Ok(PyArray2::from_array(py, &data.mapv(|x| i64::extract(&x))))
137    }
138
139    #[pyo3(signature = (keys=None, transposed=None))]
140    pub fn as_numpy_f32<'py>(
141        &self,
142        keys: Option<Vec<String>>,
143        transposed: Option<bool>,
144        py: Python<'py>,
145    ) -> PyResult<Bound<'py, numpy::PyArray2<f32>>> {
146        let data = self
147            .select_data(keys, transposed)
148            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
149        Ok(PyArray2::from_array(py, &data.mapv(|x| f32::extract(&x))))
150    }
151
152    #[pyo3(signature = (keys=None, transposed=None))]
153    pub fn as_numpy_f64<'py>(
154        &self,
155        keys: Option<Vec<String>>,
156        transposed: Option<bool>,
157        py: Python<'py>,
158    ) -> PyResult<Bound<'py, numpy::PyArray2<f64>>> {
159        let data = self
160            .select_data(keys, transposed)
161            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
162        Ok(PyArray2::from_array(py, &data.mapv(|x| f64::extract(&x))))
163    }
164
165    #[pyo3(name = "shrink")]
166    pub fn py_shrink(&mut self) {
167        self.dataframe.shrink();
168    }
169
170    #[pyo3(name = "add_metadata")]
171    pub fn py_add_metadata(&mut self, key: String, value: DataValue) {
172        self.metadata.insert(key, value);
173    }
174
175    #[pyo3(name = "get_metadata")]
176    pub fn py_get_metadata(&self, key: &str) -> Option<DataValue> {
177        self.metadata.get(key).cloned()
178    }
179
180    #[pyo3(name = "rename_key")]
181    pub fn py_rename_key(&mut self, key: &str, new_name: &str) -> Result<(), PyErr> {
182        // fixme this may have a problem when the type is different and checked
183        self.dataframe
184            .rename_key(key, new_name.into())
185            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("{e}")))
186    }
187
188    #[pyo3(name = "add_alias")]
189    pub fn py_add_alias(&mut self, key: &str, new_name: &str) -> Result<(), PyErr> {
190        self.dataframe
191            .add_alias(key, new_name)
192            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("{e}")))
193    }
194
195    #[pyo3(name = "select", signature = (keys=None))]
196    pub fn py_select<'py>(
197        &self,
198        py: Python<'py>,
199        keys: Option<Vec<String>>,
200    ) -> Result<Bound<'py, PyList>, PyErr> {
201        let keys = keys
202            .unwrap_or(self.keys())
203            .into_iter()
204            .map(Key::from)
205            .collect::<Vec<Key>>();
206        let selected = self
207            .select(Some(keys.as_slice()))
208            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
209
210        let list = PyList::empty(py);
211        for rows in selected.rows() {
212            let row = PyList::empty(py);
213            for value in rows.iter() {
214                row.append(value.clone())
215                    .expect("BUG: cannot append to list");
216            }
217            list.append(row).expect("BUG: cannot append to list");
218        }
219        Ok(list)
220    }
221
222    #[pyo3(name = "select_transposed", signature = (keys=None))]
223    pub fn py_select_transposed<'py>(
224        &self,
225        py: Python<'py>,
226        keys: Option<Vec<String>>,
227    ) -> Result<Bound<'py, PyList>, PyErr> {
228        let keys = keys
229            .unwrap_or(self.keys())
230            .into_iter()
231            .map(Key::from)
232            .collect::<Vec<Key>>();
233        let selected = self
234            .select_transposed(Some(keys.as_slice()))
235            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
236
237        let list = PyList::empty(py);
238        for rows in selected.rows() {
239            let row = PyList::empty(py);
240            for value in rows.iter() {
241                row.append(value.clone())?;
242            }
243            list.append(row)?;
244        }
245        Ok(list)
246    }
247
248    #[pyo3(name = "select_column")]
249    pub fn py_select_column<'py>(
250        &self,
251        py: Python<'py>,
252        key: String,
253    ) -> Result<Bound<'py, PyList>, PyErr> {
254        let selected = self
255            .select_column(Key::from(key))
256            .ok_or_else(|| PyErr::new::<PyTypeError, _>("Cannot select column"))?;
257
258        let list = PyList::empty(py);
259        for x in selected.to_vec().into_iter() {
260            list.append(x)?;
261        }
262
263        Ok(list)
264    }
265
266    #[pyo3(name = "join")]
267    pub fn py_join(&mut self, other: DataFrame, join_type: JoinRelation) -> Result<(), PyErr> {
268        self.dataframe
269            .join(other.dataframe, &join_type)
270            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot join data: {e}")))?;
271
272        Ok(())
273    }
274
275    #[pyo3(name = "push")]
276    pub fn py_push(&mut self, data: HashMap<Key, DataValue>) -> Result<(), PyErr> {
277        self.dataframe
278            .push(data)
279            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot join data: {e}")))?;
280        Ok(())
281    }
282
283    #[pyo3(name = "add_column")]
284    pub fn py_add_column(&mut self, key: Key, data: Vec<DataValue>) -> Result<(), PyErr> {
285        self.dataframe
286            .add_single_column(key, Array1::from_vec(data))
287            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot join data: {e}")))?;
288        Ok(())
289    }
290
291    pub fn add_constant(&mut self, key: Key, feature: DataValue) -> Result<(), PyErr> {
292        self.constants.insert(key, feature);
293        Ok(())
294    }
295
296    fn __repr__(&self) -> String {
297        self.to_string()
298    }
299
300    fn __str__(&self) -> String {
301        self.to_string()
302    }
303
304    pub fn __iadd__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
305        trace!("{object:?}");
306        let df_or_dict = DfOrDict::new(object)?;
307        match df_or_dict {
308            DfOrDict::DataFrame(df) => {
309                self.dataframe += df.dataframe;
310            }
311            DfOrDict::Dict(dict) => {
312                self.dataframe += dict;
313            }
314        }
315        Ok(())
316    }
317
318    pub fn __isub__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
319        trace!("{object:?}");
320
321        let df_or_dict = DfOrDict::new(object)?;
322        match df_or_dict {
323            DfOrDict::DataFrame(df) => {
324                self.dataframe -= df.dataframe;
325            }
326            DfOrDict::Dict(dict) => {
327                self.dataframe -= dict;
328            }
329        }
330        Ok(())
331    }
332
333    pub fn __imul__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
334        trace!("{object:?}");
335        let df_or_dict = DfOrDict::new(object)?;
336        match df_or_dict {
337            DfOrDict::DataFrame(df) => {
338                self.dataframe *= df.dataframe;
339            }
340            DfOrDict::Dict(dict) => {
341                self.dataframe *= dict;
342            }
343        }
344        Ok(())
345    }
346
347    pub fn __itruediv__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
348        trace!("{object:?}");
349        let df_or_dict = DfOrDict::new(object)?;
350        match df_or_dict {
351            DfOrDict::DataFrame(df) => {
352                self.dataframe /= df.dataframe;
353            }
354            DfOrDict::Dict(dict) => {
355                self.dataframe /= dict;
356            }
357        }
358        Ok(())
359    }
360
361    pub fn __len__(&mut self) -> Result<usize, PyErr> {
362        Ok(self.dataframe.len())
363    }
364}
365
366#[cfg(test)]
367mod test {
368
369    use super::*;
370    use data_value::{stdhashmap, DataValue};
371    use halfbrown::hashmap;
372    use pyo3::ffi::c_str;
373    use rstest::*;
374    use tracing_test::traced_test;
375
376    #[fixture]
377    fn df() -> DataFrame {
378        let mut df = DataFrame::init();
379        assert!(df
380            .push(hashmap! {
381                Key::from("key1") => DataValue::U32(1),
382                Key::from("key2") => DataValue::U32(2),
383            })
384            .is_ok());
385        assert!(df
386            .push(hashmap! {
387                Key::from("key1") => DataValue::U32(11),
388                Key::from("key2") => DataValue::U32(21),
389            })
390            .is_ok());
391        df
392    }
393
394    #[fixture]
395    fn hm() -> HashMap<String, DataValue> {
396        stdhashmap!(
397            "key1".to_string() => DataValue::U32(2),
398            "key2".to_string() => DataValue::U32(3),
399        )
400    }
401
402    #[rstest]
403    fn test_select_data(df: DataFrame) {
404        let data = df.select_data(Some(vec!["key1".into(), "key2".into()]), Some(false));
405        assert!(data.is_ok());
406        assert_eq!(
407            data.unwrap(),
408            ndarray::array![[1u32.into(), 11u32.into()], [2u32.into(), 21u32.into()]]
409        );
410
411        let data = df.select_data(Some(vec!["key1".into(), "key2".into()]), Some(true));
412        assert!(data.is_ok());
413        assert_eq!(
414            data.unwrap(),
415            ndarray::array![[1u32.into(), 2u32.into()], [11u32.into(), 21u32.into()]]
416        );
417    }
418
419    #[rstest]
420    #[traced_test]
421    fn basic_ops_add(mut df: DataFrame, hm: HashMap<String, DataValue>) {
422        let mut df_expect = df.clone();
423        let df2 = df.clone();
424        let exec = Python::with_gil(|py| -> PyResult<()> {
425            df.__iadd__(df.clone().into_pyobject(py)?.into_any())?;
426            df_expect.dataframe += df2.dataframe;
427            tracing::trace!("{} vs {}", df, df_expect);
428            assert_eq!(df.dataframe, df_expect.dataframe);
429
430            df.__iadd__(hm.clone().into_pyobject(py)?.into_any())?;
431            df_expect.dataframe += hm;
432            tracing::trace!("{} vs {}", df, df_expect);
433            assert_eq!(df.dataframe, df_expect.dataframe);
434
435            Ok(())
436        });
437
438        assert!(exec.is_ok(), "{:?}", exec);
439    }
440
441    #[rstest]
442    #[traced_test]
443    fn basic_ops_sub(mut df: DataFrame, hm: HashMap<String, DataValue>) {
444        let mut df_expect = df.clone();
445        let df2 = df.clone();
446        let exec = Python::with_gil(|py| -> PyResult<()> {
447            df.__isub__(df.clone().into_pyobject(py)?.into_any())?;
448            df_expect.dataframe -= df2.dataframe;
449            tracing::trace!("{} vs {}", df, df_expect);
450            assert_eq!(df.dataframe, df_expect.dataframe);
451
452            df.__isub__(hm.clone().into_pyobject(py)?.into_any())?;
453            df_expect.dataframe -= hm;
454            tracing::trace!("{} vs {}", df, df_expect);
455            assert_eq!(df.dataframe, df_expect.dataframe);
456
457            Ok(())
458        });
459
460        assert!(exec.is_ok(), "{:?}", exec);
461    }
462
463    #[rstest]
464    #[traced_test]
465    fn basic_ops_mul(mut df: DataFrame, hm: HashMap<String, DataValue>) {
466        let mut df_expect = df.clone();
467        let df2 = df.clone();
468        let exec = Python::with_gil(|py| -> PyResult<()> {
469            df.__imul__(df.clone().into_pyobject(py)?.into_any())?;
470            df_expect.dataframe *= df2.dataframe;
471            tracing::trace!("{} vs {}", df, df_expect);
472            assert_eq!(df.dataframe, df_expect.dataframe);
473
474            df.__imul__(hm.clone().into_pyobject(py)?.into_any())?;
475            df_expect.dataframe *= hm;
476            tracing::trace!("{} vs {}", df, df_expect);
477            assert_eq!(df.dataframe, df_expect.dataframe);
478            Ok(())
479        });
480
481        assert!(exec.is_ok(), "{:?}", exec);
482    }
483
484    #[rstest]
485    #[traced_test]
486    fn basic_ops_div(mut df: DataFrame, hm: HashMap<String, DataValue>) {
487        let mut df_expect = df.clone();
488        let df2 = df.clone();
489        let exec = Python::with_gil(|py| -> PyResult<()> {
490            df.__itruediv__(df.clone().into_pyobject(py)?.into_any())?;
491            df_expect.dataframe /= df2.dataframe;
492            tracing::trace!("{} vs {}", df, df_expect);
493            assert_eq!(df.dataframe, df_expect.dataframe);
494
495            df.__itruediv__(hm.clone().into_pyobject(py)?.into_any())?;
496            df_expect.dataframe /= hm;
497            tracing::trace!("{} vs {}", df, df_expect);
498            assert_eq!(df.dataframe, df_expect.dataframe);
499            Ok(())
500        });
501
502        assert!(exec.is_ok(), "{:?}", exec);
503    }
504
505    #[rstest]
506    #[traced_test]
507    #[rstest]
508    fn test_numpy(mut df: DataFrame) {
509        let exec = Python::with_gil(|py| -> PyResult<()> {
510            let code = c_str!(
511                r#"
512def example(df):
513    import numpy as np
514    a_np = df.as_numpy_f32(['key1', 'key2'])
515    print(a_np)
516    b_np = df.as_numpy_u32(['key1', 'key'])
517    print(b_np)
518    b_np = df.as_numpy_i32(['key1', 'key'])
519    print(b_np)
520    b_np = df.as_numpy_i64(['key1', 'key'])
521    print(b_np)
522    b_np = df.as_numpy_u64(['key1', 'key'])
523    print(b_np)
524    b_np = df.as_numpy_f64(['key1', 'key'])
525    print(b_np)
526    b_np = df.as_numpy_f64(['key1', 'key'], transposed=True)
527    print(b_np)
528    return df
529            "#
530            );
531            let fun: Py<PyAny> = PyModule::from_code(py, code, c_str!(""), c_str!(""))?
532                .getattr("example")?
533                .into();
534            let result = fun.call1(py, (df.clone(),));
535            assert!(df.py_join(df.clone(), JoinRelation::default()).is_ok());
536            // user may not have installed polars, we need to get an error in that
537            // case
538            if py.import("numpy").is_ok() {
539                assert!(result.is_ok(), "{:?}", result);
540            } else {
541                assert!(result.is_err(), "{:?}", result);
542            }
543            Ok(())
544        });
545        assert!(exec.is_ok(), "{:?}", exec);
546    }
547
548    #[rstest]
549    #[traced_test]
550    #[rstest]
551    fn test_fill_from_python(df: DataFrame) {
552        let exec = Python::with_gil(|_py| -> PyResult<()> {
553            let hm = stdhashmap!(
554                Key::from("key1") => DataValue::U32(1),
555                Key::from("key2") => DataValue::U32(2),
556            );
557            let mut df2 = DataFrame::init();
558            assert!(df2.py_push(hm).is_ok());
559            assert!(df2
560                .py_push(stdhashmap!(
561                    Key::from("key1") => DataValue::U32(11),
562                    Key::from("key2") => DataValue::U32(21),
563                ))
564                .is_ok());
565
566            assert_eq!(df, df2);
567
568            let mut df2 = DataFrame::init();
569            assert!(df2
570                .py_add_column(
571                    Key::from("key1"),
572                    vec![DataValue::U32(1), DataValue::U32(11)]
573                )
574                .is_ok());
575            assert!(df2
576                .py_add_column(
577                    Key::from("key2"),
578                    vec![DataValue::U32(2), DataValue::U32(21)]
579                )
580                .is_ok());
581
582            assert_eq!(df, df2);
583            Ok(())
584        });
585        assert!(exec.is_ok(), "{:?}", exec);
586    }
587
588    #[rstest]
589    fn basic_python_dataframe(mut df: DataFrame) {
590        let exec = Python::with_gil(|py| -> PyResult<()> {
591            let fun: Py<PyAny> = PyModule::from_code(
592                py,
593                c_str!(
594                    "
595def example(df):
596    print(df)
597    df.shrink()
598    assert len(df) == 2
599    df.add_alias('key1', 'key1-alias')
600    a = df.select(['key1', 'key2'])
601    print(a)
602    b = df.select(['key1-alias', 'key2'])
603    print(b)
604    df.rename_key('key1', 'key1new')
605    df.rename_key('key1new', 'key1')
606    assert a == [[1, 2], [11, 21]]
607    assert a == b
608    df.add_metadata('test', 1)
609    m = df.get_metadata('test')
610    assert m == 1
611    b = df.select_transposed(['key1', 'key2'])
612    print(b)
613    assert b == [[1, 11], [2, 21]]
614    c = df.select_column('key1')
615    print(c)
616    assert c == [1, 11]
617
618    a += b
619    print(a)
620    assert a == [[2, 13], [4, 23]]
621    a -= b
622    print(a)
623    assert e == a
624    f = e * b
625    print(f)
626    assert f == [[1, 22], [44, 441]]
627    g = f / b
628    print(g)
629    assert g == e
630
631                "
632                ),
633                c_str!(""),
634                c_str!(""),
635            )?
636            .getattr("example")?
637            .into();
638            let _ = fun.call1(py, (df.clone(),));
639            assert!(df.py_join(df.clone(), JoinRelation::default()).is_ok());
640            Ok(())
641        });
642        assert!(exec.is_ok(), "{:?}", exec);
643    }
644
645    #[rstest]
646    fn dummy_test_apply(mut df: DataFrame) {
647        let exec = Python::with_gil(|py| -> PyResult<()> {
648            let fun: Py<PyAny> = PyModule::from_code(
649                py,
650                c_str!(
651                    r#"
652def multiply_by_ten(x):
653    print(x)
654    x *= {"key1": 10}
655    print(x)
656    return x
657
658def example(df):
659    print(df)
660    df.apply(multiply_by_ten)
661                "#
662                ),
663                c_str!(""),
664                c_str!(""),
665            )?
666            .getattr("example")?
667            .into();
668            let _ = fun.call1(py, (df.clone(),));
669            assert!(df.py_join(df.clone(), JoinRelation::default()).is_ok());
670            Ok(())
671        });
672        assert!(exec.is_ok(), "{:?}", exec);
673    }
674}