sql_splitter/analyzer/mod.rs

1use crate::parser::{determine_buffer_size, Parser, SqlDialect, StatementType};
2use ahash::AHashMap;
3use std::fs::File;
4use std::io::Read;
5use std::path::PathBuf;
6
/// Aggregated statistics for a single table found in a SQL dump.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TableStats {
    /// Table name as extracted by the statement parser.
    pub table_name: String,
    /// Number of statements classified as `StatementType::Insert` or `StatementType::Copy`.
    pub insert_count: u64,
    /// Number of statements classified as `StatementType::CreateTable`.
    pub create_count: u64,
    /// Sum of the byte lengths of every statement attributed to this table.
    pub total_bytes: u64,
    /// Number of statements attributed to this table, of any recognized type
    /// (including types counted in neither `insert_count` nor `create_count`).
    pub statement_count: u64,
}

impl TableStats {
    /// Creates a record for `table_name` with every counter set to zero.
    fn new(table_name: String) -> Self {
        Self {
            table_name,
            insert_count: 0,
            create_count: 0,
            total_bytes: 0,
            statement_count: 0,
        }
    }
}
27
28pub struct Analyzer {
29    input_file: PathBuf,
30    dialect: SqlDialect,
31    stats: AHashMap<String, TableStats>,
32}
33
34impl Analyzer {
35    pub fn new(input_file: PathBuf) -> Self {
36        Self {
37            input_file,
38            dialect: SqlDialect::default(),
39            stats: AHashMap::new(),
40        }
41    }
42
43    pub fn with_dialect(mut self, dialect: SqlDialect) -> Self {
44        self.dialect = dialect;
45        self
46    }
47
48    pub fn analyze(mut self) -> anyhow::Result<Vec<TableStats>> {
49        let file = File::open(&self.input_file)?;
50        let file_size = file.metadata()?.len();
51        let buffer_size = determine_buffer_size(file_size);
52        let dialect = self.dialect;
53
54        let mut parser = Parser::with_dialect(file, buffer_size, dialect);
55
56        while let Some(stmt) = parser.read_statement()? {
57            let (stmt_type, table_name) =
58                Parser::<&[u8]>::parse_statement_with_dialect(&stmt, dialect);
59
60            if stmt_type == StatementType::Unknown || table_name.is_empty() {
61                continue;
62            }
63
64            self.update_stats(&table_name, stmt_type, stmt.len() as u64);
65        }
66
67        Ok(self.get_sorted_stats())
68    }
69
70    pub fn analyze_with_progress<F: Fn(u64)>(
71        mut self,
72        progress_fn: F,
73    ) -> anyhow::Result<Vec<TableStats>> {
74        let file = File::open(&self.input_file)?;
75        let file_size = file.metadata()?.len();
76        let buffer_size = determine_buffer_size(file_size);
77        let dialect = self.dialect;
78
79        let reader = ProgressReader::new(file, progress_fn);
80        let mut parser = Parser::with_dialect(reader, buffer_size, dialect);
81
82        while let Some(stmt) = parser.read_statement()? {
83            let (stmt_type, table_name) =
84                Parser::<&[u8]>::parse_statement_with_dialect(&stmt, dialect);
85
86            if stmt_type == StatementType::Unknown || table_name.is_empty() {
87                continue;
88            }
89
90            self.update_stats(&table_name, stmt_type, stmt.len() as u64);
91        }
92
93        Ok(self.get_sorted_stats())
94    }
95
96    fn update_stats(&mut self, table_name: &str, stmt_type: StatementType, bytes: u64) {
97        let stats = self
98            .stats
99            .entry(table_name.to_string())
100            .or_insert_with(|| TableStats::new(table_name.to_string()));
101
102        stats.statement_count += 1;
103        stats.total_bytes += bytes;
104
105        match stmt_type {
106            StatementType::CreateTable => stats.create_count += 1,
107            StatementType::Insert | StatementType::Copy => stats.insert_count += 1,
108            _ => {}
109        }
110    }
111
112    fn get_sorted_stats(&self) -> Vec<TableStats> {
113        let mut result: Vec<TableStats> = self.stats.values().cloned().collect();
114        result.sort_by(|a, b| b.insert_count.cmp(&a.insert_count));
115        result
116    }
117}
118
/// [`Read`] adapter that forwards to `reader` and reports cumulative progress.
///
/// After every successful `read`, including a zero-length read at end of
/// input, `callback` is invoked with the total number of bytes read so far.
/// It is not invoked when the underlying read returns an error.
struct ProgressReader<R, F> {
    /// Wrapped source of bytes.
    reader: R,
    /// Receives the running byte total after each successful read.
    callback: F,
    /// Total number of bytes returned by `reader` so far.
    bytes_read: u64,
}

