trs_dataframe/dataframe/
python.rs

1use std::collections::HashMap;
2
3use crate::{DataFrame, DataValue, JoinRelation, Key};
4use data_value::Extract as _;
5use ndarray::Array1;
6use numpy::PyArray2;
7use pyo3::{exceptions::PyTypeError, prelude::*, types::PyList};
8use tracing::trace;
9
10impl DataFrame {
11    fn select_data(
12        &self,
13        keys: Option<Vec<String>>,
14        transposed: Option<bool>,
15    ) -> Result<ndarray::Array2<DataValue>, crate::error::Error> {
16        let keys = keys
17            .unwrap_or(self.keys())
18            .into_iter()
19            .map(Key::from)
20            .collect::<Vec<Key>>();
21        if transposed.unwrap_or(false) {
22            self.select(Some(keys.as_slice()))
23        } else {
24            self.select_transposed(Some(keys.as_slice()))
25        }
26    }
27}
28
29enum DfOrDict {
30    DataFrame(DataFrame),
31    Dict(HashMap<String, DataValue>),
32}
33
34impl DfOrDict {
35    pub fn new(object: Bound<'_, PyAny>) -> Result<DfOrDict, PyErr> {
36        if let Ok(df) = object.extract::<DataFrame>() {
37            Ok(DfOrDict::DataFrame(df))
38        } else {
39            let dict: HashMap<String, DataValue> = object.extract()?;
40            Ok(DfOrDict::Dict(dict))
41        }
42    }
43}
44
45#[pymethods]
46impl DataFrame {
47    #[new]
48    pub fn init() -> Self {
49        Self::default()
50    }
51
52    pub fn keys(&self) -> Vec<String> {
53        self.dataframe
54            .keys()
55            .iter()
56            .map(|x| x.name().to_string())
57            .collect()
58    }
59
60    #[pyo3(signature = (keys=None, transposed=None))]
61    pub fn as_numpy_u32<'py>(
62        &self,
63        keys: Option<Vec<String>>,
64        transposed: Option<bool>,
65        py: Python<'py>,
66    ) -> PyResult<Bound<'py, numpy::PyArray2<u32>>> {
67        let data = self
68            .select_data(keys, transposed)
69            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
70        Ok(PyArray2::from_array(py, &data.mapv(|x| u32::extract(&x))))
71    }
72
73    #[pyo3(signature = (keys=None, transposed=None))]
74    pub fn as_numpy_u64<'py>(
75        &self,
76        keys: Option<Vec<String>>,
77        transposed: Option<bool>,
78        py: Python<'py>,
79    ) -> PyResult<Bound<'py, numpy::PyArray2<u64>>> {
80        let data = self
81            .select_data(keys, transposed)
82            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
83        Ok(PyArray2::from_array(py, &data.mapv(|x| u64::extract(&x))))
84    }
85
86    #[pyo3(signature = (keys=None, transposed=None))]
87    pub fn as_numpy_i32<'py>(
88        &self,
89        keys: Option<Vec<String>>,
90        transposed: Option<bool>,
91        py: Python<'py>,
92    ) -> PyResult<Bound<'py, numpy::PyArray2<i32>>> {
93        let data = self
94            .select_data(keys, transposed)
95            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
96        Ok(PyArray2::from_array(py, &data.mapv(|x| i32::extract(&x))))
97    }
98
99    #[pyo3(signature = (keys=None, transposed=None))]
100    pub fn as_numpy_i64<'py>(
101        &self,
102        keys: Option<Vec<String>>,
103        transposed: Option<bool>,
104        py: Python<'py>,
105    ) -> PyResult<Bound<'py, numpy::PyArray2<i64>>> {
106        let data = self
107            .select_data(keys, transposed)
108            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
109        Ok(PyArray2::from_array(py, &data.mapv(|x| i64::extract(&x))))
110    }
111
112    #[pyo3(signature = (keys=None, transposed=None))]
113    pub fn as_numpy_f32<'py>(
114        &self,
115        keys: Option<Vec<String>>,
116        transposed: Option<bool>,
117        py: Python<'py>,
118    ) -> PyResult<Bound<'py, numpy::PyArray2<f32>>> {
119        let data = self
120            .select_data(keys, transposed)
121            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
122        Ok(PyArray2::from_array(py, &data.mapv(|x| f32::extract(&x))))
123    }
124
125    #[pyo3(signature = (keys=None, transposed=None))]
126    pub fn as_numpy_f64<'py>(
127        &self,
128        keys: Option<Vec<String>>,
129        transposed: Option<bool>,
130        py: Python<'py>,
131    ) -> PyResult<Bound<'py, numpy::PyArray2<f64>>> {
132        let data = self
133            .select_data(keys, transposed)
134            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
135        Ok(PyArray2::from_array(py, &data.mapv(|x| f64::extract(&x))))
136    }
137
138    #[pyo3(name = "shrink")]
139    pub fn py_shrink(&mut self) {
140        self.dataframe.shrink();
141    }
142
143    #[pyo3(name = "add_metadata")]
144    pub fn py_add_metadata(&mut self, key: String, value: DataValue) {
145        self.metadata.insert(key, value);
146    }
147
148    #[pyo3(name = "get_metadata")]
149    pub fn py_get_metadata(&self, key: &str) -> Option<DataValue> {
150        self.metadata.get(key).cloned()
151    }
152
153    #[pyo3(name = "rename_key")]
154    pub fn py_rename_key(&mut self, key: &str, new_name: &str) -> Result<(), PyErr> {
155        // fixme this may have a problem when the type is different and checked
156        self.dataframe
157            .rename_key(key, new_name.into())
158            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("{e}")))
159    }
160
161    #[pyo3(name = "add_alias")]
162    pub fn py_add_alias(&mut self, key: &str, new_name: &str) -> Result<(), PyErr> {
163        self.dataframe
164            .add_alias(key, new_name)
165            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("{e}")))
166    }
167
168    #[pyo3(name = "select", signature = (keys=None))]
169    pub fn py_select<'py>(
170        &self,
171        py: Python<'py>,
172        keys: Option<Vec<String>>,
173    ) -> Result<Bound<'py, PyList>, PyErr> {
174        let keys = keys
175            .unwrap_or(self.keys())
176            .into_iter()
177            .map(Key::from)
178            .collect::<Vec<Key>>();
179        let selected = self
180            .select(Some(keys.as_slice()))
181            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
182
183        let list = PyList::empty(py);
184        for rows in selected.rows() {
185            let row = PyList::empty(py);
186            for value in rows.iter() {
187                row.append(value.clone())
188                    .expect("BUG: cannot append to list");
189            }
190            list.append(row).expect("BUG: cannot append to list");
191        }
192        Ok(list)
193    }
194
195    #[pyo3(name = "select_transposed", signature = (keys=None))]
196    pub fn py_select_transposed<'py>(
197        &self,
198        py: Python<'py>,
199        keys: Option<Vec<String>>,
200    ) -> Result<Bound<'py, PyList>, PyErr> {
201        let keys = keys
202            .unwrap_or(self.keys())
203            .into_iter()
204            .map(Key::from)
205            .collect::<Vec<Key>>();
206        let selected = self
207            .select_transposed(Some(keys.as_slice()))
208            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
209
210        let list = PyList::empty(py);
211        for rows in selected.rows() {
212            let row = PyList::empty(py);
213            for value in rows.iter() {
214                row.append(value.clone())?;
215            }
216            list.append(row)?;
217        }
218        Ok(list)
219    }
220
221    #[pyo3(name = "select_column")]
222    pub fn py_select_column<'py>(
223        &self,
224        py: Python<'py>,
225        key: String,
226    ) -> Result<Bound<'py, PyList>, PyErr> {
227        let selected = self
228            .select_column(Key::from(key))
229            .ok_or_else(|| PyErr::new::<PyTypeError, _>("Cannot select column"))?;
230
231        let list = PyList::empty(py);
232        for x in selected.to_vec().into_iter() {
233            list.append(x)?;
234        }
235
236        Ok(list)
237    }
238
239    #[pyo3(name = "join")]
240    pub fn py_join(&mut self, other: DataFrame, join_type: JoinRelation) -> Result<(), PyErr> {
241        self.dataframe
242            .join(other.dataframe, &join_type)
243            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot join data: {e}")))?;
244
245        Ok(())
246    }
247
248    #[pyo3(name = "push")]
249    pub fn py_push(&mut self, data: HashMap<Key, DataValue>) -> Result<(), PyErr> {
250        self.dataframe
251            .push(data)
252            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot join data: {e}")))?;
253        Ok(())
254    }
255
256    #[pyo3(name = "add_column")]
257    pub fn py_add_column(&mut self, key: Key, data: Vec<DataValue>) -> Result<(), PyErr> {
258        self.dataframe
259            .add_single_column(key, Array1::from_vec(data))
260            .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot join data: {e}")))?;
261        Ok(())
262    }
263
264    pub fn add_constant(&mut self, key: Key, feature: DataValue) -> Result<(), PyErr> {
265        self.constants.insert(key, feature);
266        Ok(())
267    }
268
269    fn __repr__(&self) -> String {
270        self.to_string()
271    }
272
273    fn __str__(&self) -> String {
274        self.to_string()
275    }
276
277    pub fn __iadd__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
278        trace!("{object:?}");
279        let df_or_dict = DfOrDict::new(object)?;
280        match df_or_dict {
281            DfOrDict::DataFrame(df) => {
282                self.dataframe += df.dataframe;
283            }
284            DfOrDict::Dict(dict) => {
285                self.dataframe += dict;
286            }
287        }
288        Ok(())
289    }
290
291    pub fn __isub__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
292        trace!("{object:?}");
293
294        let df_or_dict = DfOrDict::new(object)?;
295        match df_or_dict {
296            DfOrDict::DataFrame(df) => {
297                self.dataframe -= df.dataframe;
298            }
299            DfOrDict::Dict(dict) => {
300                self.dataframe -= dict;
301            }
302        }
303        Ok(())
304    }
305
306    pub fn __imul__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
307        trace!("{object:?}");
308        let df_or_dict = DfOrDict::new(object)?;
309        match df_or_dict {
310            DfOrDict::DataFrame(df) => {
311                self.dataframe *= df.dataframe;
312            }
313            DfOrDict::Dict(dict) => {
314                self.dataframe *= dict;
315            }
316        }
317        Ok(())
318    }
319
320    pub fn __itruediv__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
321        trace!("{object:?}");
322        let df_or_dict = DfOrDict::new(object)?;
323        match df_or_dict {
324            DfOrDict::DataFrame(df) => {
325                self.dataframe /= df.dataframe;
326            }
327            DfOrDict::Dict(dict) => {
328                self.dataframe /= dict;
329            }
330        }
331        Ok(())
332    }
333
334    pub fn __len__(&mut self) -> Result<usize, PyErr> {
335        Ok(self.dataframe.len())
336    }
337}
338
#[cfg(test)]
mod test {

