1use std::collections::HashMap;
2
3use crate::{DataFrame, DataValue, JoinRelation, Key};
4use data_value::Extract as _;
5use ndarray::Array1;
6use numpy::PyArray2;
7use pyo3::{exceptions::PyTypeError, prelude::*, types::PyList};
8use tracing::trace;
9
10impl DataFrame {
11 fn select_data(
12 &self,
13 keys: Option<Vec<String>>,
14 transposed: Option<bool>,
15 ) -> Result<ndarray::Array2<DataValue>, crate::error::Error> {
16 let keys = keys
17 .unwrap_or(self.keys())
18 .into_iter()
19 .map(Key::from)
20 .collect::<Vec<Key>>();
21 if transposed.unwrap_or(false) {
22 self.select(Some(keys.as_slice()))
23 } else {
24 self.select_transposed(Some(keys.as_slice()))
25 }
26 }
27}
28
29enum DfOrDict {
30 DataFrame(DataFrame),
31 Dict(HashMap<String, DataValue>),
32}
33
34impl DfOrDict {
35 pub fn new(object: Bound<'_, PyAny>) -> Result<DfOrDict, PyErr> {
36 if let Ok(df) = object.extract::<DataFrame>() {
37 Ok(DfOrDict::DataFrame(df))
38 } else {
39 let dict: HashMap<String, DataValue> = object.extract()?;
40 Ok(DfOrDict::Dict(dict))
41 }
42 }
43}
44
45#[pymethods]
46impl DataFrame {
47 #[new]
48 pub fn init() -> Self {
49 Self::default()
50 }
51
52 pub fn keys(&self) -> Vec<String> {
53 self.dataframe
54 .keys()
55 .iter()
56 .map(|x| x.name().to_string())
57 .collect()
58 }
59
60 #[cfg(feature = "polars-df")]
61 #[pyo3(name = "as_polars")]
62 pub fn py_as_polars(&self) -> PyResult<polars_python::dataframe::PyDataFrame> {
63 let df = self
64 .as_polars()
65 .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot prepare polars DF: {e}")))?;
66 Ok(df.into())
67 }
68
69 pub fn apply(&mut self, function: Bound<'_, PyAny>) -> Result<(), PyErr> {
70 let df: DataFrame = pyo3::Python::with_gil(|py| {
71 let self_ = self
72 .clone()
73 .into_pyobject(py)
74 .expect("BUG: cannot convert to PyObject");
75 let result = function.call1((self_,)).expect("BUG: cannot call function");
76 result
77 .extract::<Bound<DataFrame>>()
78 .expect("BUG: cannot extract data frame")
79 .unbind()
80 .extract(py)
81 .expect("BUG: cannot extract data frame")
82 });
83 self.dataframe = df.dataframe;
84 Ok(())
85 }
86
87 #[pyo3(signature = (keys=None, transposed=None))]
88 pub fn as_numpy_u32<'py>(
89 &self,
90 keys: Option<Vec<String>>,
91 transposed: Option<bool>,
92 py: Python<'py>,
93 ) -> PyResult<Bound<'py, numpy::PyArray2<u32>>> {
94 let data = self
95 .select_data(keys, transposed)
96 .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
97 Ok(PyArray2::from_array(py, &data.mapv(|x| u32::extract(&x))))
98 }
99
100 #[pyo3(signature = (keys=None, transposed=None))]
101 pub fn as_numpy_u64<'py>(
102 &self,
103 keys: Option<Vec<String>>,
104 transposed: Option<bool>,
105 py: Python<'py>,
106 ) -> PyResult<Bound<'py, numpy::PyArray2<u64>>> {
107 let data = self
108 .select_data(keys, transposed)
109 .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
110 Ok(PyArray2::from_array(py, &data.mapv(|x| u64::extract(&x))))
111 }
112
113 #[pyo3(signature = (keys=None, transposed=None))]
114 pub fn as_numpy_i32<'py>(
115 &self,
116 keys: Option<Vec<String>>,
117 transposed: Option<bool>,
118 py: Python<'py>,
119 ) -> PyResult<Bound<'py, numpy::PyArray2<i32>>> {
120 let data = self
121 .select_data(keys, transposed)
122 .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
123 Ok(PyArray2::from_array(py, &data.mapv(|x| i32::extract(&x))))
124 }
125
126 #[pyo3(signature = (keys=None, transposed=None))]
127 pub fn as_numpy_i64<'py>(
128 &self,
129 keys: Option<Vec<String>>,
130 transposed: Option<bool>,
131 py: Python<'py>,
132 ) -> PyResult<Bound<'py, numpy::PyArray2<i64>>> {
133 let data = self
134 .select_data(keys, transposed)
135 .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
136 Ok(PyArray2::from_array(py, &data.mapv(|x| i64::extract(&x))))
137 }
138
139 #[pyo3(signature = (keys=None, transposed=None))]
140 pub fn as_numpy_f32<'py>(
141 &self,
142 keys: Option<Vec<String>>,
143 transposed: Option<bool>,
144 py: Python<'py>,
145 ) -> PyResult<Bound<'py, numpy::PyArray2<f32>>> {
146 let data = self
147 .select_data(keys, transposed)
148 .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
149 Ok(PyArray2::from_array(py, &data.mapv(|x| f32::extract(&x))))
150 }
151
152 #[pyo3(signature = (keys=None, transposed=None))]
153 pub fn as_numpy_f64<'py>(
154 &self,
155 keys: Option<Vec<String>>,
156 transposed: Option<bool>,
157 py: Python<'py>,
158 ) -> PyResult<Bound<'py, numpy::PyArray2<f64>>> {
159 let data = self
160 .select_data(keys, transposed)
161 .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
162 Ok(PyArray2::from_array(py, &data.mapv(|x| f64::extract(&x))))
163 }
164
165 #[pyo3(name = "shrink")]
166 pub fn py_shrink(&mut self) {
167 self.dataframe.shrink();
168 }
169
170 #[pyo3(name = "add_metadata")]
171 pub fn py_add_metadata(&mut self, key: String, value: DataValue) {
172 self.metadata.insert(key, value);
173 }
174
175 #[pyo3(name = "get_metadata")]
176 pub fn py_get_metadata(&self, key: &str) -> Option<DataValue> {
177 self.metadata.get(key).cloned()
178 }
179
180 #[pyo3(name = "rename_key")]
181 pub fn py_rename_key(&mut self, key: &str, new_name: &str) -> Result<(), PyErr> {
182 self.dataframe
184 .rename_key(key, new_name.into())
185 .map_err(|e| PyErr::new::<PyTypeError, _>(format!("{e}")))
186 }
187
188 #[pyo3(name = "add_alias")]
189 pub fn py_add_alias(&mut self, key: &str, new_name: &str) -> Result<(), PyErr> {
190 self.dataframe
191 .add_alias(key, new_name)
192 .map_err(|e| PyErr::new::<PyTypeError, _>(format!("{e}")))
193 }
194
195 #[pyo3(name = "select", signature = (keys=None))]
196 pub fn py_select<'py>(
197 &self,
198 py: Python<'py>,
199 keys: Option<Vec<String>>,
200 ) -> Result<Bound<'py, PyList>, PyErr> {
201 let keys = keys
202 .unwrap_or(self.keys())
203 .into_iter()
204 .map(Key::from)
205 .collect::<Vec<Key>>();
206 let selected = self
207 .select(Some(keys.as_slice()))
208 .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
209
210 let list = PyList::empty(py);
211 for rows in selected.rows() {
212 let row = PyList::empty(py);
213 for value in rows.iter() {
214 row.append(value.clone())
215 .expect("BUG: cannot append to list");
216 }
217 list.append(row).expect("BUG: cannot append to list");
218 }
219 Ok(list)
220 }
221
222 #[pyo3(name = "select_transposed", signature = (keys=None))]
223 pub fn py_select_transposed<'py>(
224 &self,
225 py: Python<'py>,
226 keys: Option<Vec<String>>,
227 ) -> Result<Bound<'py, PyList>, PyErr> {
228 let keys = keys
229 .unwrap_or(self.keys())
230 .into_iter()
231 .map(Key::from)
232 .collect::<Vec<Key>>();
233 let selected = self
234 .select_transposed(Some(keys.as_slice()))
235 .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot select data: {e}")))?;
236
237 let list = PyList::empty(py);
238 for rows in selected.rows() {
239 let row = PyList::empty(py);
240 for value in rows.iter() {
241 row.append(value.clone())?;
242 }
243 list.append(row)?;
244 }
245 Ok(list)
246 }
247
248 #[pyo3(name = "select_column")]
249 pub fn py_select_column<'py>(
250 &self,
251 py: Python<'py>,
252 key: String,
253 ) -> Result<Bound<'py, PyList>, PyErr> {
254 let selected = self
255 .select_column(Key::from(key))
256 .ok_or_else(|| PyErr::new::<PyTypeError, _>("Cannot select column"))?;
257
258 let list = PyList::empty(py);
259 for x in selected.to_vec().into_iter() {
260 list.append(x)?;
261 }
262
263 Ok(list)
264 }
265
266 #[pyo3(name = "join")]
267 pub fn py_join(&mut self, other: DataFrame, join_type: JoinRelation) -> Result<(), PyErr> {
268 self.dataframe
269 .join(other.dataframe, &join_type)
270 .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot join data: {e}")))?;
271
272 Ok(())
273 }
274
275 #[pyo3(name = "push")]
276 pub fn py_push(&mut self, data: HashMap<Key, DataValue>) -> Result<(), PyErr> {
277 self.dataframe
278 .push(data)
279 .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot join data: {e}")))?;
280 Ok(())
281 }
282
283 #[pyo3(name = "add_column")]
284 pub fn py_add_column(&mut self, key: Key, data: Vec<DataValue>) -> Result<(), PyErr> {
285 self.dataframe
286 .add_single_column(key, Array1::from_vec(data))
287 .map_err(|e| PyErr::new::<PyTypeError, _>(format!("Cannot join data: {e}")))?;
288 Ok(())
289 }
290
291 pub fn add_constant(&mut self, key: Key, feature: DataValue) -> Result<(), PyErr> {
292 self.constants.insert(key, feature);
293 Ok(())
294 }
295
296 fn __repr__(&self) -> String {
297 self.to_string()
298 }
299
300 fn __str__(&self) -> String {
301 self.to_string()
302 }
303
304 pub fn __iadd__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
305 trace!("{object:?}");
306 let df_or_dict = DfOrDict::new(object)?;
307 match df_or_dict {
308 DfOrDict::DataFrame(df) => {
309 self.dataframe += df.dataframe;
310 }
311 DfOrDict::Dict(dict) => {
312 self.dataframe += dict;
313 }
314 }
315 Ok(())
316 }
317
318 pub fn __isub__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
319 trace!("{object:?}");
320
321 let df_or_dict = DfOrDict::new(object)?;
322 match df_or_dict {
323 DfOrDict::DataFrame(df) => {
324 self.dataframe -= df.dataframe;
325 }
326 DfOrDict::Dict(dict) => {
327 self.dataframe -= dict;
328 }
329 }
330 Ok(())
331 }
332
333 pub fn __imul__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
334 trace!("{object:?}");
335 let df_or_dict = DfOrDict::new(object)?;
336 match df_or_dict {
337 DfOrDict::DataFrame(df) => {
338 self.dataframe *= df.dataframe;
339 }
340 DfOrDict::Dict(dict) => {
341 self.dataframe *= dict;
342 }
343 }
344 Ok(())
345 }
346
347 pub fn __itruediv__(&mut self, object: Bound<'_, PyAny>) -> Result<(), PyErr> {
348 trace!("{object:?}");
349 let df_or_dict = DfOrDict::new(object)?;
350 match df_or_dict {
351 DfOrDict::DataFrame(df) => {
352 self.dataframe /= df.dataframe;
353 }
354 DfOrDict::Dict(dict) => {
355 self.dataframe /= dict;
356 }
357 }
358 Ok(())
359 }
360
361 pub fn __len__(&mut self) -> Result<usize, PyErr> {
362 Ok(self.dataframe.len())
363 }
364}
365
#[cfg(test)]
mod test {

