1pub mod candidate;
2pub mod dataframe;
3pub mod error;
4pub use candidate::CandidateData;
5pub use data_value;
6pub use data_value::DataValue;
7pub use dataframe::DataFrame;
8pub mod filter;
9pub use dataframe::join::{JoinBy, JoinById, JoinRelation};
10pub use dataframe::{
11 column_store::{ColumnFrame, KeyIndex},
12 key::Key,
13};
/// Hash map (from the `halfbrown` crate) keyed by `smartstring` strings, where
/// each key owns a `Vec` of [`DataValue`]s.
///
/// NOTE(review): presumably maps a column name to that column's values —
/// confirm against callers.
pub type MLChefMap = halfbrown::HashMap<smartstring::alias::String, Vec<DataValue>>;
15pub use ndarray;
16
// With the `jmalloc` Cargo feature enabled, jemalloc replaces Rust's default
// global allocator for everything linked into this binary/library.
#[cfg(feature = "jmalloc")]
#[global_allocator]
static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc;
20
21#[cfg(feature = "polars-df")]
22pub use polars;
23#[derive(
25 Debug, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq, Hash, Default,
26)]
27#[cfg_attr(feature = "python", pyo3::pyclass(eq, eq_int))]
28#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
29pub enum DataType {
30 Bool,
31 U32,
32 I32,
33 U8,
34 U64,
35 I64,
36 F32,
37 F64,
38 I128,
39 U128,
40 String,
41 Bytes,
42 #[default]
43 Unknown,
44 Vec,
45 Map,
46}
47
48pub fn detect_dtype(value: &DataValue) -> DataType {
50 use DataValue::*;
51 match value {
52 Bool(_) => DataType::Bool,
53 I32(_) => DataType::I32,
54 U32(_) => DataType::U32,
55 I64(_) => DataType::I64,
56 U64(_) => DataType::U64,
57 F32(_) => DataType::F32,
58 F64(_) => DataType::F64,
59 I128(_) => DataType::I128,
60 U128(_) => DataType::U128,
61 String(_) => DataType::String,
62 Bytes(_) => DataType::Bytes,
63 Vec(_) => DataType::Vec,
64 Map(_) => DataType::Map,
65 _ => DataType::Unknown,
66 }
67}
68pub fn detect_dtype_arr(value: &[DataValue]) -> DataType {
69 let mut dtype = DataType::Unknown;
70 let mut find_count = 3;
71 for val in value {
72 let new_dtype = detect_dtype(val);
73 if new_dtype != dtype {
74 dtype = new_dtype;
75 } else if new_dtype == dtype && !matches!(dtype, DataType::Unknown) {
76 find_count -= 1;
77 }
78 if find_count == 0 {
79 break;
80 }
81 }
82
83 dtype
84}
85
86#[cfg(feature = "python")]
87use pyo3::prelude::*;
88
/// Python extension module entry point: registers the crate's Python-facing
/// classes on module `m`.
#[cfg(feature = "python")]
#[pymodule]
pub fn trs_dataframe(_py: pyo3::Python<'_>, m: pyo3::Bound<'_, PyModule>) -> pyo3::PyResult<()> {
    m.add_class::<DataFrame>()?;
    m.add_class::<JoinRelation>()?;
    m.add_class::<Key>()?;
    // `DataType` is declared `#[pyclass]` under the same `python` feature.
    // Without registering it here, Python code cannot import or name the
    // class (e.g. to construct a value or use it in `isinstance`).
    m.add_class::<DataType>()?;
    Ok(())
}
#[cfg(test)]
mod test {
    use crate::dataframe::column_store::convert_data_value;

    use super::*;
    use rstest::*;

    /// Each `DataValue` variant maps to the expected `DataType`, the tag
    /// survives a JSON round trip, and converting a value to its own detected
    /// type leaves it unchanged.
    #[rstest]
    #[case(DataType::Bool, DataValue::Bool(true))]
    #[case(DataType::I32, DataValue::I32(1))]
    #[case(DataType::U32, DataValue::U32(1))]
    #[case(DataType::I64, DataValue::I64(1))]
    #[case(DataType::U64, DataValue::U64(1))]
    #[case(DataType::F32, DataValue::F32(1.0))]
    #[case(DataType::F64, DataValue::F64(1.0))]
    #[case(DataType::U128, DataValue::U128(1))]
    #[case(DataType::I128, DataValue::I128(1))]
    #[case(DataType::String, DataValue::String("1".into()))]
    #[case(DataType::Bytes, DataValue::Bytes(b"1".to_vec()))]
    #[case(DataType::Vec, DataValue::Vec(vec![DataValue::I32(1)]))]
    #[case(DataType::Map, DataValue::Map(std::collections::HashMap::new()))]
    #[case(DataType::Unknown, DataValue::Null)]
    fn detection_test(#[case] dtype: DataType, #[case] value: DataValue) {
        assert_eq!(detect_dtype(&value), dtype);
        let serde_dtype: DataType =
            serde_json::from_str(&serde_json::to_string(&dtype).expect("BUG: cannot serialize"))
                .expect("BUG: cannot deserialize");
        assert_eq!(serde_dtype, dtype);
        let dt = convert_data_value(value.clone(), dtype);
        assert_eq!(dt, value);
    }

    /// `detect_dtype_arr` on empty, all-null, homogeneous, and null-interleaved
    /// inputs.
    #[rstest]
    #[case(vec![], DataType::Unknown)]
    #[case(vec![DataValue::Null], DataType::Unknown)]
    #[case(vec![DataValue::I32(1)], DataType::I32)]
    #[case(vec![DataValue::Null, DataValue::I32(1)], DataType::I32)]
    #[case(vec![DataValue::I64(1); 5], DataType::I64)]
    #[case(
        vec![
            DataValue::String("a".into()),
            DataValue::Null,
            DataValue::String("b".into()),
        ],
        DataType::String
    )]
    fn array_detection_test(#[case] values: Vec<DataValue>, #[case] dtype: DataType) {
        assert_eq!(detect_dtype_arr(&values), dtype);
    }
}