impl<R: Read, F: Fn(u64)> ProgressReader<R, F> {
    /// Wraps `reader`, starting the byte count at zero.
    fn new(reader: R, callback: F) -> Self {
        Self {
            reader,
            callback,
            bytes_read: 0,
        }
    }
}

impl<R: Read, F: Fn(u64)> Read for ProgressReader<R, F> {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let n = self.reader.read(buf)?;
        self.bytes_read += n as u64;
        (self.callback)(self.bytes_read);
        Ok(n)
    }
}
143
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;
    use tempfile::TempDir;

    /// Writes `contents` to `input.sql` inside a fresh temporary directory.
    ///
    /// The `TempDir` is returned so the caller keeps the directory (and the
    /// file) alive for the duration of the test.
    fn write_input(contents: &[u8]) -> (TempDir, PathBuf) {
        let temp_dir = TempDir::new().unwrap();
        let input_file = temp_dir.path().join("input.sql");
        std::fs::write(&input_file, contents).unwrap();
        (temp_dir, input_file)
    }

    #[test]
    fn test_analyzer_basic() {
        let (_temp_dir, input_file) = write_input(
            b"CREATE TABLE users (id INT);\nINSERT INTO users VALUES (1);\nINSERT INTO users VALUES (2);\nCREATE TABLE posts (id INT);\nINSERT INTO posts VALUES (1);",
        );

        let stats = Analyzer::new(input_file).analyze().unwrap();

        assert_eq!(stats.len(), 2);

        let users_stats = stats.iter().find(|s| s.table_name == "users").unwrap();
        assert_eq!(users_stats.insert_count, 2);
        assert_eq!(users_stats.create_count, 1);
        assert_eq!(users_stats.statement_count, 3);

        let posts_stats = stats.iter().find(|s| s.table_name == "posts").unwrap();
        assert_eq!(posts_stats.insert_count, 1);
        assert_eq!(posts_stats.create_count, 1);
        assert_eq!(posts_stats.statement_count, 2);
    }

    #[test]
    fn test_analyzer_sorted_by_insert_count() {
        let (_temp_dir, input_file) = write_input(
            b"CREATE TABLE a (id INT);\nINSERT INTO a VALUES (1);\nCREATE TABLE b (id INT);\nINSERT INTO b VALUES (1);\nINSERT INTO b VALUES (2);\nINSERT INTO b VALUES (3);",
        );

        let stats = Analyzer::new(input_file).analyze().unwrap();

        assert_eq!(stats[0].table_name, "b");
        assert_eq!(stats[0].insert_count, 3);
        assert_eq!(stats[1].table_name, "a");
        assert_eq!(stats[1].insert_count, 1);
    }

    #[test]
    fn test_analyzer_empty_input() {
        let (_temp_dir, input_file) = write_input(b"");

        let stats = Analyzer::new(input_file).analyze().unwrap();

        assert!(stats.is_empty());
    }

    #[test]
    fn test_analyze_with_progress_reports_whole_file() {
        let contents: &[u8] = b"CREATE TABLE users (id INT);\nINSERT INTO users VALUES (1);";
        let (_temp_dir, input_file) = write_input(contents);

        let last_reported = Cell::new(0u64);
        let stats = Analyzer::new(input_file)
            .analyze_with_progress(|bytes| last_reported.set(bytes))
            .unwrap();

        // The parser must consume the file to EOF, so the final progress value
        // equals the file size.
        assert_eq!(last_reported.get(), contents.len() as u64);

        assert_eq!(stats.len(), 1);
        assert_eq!(stats[0].table_name, "users");
        assert_eq!(stats[0].insert_count, 1);
        assert_eq!(stats[0].create_count, 1);
        assert_eq!(stats[0].statement_count, 2);
    }
}