ta_lib_in_rust/util/
dataframe_utils.rs1use polars::frame::column::Column;
2use polars::prelude::*;
3
4pub fn ensure_f64_column(df: &mut DataFrame, column_name: &str) -> PolarsResult<()> {
15 let s: Series = df.column(column_name)?.as_materialized_series().clone();
17 let mut col: Column = s.into_column();
18
19 let series_mut: &mut Series = col.into_materialized_series();
21 *series_mut = series_mut.cast(&DataType::Float64)?;
22
23 let series: Series = col.take_materialized_series();
25 df.replace(column_name, series)?;
26
27 Ok(())
28}
29
30pub fn check_window_size(df: &DataFrame, window: usize, indicator_name: &str) -> PolarsResult<()> {
42 if df.height() < window {
43 return Err(PolarsError::ComputeError(
44 format!(
45 "Not enough data points ({}) for {} window ({})",
46 df.height(),
47 indicator_name,
48 window
49 )
50 .into(),
51 ));
52 }
53 Ok(())
54}
55
#[cfg(test)]
mod tests {
    use super::*;
    use polars::prelude::*;

    /// Wraps a single series into a one-column DataFrame.
    fn single_column_frame(series: Series) -> DataFrame {
        DataFrame::new(vec![series.into()]).unwrap()
    }

    /// Reads every entry of the `Float64` column `name` back out of `df`.
    /// Panics if the column is missing, is not `Float64`, or contains a null.
    fn f64_values(df: &DataFrame, name: &str) -> Vec<f64> {
        let values = df.column(name).unwrap().f64().unwrap();
        (0..values.len()).map(|idx| values.get(idx).unwrap()).collect()
    }

    #[test]
    fn test_ensure_f64_column() {
        let mut df = single_column_frame(Series::new("values".into(), &[1, 2, 3, 4, 5]));

        ensure_f64_column(&mut df, "values").unwrap();

        assert_eq!(df.column("values").unwrap().dtype(), &DataType::Float64);
        assert_eq!(f64_values(&df, "values"), [1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_ensure_f64_column_already_f64() {
        let mut df =
            single_column_frame(Series::new("values".into(), &[1.0, 2.0, 3.0, 4.0, 5.0]));

        ensure_f64_column(&mut df, "values").unwrap();

        assert_eq!(df.column("values").unwrap().dtype(), &DataType::Float64);
        assert_eq!(f64_values(&df, "values"), [1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    #[should_panic(expected = "not found")]
    fn test_ensure_f64_column_nonexistent() {
        let mut df = single_column_frame(Series::new("values".into(), &[1, 2, 3, 4, 5]));

        // Looking up a missing column must surface polars' "not found" error.
        ensure_f64_column(&mut df, "nonexistent").unwrap();
    }

    #[test]
    fn test_check_window_size_sufficient() {
        let df = single_column_frame(Series::new("values".into(), &[1.0, 2.0, 3.0, 4.0, 5.0]));

        // Both a smaller window and a window equal to the row count are accepted.
        for window in [3, 5] {
            assert!(check_window_size(&df, window, "Test").is_ok());
        }
    }

    #[test]
    fn test_check_window_size_insufficient() {
        let df = single_column_frame(Series::new("values".into(), &[1.0, 2.0, 3.0]));

        match check_window_size(&df, 4, "Test") {
            Err(PolarsError::ComputeError(msg)) => {
                assert!(msg.contains("Not enough data points"));
                assert!(msg.contains("Test"));
            }
            other => panic!("Expected ComputeError, got {:?}", other),
        }
    }
}