    use super::*;
    use data_value::{stdhashmap, DataValue};
    use halfbrown::hashmap;
    use pyo3::ffi::c_str;
    use rstest::*;
    use tracing_test::traced_test;

    /// Two records: `{key1: 1, key2: 2}` and `{key1: 11, key2: 21}`.
    #[fixture]
    fn df() -> DataFrame {
        let mut df = DataFrame::init();
        assert!(df
            .push(hashmap! {
                Key::from("key1") => DataValue::U32(1),
                Key::from("key2") => DataValue::U32(2),
            })
            .is_ok());
        assert!(df
            .push(hashmap! {
                Key::from("key1") => DataValue::U32(11),
                Key::from("key2") => DataValue::U32(21),
            })
            .is_ok());
        df
    }

    #[fixture]
    fn hm() -> HashMap<String, DataValue> {
        stdhashmap!(
            "key1".to_string() => DataValue::U32(2),
            "key2".to_string() => DataValue::U32(3),
        )
    }

    #[rstest]
    fn test_select_data(df: DataFrame) {
        // `transposed == false`: one row per key.
        let data = df.select_data(Some(vec!["key1".into(), "key2".into()]), Some(false));
        assert!(data.is_ok());
        assert_eq!(
            data.unwrap(),
            ndarray::array![[1u32.into(), 11u32.into()], [2u32.into(), 21u32.into()]]
        );

        // `transposed == true`: one row per record.
        let data = df.select_data(Some(vec!["key1".into(), "key2".into()]), Some(true));
        assert!(data.is_ok());
        assert_eq!(
            data.unwrap(),
            ndarray::array![[1u32.into(), 2u32.into()], [11u32.into(), 21u32.into()]]
        );
    }

