Skip to main content

trs_dataframe/
lib.rs

1pub mod candidate;
2pub mod dataframe;
3pub mod error;
4pub use candidate::CandidateData;
5pub use data_value;
6pub use data_value::DataValue;
7pub use dataframe::DataFrame;
8pub mod filter;
9pub use dataframe::join::{JoinBy, JoinById, JoinRelation};
10pub use dataframe::{
11    column_store::{ColumnFrame, KeyIndex},
12    key::Key,
13};
/// Map from a `smartstring` key to a vector of [`DataValue`]s, backed by
/// `halfbrown::HashMap`.
pub type MLChefMap = halfbrown::HashMap<smartstring::alias::String, Vec<DataValue>>;
15pub use ndarray;
16
// When the `jmalloc` cargo feature is enabled, jemalloc replaces the system
// allocator for the whole program via `#[global_allocator]`.
#[cfg(feature = "jmalloc")]
#[global_allocator]
static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc;
20
21#[cfg(feature = "polars-df")]
22pub use polars;
/// Data type for the values in the dataframe
///
/// Variants carry implicit discriminants determined by their declaration order
/// (compared against by `eq_int` when the `python` feature is enabled), so new
/// variants should be appended rather than inserted.
#[derive(
    Debug, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq, Hash, Default,
)]
#[cfg_attr(feature = "python", pyo3::pyclass(eq, eq_int))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum DataType {
    /// Detected from `DataValue::Bool`.
    Bool,
    /// Detected from `DataValue::U32`.
    U32,
    /// Detected from `DataValue::I32`.
    I32,
    /// NOTE(review): `detect_dtype` currently has no arm that returns this variant.
    U8,
    /// Detected from `DataValue::U64`.
    U64,
    /// Detected from `DataValue::I64`.
    I64,
    /// Detected from `DataValue::F32`.
    F32,
    /// Detected from `DataValue::F64`.
    F64,
    /// Detected from `DataValue::I128`.
    I128,
    /// Detected from `DataValue::U128`.
    U128,
    /// Detected from `DataValue::String`.
    String,
    /// Detected from `DataValue::Bytes`.
    Bytes,
    /// Default variant; also the fallback for `DataValue::Null` and any value
    /// without a dedicated arm in `detect_dtype`.
    #[default]
    Unknown,
    /// Detected from `DataValue::Vec`.
    Vec,
    /// Detected from `DataValue::Map`.
    Map,
}
47
48/// Autodetector for the data type from [`DataValue`]
49pub fn detect_dtype(value: &DataValue) -> DataType {
50    use DataValue::*;
51    match value {
52        Bool(_) => DataType::Bool,
53        I32(_) => DataType::I32,
54        U32(_) => DataType::U32,
55        I64(_) => DataType::I64,
56        U64(_) => DataType::U64,
57        F32(_) => DataType::F32,
58        F64(_) => DataType::F64,
59        I128(_) => DataType::I128,
60        U128(_) => DataType::U128,
61        String(_) => DataType::String,
62        Bytes(_) => DataType::Bytes,
63        Vec(_) => DataType::Vec,
64        Map(_) => DataType::Map,
65        _ => DataType::Unknown,
66    }
67}
68pub fn detect_dtype_arr(value: &[DataValue]) -> DataType {
69    let mut dtype = DataType::Unknown;
70    let mut find_count = 3;
71    for val in value {
72        let new_dtype = detect_dtype(val);
73        if new_dtype != dtype {
74            dtype = new_dtype;
75        } else if new_dtype == dtype && !matches!(dtype, DataType::Unknown) {
76            find_count -= 1;
77        }
78        if find_count == 0 {
79            break;
80        }
81    }
82
83    dtype
84}
85
86#[cfg(feature = "python")]
87use pyo3::prelude::*;
88
#[cfg(feature = "python")]
/// Python module entry point for `trs_dataframe`.
///
/// Registers [`DataFrame`], [`JoinRelation`], [`Key`] and [`DataType`] as classes
/// on the module `m`.
///
/// # Errors
///
/// Returns the first error produced by `add_class` while registering a class.
///
/// # Examples
///
/// ```
/// use pyo3::prelude::*;
///
///  fn main() {
///     let result = pyo3::Python::with_gil(|py| -> PyResult<()> {
///         let module = PyModule::new(py, "trs_dataframe")?;
///         let _m = trs_dataframe::trs_dataframe(py, module)?;
///         Ok(())
///         });
///     assert!(result.is_ok(), "{:?}", result);
///  }
///
/// ```
#[pymodule]
pub fn trs_dataframe(_py: pyo3::Python<'_>, m: pyo3::Bound<'_, PyModule>) -> pyo3::PyResult<()> {
    m.add_class::<DataFrame>()?;
    m.add_class::<JoinRelation>()?;
    m.add_class::<Key>()?;
    // `DataType` is declared as a `pyclass` under this same feature; register it so
    // Python code can refer to `trs_dataframe.DataType` directly.
    m.add_class::<DataType>()?;
    Ok(())
}
#[cfg(test)]
mod test {
    use crate::dataframe::column_store::convert_data_value;

    use super::*;
    use rstest::*;

    /// `detect_dtype` maps each value to its type, the type survives a JSON
    /// round-trip, and converting a value to its own detected type is a no-op.
    #[rstest]
    #[case(DataType::Bool, DataValue::Bool(true))]
    #[case(DataType::I32, DataValue::I32(1))]
    #[case(DataType::U32, DataValue::U32(1))]
    #[case(DataType::I64, DataValue::I64(1))]
    #[case(DataType::U64, DataValue::U64(1))]
    #[case(DataType::F32, DataValue::F32(1.0))]
    #[case(DataType::F64, DataValue::F64(1.0))]
    #[case(DataType::U128, DataValue::U128(1))]
    #[case(DataType::I128, DataValue::I128(1))]
    #[case(DataType::String, DataValue::String("1".into()))]
    #[case(DataType::Bytes, DataValue::Bytes(b"1".to_vec()))]
    #[case(DataType::Vec, DataValue::Vec(vec![DataValue::I32(1)]))]
    #[case(DataType::Map, DataValue::Map(std::collections::HashMap::new()))]
    #[case(DataType::Unknown, DataValue::Null)]
    fn detection_test(#[case] dtype: DataType, #[case] value: DataValue) {
        assert_eq!(detect_dtype(&value), dtype);
        let serde_dtype: DataType =
            serde_json::from_str(&serde_json::to_string(&dtype).expect("BUG: cannot serialize"))
                .expect("BUG: cannot deserialize");
        assert_eq!(serde_dtype, dtype);
        let dt = convert_data_value(value.clone(), dtype);
        assert_eq!(dt, value);
    }

    /// `detect_dtype_arr` on inputs whose expected type is unambiguous: empty and
    /// all-null slices are `Unknown`, homogeneous slices report their element type
    /// whether they are shorter or longer than the early-stop threshold.
    #[rstest]
    #[case(DataType::Unknown, vec![])]
    #[case(DataType::Unknown, vec![DataValue::Null, DataValue::Null])]
    #[case(DataType::String, vec![DataValue::String("a".into())])]
    #[case(DataType::F64, vec![DataValue::F64(1.0), DataValue::F64(2.0)])]
    #[case(DataType::I32, vec![DataValue::I32(1); 5])]
    fn array_detection_test(#[case] dtype: DataType, #[case] values: Vec<DataValue>) {
        assert_eq!(detect_dtype_arr(&values), dtype);
    }
}