trs_dataframe/
lib.rs

1pub mod candidate;
2pub mod dataframe;
3pub mod error;
4pub use candidate::CandidateData;
5pub use data_value;
6pub use data_value::DataValue;
7pub use dataframe::DataFrame;
8pub mod filter;
9pub mod utils;
10pub use dataframe::join::{JoinBy, JoinById, JoinRelation};
11pub use dataframe::{
12    column_store::{ColumnFrame, KeyIndex},
13    key::Key,
14};
15pub type MLChefMap = halfbrown::HashMap<smartstring::alias::String, Vec<DataValue>>;
16pub use ndarray;
17
// With the `jmalloc` feature enabled, jemalloc is registered as the program-wide global
// allocator, replacing the system allocator for every heap allocation.
#[cfg(feature = "jmalloc")]
#[global_allocator]
static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc;
21
22#[cfg(feature = "polars-df")]
23pub use polars;
24/// Data type for the values in the dataframe
25#[derive(
26    Debug, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq, Hash, Default,
27)]
28#[cfg_attr(feature = "python", pyo3::pyclass(eq, eq_int))]
29#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
30pub enum DataType {
31    Bool,
32    U32,
33    I32,
34    U8,
35    U64,
36    I64,
37    F32,
38    F64,
39    I128,
40    U128,
41    String,
42    Bytes,
43    #[default]
44    Unknown,
45    Vec,
46    Map,
47}
48
49/// Autodetector for the data type from [`DataValue`]
50pub fn detect_dtype(value: &DataValue) -> DataType {
51    use DataValue::*;
52    match value {
53        Bool(_) => DataType::Bool,
54        I32(_) => DataType::I32,
55        U32(_) => DataType::U32,
56        I64(_) => DataType::I64,
57        U64(_) => DataType::U64,
58        F32(_) => DataType::F32,
59        F64(_) => DataType::F64,
60        I128(_) => DataType::I128,
61        U128(_) => DataType::U128,
62        String(_) => DataType::String,
63        Bytes(_) => DataType::Bytes,
64        Vec(_) => DataType::Vec,
65        Map(_) => DataType::Map,
66        _ => DataType::Unknown,
67    }
68}
69pub fn detect_dtype_arr(value: &[DataValue]) -> DataType {
70    let mut dtype = DataType::Unknown;
71    let mut find_count = 3;
72    for val in value {
73        let new_dtype = detect_dtype(val);
74        if new_dtype != dtype {
75            dtype = new_dtype;
76        } else if new_dtype == dtype && !matches!(dtype, DataType::Unknown) {
77            find_count -= 1;
78        }
79        if find_count == 0 {
80            break;
81        }
82    }
83
84    dtype
85}
86
87#[cfg(feature = "python")]
88use pyo3::prelude::*;
89
#[cfg(feature = "python")]
/// Python module entry point for `trs_dataframe`.
///
/// Registers the [`DataFrame`], [`JoinRelation`] and [`Key`] classes on the module `m`.
///
/// # Errors
///
/// Returns the first error reported by `add_class` while registering a class.
///
/// # Examples
///
/// ```
/// use pyo3::prelude::*;
///
/// fn main() {
///     let result = pyo3::Python::with_gil(|py| -> PyResult<()> {
///         let module = PyModule::new(py, "trs_dataframe")?;
///         let _m = trs_dataframe::trs_dataframe(py, module)?;
///         Ok(())
///     });
///     assert!(result.is_ok(), "{:?}", result);
/// }
/// ```
#[pymodule]
pub fn trs_dataframe(_py: pyo3::Python<'_>, m: pyo3::Bound<'_, PyModule>) -> pyo3::PyResult<()> {
    m.add_class::<DataFrame>()?;
    m.add_class::<JoinRelation>()?;
    // The last registration's result is the function's result.
    m.add_class::<Key>()
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::dataframe::column_store::convert_data_value;
    use rstest::*;

    // One representative value per `DataType`. For every case we check that:
    //   * `detect_dtype` reports the expected type,
    //   * the type survives a JSON serialize/deserialize round trip,
    //   * `convert_data_value` into that type leaves the value unchanged.
    #[rstest]
    #[case(DataType::Bool, DataValue::Bool(true))]
    #[case(DataType::I32, DataValue::I32(1))]
    #[case(DataType::U32, DataValue::U32(1))]
    #[case(DataType::I64, DataValue::I64(1))]
    #[case(DataType::U64, DataValue::U64(1))]
    #[case(DataType::F32, DataValue::F32(1.0))]
    #[case(DataType::F64, DataValue::F64(1.0))]
    #[case(DataType::U128, DataValue::U128(1))]
    #[case(DataType::I128, DataValue::I128(1))]
    #[case(DataType::String, DataValue::String("1".into()))]
    #[case(DataType::Bytes, DataValue::Bytes(b"1".to_vec()))]
    #[case(DataType::Vec, DataValue::Vec(vec![DataValue::I32(1)]))]
    #[case(DataType::Map, DataValue::Map(std::collections::HashMap::new()))]
    #[case(DataType::Unknown, DataValue::Null)]
    fn detection_test(#[case] dtype: DataType, #[case] value: DataValue) {
        assert_eq!(detect_dtype(&value), dtype);

        let json = serde_json::to_string(&dtype).expect("BUG: cannot serialize");
        let round_tripped: DataType =
            serde_json::from_str(&json).expect("BUG: cannot deserialize");
        assert_eq!(round_tripped, dtype);

        let converted = convert_data_value(value.clone(), dtype);
        assert_eq!(converted, value);
    }
}