ta_lib_in_rust/util/
dataframe_utils.rs

1use polars::prelude::*;
2use polars::frame::column::Column;
3
4/// Ensure a column in a DataFrame is of Float64 type
5/// 
6/// # Arguments
7/// 
8/// * `df` - DataFrame to modify
9/// * `column_name` - Name of the column to convert
10/// 
11/// # Returns
12/// 
13/// Returns a PolarsResult indicating success or failure
14pub fn ensure_f64_column(df: &mut DataFrame, column_name: &str) -> PolarsResult<()> {
15    // 1) Wrap the existing Series in a Column for in-place mutation
16    let s: Series = df.column(column_name)?.as_materialized_series().clone();
17    let mut col: Column = s.into_column();
18
19    // 2) Materialize and get a &mut Series to cast in place
20    let series_mut: &mut Series = col.into_materialized_series();
21    *series_mut = series_mut.cast(&DataType::Float64)?;
22
23    // 3) Convert the Column back into a Series and replace it in the DataFrame
24    let series: Series = col.take_materialized_series();
25    df.replace(column_name, series)?;
26    
27    Ok(())
28}
29
30/// Check if a DataFrame has enough rows for a given window size
31///
32/// # Arguments
33///
34/// * `df` - The DataFrame to check
35/// * `window` - The window size required
36/// * `indicator_name` - Name of the indicator (for error message)
37///
38/// # Returns
39///
40/// Returns a PolarsResult<()> or an error if there are not enough rows
41pub fn check_window_size(df: &DataFrame, window: usize, indicator_name: &str) -> PolarsResult<()> {
42    if df.height() < window {
43        return Err(PolarsError::ComputeError(
44            format!("Not enough data points ({}) for {} window ({})", 
45                    df.height(), indicator_name, window).into()
46        ));
47    }
48    Ok(())
49} 
50
#[cfg(test)]
mod tests {
    use super::*;
    use polars::prelude::*;

    /// Collect a Float64 column into a `Vec` so values *and* length can be
    /// asserted with a single comparison.
    fn f64_values(df: &DataFrame, name: &str) -> Vec<Option<f64>> {
        df.column(name)
            .unwrap()
            .f64()
            .unwrap()
            .into_iter()
            .collect()
    }

    #[test]
    fn test_ensure_f64_column() {
        // Integer column should be cast to Float64.
        let int_series = Series::new("values".into(), &[1, 2, 3, 4, 5]);
        let mut df = DataFrame::new(vec![int_series.into()]).unwrap();

        ensure_f64_column(&mut df, "values").unwrap();

        assert_eq!(df.column("values").unwrap().dtype(), &DataType::Float64);
        // Whole-column comparison also catches a wrong length.
        assert_eq!(
            f64_values(&df, "values"),
            vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0), Some(5.0)]
        );
        assert_eq!(df.width(), 1);
    }

    #[test]
    fn test_ensure_f64_column_already_f64() {
        // An existing Float64 column should come through unchanged.
        let f64_series = Series::new("values".into(), &[1.0, 2.0, 3.0, 4.0, 5.0]);
        let mut df = DataFrame::new(vec![f64_series.into()]).unwrap();

        ensure_f64_column(&mut df, "values").unwrap();

        assert_eq!(df.column("values").unwrap().dtype(), &DataType::Float64);
        assert_eq!(
            f64_values(&df, "values"),
            vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0), Some(5.0)]
        );
        assert_eq!(df.width(), 1);
    }

    #[test]
    fn test_ensure_f64_column_nonexistent() {
        let series = Series::new("values".into(), &[1, 2, 3, 4, 5]);
        let mut df = DataFrame::new(vec![series.into()]).unwrap();

        // Match on the error variant instead of the panic text, which depends on
        // polars' Debug formatting.
        let err = ensure_f64_column(&mut df, "nonexistent").unwrap_err();
        assert!(
            matches!(err, PolarsError::ColumnNotFound(_)),
            "unexpected error: {err:?}"
        );

        // The existing column must be left untouched.
        assert_eq!(df.column("values").unwrap().dtype(), &DataType::Int32);
    }

    #[test]
    fn test_check_window_size_sufficient() {
        let series = Series::new("values".into(), &[1.0, 2.0, 3.0, 4.0, 5.0]);
        let df = DataFrame::new(vec![series.into()]).unwrap();

        // Window smaller than the row count.
        assert!(check_window_size(&df, 3, "Test").is_ok());
        // Window equal to the row count (boundary).
        assert!(check_window_size(&df, 5, "Test").is_ok());
    }

    #[test]
    fn test_check_window_size_insufficient() {
        let series = Series::new("values".into(), &[1.0, 2.0, 3.0]);
        let df = DataFrame::new(vec![series.into()]).unwrap();

        match check_window_size(&df, 4, "Test") {
            Err(PolarsError::ComputeError(msg)) => {
                assert!(msg.contains("Not enough data points"));
                assert!(msg.contains("Test"));
                // Both the available row count and the requested window are reported.
                assert!(msg.contains("(3)"));
                assert!(msg.contains("(4)"));
            }
            other => panic!("Expected ComputeError, got {other:?}"),
        }
    }

    #[test]
    fn test_check_window_size_empty_frame() {
        let df = DataFrame::empty();

        // A zero-length window is trivially satisfied, even with no rows.
        assert!(check_window_size(&df, 0, "Test").is_ok());
        // Any positive window is not.
        assert!(check_window_size(&df, 1, "Test").is_err());
    }
}