    use super::*;
    use data_value::{stdhashmap, DataValue};
    use halfbrown::hashmap;
    use pyo3::ffi::c_str;
    use rstest::*;
    use tracing_test::traced_test;

    /// Two-row frame with columns `key1` = [1, 11] and `key2` = [2, 21].
    #[fixture]
    fn df() -> DataFrame {
        let mut df = DataFrame::init();
        assert!(df
            .push(hashmap! {
                Key::from("key1") => DataValue::U32(1),
                Key::from("key2") => DataValue::U32(2),
            })
            .is_ok());
        assert!(df
            .push(hashmap! {
                Key::from("key1") => DataValue::U32(11),
                Key::from("key2") => DataValue::U32(21),
            })
            .is_ok());
        df
    }

    /// Mapping operand used as the right-hand side of the in-place operators.
    #[fixture]
    fn hm() -> HashMap<String, DataValue> {
        stdhashmap!(
            "key1".to_string() => DataValue::U32(2),
            "key2".to_string() => DataValue::U32(3),
        )
    }

    /// `__iadd__` must match the native `+=` for both operand kinds.
    #[rstest]
    #[traced_test]
    fn basic_ops_add(mut df: DataFrame, hm: HashMap<String, DataValue>) {
        let mut df_expect = df.clone();
        let df2 = df.clone();
        let exec = Python::with_gil(|py| -> PyResult<()> {
            df.__iadd__(df.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe += df2.dataframe;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);

            df.__iadd__(hm.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe += hm;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);

            Ok(())
        });

        assert!(exec.is_ok(), "{:?}", exec);
    }

