ta_lib_in_rust/util/
dataframe_utils.rs1use polars::prelude::*;
2use polars::frame::column::Column;
3
4pub fn ensure_f64_column(df: &mut DataFrame, column_name: &str) -> PolarsResult<()> {
15 let s: Series = df.column(column_name)?.as_materialized_series().clone();
17 let mut col: Column = s.into_column();
18
19 let series_mut: &mut Series = col.into_materialized_series();
21 *series_mut = series_mut.cast(&DataType::Float64)?;
22
23 let series: Series = col.take_materialized_series();
25 df.replace(column_name, series)?;
26
27 Ok(())
28}
29
30pub fn check_window_size(df: &DataFrame, window: usize, indicator_name: &str) -> PolarsResult<()> {
42 if df.height() < window {
43 return Err(PolarsError::ComputeError(
44 format!("Not enough data points ({}) for {} window ({})",
45 df.height(), indicator_name, window).into()
46 ));
47 }
48 Ok(())
49}
50
#[cfg(test)]
mod tests {
    use super::*;
    use polars::prelude::*;

    /// Asserts that the `"values"` column of `df` is `Float64` and starts with exactly `expected`.
    fn assert_values_are_f64(df: &DataFrame, expected: &[f64]) {
        let column = df.column("values").unwrap();
        assert_eq!(column.dtype(), &DataType::Float64);

        let values = column.f64().unwrap();
        for (idx, want) in expected.iter().enumerate() {
            assert_eq!(values.get(idx).unwrap(), *want);
        }
    }

    #[test]
    fn test_ensure_f64_column() {
        let ints = Series::new("values".into(), &[1, 2, 3, 4, 5]);
        let mut df = DataFrame::new(vec![ints.into()]).unwrap();

        ensure_f64_column(&mut df, "values").unwrap();

        assert_values_are_f64(&df, &[1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_ensure_f64_column_already_f64() {
        let floats = Series::new("values".into(), &[1.0, 2.0, 3.0, 4.0, 5.0]);
        let mut df = DataFrame::new(vec![floats.into()]).unwrap();

        ensure_f64_column(&mut df, "values").unwrap();

        assert_values_are_f64(&df, &[1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    #[should_panic(expected = "not found")]
    fn test_ensure_f64_column_nonexistent() {
        let ints = Series::new("values".into(), &[1, 2, 3, 4, 5]);
        let mut df = DataFrame::new(vec![ints.into()]).unwrap();

        // Unwrapping the "column not found" error is what triggers the expected panic.
        ensure_f64_column(&mut df, "nonexistent").unwrap();
    }

    #[test]
    fn test_check_window_size_sufficient() {
        let floats = Series::new("values".into(), &[1.0, 2.0, 3.0, 4.0, 5.0]);
        let df = DataFrame::new(vec![floats.into()]).unwrap();

        // A window smaller than the frame, and one exactly equal to its height.
        for window in [3, 5] {
            assert!(check_window_size(&df, window, "Test").is_ok());
        }
    }

    #[test]
    fn test_check_window_size_insufficient() {
        let floats = Series::new("values".into(), &[1.0, 2.0, 3.0]);
        let df = DataFrame::new(vec![floats.into()]).unwrap();

        let outcome = check_window_size(&df, 4, "Test");
        assert!(outcome.is_err());

        let err = outcome.unwrap_err();
        match &err {
            PolarsError::ComputeError(msg) => {
                assert!(msg.contains("Not enough data points"));
                assert!(msg.contains("Test"));
            }
            other => panic!("Expected ComputeError, got {:?}", other),
        }
    }
}