Skip to main content

sql_splitter/parser/
postgres_copy.rs

1//! PostgreSQL COPY statement parser.
2//!
3//! Parses COPY ... FROM stdin data blocks to extract individual rows
4//! and optionally extract PK/FK column values for dependency tracking.
5
6use crate::schema::{ColumnId, ColumnType, TableSchema};
7use smallvec::SmallVec;
8
9// Re-use types from mysql_insert for consistency
10use super::mysql_insert::{FkRef, PkValue};
11
12/// Tuple of PK values for composite primary keys
13pub type PkTuple = SmallVec<[PkValue; 2]>;
14
15/// A parsed row from a COPY data block
16#[derive(Debug, Clone)]
17pub struct ParsedCopyRow {
18    /// Raw bytes of the row (tab-separated values, no newline)
19    pub raw: Vec<u8>,
20    /// Extracted primary key values (if table has PK and values are non-NULL)
21    pub pk: Option<PkTuple>,
22    /// Extracted foreign key values with their references
23    pub fk_values: Vec<(FkRef, PkTuple)>,
24    /// All column values (for data diff comparison)
25    pub all_values: Vec<PkValue>,
26    /// Mapping from schema column index to value index (for finding specific columns)
27    pub column_map: Vec<Option<usize>>,
28}
29
30impl ParsedCopyRow {
31    /// Get the value for a specific schema column index
32    pub fn get_column_value(&self, schema_col_index: usize) -> Option<&PkValue> {
33        self.column_map
34            .get(schema_col_index)
35            .and_then(|v| *v)
36            .and_then(|val_idx| self.all_values.get(val_idx))
37    }
38}
39
40/// Parser for PostgreSQL COPY data blocks
41pub struct CopyParser<'a> {
42    data: &'a [u8],
43    table_schema: Option<&'a TableSchema>,
44    /// Column order from COPY header
45    column_order: Vec<Option<ColumnId>>,
46}
47
48impl<'a> CopyParser<'a> {
49    /// Create a new parser for COPY data
50    pub fn new(data: &'a [u8]) -> Self {
51        Self {
52            data,
53            table_schema: None,
54            column_order: Vec::new(),
55        }
56    }
57
58    /// Set the table schema for PK/FK extraction
59    pub fn with_schema(mut self, schema: &'a TableSchema) -> Self {
60        self.table_schema = Some(schema);
61        self
62    }
63
64    /// Set column order from COPY header
65    pub fn with_column_order(mut self, columns: Vec<String>) -> Self {
66        if let Some(schema) = self.table_schema {
67            self.column_order = columns
68                .iter()
69                .map(|name| schema.get_column_id(name))
70                .collect();
71        }
72        self
73    }
74
75    /// Parse all rows from the COPY data block
76    pub fn parse_rows(&mut self) -> anyhow::Result<Vec<ParsedCopyRow>> {
77        // If no explicit column order, use natural schema order
78        if self.column_order.is_empty() {
79            if let Some(schema) = self.table_schema {
80                self.column_order = schema.columns.iter().map(|c| Some(c.ordinal)).collect();
81            }
82        }
83
84        let mut rows = Vec::new();
85        let mut pos = 0;
86
87        while pos < self.data.len() {
88            // Find end of line
89            let line_end = self.data[pos..]
90                .iter()
91                .position(|&b| b == b'\n')
92                .map(|p| pos + p)
93                .unwrap_or(self.data.len());
94
95            let line = &self.data[pos..line_end];
96
97            // Check for terminator
98            if line == b"\\." || line.is_empty() {
99                pos = line_end + 1;
100                continue;
101            }
102
103            // Parse the row
104            if let Some(row) = self.parse_row(line)? {
105                rows.push(row);
106            }
107
108            pos = line_end + 1;
109        }
110
111        Ok(rows)
112    }
113
114    /// Parse a single tab-separated row
115    fn parse_row(&self, line: &[u8]) -> anyhow::Result<Option<ParsedCopyRow>> {
116        let raw = line.to_vec();
117
118        // Split by tabs
119        let values: Vec<CopyValue> = self.split_and_parse_values(line);
120
121        // Extract PK, FK, all values, and column map if we have schema
122        let (pk, fk_values, all_values, column_map) = if let Some(schema) = self.table_schema {
123            let (pk, fk_values, all_values) = self.extract_pk_fk(&values, schema);
124            let column_map = self.build_column_map(schema);
125            (pk, fk_values, all_values, column_map)
126        } else {
127            (None, Vec::new(), Vec::new(), Vec::new())
128        };
129
130        Ok(Some(ParsedCopyRow {
131            raw,
132            pk,
133            fk_values,
134            all_values,
135            column_map,
136        }))
137    }
138
139    /// Build a mapping from schema column index to value index
140    fn build_column_map(&self, schema: &TableSchema) -> Vec<Option<usize>> {
141        let mut map = vec![None; schema.columns.len()];
142
143        for (val_idx, col_id_opt) in self.column_order.iter().enumerate() {
144            if let Some(col_id) = col_id_opt {
145                let ordinal = col_id.0 as usize;
146                if ordinal < map.len() {
147                    map[ordinal] = Some(val_idx);
148                }
149            }
150        }
151
152        map
153    }
154
155    /// Split line by tabs and parse each value
156    fn split_and_parse_values(&self, line: &[u8]) -> Vec<CopyValue> {
157        let mut values = Vec::new();
158        let mut start = 0;
159
160        for (i, &b) in line.iter().enumerate() {
161            if b == b'\t' {
162                values.push(self.parse_copy_value(&line[start..i]));
163                start = i + 1;
164            }
165        }
166        // Last value
167        if start <= line.len() {
168            values.push(self.parse_copy_value(&line[start..]));
169        }
170
171        values
172    }
173
174    /// Parse a single COPY value
175    fn parse_copy_value(&self, value: &[u8]) -> CopyValue {
176        // Check for NULL marker
177        if value == b"\\N" {
178            return CopyValue::Null;
179        }
180
181        // Decode escape sequences
182        let decoded = self.decode_copy_escapes(value);
183
184        // Try to parse as integer
185        if let Ok(s) = std::str::from_utf8(&decoded) {
186            if let Ok(n) = s.parse::<i64>() {
187                return CopyValue::Integer(n);
188            }
189            if let Ok(n) = s.parse::<i128>() {
190                return CopyValue::BigInteger(n);
191            }
192        }
193
194        CopyValue::Text(decoded)
195    }
196
197    /// Decode PostgreSQL COPY escape sequences
198    pub fn decode_copy_escapes(&self, value: &[u8]) -> Vec<u8> {
199        let mut result = Vec::with_capacity(value.len());
200        let mut i = 0;
201
202        while i < value.len() {
203            if value[i] == b'\\' && i + 1 < value.len() {
204                let next = value[i + 1];
205                let decoded = match next {
206                    b'n' => b'\n',
207                    b'r' => b'\r',
208                    b't' => b'\t',
209                    b'\\' => b'\\',
210                    b'N' => {
211                        // This shouldn't happen here since we check for \N above
212                        result.push(b'\\');
213                        result.push(b'N');
214                        i += 2;
215                        continue;
216                    }
217                    _ => {
218                        // Unknown escape, keep as-is
219                        result.push(b'\\');
220                        result.push(next);
221                        i += 2;
222                        continue;
223                    }
224                };
225                result.push(decoded);
226                i += 2;
227            } else {
228                result.push(value[i]);
229                i += 1;
230            }
231        }
232
233        result
234    }
235
236    /// Extract PK, FK, and all values from parsed values
237    fn extract_pk_fk(
238        &self,
239        values: &[CopyValue],
240        schema: &TableSchema,
241    ) -> (Option<PkTuple>, Vec<(FkRef, PkTuple)>, Vec<PkValue>) {
242        let mut pk_values = PkTuple::new();
243        let mut fk_values = Vec::new();
244
245        // Build all_values: convert each value to PkValue
246        let all_values: Vec<PkValue> = values
247            .iter()
248            .enumerate()
249            .map(|(idx, v)| {
250                let col = self
251                    .column_order
252                    .get(idx)
253                    .and_then(|c| *c)
254                    .and_then(|id| schema.column(id));
255                self.value_to_pk(v, col)
256            })
257            .collect();
258
259        // Build PK from columns marked as primary key
260        for (idx, col_id_opt) in self.column_order.iter().enumerate() {
261            if let Some(col_id) = col_id_opt {
262                if schema.is_pk_column(*col_id) {
263                    if let Some(value) = values.get(idx) {
264                        let pk_val = self.value_to_pk(value, schema.column(*col_id));
265                        pk_values.push(pk_val);
266                    }
267                }
268            }
269        }
270
271        // Build FK tuples
272        for (fk_idx, fk) in schema.foreign_keys.iter().enumerate() {
273            if fk.referenced_table_id.is_none() {
274                continue;
275            }
276
277            let mut fk_tuple = PkTuple::new();
278            let mut all_non_null = true;
279
280            for &col_id in &fk.columns {
281                if let Some(idx) = self.column_order.iter().position(|&c| c == Some(col_id)) {
282                    if let Some(value) = values.get(idx) {
283                        let pk_val = self.value_to_pk(value, schema.column(col_id));
284                        if pk_val.is_null() {
285                            all_non_null = false;
286                            break;
287                        }
288                        fk_tuple.push(pk_val);
289                    }
290                }
291            }
292
293            if all_non_null && !fk_tuple.is_empty() {
294                fk_values.push((
295                    FkRef {
296                        table_id: schema.id.0,
297                        fk_index: fk_idx as u16,
298                    },
299                    fk_tuple,
300                ));
301            }
302        }
303
304        let pk = if pk_values.is_empty() || pk_values.iter().any(|v| v.is_null()) {
305            None
306        } else {
307            Some(pk_values)
308        };
309
310        (pk, fk_values, all_values)
311    }
312
313    /// Convert a parsed value to a PkValue
314    fn value_to_pk(&self, value: &CopyValue, col: Option<&crate::schema::Column>) -> PkValue {
315        match value {
316            CopyValue::Null => PkValue::Null,
317            CopyValue::Integer(n) => PkValue::Int(*n),
318            CopyValue::BigInteger(n) => PkValue::BigInt(*n),
319            CopyValue::Text(bytes) => {
320                let s = String::from_utf8_lossy(bytes);
321
322                // Check if this might be an integer stored as text
323                if let Some(col) = col {
324                    match col.col_type {
325                        ColumnType::Int => {
326                            if let Ok(n) = s.parse::<i64>() {
327                                return PkValue::Int(n);
328                            }
329                        }
330                        ColumnType::BigInt => {
331                            if let Ok(n) = s.parse::<i128>() {
332                                return PkValue::BigInt(n);
333                            }
334                        }
335                        _ => {}
336                    }
337                }
338
339                PkValue::Text(s.into_owned().into_boxed_str())
340            }
341        }
342    }
343}
344
/// A single field of a COPY row after NULL detection, escape decoding and
/// integer classification.
#[derive(Clone, Debug)]
enum CopyValue {
    /// The field was the NULL marker
    Null,
    /// The decoded field parsed as an `i64`
    Integer(i64),
    /// The decoded field parsed as an `i128` but not as an `i64`
    BigInteger(i128),
    /// Anything else: the decoded bytes of the field
    Text(Vec<u8>),
}
353
/// Parse the column list from a COPY header.
///
/// For `COPY table_name (col1, col2, ...) FROM stdin;` this returns
/// `["col1", "col2", ...]`. Each name is trimmed of surrounding whitespace
/// and of surrounding double quotes.
///
/// Returns an empty vector when the header has no column list, i.e. when:
///
/// * there is no `(` at all,
/// * the first `(` only appears after the `FROM` keyword (it then belongs to
///   an option list such as `WITH (FORMAT csv)`, not to a column list), or
/// * there is no `)` after the opening parenthesis.
///
/// Names are split on every comma, so a quoted identifier that itself
/// contains a comma or a parenthesis is not supported.
pub fn parse_copy_columns(header: &str) -> Vec<String> {
    let open = match header.find('(') {
        Some(idx) => idx,
        None => return Vec::new(),
    };

    // A parenthesis that follows FROM is the options list, not the column list.
    let after_from = header[..open]
        .split_whitespace()
        .any(|token| token.eq_ignore_ascii_case("from"));
    if after_from {
        return Vec::new();
    }

    // Search for the closing parenthesis only after the opening one, so a stray
    // earlier `)` can never produce an inverted (panicking) slice range.
    let list = &header[open + 1..];
    match list.find(')') {
        Some(close) => list[..close]
            .split(',')
            .map(|c| c.trim().trim_matches('"').to_string())
            .collect(),
        None => Vec::new(),
    }
}
368
369/// Parse all rows from a PostgreSQL COPY data block
370pub fn parse_postgres_copy_rows(
371    data: &[u8],
372    schema: &TableSchema,
373    column_order: Vec<String>,
374) -> anyhow::Result<Vec<ParsedCopyRow>> {
375    let mut parser = CopyParser::new(data)
376        .with_schema(schema)
377        .with_column_order(column_order);
378    parser.parse_rows()
379}