    #[rstest]
    #[traced_test]
    fn basic_ops_add(mut df: DataFrame, hm: HashMap<String, DataValue>) {
        let mut df_expect = df.clone();
        let df2 = df.clone();
        let exec = Python::with_gil(|py| -> PyResult<()> {
            df.__iadd__(df.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe += df2.dataframe;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);

            df.__iadd__(hm.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe += hm;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);

            Ok(())
        });

        assert!(exec.is_ok(), "{:?}", exec);
    }

    #[rstest]
    #[traced_test]
    fn basic_ops_sub(mut df: DataFrame, hm: HashMap<String, DataValue>) {
        let mut df_expect = df.clone();
        let df2 = df.clone();
        let exec = Python::with_gil(|py| -> PyResult<()> {
            df.__isub__(df.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe -= df2.dataframe;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);

            df.__isub__(hm.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe -= hm;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);

            Ok(())
        });

        assert!(exec.is_ok(), "{:?}", exec);
    }

    #[rstest]
    #[traced_test]
    fn basic_ops_mul(mut df: DataFrame, hm: HashMap<String, DataValue>) {
        let mut df_expect = df.clone();
        let df2 = df.clone();
        let exec = Python::with_gil(|py| -> PyResult<()> {
            df.__imul__(df.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe *= df2.dataframe;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);

            df.__imul__(hm.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe *= hm;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);
            Ok(())
        });

        assert!(exec.is_ok(), "{:?}", exec);
    }

