sql_splitter/writer/
mod.rs1use ahash::AHashMap;
2use std::fs::{self, File};
3use std::io::{BufWriter, Write};
4use std::path::{Path, PathBuf};
5
/// Capacity, in bytes, of the `BufWriter` wrapped around each output file
/// (256 KiB).
pub const WRITER_BUFFER_SIZE: usize = 256 * 1024;

/// Number of statements a `TableWriter` accepts before it flushes its buffer
/// on its own.
pub const STMT_BUFFER_COUNT: usize = 100;
8
9pub struct TableWriter {
10 writer: BufWriter<File>,
11 write_count: usize,
12 max_stmt_buffer: usize,
13}
14
15impl TableWriter {
16 pub fn new(filename: &Path) -> std::io::Result<Self> {
17 let file = File::create(filename)?;
18 let writer = BufWriter::with_capacity(WRITER_BUFFER_SIZE, file);
19
20 Ok(Self {
21 writer,
22 write_count: 0,
23 max_stmt_buffer: STMT_BUFFER_COUNT,
24 })
25 }
26
27 pub fn write_statement(&mut self, stmt: &[u8]) -> std::io::Result<()> {
28 self.writer.write_all(stmt)?;
29 self.writer.write_all(b"\n")?;
30
31 self.write_count += 1;
32 if self.write_count >= self.max_stmt_buffer {
33 self.write_count = 0;
34 self.writer.flush()?;
35 }
36
37 Ok(())
38 }
39
40 pub fn flush(&mut self) -> std::io::Result<()> {
41 self.write_count = 0;
42 self.writer.flush()
43 }
44}
45
46pub struct WriterPool {
47 output_dir: PathBuf,
48 writers: AHashMap<String, TableWriter>,
49}
50
51impl WriterPool {
52 pub fn new(output_dir: PathBuf) -> Self {
53 Self {
54 output_dir,
55 writers: AHashMap::new(),
56 }
57 }
58
59 pub fn ensure_output_dir(&self) -> std::io::Result<()> {
60 fs::create_dir_all(&self.output_dir)
61 }
62
63 pub fn get_writer(&mut self, table_name: &str) -> std::io::Result<&mut TableWriter> {
64 if !self.writers.contains_key(table_name) {
65 let filename = self.output_dir.join(format!("{}.sql", table_name));
66 let writer = TableWriter::new(&filename)?;
67 self.writers.insert(table_name.to_string(), writer);
68 }
69
70 Ok(self.writers.get_mut(table_name).unwrap())
71 }
72
73 pub fn write_statement(&mut self, table_name: &str, stmt: &[u8]) -> std::io::Result<()> {
74 let writer = self.get_writer(table_name)?;
75 writer.write_statement(stmt)
76 }
77
78 pub fn close_all(&mut self) -> std::io::Result<()> {
79 for (_, writer) in self.writers.iter_mut() {
80 writer.flush()?;
81 }
82 Ok(())
83 }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Statements written through a `TableWriter` are readable from the file
    /// once it has been flushed.
    #[test]
    fn test_table_writer() {
        let scratch = TempDir::new().unwrap();
        let target = scratch.path().join("test.sql");

        let mut table_writer = TableWriter::new(&target).unwrap();
        let statements = [
            "CREATE TABLE t1 (id INT);".as_bytes(),
            "INSERT INTO t1 VALUES (1);".as_bytes(),
        ];
        for stmt in statements.iter() {
            table_writer.write_statement(stmt).unwrap();
        }
        table_writer.flush().unwrap();

        let written = std::fs::read_to_string(&target).unwrap();
        for needle in ["CREATE TABLE t1", "INSERT INTO t1"].iter() {
            assert!(written.contains(needle));
        }
    }

    /// A `WriterPool` sends each statement to the file named after its table.
    #[test]
    fn test_writer_pool() {
        let scratch = TempDir::new().unwrap();
        let mut pool = WriterPool::new(scratch.path().to_path_buf());
        pool.ensure_output_dir().unwrap();

        let routed = [
            ("users", "CREATE TABLE users (id INT);".as_bytes()),
            ("posts", "CREATE TABLE posts (id INT);".as_bytes()),
            ("users", "INSERT INTO users VALUES (1);".as_bytes()),
        ];
        for (table, stmt) in routed.iter() {
            pool.write_statement(table, stmt).unwrap();
        }

        pool.close_all().unwrap();

        let read_output =
            |file_name: &str| std::fs::read_to_string(scratch.path().join(file_name)).unwrap();

        let users_content = read_output("users.sql");
        assert!(users_content.contains("CREATE TABLE users"));
        assert!(users_content.contains("INSERT INTO users"));

        let posts_content = read_output("posts.sql");
        assert!(posts_content.contains("CREATE TABLE posts"));
    }
}