// sql_splitter/parser/postgres_copy.rs

//! PostgreSQL COPY statement parser.
//!
//! Parses the data blocks of `COPY ... FROM stdin` statements into individual
//! rows and, when a table schema is supplied, extracts PK/FK column values for
//! dependency tracking.

6use crate::schema::{ColumnId, ColumnType, TableSchema};
7use smallvec::SmallVec;
8
9// Re-use types from mysql_insert for consistency
10use super::mysql_insert::{FkRef, PkValue};
11
12/// Tuple of PK values for composite primary keys
13pub type PkTuple = SmallVec<[PkValue; 2]>;
14
15/// A parsed row from a COPY data block
16#[derive(Debug, Clone)]
17pub struct ParsedCopyRow {
18    /// Raw bytes of the row (tab-separated values, no newline)
19    pub raw: Vec<u8>,
20    /// Extracted primary key values (if table has PK and values are non-NULL)
21    pub pk: Option<PkTuple>,
22    /// Extracted foreign key values with their references
23    pub fk_values: Vec<(FkRef, PkTuple)>,
24    /// All column values (for data diff comparison)
25    pub all_values: Vec<PkValue>,
26}
27
28/// Parser for PostgreSQL COPY data blocks
29pub struct CopyParser<'a> {
30    data: &'a [u8],
31    table_schema: Option<&'a TableSchema>,
32    /// Column order from COPY header
33    column_order: Vec<Option<ColumnId>>,
34}
35
36impl<'a> CopyParser<'a> {
37    /// Create a new parser for COPY data
38    pub fn new(data: &'a [u8]) -> Self {
39        Self {
40            data,
41            table_schema: None,
42            column_order: Vec::new(),
43        }
44    }
45
46    /// Set the table schema for PK/FK extraction
47    pub fn with_schema(mut self, schema: &'a TableSchema) -> Self {
48        self.table_schema = Some(schema);
49        self
50    }
51
52    /// Set column order from COPY header
53    pub fn with_column_order(mut self, columns: Vec<String>) -> Self {
54        if let Some(schema) = self.table_schema {
55            self.column_order = columns
56                .iter()
57                .map(|name| schema.get_column_id(name))
58                .collect();
59        }
60        self
61    }
62
63    /// Parse all rows from the COPY data block
64    pub fn parse_rows(&mut self) -> anyhow::Result<Vec<ParsedCopyRow>> {
65        // If no explicit column order, use natural schema order
66        if self.column_order.is_empty() {
67            if let Some(schema) = self.table_schema {
68                self.column_order = schema.columns.iter().map(|c| Some(c.ordinal)).collect();
69            }
70        }
71
72        let mut rows = Vec::new();
73        let mut pos = 0;
74
75        while pos < self.data.len() {
76            // Find end of line
77            let line_end = self.data[pos..]
78                .iter()
79                .position(|&b| b == b'\n')
80                .map(|p| pos + p)
81                .unwrap_or(self.data.len());
82
83            let line = &self.data[pos..line_end];
84
85            // Check for terminator
86            if line == b"\\." || line.is_empty() {
87                pos = line_end + 1;
88                continue;
89            }
90
91            // Parse the row
92            if let Some(row) = self.parse_row(line)? {
93                rows.push(row);
94            }
95
96            pos = line_end + 1;
97        }
98
99        Ok(rows)
100    }
101
102    /// Parse a single tab-separated row
103    fn parse_row(&self, line: &[u8]) -> anyhow::Result<Option<ParsedCopyRow>> {
104        let raw = line.to_vec();
105
106        // Split by tabs
107        let values: Vec<CopyValue> = self.split_and_parse_values(line);
108
109        // Extract PK, FK, and all values if we have schema
110        let (pk, fk_values, all_values) = if let Some(schema) = self.table_schema {
111            self.extract_pk_fk(&values, schema)
112        } else {
113            (None, Vec::new(), Vec::new())
114        };
115
116        Ok(Some(ParsedCopyRow {
117            raw,
118            pk,
119            fk_values,
120            all_values,
121        }))
122    }
123
124    /// Split line by tabs and parse each value
125    fn split_and_parse_values(&self, line: &[u8]) -> Vec<CopyValue> {
126        let mut values = Vec::new();
127        let mut start = 0;
128
129        for (i, &b) in line.iter().enumerate() {
130            if b == b'\t' {
131                values.push(self.parse_copy_value(&line[start..i]));
132                start = i + 1;
133            }
134        }
135        // Last value
136        if start <= line.len() {
137            values.push(self.parse_copy_value(&line[start..]));
138        }
139
140        values
141    }
142
143    /// Parse a single COPY value
144    fn parse_copy_value(&self, value: &[u8]) -> CopyValue {
145        // Check for NULL marker
146        if value == b"\\N" {
147            return CopyValue::Null;
148        }
149
150        // Decode escape sequences
151        let decoded = self.decode_copy_escapes(value);
152
153        // Try to parse as integer
154        if let Ok(s) = std::str::from_utf8(&decoded) {
155            if let Ok(n) = s.parse::<i64>() {
156                return CopyValue::Integer(n);
157            }
158            if let Ok(n) = s.parse::<i128>() {
159                return CopyValue::BigInteger(n);
160            }
161        }
162
163        CopyValue::Text(decoded)
164    }
165
166    /// Decode PostgreSQL COPY escape sequences
167    pub fn decode_copy_escapes(&self, value: &[u8]) -> Vec<u8> {
168        let mut result = Vec::with_capacity(value.len());
169        let mut i = 0;
170
171        while i < value.len() {
172            if value[i] == b'\\' && i + 1 < value.len() {
173                let next = value[i + 1];
174                let decoded = match next {
175                    b'n' => b'\n',
176                    b'r' => b'\r',
177                    b't' => b'\t',
178                    b'\\' => b'\\',
179                    b'N' => {
180                        // This shouldn't happen here since we check for \N above
181                        result.push(b'\\');
182                        result.push(b'N');
183                        i += 2;
184                        continue;
185                    }
186                    _ => {
187                        // Unknown escape, keep as-is
188                        result.push(b'\\');
189                        result.push(next);
190                        i += 2;
191                        continue;
192                    }
193                };
194                result.push(decoded);
195                i += 2;
196            } else {
197                result.push(value[i]);
198                i += 1;
199            }
200        }
201
202        result
203    }
204
205    /// Extract PK, FK, and all values from parsed values
206    fn extract_pk_fk(
207        &self,
208        values: &[CopyValue],
209        schema: &TableSchema,
210    ) -> (Option<PkTuple>, Vec<(FkRef, PkTuple)>, Vec<PkValue>) {
211        let mut pk_values = PkTuple::new();
212        let mut fk_values = Vec::new();
213
214        // Build all_values: convert each value to PkValue
215        let all_values: Vec<PkValue> = values
216            .iter()
217            .enumerate()
218            .map(|(idx, v)| {
219                let col = self
220                    .column_order
221                    .get(idx)
222                    .and_then(|c| *c)
223                    .and_then(|id| schema.column(id));
224                self.value_to_pk(v, col)
225            })
226            .collect();
227
228        // Build PK from columns marked as primary key
229        for (idx, col_id_opt) in self.column_order.iter().enumerate() {
230            if let Some(col_id) = col_id_opt {
231                if schema.is_pk_column(*col_id) {
232                    if let Some(value) = values.get(idx) {
233                        let pk_val = self.value_to_pk(value, schema.column(*col_id));
234                        pk_values.push(pk_val);
235                    }
236                }
237            }
238        }
239
240        // Build FK tuples
241        for (fk_idx, fk) in schema.foreign_keys.iter().enumerate() {
242            if fk.referenced_table_id.is_none() {
243                continue;
244            }
245
246            let mut fk_tuple = PkTuple::new();
247            let mut all_non_null = true;
248
249            for &col_id in &fk.columns {
250                if let Some(idx) = self.column_order.iter().position(|&c| c == Some(col_id)) {
251                    if let Some(value) = values.get(idx) {
252                        let pk_val = self.value_to_pk(value, schema.column(col_id));
253                        if pk_val.is_null() {
254                            all_non_null = false;
255                            break;
256                        }
257                        fk_tuple.push(pk_val);
258                    }
259                }
260            }
261
262            if all_non_null && !fk_tuple.is_empty() {
263                fk_values.push((
264                    FkRef {
265                        table_id: schema.id.0,
266                        fk_index: fk_idx as u16,
267                    },
268                    fk_tuple,
269                ));
270            }
271        }
272
273        let pk = if pk_values.is_empty() || pk_values.iter().any(|v| v.is_null()) {
274            None
275        } else {
276            Some(pk_values)
277        };
278
279        (pk, fk_values, all_values)
280    }
281
282    /// Convert a parsed value to a PkValue
283    fn value_to_pk(&self, value: &CopyValue, col: Option<&crate::schema::Column>) -> PkValue {
284        match value {
285            CopyValue::Null => PkValue::Null,
286            CopyValue::Integer(n) => PkValue::Int(*n),
287            CopyValue::BigInteger(n) => PkValue::BigInt(*n),
288            CopyValue::Text(bytes) => {
289                let s = String::from_utf8_lossy(bytes);
290
291                // Check if this might be an integer stored as text
292                if let Some(col) = col {
293                    match col.col_type {
294                        ColumnType::Int => {
295                            if let Ok(n) = s.parse::<i64>() {
296                                return PkValue::Int(n);
297                            }
298                        }
299                        ColumnType::BigInt => {
300                            if let Ok(n) = s.parse::<i128>() {
301                                return PkValue::BigInt(n);
302                            }
303                        }
304                        _ => {}
305                    }
306                }
307
308                PkValue::Text(s.into_owned().into_boxed_str())
309            }
310        }
311    }
312}
313
/// A single COPY field after NULL detection, escape decoding and integer
/// sniffing; internal to this parser.
#[derive(Debug, Clone)]
enum CopyValue {
    /// The `\N` NULL marker.
    Null,
    /// An integer that fits in `i64`.
    Integer(i64),
    /// An integer too large for `i64` but within `i128`.
    BigInteger(i128),
    /// Any other value, as decoded bytes.
    Text(Vec<u8>),
}
322
/// Extract the column list from a `COPY` statement header.
///
/// For `COPY table_name (col1, "Col 2", ...) FROM stdin;` this returns
/// `["col1", "Col 2", ...]`. Quoted identifiers are unquoted, with `""`
/// collapsed to `"`, and may contain commas, parentheses or spaces. Whitespace
/// outside quotes is ignored, and identifier case is kept as written.
///
/// Returns an empty vector in these cases:
/// - the header has no column list, including the
///   `COPY t FROM stdin WITH (...)` options form;
/// - the header uses the `COPY (query) TO ...` form;
/// - the column list is never closed.
pub fn parse_copy_columns(header: &str) -> Vec<String> {
    find_column_list_start(header)
        .and_then(|open| split_column_list(&header[open + 1..]))
        .unwrap_or_default()
}

