// sql_splitter/analyzer/mod.rs

use std::fs::File;
use std::io::Read;
use std::path::PathBuf;

use ahash::AHashMap;

use crate::parser::{determine_buffer_size, Parser, SqlDialect, StatementType};

/// Aggregated statistics for a single table found in a SQL dump.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TableStats {
    /// Name of the table, as extracted from the statement text.
    pub table_name: String,
    /// Number of data-loading statements (`INSERT` or `COPY`) attributed to the table.
    pub insert_count: u64,
    /// Number of `CREATE TABLE` statements attributed to the table.
    pub create_count: u64,
    /// Sum of the byte lengths of every statement attributed to the table.
    pub total_bytes: u64,
    /// Total number of statements attributed to the table, of any type.
    pub statement_count: u64,
}

impl TableStats {
    /// Creates statistics for `table_name` with every counter set to zero.
    fn new(table_name: String) -> Self {
        Self {
            table_name,
            insert_count: 0,
            create_count: 0,
            total_bytes: 0,
            statement_count: 0,
        }
    }
}

28pub struct Analyzer {
29 input_file: PathBuf,
30 dialect: SqlDialect,
31 stats: AHashMap<String, TableStats>,
32}
33
34impl Analyzer {
35 pub fn new(input_file: PathBuf) -> Self {
36 Self {
37 input_file,
38 dialect: SqlDialect::default(),
39 stats: AHashMap::new(),
40 }
41 }
42
43 pub fn with_dialect(mut self, dialect: SqlDialect) -> Self {
44 self.dialect = dialect;
45 self
46 }
47
48 pub fn analyze(mut self) -> anyhow::Result<Vec<TableStats>> {
49 let file = File::open(&self.input_file)?;
50 let file_size = file.metadata()?.len();
51 let buffer_size = determine_buffer_size(file_size);
52 let dialect = self.dialect;
53
54 let mut parser = Parser::with_dialect(file, buffer_size, dialect);
55
56 while let Some(stmt) = parser.read_statement()? {
57 let (stmt_type, table_name) = Parser::<&[u8]>::parse_statement_with_dialect(&stmt, dialect);
58
59 if stmt_type == StatementType::Unknown || table_name.is_empty() {
60 continue;
61 }
62
63 self.update_stats(&table_name, stmt_type, stmt.len() as u64);
64 }
65
66 Ok(self.get_sorted_stats())
67 }
68
69 pub fn analyze_with_progress<F: Fn(u64)>(
70 mut self,
71 progress_fn: F,
72 ) -> anyhow::Result<Vec<TableStats>> {
73 let file = File::open(&self.input_file)?;
74 let file_size = file.metadata()?.len();
75 let buffer_size = determine_buffer_size(file_size);
76 let dialect = self.dialect;
77
78 let reader = ProgressReader::new(file, progress_fn);
79 let mut parser = Parser::with_dialect(reader, buffer_size, dialect);
80
81 while let Some(stmt) = parser.read_statement()? {
82 let (stmt_type, table_name) = Parser::<&[u8]>::parse_statement_with_dialect(&stmt, dialect);
83
84 if stmt_type == StatementType::Unknown || table_name.is_empty() {
85 continue;
86 }
87
88 self.update_stats(&table_name, stmt_type, stmt.len() as u64);
89 }
90
91 Ok(self.get_sorted_stats())
92 }
93
94 fn update_stats(&mut self, table_name: &str, stmt_type: StatementType, bytes: u64) {
95 let stats = self
96 .stats
97 .entry(table_name.to_string())
98 .or_insert_with(|| TableStats::new(table_name.to_string()));
99
100 stats.statement_count += 1;
101 stats.total_bytes += bytes;
102
103 match stmt_type {
104 StatementType::CreateTable => stats.create_count += 1,
105 StatementType::Insert | StatementType::Copy => stats.insert_count += 1,
106 _ => {}
107 }
108 }
109
110 fn get_sorted_stats(&self) -> Vec<TableStats> {
111 let mut result: Vec<TableStats> = self.stats.values().cloned().collect();
112 result.sort_by(|a, b| b.insert_count.cmp(&a.insert_count));
113 result
114 }
115}
116
/// [`Read`] adapter that reports the cumulative number of bytes read to a callback.
///
/// After every successful `read` on the inner reader, including one that
/// returns `0` at end of input, the callback receives the running total of
/// bytes read so far. Errors from the inner reader are returned without
/// invoking the callback.
struct ProgressReader<R, F> {
    reader: R,
    callback: F,
    bytes_read: u64,
}

