// ta_lib_in_rust/util/dataframe_utils.rs

1use polars::frame::column::Column;
2use polars::prelude::*;
3
4/// Ensure a column in a DataFrame is of Float64 type
5///
6/// # Arguments
7///
8/// * `df` - DataFrame to modify
9/// * `column_name` - Name of the column to convert
10///
11/// # Returns
12///
13/// Returns a PolarsResult indicating success or failure
14pub fn ensure_f64_column(df: &mut DataFrame, column_name: &str) -> PolarsResult<()> {
15    // 1) Wrap the existing Series in a Column for in-place mutation
16    let s: Series = df.column(column_name)?.as_materialized_series().clone();
17    let mut col: Column = s.into_column();
18
19    // 2) Materialize and get a &mut Series to cast in place
20    let series_mut: &mut Series = col.into_materialized_series();
21    *series_mut = series_mut.cast(&DataType::Float64)?;
22
23    // 3) Convert the Column back into a Series and replace it in the DataFrame
24    let series: Series = col.take_materialized_series();
25    df.replace(column_name, series)?;
26
27    Ok(())
28}
29
30/// Check if a DataFrame has enough rows for a given window size
31///
32/// # Arguments
33///
34/// * `df` - The DataFrame to check
35/// * `window` - The window size required
36/// * `indicator_name` - Name of the indicator (for error message)
37///
38/// # Returns
39///
40/// Returns a PolarsResult<()> or an error if there are not enough rows
41pub fn check_window_size(df: &DataFrame, window: usize, indicator_name: &str) -> PolarsResult<()> {
42    if df.height() < window {
43        return Err(PolarsError::ComputeError(
44            format!(
45                "Not enough data points ({}) for {} window ({})",
46                df.height(),
47                indicator_name,
48                window
49            )
50            .into(),
51        ));
52    }
53    Ok(())
54}
55
#[cfg(test)]
mod tests {
    use super::*;
    use polars::prelude::*;

    /// Collect a Float64 column into a `Vec<Option<f64>>` so whole columns
    /// (including their length) can be compared in one assertion.
    fn f64_values(df: &DataFrame, name: &str) -> Vec<Option<f64>> {
        df.column(name)
            .unwrap()
            .f64()
            .unwrap()
            .into_iter()
            .collect()
    }

    #[test]
    fn test_ensure_f64_column() {
        // Integer input column (i32).
        let int_series = Series::new("values".into(), &[1, 2, 3, 4, 5]);
        let mut df = DataFrame::new(vec![int_series.into()]).unwrap();

        ensure_f64_column(&mut df, "values").unwrap();

        assert_eq!(df.column("values").unwrap().dtype(), &DataType::Float64);
        assert_eq!(
            f64_values(&df, "values"),
            vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0), Some(5.0)]
        );
    }

    #[test]
    fn test_ensure_f64_column_already_f64() {
        // Input is already Float64, so the call should be a no-op.
        let f64_series = Series::new("values".into(), &[1.0, 2.0, 3.0, 4.0, 5.0]);
        let mut df = DataFrame::new(vec![f64_series.into()]).unwrap();

        ensure_f64_column(&mut df, "values").unwrap();

        assert_eq!(df.column("values").unwrap().dtype(), &DataType::Float64);
        assert_eq!(
            f64_values(&df, "values"),
            vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0), Some(5.0)]
        );
    }

    #[test]
    fn test_ensure_f64_column_nonexistent() {
        let series = Series::new("values".into(), &[1, 2, 3, 4, 5]);
        let mut df = DataFrame::new(vec![series.into()]).unwrap();

        // Assert on the error variant rather than on panic text, whose wording
        // differs between Polars versions.
        let err = ensure_f64_column(&mut df, "nonexistent").unwrap_err();
        assert!(
            matches!(err, PolarsError::ColumnNotFound(_)),
            "Expected ColumnNotFound, got {:?}",
            err
        );

        // The failed call must not touch the existing column.
        assert_eq!(df.column("values").unwrap().dtype(), &DataType::Int32);
    }

    #[test]
    fn test_check_window_size_sufficient() {
        let series = Series::new("values".into(), &[1.0, 2.0, 3.0, 4.0, 5.0]);
        let df = DataFrame::new(vec![series.into()]).unwrap();

        // Window smaller than the number of rows.
        assert!(check_window_size(&df, 3, "Test").is_ok());

        // Window exactly equal to the number of rows.
        assert!(check_window_size(&df, 5, "Test").is_ok());
    }

    #[test]
    fn test_check_window_size_insufficient() {
        let series = Series::new("values".into(), &[1.0, 2.0, 3.0]);
        let df = DataFrame::new(vec![series.into()]).unwrap();

        // Window larger than the number of rows.
        match check_window_size(&df, 4, "Test") {
            Err(PolarsError::ComputeError(msg)) => {
                assert!(msg.contains("Not enough data points (3)"));
                assert!(msg.contains("Test window (4)"));
            }
            other => panic!("Expected ComputeError, got {:?}", other),
        }
    }
}