/// Byte index of the `(` opening the column list, if there is one.
///
/// Only an unquoted `(` that appears before the `FROM`/`TO` keyword and does
/// not directly follow `COPY` qualifies.
fn find_column_list_start(header: &str) -> Option<usize> {
    let mut in_quotes = false;
    let mut word_start: Option<usize> = None;
    let mut last_word: &str = "";

    for (i, c) in header.char_indices() {
        if in_quotes {
            // `""` inside quotes closes and immediately reopens; skipping is unaffected.
            if c == '"' {
                in_quotes = false;
            }
            continue;
        }
        if c.is_alphanumeric() || c == '_' || c == '$' {
            if word_start.is_none() {
                word_start = Some(i);
            }
            continue;
        }
        if let Some(start) = word_start.take() {
            last_word = &header[start..i];
            if last_word.eq_ignore_ascii_case("from") || last_word.eq_ignore_ascii_case("to") {
                // Any later parenthesis belongs to options, not columns.
                return None;
            }
        }
        match c {
            '"' => {
                in_quotes = true;
                last_word = "\"";
            }
            '(' => {
                return if last_word.eq_ignore_ascii_case("copy") {
                    // `COPY (SELECT ...) TO ...`: a query, not a column list.
                    None
                } else {
                    Some(i)
                };
            }
            _ => {}
        }
    }
    None
}