impl<R: Read, F: Fn(u64)> ProgressReader<R, F> {
    /// Wraps `reader`, starting the running byte count at zero.
    fn new(reader: R, callback: F) -> Self {
        Self {
            reader,
            callback,
            bytes_read: 0,
        }
    }
}

impl<R: Read, F: Fn(u64)> Read for ProgressReader<R, F> {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let n = self.reader.read(buf)?;
        // usize -> u64 is lossless on all targets with pointers of at most 64 bits.
        self.bytes_read += n as u64;
        (self.callback)(self.bytes_read);
        Ok(n)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;
    use tempfile::TempDir;

    #[test]
    fn test_analyzer_basic() {
        let temp_dir = TempDir::new().unwrap();
        let input_file = temp_dir.path().join("input.sql");

        std::fs::write(
            &input_file,
            b"CREATE TABLE users (id INT);\nINSERT INTO users VALUES (1);\nINSERT INTO users VALUES (2);\nCREATE TABLE posts (id INT);\nINSERT INTO posts VALUES (1);",
        )
        .unwrap();

        let analyzer = Analyzer::new(input_file);
        let stats = analyzer.analyze().unwrap();

        assert_eq!(stats.len(), 2);

        let users_stats = stats.iter().find(|s| s.table_name == "users").unwrap();
        assert_eq!(users_stats.insert_count, 2);
        assert_eq!(users_stats.create_count, 1);
        assert_eq!(users_stats.statement_count, 3);

        let posts_stats = stats.iter().find(|s| s.table_name == "posts").unwrap();
        assert_eq!(posts_stats.insert_count, 1);
        assert_eq!(posts_stats.create_count, 1);
        assert_eq!(posts_stats.statement_count, 2);
    }

    #[test]
    fn test_analyzer_sorted_by_insert_count() {
        let temp_dir = TempDir::new().unwrap();
        let input_file = temp_dir.path().join("input.sql");

        std::fs::write(
            &input_file,
            b"CREATE TABLE a (id INT);\nINSERT INTO a VALUES (1);\nCREATE TABLE b (id INT);\nINSERT INTO b VALUES (1);\nINSERT INTO b VALUES (2);\nINSERT INTO b VALUES (3);",
        )
        .unwrap();

        let analyzer = Analyzer::new(input_file);
        let stats = analyzer.analyze().unwrap();

        assert_eq!(stats[0].table_name, "b");
        assert_eq!(stats[0].insert_count, 3);
        assert_eq!(stats[1].table_name, "a");
        assert_eq!(stats[1].insert_count, 1);
    }

    #[test]
    fn test_analyzer_with_progress_reports_whole_file() {
        let temp_dir = TempDir::new().unwrap();
        let input_file = temp_dir.path().join("input.sql");
        let contents: &[u8] = b"CREATE TABLE users (id INT);\nINSERT INTO users VALUES (1);\nINSERT INTO users VALUES (2);";

        std::fs::write(&input_file, contents).unwrap();

        let last_reported = Cell::new(0u64);
        let stats = Analyzer::new(input_file)
            .analyze_with_progress(|bytes| last_reported.set(bytes))
            .unwrap();

        assert_eq!(stats.len(), 1);
        assert_eq!(stats[0].table_name, "users");
        assert_eq!(stats[0].insert_count, 2);
        assert_eq!(stats[0].create_count, 1);
        assert_eq!(stats[0].statement_count, 3);
        // The parser reads the input to the end, so the final report is the file size.
        assert_eq!(last_reported.get(), contents.len() as u64);
    }
}