sql_splitter/analyzer/mod.rs

1use crate::parser::{determine_buffer_size, Parser, SqlDialect, StatementType};
2use crate::progress::ProgressReader;
3use crate::splitter::Compression;
4use ahash::AHashMap;
5use std::fs::File;
6use std::io::Read;
7use std::path::PathBuf;
8
/// Aggregated statistics for a single table found in a SQL dump.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TableStats {
    /// Name of the table these statistics belong to.
    pub table_name: String,
    /// Number of `INSERT` (and `COPY`) statements attributed to this table.
    pub insert_count: u64,
    /// Number of `CREATE TABLE` statements attributed to this table.
    pub create_count: u64,
    /// Total size in bytes of all statements attributed to this table.
    pub total_bytes: u64,
    /// Total number of statements attributed to this table (all types).
    pub statement_count: u64,
}
17
18impl TableStats {
19    fn new(table_name: String) -> Self {
20        Self {
21            table_name,
22            insert_count: 0,
23            create_count: 0,
24            total_bytes: 0,
25            statement_count: 0,
26        }
27    }
28}
29
/// Streams a (possibly compressed) SQL dump file and accumulates
/// per-table statistics without loading the whole file into memory.
pub struct Analyzer {
    // Path to the SQL dump to analyze; compression is detected from it.
    input_file: PathBuf,
    // SQL dialect used when parsing statements.
    dialect: SqlDialect,
    // Per-table accumulators, keyed by table name.
    stats: AHashMap<String, TableStats>,
}
35
36impl Analyzer {
37    pub fn new(input_file: PathBuf) -> Self {
38        Self {
39            input_file,
40            dialect: SqlDialect::default(),
41            stats: AHashMap::new(),
42        }
43    }
44
45    pub fn with_dialect(mut self, dialect: SqlDialect) -> Self {
46        self.dialect = dialect;
47        self
48    }
49
50    pub fn analyze(mut self) -> anyhow::Result<Vec<TableStats>> {
51        let file = File::open(&self.input_file)?;
52        let file_size = file.metadata()?.len();
53        let buffer_size = determine_buffer_size(file_size);
54        let dialect = self.dialect;
55
56        // Detect and apply decompression
57        let compression = Compression::from_path(&self.input_file);
58        let reader: Box<dyn Read> = compression.wrap_reader(Box::new(file));
59
60        let mut parser = Parser::with_dialect(reader, buffer_size, dialect);
61
62        while let Some(stmt) = parser.read_statement()? {
63            let (stmt_type, table_name) =
64                Parser::<&[u8]>::parse_statement_with_dialect(&stmt, dialect);
65
66            if stmt_type == StatementType::Unknown || table_name.is_empty() {
67                continue;
68            }
69
70            self.update_stats(&table_name, stmt_type, stmt.len() as u64);
71        }
72
73        Ok(self.get_sorted_stats())
74    }
75
76    pub fn analyze_with_progress<F: Fn(u64) + 'static>(
77        mut self,
78        progress_fn: F,
79    ) -> anyhow::Result<Vec<TableStats>> {
80        let file = File::open(&self.input_file)?;
81        let file_size = file.metadata()?.len();
82        let buffer_size = determine_buffer_size(file_size);
83        let dialect = self.dialect;
84
85        // Detect and apply decompression
86        let compression = Compression::from_path(&self.input_file);
87        let progress_reader = ProgressReader::new(file, progress_fn);
88        let reader: Box<dyn Read> = compression.wrap_reader(Box::new(progress_reader));
89
90        let mut parser = Parser::with_dialect(reader, buffer_size, dialect);
91
92        while let Some(stmt) = parser.read_statement()? {
93            let (stmt_type, table_name) =
94                Parser::<&[u8]>::parse_statement_with_dialect(&stmt, dialect);
95
96            if stmt_type == StatementType::Unknown || table_name.is_empty() {
97                continue;
98            }
99
100            self.update_stats(&table_name, stmt_type, stmt.len() as u64);
101        }
102
103        Ok(self.get_sorted_stats())
104    }
105
106    fn update_stats(&mut self, table_name: &str, stmt_type: StatementType, bytes: u64) {
107        let stats = self
108            .stats
109            .entry(table_name.to_string())
110            .or_insert_with(|| TableStats::new(table_name.to_string()));
111
112        stats.statement_count += 1;
113        stats.total_bytes += bytes;
114
115        match stmt_type {
116            StatementType::CreateTable => stats.create_count += 1,
117            StatementType::Insert | StatementType::Copy => stats.insert_count += 1,
118            _ => {}
119        }
120    }
121
122    fn get_sorted_stats(&self) -> Vec<TableStats> {
123        let mut result: Vec<TableStats> = self.stats.values().cloned().collect();
124        result.sort_by(|a, b| b.insert_count.cmp(&a.insert_count));
125        result
126    }
127}