Skip to main content

trs_dataframe/
lib.rs

1/// Row-oriented candidate storage and conversion utilities.
2pub mod candidate;
3/// Column-oriented dataframe storage, joins, keys, and indexing.
4pub mod dataframe;
5/// Error types used throughout the crate.
6pub mod error;
7pub use candidate::CandidateData;
8pub use data_value;
9pub use data_value::DataValue;
10#[cfg(feature = "python")]
11pub use dataframe::python::DataFrameOrDict;
12pub use dataframe::DataFrame;
13/// Expression-based row filtering for dataframes.
14pub mod filter;
15pub use dataframe::join::{JoinBy, JoinById, JoinRelation};
16pub use dataframe::{
17    column_store::{
18        typed_array::{TypedData, TypedDataArray},
19        ColumnFrame, KeyIndex, MaybeView,
20    },
21    index::hash_datavalue,
22    key::Key,
23};
24/// Convenience alias for a string-keyed map of `DataValue` vectors.
25pub type MLChefMap = halfbrown::HashMap<smartstring::alias::String, Vec<DataValue>>;
26pub use ndarray;
27
28#[cfg(feature = "jmalloc")]
29#[global_allocator]
30static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc;
31
32#[cfg(feature = "polars-df")]
33pub use polars;
34/// Discriminant for the primitive type stored in a column or [`DataValue`].
35#[derive(
36    Debug, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq, Hash, Default,
37)]
38#[cfg_attr(feature = "python", pyo3::pyclass(eq, eq_int))]
39#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
40pub enum DataType {
41    Bool,
42    U32,
43    I32,
44    U8,
45    U64,
46    I64,
47    F32,
48    F64,
49    I128,
50    U128,
51    String,
52    Bytes,
53    #[default]
54    Unknown,
55    Vec,
56    Map,
57}
58
59#[inline]
60/// Autodetector for the data type from [`DataValue`]
61pub fn detect_dtype(value: &DataValue) -> DataType {
62    use DataValue::*;
63    match value {
64        Bool(_) => DataType::Bool,
65        I32(_) => DataType::I32,
66        U32(_) => DataType::U32,
67        I64(_) => DataType::I64,
68        U64(_) => DataType::U64,
69        F32(_) => DataType::F32,
70        F64(_) => DataType::F64,
71        I128(_) => DataType::I128,
72        U128(_) => DataType::U128,
73        String(_) => DataType::String,
74        Bytes(_) => DataType::Bytes,
75        Vec(_) => DataType::Vec,
76        Map(_) => DataType::Map,
77        _ => DataType::Unknown,
78    }
79}
80/// Scans a slice of [`DataValue`]s and returns the dominant [`DataType`].
81///
82/// Inspects up to three consecutive values of the same type before
83/// short-circuiting. Heterogeneous slices return the type of the last
84/// observed change.
85pub fn detect_dtype_arr(value: &[DataValue]) -> DataType {
86    let mut dtype = DataType::Unknown;
87    let mut find_count = 3;
88    for val in value {
89        let new_dtype = detect_dtype(val);
90        if new_dtype != dtype {
91            dtype = new_dtype;
92        } else if new_dtype == dtype && !matches!(dtype, DataType::Unknown) {
93            find_count -= 1;
94        }
95        if find_count == 0 {
96            break;
97        }
98    }
99
100    dtype
101}
102
103#[cfg(feature = "python")]
104use pyo3::prelude::*;
105
#[cfg(feature = "python")]
/// Python module entry point: registers the classes exposed to Python.
///
/// Adds [`DataFrame`], [`JoinRelation`] and [`Key`] to `m`, in that order.
/// Registration stops at the first failure and that error is returned.
///
/// # Examples
///
/// ```
/// use pyo3::prelude::*;
///
/// fn main() {
///     let result = pyo3::Python::with_gil(|py| -> PyResult<()> {
///         let module = PyModule::new(py, "trs_dataframe")?;
///         let _m = trs_dataframe::trs_dataframe(py, module)?;
///         Ok(())
///     });
///     assert!(result.is_ok(), "{:?}", result);
/// }
/// ```
#[pymodule]
pub fn trs_dataframe(_py: pyo3::Python<'_>, m: pyo3::Bound<'_, PyModule>) -> pyo3::PyResult<()> {
    // Each step runs only if the previous registration succeeded.
    m.add_class::<DataFrame>()
        .and_then(|()| m.add_class::<JoinRelation>())
        .and_then(|()| m.add_class::<Key>())
}
#[cfg(test)]
mod test {
    use crate::dataframe::column_store::convert_data_value;

    use super::*;
    use rstest::*;

    /// For each (tag, value) pair: the value is detected as the tag, the tag
    /// survives a JSON round-trip, and converting the value to its own tag
    /// returns an equal value.
    #[rstest]
    #[case(DataType::Bool, DataValue::Bool(true))]
    #[case(DataType::I32, DataValue::I32(1))]
    #[case(DataType::U32, DataValue::U32(1))]
    #[case(DataType::I64, DataValue::I64(1))]
    #[case(DataType::U64, DataValue::U64(1))]
    #[case(DataType::F32, DataValue::F32(1.0))]
    #[case(DataType::F64, DataValue::F64(1.0))]
    #[case(DataType::U128, DataValue::U128(1))]
    #[case(DataType::I128, DataValue::I128(1))]
    #[case(DataType::String, DataValue::String("1".into()))]
    #[case(DataType::Bytes, DataValue::Bytes(b"1".to_vec()))]
    #[case(DataType::Vec, DataValue::Vec(vec![DataValue::I32(1)]))]
    #[case(DataType::Map, DataValue::Map(std::collections::HashMap::new()))]
    #[case(DataType::Unknown, DataValue::Null)]
    fn detection_test(#[case] dtype: DataType, #[case] value: DataValue) {
        assert_eq!(detect_dtype(&value), dtype);

        let encoded = serde_json::to_string(&dtype).expect("BUG: cannot serialize");
        let decoded: DataType = serde_json::from_str(&encoded).expect("BUG: cannot deserialize");
        assert_eq!(decoded, dtype);

        let converted = convert_data_value(value.clone(), dtype);
        assert_eq!(converted, value);
    }

    #[test]
    fn detect_dtype_arr_unknown_for_empty() {
        let empty: [DataValue; 0] = [];
        assert_eq!(detect_dtype_arr(&empty), DataType::Unknown);
    }

    #[test]
    fn detect_dtype_arr_settles_on_repeated_dtype() {
        // The first I32 sets the candidate and the next three confirm it,
        // which ends the scan before the trailing F64 is inspected.
        let values = [
            DataValue::I32(1),
            DataValue::I32(2),
            DataValue::I32(3),
            DataValue::I32(4),
            DataValue::F64(5.0),
        ];
        assert_eq!(detect_dtype_arr(&values), DataType::I32);
    }

    #[test]
    fn detect_dtype_arr_overrides_on_change() {
        // Every type change replaces the candidate, so with no early exit the
        // last value decides the result.
        let values = [DataValue::I32(1), DataValue::Null, DataValue::F64(1.0)];
        assert_eq!(detect_dtype_arr(&values), DataType::F64);
    }

    #[test]
    fn detect_dtype_arr_ignores_repeated_unknown() {
        // Repeated Unknown readings must not count as confirmations; the
        // result stays Unknown for an all-null input.
        let values = [DataValue::Null, DataValue::Null, DataValue::Null];
        assert_eq!(detect_dtype_arr(&values), DataType::Unknown);
    }
}