tigers/structs/
dataframe.rs

1use std::collections::HashMap;
2
3use crate::{left_pad, max_length, structs::column::Column, structs::row::Row};
4
5#[derive(Debug)]
6pub struct DataFrame {
7    headers: Vec<String>,
8    rows: Vec<Row>,
9    columns: HashMap<String, Column>,
10}
11
12impl DataFrame {
13    fn new(headers: Vec<String>, rows: Vec<Row>) -> DataFrame {
14        let columns = DataFrame::get_columns(&headers, &rows);
15        DataFrame {
16            headers,
17            rows,
18            columns,
19        }
20    }
21
22    fn get_columns(headers: &[String], rows: &[Row]) -> HashMap<String, Column> {
23        let mut columns = HashMap::new();
24        for (i, header) in headers.iter().enumerate() {
25            let values: Vec<String> = rows.iter().map(|row| row[i].clone()).collect();
26            columns.insert(header.clone(), Column::new(header.clone(), values));
27        }
28        columns
29    }
30
31    pub fn from_csv(path: &str) -> Result<DataFrame, std::io::Error> {
32        let mut reader = csv::Reader::from_path(path)?;
33        let mut rows: Vec<Row> = Vec::new();
34        let headers: Vec<String> = reader.headers()?.iter().map(|h| h.to_string()).collect();
35        if headers.is_empty() {
36            return Err(std::io::Error::new(
37                std::io::ErrorKind::InvalidData,
38                "No headers found in CSV file",
39            ));
40        }
41        for result in reader.records() {
42            let values: Vec<String> = result?.iter().map(|v| v.to_string()).collect();
43            if values.len() != headers.len() {
44                return Err(std::io::Error::new(
45                    std::io::ErrorKind::InvalidData,
46                    format!(
47                        "Number of values in row {} does not match number of headers",
48                        rows.len() + 1
49                    ),
50                ));
51            }
52            rows.push(Row::new(headers.clone(), values));
53        }
54        if rows.is_empty() {
55            return Err(std::io::Error::new(
56                std::io::ErrorKind::InvalidData,
57                "No rows found in CSV file",
58            ));
59        }
60        let df: DataFrame = DataFrame::new(headers, rows);
61        Ok(df)
62    }
63
64    pub fn to_csv(&self, path: &str) -> Result<(), std::io::Error> {
65        let mut writer = csv::Writer::from_path(path)?;
66        writer.write_record(&self.headers)?;
67        for row in &self.rows {
68            writer.write_record(row.get_values())?;
69        }
70        Ok(())
71    }
72
73    pub fn head(&self, n: usize) -> DataFrame {
74        let headers = self.headers.clone();
75        let rows = self.rows.iter().take(n).cloned().collect();
76        DataFrame::new(headers, rows)
77    }
78
79    pub fn rename(&self, map: &HashMap<String, String>) -> DataFrame {
80        let headers = self
81            .headers
82            .iter()
83            .map(|h| map.get(h).unwrap_or(h).clone())
84            .collect::<Vec<String>>();
85        let rows = self.rows.iter().map(|r| r.rename(&headers)).collect();
86        DataFrame::new(headers, rows)
87    }
88}
89
90impl std::ops::Index<usize> for DataFrame {
91    type Output = Row;
92
93    fn index(&self, index: usize) -> &Self::Output {
94        &self.rows[index]
95    }
96}
97
98impl std::ops::Index<&str> for DataFrame {
99    type Output = Column;
100
101    fn index(&self, index: &str) -> &Self::Output {
102        &self.columns[index]
103    }
104}
105
106impl std::fmt::Display for DataFrame {
107    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
108        let w = 4;
109
110        // Compute the width of each column
111        let widths: Vec<usize> = self
112            .headers
113            .iter()
114            .map(|header| max_length!(&self.columns[header]).max(header.len()))
115            .collect();
116
117        // Right-align the headers
118        for (i, header) in self.headers.iter().enumerate() {
119            write!(f, "{}", left_pad!(header, widths[i]))?;
120            if i < self.headers.len() - 1 {
121                write!(f, "{}", " ".repeat(w))?;
122            }
123        }
124        writeln!(f)?;
125
126        // Right-align the values
127        for (i, row) in self.rows.iter().enumerate() {
128            for (i, value) in row.get_values().iter().enumerate() {
129                write!(f, "{}", left_pad!(value, widths[i]))?;
130                if i < row.get_values().len() - 1 {
131                    write!(f, "{}", " ".repeat(w))?;
132                }
133            }
134            if i < self.rows.len() - 1 {
135                writeln!(f)?;
136            }
137        }
138
139        Ok(())
140    }
141}