    #[rstest]
    #[traced_test]
    fn basic_ops_div(mut df: DataFrame, hm: HashMap<String, DataValue>) {
        let mut df_expect = df.clone();
        let df2 = df.clone();
        let exec = Python::with_gil(|py| -> PyResult<()> {
            df.__itruediv__(df.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe /= df2.dataframe;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);

            df.__itruediv__(hm.clone().into_pyobject(py)?.into_any())?;
            df_expect.dataframe /= hm;
            tracing::trace!("{} vs {}", df, df_expect);
            assert_eq!(df.dataframe, df_expect.dataframe);
            Ok(())
        });

        assert!(exec.is_ok(), "{:?}", exec);
    }

    #[rstest]
    #[traced_test]
    fn test_numpy(mut df: DataFrame) {
        let exec = Python::with_gil(|py| -> PyResult<()> {
            let code = c_str!(
                r#"
def example(df):
    import numpy as np
    a_np = df.as_numpy_f32(['key1', 'key2'])
    print(a_np)
    b_np = df.as_numpy_u32(['key1', 'key'])
    print(b_np)
    b_np = df.as_numpy_i32(['key1', 'key'])
    print(b_np)
    b_np = df.as_numpy_i64(['key1', 'key'])
    print(b_np)
    b_np = df.as_numpy_u64(['key1', 'key'])
    print(b_np)
    b_np = df.as_numpy_f64(['key1', 'key'])
    print(b_np)
    b_np = df.as_numpy_f64(['key1', 'key'], transposed=True)
    print(b_np)
    return df
"#
            );
            let fun: Py<PyAny> = PyModule::from_code(py, code, c_str!(""), c_str!(""))?
                .getattr("example")?
                .into();
            let result = fun.call1(py, (df.clone(),));
            assert!(df.py_join(df.clone(), JoinRelation::default()).is_ok());
            // Without numpy installed the Python `import numpy` must fail.
            if py.import("numpy").is_ok() {
                assert!(result.is_ok(), "{:?}", result);
            } else {
                assert!(result.is_err(), "{:?}", result);
            }
            Ok(())
        });
        assert!(exec.is_ok(), "{:?}", exec);
    }

