trs_dataframe/
lib.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
pub mod candidate;
pub mod dataframe;
pub mod error;
pub use candidate::CandidateData;
pub use data_value;
pub use data_value::DataValue;
pub use dataframe::DataFrame;
pub mod utils;
pub use dataframe::join::{JoinBy, JoinById, JoinRelation};
pub use dataframe::{
    colums_store::{ColumnFrame, KeyIndex},
    key::Key,
};
/// Hash map keyed by `smartstring` strings whose values are lists of [`DataValue`].
///
/// NOTE(review): presumably maps a column name to that column's values — confirm against callers.
pub type MLChefMap = halfbrown::HashMap<smartstring::alias::String, Vec<DataValue>>;
pub use ndarray;
/// Data type for the values in the dataframe
///
/// The enum is `Copy`, hashable and (de)serializable through `serde`.
/// [`DataType::Unknown`] is the value produced by `DataType::default()`.
#[derive(
    Debug, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq, Hash, Default,
)]
// Exposed as a Python class only when the `python` feature is enabled.
#[cfg_attr(feature = "python", pyo3::pyclass(eq, eq_int))]
pub enum DataType {
    /// Boolean value (see `DataValue::Bool`).
    Bool,
    /// Unsigned 32-bit integer (see `DataValue::U32`).
    U32,
    /// Signed 32-bit integer (see `DataValue::I32`).
    I32,
    /// Unsigned 8-bit integer.
    U8,
    /// Unsigned 64-bit integer (see `DataValue::U64`).
    U64,
    /// Signed 64-bit integer (see `DataValue::I64`).
    I64,
    /// 32-bit floating point number (see `DataValue::F32`).
    F32,
    /// 64-bit floating point number (see `DataValue::F64`).
    F64,
    /// Text value (see `DataValue::String`).
    String,
    /// Raw byte buffer (see `DataValue::Bytes`).
    Bytes,
    /// Type could not be determined; this is the default variant.
    #[default]
    Unknown,
    /// List of nested values (see `DataValue::Vec`).
    Vec,
    /// Key/value map (see `DataValue::Map`).
    Map,
}

/// Autodetector for the data type from [`DataValue`]
pub fn detect_dtype(value: &DataValue) -> DataType {
    use DataValue::*;
    match value {
        Bool(_) => DataType::Bool,
        I32(_) => DataType::I32,
        U32(_) => DataType::U32,
        I64(_) => DataType::I64,
        U64(_) => DataType::U64,
        F32(_) => DataType::F32,
        F64(_) => DataType::F64,
        String(_) => DataType::String,
        Bytes(_) => DataType::Bytes,
        Vec(_) => DataType::Vec,
        Map(_) => DataType::Map,
        _ => DataType::Unknown,
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use rstest::*;

    /// Checks, for each case, that `detect_dtype` reports the expected [`DataType`]
    /// and that the [`DataType`] survives a JSON serialize/deserialize round trip.
    #[rstest]
    #[case(DataType::Bool, DataValue::Bool(true))]
    #[case(DataType::I32, DataValue::I32(1))]
    #[case(DataType::U32, DataValue::U32(1))]
    #[case(DataType::I64, DataValue::I64(1))]
    #[case(DataType::U64, DataValue::U64(1))]
    #[case(DataType::F32, DataValue::F32(1.0))]
    #[case(DataType::F64, DataValue::F64(1.0))]
    #[case(DataType::String, DataValue::String("1".into()))]
    #[case(DataType::Bytes, DataValue::Bytes(b"1".to_vec()))]
    #[case(DataType::Vec, DataValue::Vec(vec![DataValue::I32(1)]))]
    #[case(DataType::Map, DataValue::Map(std::collections::HashMap::new()))]
    #[case(DataType::Unknown, DataValue::Null)]
    fn detection_test(#[case] dtype: DataType, #[case] value: DataValue) {
        // Detection must yield the expected type for the given value.
        let detected = detect_dtype(&value);
        assert_eq!(detected, dtype);

        // Round trip through JSON: serialize first, then parse the text back.
        let encoded = serde_json::to_string(&dtype).expect("BUG: cannot serialize");
        let decoded: DataType = serde_json::from_str(&encoded).expect("BUG: cannot deserialize");
        assert_eq!(decoded, dtype);
    }
}