    /// `__isub__` must match the native `-=` for both operand kinds.
    #[rstest]
    #[traced_test]
    fn basic_ops_sub(mut df: DataFrame, hm: HashMap<String, DataValue>) {
        let mut df_expect = df.clone();
        let df2 = df.clone();
        let exec = Python::with_gil(|py| -> PyResult<()> {
            df.__isub__(df.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe -= df2.dataframe;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);

            df.__isub__(hm.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe -= hm;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);

            Ok(())
        });

        assert!(exec.is_ok(), "{:?}", exec);
    }

    /// `__imul__` must match the native `*=` for both operand kinds.
    #[rstest]
    #[traced_test]
    fn basic_ops_mul(mut df: DataFrame, hm: HashMap<String, DataValue>) {
        let mut df_expect = df.clone();
        let df2 = df.clone();
        let exec = Python::with_gil(|py| -> PyResult<()> {
            df.__imul__(df.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe *= df2.dataframe;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);

            df.__imul__(hm.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe *= hm;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);
            Ok(())
        });

        assert!(exec.is_ok(), "{:?}", exec);
    }

    /// `__itruediv__` must match the native `/=` for both operand kinds.
    #[rstest]
    #[traced_test]
    fn basic_ops_div(mut df: DataFrame, hm: HashMap<String, DataValue>) {
        let mut df_expect = df.clone();
        let df2 = df.clone();
        let exec = Python::with_gil(|py| -> PyResult<()> {
            df.__itruediv__(df.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe /= df2.dataframe;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);

            df.__itruediv__(hm.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe /= hm;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);
            Ok(())
        });

        assert!(exec.is_ok(), "{:?}", exec);
    }