    #[rstest]
    #[traced_test]
    fn test_fill_from_python(df: DataFrame) {
        let exec = Python::with_gil(|_py| -> PyResult<()> {
            // Record by record.
            let hm = stdhashmap!(
                Key::from("key1") => DataValue::U32(1),
                Key::from("key2") => DataValue::U32(2),
            );
            let mut df2 = DataFrame::init();
            assert!(df2.py_push(hm).is_ok());
            assert!(df2
                .py_push(stdhashmap!(
                    Key::from("key1") => DataValue::U32(11),
                    Key::from("key2") => DataValue::U32(21),
                ))
                .is_ok());

            assert_eq!(df, df2);

            // Column by column.
            let mut df2 = DataFrame::init();
            assert!(df2
                .py_add_column(
                    Key::from("key1"),
                    vec![DataValue::U32(1), DataValue::U32(11)]
                )
                .is_ok());
            assert!(df2
                .py_add_column(
                    Key::from("key2"),
                    vec![DataValue::U32(2), DataValue::U32(21)]
                )
                .is_ok());

            assert_eq!(df, df2);
            Ok(())
        });
        assert!(exec.is_ok(), "{:?}", exec);
    }

    #[rstest]
    fn basic_python_dataframe(mut df: DataFrame) {
        let exec = Python::with_gil(|py| -> PyResult<()> {
            let fun: Py<PyAny> = PyModule::from_code(
                py,
                c_str!(
                    "
def example(df):
    print(df)
    df.shrink()
    assert len(df) == 2
    df.add_alias('key1', 'key1-alias')
    a = df.select(['key1', 'key2'])
    print(a)
    b = df.select(['key1-alias', 'key2'])
    print(b)
    df.rename_key('key1', 'key1new')
    df.rename_key('key1new', 'key1')
    assert a == [[1, 2], [11, 21]]
    assert a == b
    df.add_metadata('test', 1)
    m = df.get_metadata('test')
    assert m == 1
    b = df.select_transposed(['key1', 'key2'])
    print(b)
    assert b == [[1, 11], [2, 21]]
    c = df.select_column('key1')
    print(c)
    assert c == [1, 11]
"
                ),
                c_str!(""),
                c_str!(""),
            )?
            .getattr("example")?
            .into();
            // Surface Python-side assertion failures instead of discarding them.
            let result = fun.call1(py, (df.clone(),));
            assert!(result.is_ok(), "{:?}", result);
            assert!(df.py_join(df.clone(), JoinRelation::default()).is_ok());
            Ok(())
        });
        assert!(exec.is_ok(), "{:?}", exec);
    }

    #[rstest]
    fn dummy_test_apply(mut df: DataFrame) {
        let exec = Python::with_gil(|py| -> PyResult<()> {
            let fun: Py<PyAny> = PyModule::from_code(
                py,
                c_str!(
                    r#"
def multiply_by_ten(x):
    print(x)
    x *= {"key1": 10}
    print(x)
    return x

def example(df):
    print(df)
    df.apply(multiply_by_ten)
"#
                ),
                c_str!(""),
                c_str!(""),
            )?
            .getattr("example")?
            .into();
            // Surface failures raised by `apply` or the callback.
            let result = fun.call1(py, (df.clone(),));
            assert!(result.is_ok(), "{:?}", result);
            assert!(df.py_join(df.clone(), JoinRelation::default()).is_ok());
            Ok(())
        });
        assert!(exec.is_ok(), "{:?}", exec);
    }
}