/// Split the text after a column list's `(` into identifiers, up to the
/// matching unquoted `)`. Returns `None` if the list is never closed.
fn split_column_list(list: &str) -> Option<Vec<String>> {
    let mut columns = Vec::new();
    let mut current = String::new();
    let mut in_quotes = false;
    let mut chars = list.chars().peekable();

    while let Some(c) = chars.next() {
        if in_quotes {
            match c {
                '"' if chars.peek() == Some(&'"') => {
                    chars.next();
                    current.push('"');
                }
                '"' => in_quotes = false,
                _ => current.push(c),
            }
            continue;
        }
        match c {
            '"' => in_quotes = true,
            ',' => columns.push(std::mem::take(&mut current)),
            ')' => {
                if !(columns.is_empty() && current.is_empty()) {
                    columns.push(current);
                }
                return Some(columns);
            }
            c if c.is_whitespace() => {}
            _ => current.push(c),
        }
    }
    None
}
337
338/// Parse all rows from a PostgreSQL COPY data block
339pub fn parse_postgres_copy_rows(
340    data: &[u8],
341    schema: &TableSchema,
342    column_order: Vec<String>,
343) -> anyhow::Result<Vec<ParsedCopyRow>> {
344    let mut parser = CopyParser::new(data)
345        .with_schema(schema)
346        .with_column_order(column_order);
347    parser.parse_rows()
348}