sql_splitter/analyzer/
mod.rs

1use crate::parser::{determine_buffer_size, Parser, SqlDialect, StatementType};
2use crate::progress::ProgressReader;
3use crate::splitter::Compression;
4use ahash::AHashMap;
5use serde::Serialize;
6use std::fs::File;
7use std::io::Read;
8use std::path::PathBuf;
9
10#[derive(Debug, Clone, Serialize)]
11pub struct TableStats {
12    pub table_name: String,
13    pub insert_count: u64,
14    pub create_count: u64,
15    pub total_bytes: u64,
16    pub statement_count: u64,
17}
18
19impl TableStats {
20    fn new(table_name: String) -> Self {
21        Self {
22            table_name,
23            insert_count: 0,
24            create_count: 0,
25            total_bytes: 0,
26            statement_count: 0,
27        }
28    }
29}
30
31pub struct Analyzer {
32    input_file: PathBuf,
33    dialect: SqlDialect,
34    stats: AHashMap<String, TableStats>,
35}
36
37impl Analyzer {
38    pub fn new(input_file: PathBuf) -> Self {
39        Self {
40            input_file,
41            dialect: SqlDialect::default(),
42            stats: AHashMap::new(),
43        }
44    }
45
46    pub fn with_dialect(mut self, dialect: SqlDialect) -> Self {
47        self.dialect = dialect;
48        self
49    }
50
51    pub fn analyze(mut self) -> anyhow::Result<Vec<TableStats>> {
52        let file = File::open(&self.input_file)?;
53        let file_size = file.metadata()?.len();
54        let buffer_size = determine_buffer_size(file_size);
55        let dialect = self.dialect;
56
57        // Detect and apply decompression
58        let compression = Compression::from_path(&self.input_file);
59        let reader: Box<dyn Read> = compression.wrap_reader(Box::new(file));
60
61        let mut parser = Parser::with_dialect(reader, buffer_size, dialect);
62
63        while let Some(stmt) = parser.read_statement()? {
64            let (stmt_type, table_name) =
65                Parser::<&[u8]>::parse_statement_with_dialect(&stmt, dialect);
66
67            if stmt_type == StatementType::Unknown || table_name.is_empty() {
68                continue;
69            }
70
71            self.update_stats(&table_name, stmt_type, stmt.len() as u64);
72        }
73
74        Ok(self.get_sorted_stats())
75    }
76
77    pub fn analyze_with_progress<F: Fn(u64) + 'static>(
78        mut self,
79        progress_fn: F,
80    ) -> anyhow::Result<Vec<TableStats>> {
81        let file = File::open(&self.input_file)?;
82        let file_size = file.metadata()?.len();
83        let buffer_size = determine_buffer_size(file_size);
84        let dialect = self.dialect;
85
86        // Detect and apply decompression
87        let compression = Compression::from_path(&self.input_file);
88        let progress_reader = ProgressReader::new(file, progress_fn);
89        let reader: Box<dyn Read> = compression.wrap_reader(Box::new(progress_reader));
90
91        let mut parser = Parser::with_dialect(reader, buffer_size, dialect);
92
93        while let Some(stmt) = parser.read_statement()? {
94            let (stmt_type, table_name) =
95                Parser::<&[u8]>::parse_statement_with_dialect(&stmt, dialect);
96
97            if stmt_type == StatementType::Unknown || table_name.is_empty() {
98                continue;
99            }
100
101            self.update_stats(&table_name, stmt_type, stmt.len() as u64);
102        }
103
104        Ok(self.get_sorted_stats())
105    }
106
107    fn update_stats(&mut self, table_name: &str, stmt_type: StatementType, bytes: u64) {
108        let stats = self
109            .stats
110            .entry(table_name.to_string())
111            .or_insert_with(|| TableStats::new(table_name.to_string()));
112
113        stats.statement_count += 1;
114        stats.total_bytes += bytes;
115
116        match stmt_type {
117            StatementType::CreateTable => stats.create_count += 1,
118            StatementType::Insert | StatementType::Copy => stats.insert_count += 1,
119            _ => {}
120        }
121    }
122
123    fn get_sorted_stats(&self) -> Vec<TableStats> {
124        let mut result: Vec<TableStats> = self.stats.values().cloned().collect();
125        result.sort_by(|a, b| b.insert_count.cmp(&a.insert_count));
126        result
127    }
128}