1use polars::prelude::*;
2use std::fs::File;
3use std::io::BufRead;
4use std::path::Path;
5
/// Names of the columns in a loaded data frame that hold each OHLCV field.
///
/// Every field is `None` when no matching column was identified. The
/// `Default` value has all fields set to `None`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct FinancialColumns {
    /// Name of the date/time column, if one was identified.
    pub date: Option<String>,
    /// Name of the opening-price column, if one was identified.
    pub open: Option<String>,
    /// Name of the high-price column, if one was identified.
    pub high: Option<String>,
    /// Name of the low-price column, if one was identified.
    pub low: Option<String>,
    /// Name of the closing-price column, if one was identified.
    pub close: Option<String>,
    /// Name of the volume column, if one was identified.
    pub volume: Option<String>,
}
16
17pub fn read_csv<P: AsRef<Path>>(
38 file_path: P,
39 has_header: bool,
40 delimiter: char,
41) -> PolarsResult<DataFrame> {
42 let file = File::open(file_path)?;
43
44 let csv_options = CsvReadOptions::default()
46 .with_has_header(has_header)
47 .map_parse_options(|opts| opts.with_separator(delimiter as u8));
48
49 CsvReader::new(file).with_options(csv_options).finish()
50}
51
52pub fn read_csv_default<P: AsRef<Path>>(file_path: P) -> PolarsResult<DataFrame> {
62 read_csv(file_path, true, ',')
63}
64
65pub fn read_parquet<P: AsRef<Path>>(file_path: P) -> PolarsResult<DataFrame> {
84 let file = File::open(file_path)?;
85 ParquetReader::new(file).finish()
86}
87
88pub fn read_financial_data<P: AsRef<Path>>(
151 file_path: P,
152) -> PolarsResult<(DataFrame, FinancialColumns)> {
153 let path = file_path.as_ref();
154
155 let file_type = path
157 .extension()
158 .and_then(|ext| ext.to_str())
159 .map(|ext| ext.to_lowercase())
160 .ok_or_else(|| {
161 PolarsError::ComputeError("Could not determine file type from extension".into())
162 })?;
163
164 let df = match file_type.as_str() {
166 "csv" => {
167 let file = File::open(path)?;
169 let mut reader = std::io::BufReader::new(file);
170 let mut first_line = String::new();
171 reader.read_line(&mut first_line)?;
172
173 let has_header = ["date", "time", "open", "high", "low", "close", "volume"]
175 .iter()
176 .any(|&name| first_line.to_lowercase().contains(name));
177
178 let delimiters = [',', ';', '\t', '|'];
180 let mut last_error = None;
181
182 for &delimiter in &delimiters {
183 match read_csv(path, has_header, delimiter) {
184 Ok(df) => return process_dataframe(df, has_header),
185 Err(e) => last_error = Some(e),
186 }
187 }
188
189 Err(last_error.unwrap_or_else(|| {
191 PolarsError::ComputeError("Failed to read CSV with any common delimiter".into())
192 }))?
193 }
194 "parquet" => read_parquet(path)?,
195 _ => {
196 return Err(PolarsError::ComputeError(
197 format!("Unsupported file type: {}", file_type).into(),
198 ))
199 }
200 };
201
202 let columns = map_columns_with_headers(&df)?;
204 Ok((df, columns))
205}
206
207fn process_dataframe(
209 mut df: DataFrame,
210 has_header: bool,
211) -> PolarsResult<(DataFrame, FinancialColumns)> {
212 let columns = if has_header {
213 map_columns_with_headers(&df)?
214 } else {
215 rename_columns_without_headers(&mut df)?
216 };
217 Ok((df, columns))
218}
219
220fn map_columns_with_headers(df: &DataFrame) -> PolarsResult<FinancialColumns> {
222 let column_names: Vec<String> = df
223 .get_column_names()
224 .into_iter()
225 .map(|s| s.to_string())
226 .collect();
227
228 let mut financial_columns = FinancialColumns {
230 date: None,
231 open: None,
232 high: None,
233 low: None,
234 close: None,
235 volume: None,
236 };
237
238 let date_variations = ["date", "time", "datetime", "timestamp", "dt"];
240 let open_variations = ["open", "o", "opening"];
241 let high_variations = ["high", "h", "highest"];
242 let low_variations = ["low", "l", "lowest"];
243 let close_variations = ["close", "c", "closing"];
244 let volume_variations = ["volume", "vol", "v"];
245
246 for col in column_names {
247 let lower_col = col.to_lowercase();
248
249 if financial_columns.date.is_none()
250 && date_variations.iter().any(|&v| lower_col.contains(v))
251 {
252 financial_columns.date = Some(col.clone());
253 } else if financial_columns.open.is_none()
254 && open_variations.iter().any(|&v| lower_col.contains(v))
255 {
256 financial_columns.open = Some(col.clone());
257 } else if financial_columns.high.is_none()
258 && high_variations.iter().any(|&v| lower_col.contains(v))
259 {
260 financial_columns.high = Some(col.clone());
261 } else if financial_columns.low.is_none()
262 && low_variations.iter().any(|&v| lower_col.contains(v))
263 {
264 financial_columns.low = Some(col.clone());
265 } else if financial_columns.close.is_none()
266 && close_variations.iter().any(|&v| lower_col.contains(v))
267 {
268 financial_columns.close = Some(col.clone());
269 } else if financial_columns.volume.is_none()
270 && volume_variations.iter().any(|&v| lower_col.contains(v))
271 {
272 financial_columns.volume = Some(col.clone());
273 }
274 }
275
276 Ok(financial_columns)
277}
278
279fn rename_columns_without_headers(df: &mut DataFrame) -> PolarsResult<FinancialColumns> {
281 let n_cols = df.width();
282 let mut col_names = vec![String::new(); n_cols];
283 let mut identified_cols = vec![false; n_cols];
284
285 let mut financial_columns = FinancialColumns {
287 date: None,
288 open: None,
289 high: None,
290 low: None,
291 close: None,
292 volume: None,
293 };
294
295 for i in 0..n_cols {
297 if let Some(series) = df.select_at_idx(i) {
298 if !identified_cols[i]
299 && (series.dtype() == &DataType::String
300 || series.dtype() == &DataType::Date
301 || matches!(series.dtype(), DataType::Datetime(_, _)))
302 {
303 col_names[i] = "date".to_string();
304 financial_columns.date = Some("date".to_string());
305 identified_cols[i] = true;
306 break; }
308 }
309 }
310
311 for (i, &identified) in identified_cols.iter().enumerate().take(n_cols) {
314 if identified {
315 continue;
316 }
317
318 if let Some(series) = df.select_at_idx(i) {
319 if series.dtype().is_primitive_numeric() {
320 if let Ok(f64_series) = series.cast(&DataType::Float64) {
321 if let Ok(nums) = f64_series.f64() {
322 let is_volume =
324 if let (Some(mean), Some(std_dev)) = (nums.mean(), nums.std(0)) {
325 let other_cols_mean = get_numeric_columns_mean(df, i)?;
330 mean > other_cols_mean * 100.0 && std_dev > mean * 0.1
331 } else {
332 false
333 };
334
335 if is_volume {
336 col_names[i] = "volume".to_string();
337 financial_columns.volume = Some("volume".to_string());
338 identified_cols[i] = true;
339 break; }
341 }
342 }
343 }
344 }
345 }
346
347 let mut price_stats: Vec<(usize, f64, f64, f64)> = Vec::new(); for (i, &identified) in identified_cols.iter().enumerate().take(n_cols) {
351 if identified {
352 continue;
353 }
354
355 if let Some(series) = df.select_at_idx(i) {
356 if series.dtype().is_primitive_numeric() {
357 if let Ok(f64_series) = series.cast(&DataType::Float64) {
358 if let Ok(nums) = f64_series.f64() {
359 if let (Some(min), Some(max), Some(std)) =
360 (nums.min(), nums.max(), nums.std(0))
361 {
362 price_stats.push((i, min, max, std));
363 }
364 }
365 }
366 }
367 }
368 }
369
370 price_stats.sort_by(|a, b| {
372 let a_range = a.2 - a.1;
373 let b_range = b.2 - b.1;
374 b_range
375 .partial_cmp(&a_range)
376 .unwrap_or(std::cmp::Ordering::Equal)
377 .then_with(|| b.3.partial_cmp(&a.3).unwrap_or(std::cmp::Ordering::Equal))
378 });
379
380 for (idx, stat) in price_stats.iter().enumerate() {
382 let i = stat.0;
383 if !identified_cols[i] {
384 let col_name = match idx {
385 0 => {
386 financial_columns.high = Some("high".to_string());
387 "high"
388 }
389 1 => {
390 financial_columns.low = Some("low".to_string());
391 "low"
392 }
393 2 => {
394 financial_columns.close = Some("close".to_string());
395 "close"
396 }
397 3 => {
398 financial_columns.open = Some("open".to_string());
399 "open"
400 }
401 _ => "unknown",
402 };
403 col_names[i] = col_name.to_string();
404 identified_cols[i] = true;
405 }
406 }
407
408 for (i, name) in col_names.iter_mut().enumerate().take(n_cols) {
410 if name.is_empty() {
411 *name = format!("unknown_{}", i);
412 }
413 }
414
415 df.set_column_names(&col_names)?;
417
418 Ok(financial_columns)
419}
420
421fn get_numeric_columns_mean(df: &DataFrame, exclude_idx: usize) -> PolarsResult<f64> {
423 let mut sum = 0.0;
424 let mut count = 0;
425
426 for i in 0..df.width() {
427 if i == exclude_idx {
428 continue;
429 }
430
431 if let Some(series) = df.select_at_idx(i) {
432 if series.dtype().is_primitive_numeric() {
433 if let Ok(f64_series) = series.cast(&DataType::Float64) {
434 if let Ok(nums) = f64_series.f64() {
435 if let Some(mean) = nums.mean() {
436 sum += mean;
437 count += 1;
438 }
439 }
440 }
441 }
442 }
443 }
444
445 if count > 0 {
446 Ok(sum / count as f64)
447 } else {
448 Ok(0.0)
449 }
450}