    /// Calls every `as_numpy_*` method from Python.
    #[rstest]
    #[traced_test]
    fn test_numpy(mut df: DataFrame) {
        let exec = Python::with_gil(|py| -> PyResult<()> {
            let code = c_str!(
                r#"
def example(df):
    import numpy as np
    a_np = df.as_numpy_f32(['key1', 'key2'])
    print(a_np)
    b_np = df.as_numpy_u32(['key1', 'key'])
    print(b_np)
    b_np = df.as_numpy_i32(['key1', 'key'])
    print(b_np)
    b_np = df.as_numpy_i64(['key1', 'key'])
    print(b_np)
    b_np = df.as_numpy_u64(['key1', 'key'])
    print(b_np)
    b_np = df.as_numpy_f64(['key1', 'key'])
    print(b_np)
    b_np = df.as_numpy_f64(['key1', 'key'], transposed=True)
    print(b_np)
    return df
            "#
            );
            let fun: Py<PyAny> = PyModule::from_code(py, code, c_str!(""), c_str!(""))?
                .getattr("example")?
                .into();
            let result = fun.call1(py, (df.clone(),));
            assert!(df.py_join(df.clone(), JoinRelation::default()).is_ok());
            // The user may not have numpy installed; in that case the Python
            // call is expected to fail on the import.
            if py.import("numpy").is_ok() {
                assert!(result.is_ok(), "{:?}", result);
            } else {
                assert!(result.is_err(), "{:?}", result);
            }
            Ok(())
        });
        assert!(exec.is_ok(), "{:?}", exec);
    }

    /// Filling a frame through `py_push` or `py_add_column` must yield the
    /// same frame as the `df` fixture.
    #[rstest]
    #[traced_test]
    fn test_fill_from_python(df: DataFrame) {
        let exec = Python::with_gil(|_py| -> PyResult<()> {
            let hm = stdhashmap!(
                Key::from("key1") => DataValue::U32(1),
                Key::from("key2") => DataValue::U32(2),
            );
            let mut df2 = DataFrame::init();
            assert!(df2.py_push(hm).is_ok());
            assert!(df2
                .py_push(stdhashmap!(
                    Key::from("key1") => DataValue::U32(11),
                    Key::from("key2") => DataValue::U32(21),
                ))
                .is_ok());

            assert_eq!(df, df2);

            let mut df2 = DataFrame::init();
            assert!(df2
                .py_add_column(
                    Key::from("key1"),
                    vec![DataValue::U32(1), DataValue::U32(11)]
                )
                .is_ok());
            assert!(df2
                .py_add_column(
                    Key::from("key2"),
                    vec![DataValue::U32(2), DataValue::U32(21)]
                )
                .is_ok());

            assert_eq!(df, df2);
            Ok(())
        });
        assert!(exec.is_ok(), "{:?}", exec);
    }

    /// Drives the Python-facing API from an embedded script.
    #[rstest]
    fn basic_python_dataframe(mut df: DataFrame) {
        let exec = Python::with_gil(|py| -> PyResult<()> {
            let fun: Py<PyAny> = PyModule::from_code(
                py,
                c_str!(
                    "
def example(df):
    print(df)
    df.shrink()
    assert len(df) == 2
    df.add_alias('key1', 'key1-alias')
    a = df.select(['key1', 'key2'])
    print(a)
    b = df.select(['key1-alias', 'key2'])
    print(b)
    df.rename_key('key1', 'key1new')
    df.rename_key('key1new', 'key1')
    assert a == [[1, 2], [11, 21]]
    assert a == b
    df.add_metadata('test', 1)
    m = df.get_metadata('test')
    assert m == 1
    b = df.select_transposed(['key1', 'key2'])
    print(b)
    assert b == [[1, 11], [2, 21]]
    c = df.select_column('key1')
    print(c)
    assert c == [1, 11]

    a += b
    print(a)
    assert a == [[2, 13], [4, 23]]
    a -= b
    print(a)
    assert e == a
    f = e * b
    print(f)
    assert f == [[1, 22], [44, 441]]
    g = f / b
    print(g)
    assert g == e

                "
                ),
                c_str!(""),
                c_str!(""),
            )?
            .getattr("example")?
            .into();
            // NOTE(review): the result of the call is discarded, so none of the
            // script's assertions are enforced. The tail of the script uses an
            // undefined name `e` and arithmetic on plain Python lists, so it
            // would raise if the result were checked -- the script needs fixing
            // before this can become a real assertion.
            let _ = fun.call1(py, (df.clone(),));
            assert!(df.py_join(df.clone(), JoinRelation::default()).is_ok());
            Ok(())
        });
        assert!(exec.is_ok(), "{:?}", exec);
    }
}