sql_splitter/writer/
mod.rs

1use ahash::AHashMap;
2use std::fs::{self, File};
3use std::io::{BufWriter, Write};
4use std::path::{Path, PathBuf};
5
/// Capacity, in bytes, of the [`BufWriter`] wrapping each table's output file (256 KiB).
pub const WRITER_BUFFER_SIZE: usize = 1 << 18;
/// Number of statements a [`TableWriter`] accepts before it flushes its buffer to disk.
pub const STMT_BUFFER_COUNT: usize = 100;
8
9pub struct TableWriter {
10    writer: BufWriter<File>,
11    write_count: usize,
12    max_stmt_buffer: usize,
13}
14
15impl TableWriter {
16    pub fn new(filename: &Path) -> std::io::Result<Self> {
17        let file = File::create(filename)?;
18        let writer = BufWriter::with_capacity(WRITER_BUFFER_SIZE, file);
19
20        Ok(Self {
21            writer,
22            write_count: 0,
23            max_stmt_buffer: STMT_BUFFER_COUNT,
24        })
25    }
26
27    pub fn write_statement(&mut self, stmt: &[u8]) -> std::io::Result<()> {
28        self.writer.write_all(stmt)?;
29        self.writer.write_all(b"\n")?;
30
31        self.write_count += 1;
32        if self.write_count >= self.max_stmt_buffer {
33            self.write_count = 0;
34            self.writer.flush()?;
35        }
36
37        Ok(())
38    }
39
40    pub fn flush(&mut self) -> std::io::Result<()> {
41        self.write_count = 0;
42        self.writer.flush()
43    }
44}
45
46pub struct WriterPool {
47    output_dir: PathBuf,
48    writers: AHashMap<String, TableWriter>,
49}
50
51impl WriterPool {
52    pub fn new(output_dir: PathBuf) -> Self {
53        Self {
54            output_dir,
55            writers: AHashMap::new(),
56        }
57    }
58
59    pub fn ensure_output_dir(&self) -> std::io::Result<()> {
60        fs::create_dir_all(&self.output_dir)
61    }
62
63    pub fn get_writer(&mut self, table_name: &str) -> std::io::Result<&mut TableWriter> {
64        if !self.writers.contains_key(table_name) {
65            let filename = self.output_dir.join(format!("{}.sql", table_name));
66            let writer = TableWriter::new(&filename)?;
67            self.writers.insert(table_name.to_string(), writer);
68        }
69
70        Ok(self.writers.get_mut(table_name).unwrap())
71    }
72
73    pub fn write_statement(&mut self, table_name: &str, stmt: &[u8]) -> std::io::Result<()> {
74        let writer = self.get_writer(table_name)?;
75        writer.write_statement(stmt)
76    }
77
78    pub fn close_all(&mut self) -> std::io::Result<()> {
79        for (_, writer) in self.writers.iter_mut() {
80            writer.flush()?;
81        }
82        Ok(())
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_table_writer() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("test.sql");

        let mut writer = TableWriter::new(&file_path).unwrap();
        writer
            .write_statement(b"CREATE TABLE t1 (id INT);")
            .unwrap();
        writer
            .write_statement(b"INSERT INTO t1 VALUES (1);")
            .unwrap();
        writer.flush().unwrap();

        // Exact match pins statement order and the newline after each statement.
        let content = std::fs::read_to_string(&file_path).unwrap();
        assert_eq!(
            content,
            "CREATE TABLE t1 (id INT);\nINSERT INTO t1 VALUES (1);\n"
        );
    }

    #[test]
    fn test_table_writer_flushes_after_stmt_buffer_count() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("batch.sql");
        let mut writer = TableWriter::new(&file_path).unwrap();

        for _ in 0..STMT_BUFFER_COUNT - 1 {
            writer.write_statement(b"SELECT 1;").unwrap();
        }
        // Below the threshold everything is still buffered in memory.
        assert_eq!(std::fs::read_to_string(&file_path).unwrap(), "");

        // The STMT_BUFFER_COUNT-th statement triggers an automatic flush.
        writer.write_statement(b"SELECT 1;").unwrap();
        assert_eq!(
            std::fs::read_to_string(&file_path).unwrap(),
            "SELECT 1;\n".repeat(STMT_BUFFER_COUNT)
        );
    }

    #[test]
    fn test_writer_pool() {
        let temp_dir = TempDir::new().unwrap();
        let mut pool = WriterPool::new(temp_dir.path().to_path_buf());
        pool.ensure_output_dir().unwrap();

        pool.write_statement("users", b"CREATE TABLE users (id INT);")
            .unwrap();
        pool.write_statement("posts", b"CREATE TABLE posts (id INT);")
            .unwrap();
        pool.write_statement("users", b"INSERT INTO users VALUES (1);")
            .unwrap();

        pool.close_all().unwrap();

        // Each table gets its own file, in write order, with no cross-table leakage.
        let users_content = std::fs::read_to_string(temp_dir.path().join("users.sql")).unwrap();
        assert_eq!(
            users_content,
            "CREATE TABLE users (id INT);\nINSERT INTO users VALUES (1);\n"
        );

        let posts_content = std::fs::read_to_string(temp_dir.path().join("posts.sql")).unwrap();
        assert_eq!(posts_content, "CREATE TABLE posts (id INT);\n");
    }

    #[test]
    fn test_close_all_on_empty_pool() {
        let temp_dir = TempDir::new().unwrap();
        let mut pool = WriterPool::new(temp_dir.path().to_path_buf());
        assert!(pool.close_all().is_ok());
    }
}
133}