Skip to main content

wp_data_fmt/
sql.rs

1#[allow(deprecated)]
2use crate::formatter::DataFormat;
3use crate::formatter::{RecordFormatter, ValueFormatter};
4use wp_model_core::model::fmt_def::TextFmt;
5use wp_model_core::model::{DataRecord, DataType, FieldStorage, Value, types::value::ObjectValue};
6
7pub struct SqlInsert {
8    pub table_name: String,
9    pub quote_identifiers: bool,
10    pub obj_formatter: crate::SqlFormat,
11}
12
13impl Default for SqlInsert {
14    fn default() -> Self {
15        Self {
16            table_name: String::new(),
17            quote_identifiers: true,
18            obj_formatter: crate::SqlFormat::from(&TextFmt::Json),
19        }
20    }
21}
22
23impl SqlInsert {
24    pub fn new_with_json<T: Into<String>>(table: T) -> Self {
25        Self {
26            table_name: table.into(),
27            quote_identifiers: true,
28            obj_formatter: crate::SqlFormat::from(&TextFmt::Json),
29        }
30    }
31    fn quote_identifier(&self, name: &str) -> String {
32        if self.quote_identifiers {
33            let escaped = name.replace('"', "\"\"");
34            format!("\"{}\"", escaped)
35        } else {
36            name.to_string()
37        }
38    }
39    fn escape_string(&self, value: &str) -> String {
40        value.replace('\'', "''")
41    }
42}
43
44#[allow(deprecated)]
45impl DataFormat for SqlInsert {
46    type Output = String;
47    fn format_null(&self) -> String {
48        "NULL".to_string()
49    }
50    fn format_bool(&self, value: &bool) -> String {
51        if *value { "TRUE" } else { "FALSE" }.to_string()
52    }
53    fn format_string(&self, value: &str) -> String {
54        format!("'{}'", self.escape_string(value))
55    }
56    fn format_i64(&self, value: &i64) -> String {
57        value.to_string()
58    }
59    fn format_f64(&self, value: &f64) -> String {
60        if value.is_nan() {
61            "NULL".into()
62        } else if value.is_infinite() {
63            if value.is_sign_positive() {
64                "'Infinity'".into()
65            } else {
66                "'-Infinity'".into()
67            }
68        } else {
69            value.to_string()
70        }
71    }
72    fn format_ip(&self, value: &std::net::IpAddr) -> String {
73        self.format_string(&value.to_string())
74    }
75    fn format_datetime(&self, value: &chrono::NaiveDateTime) -> String {
76        self.format_string(&value.to_string())
77    }
78    fn format_object(&self, value: &ObjectValue) -> String {
79        let inner = match &self.obj_formatter {
80            crate::SqlFormat::Json(f) => f.format_object(value),
81            crate::SqlFormat::Kv(f) => f.format_object(value),
82            crate::SqlFormat::Raw(f) => f.format_object(value),
83            crate::SqlFormat::ProtoText(f) => f.format_object(value),
84        };
85        format!("'{}'", self.escape_string(&inner))
86    }
87    fn format_array(&self, value: &[FieldStorage]) -> String {
88        let inner = match &self.obj_formatter {
89            crate::SqlFormat::Json(f) => f.format_array(value),
90            crate::SqlFormat::Kv(f) => f.format_array(value),
91            crate::SqlFormat::Raw(f) => f.format_array(value),
92            crate::SqlFormat::ProtoText(f) => f.format_array(value),
93        };
94        format!("'{}'", self.escape_string(&inner))
95    }
96    fn format_record(&self, record: &DataRecord) -> String {
97        let columns: Vec<String> = record
98            .items
99            .iter()
100            .filter(|f| *f.get_meta() != DataType::Ignore)
101            .map(|f| self.quote_identifier(f.get_name()))
102            .collect();
103        let values: Vec<String> = record
104            .items
105            .iter()
106            .filter(|f| *f.get_meta() != DataType::Ignore)
107            .map(|f| self.format_field(f))
108            .collect();
109        format!(
110            "INSERT INTO {} ({}) VALUES ({});",
111            self.quote_identifier(&self.table_name),
112            columns.join(", "),
113            values.join(", ")
114        )
115    }
116    fn format_field(&self, field: &FieldStorage) -> String {
117        if *field.get_meta() == DataType::Ignore {
118            String::new()
119        } else {
120            self.fmt_value(field.get_value())
121        }
122    }
123}
124
125impl SqlInsert {
126    #[allow(deprecated)]
127    pub fn format_batch(&self, records: &[DataRecord]) -> String {
128        if records.is_empty() {
129            return String::new();
130        }
131        let mut output = String::new();
132        let columns: Vec<String> = records[0]
133            .items
134            .iter()
135            .filter(|f| *f.get_meta() != DataType::Ignore)
136            .map(|f| self.quote_identifier(f.get_name()))
137            .collect();
138        use std::fmt::Write;
139        writeln!(
140            output,
141            "INSERT INTO {} ({}) VALUES",
142            self.quote_identifier(&self.table_name),
143            columns.join(", ")
144        )
145        .unwrap();
146        for (i, record) in records.iter().enumerate() {
147            if i > 0 {
148                output.push_str(",\n");
149            }
150            let values: Vec<String> = record
151                .items
152                .iter()
153                .filter(|f| *f.get_meta() != DataType::Ignore)
154                .map(|f| self.format_field(f))
155                .collect();
156            write!(output, "  ({})", values.join(", ")).unwrap();
157        }
158        output.push(';');
159        output
160    }
161    pub fn generate_create_table(&self, records: &[DataRecord]) -> String {
162        if records.is_empty() {
163            return String::new();
164        }
165        let mut columns = Vec::new();
166        for field in &records[0].items {
167            if *field.get_meta() == DataType::Ignore {
168                continue;
169            }
170            let sql_type = &match field.get_value() {
171                Value::Bool(_) => "BOOLEAN",
172                Value::Chars(_) => "TEXT",
173                Value::Digit(_) => "BIGINT",
174                Value::Float(_) => "DOUBLE PRECISION",
175                Value::Time(_) => "TIMESTAMP",
176                Value::IpAddr(_) => "INET",
177                Value::Obj(_) | Value::Array(_) => "JSONB",
178                _ => "TEXT",
179            };
180            columns.push(format!(
181                "  {} {}",
182                self.quote_identifier(field.get_name()),
183                sql_type
184            ));
185        }
186        format!(
187            "CREATE TABLE IF NOT EXISTS {} (\n{}\n);",
188            self.quote_identifier(&self.table_name),
189            columns.join(",\n")
190        )
191    }
192    #[allow(deprecated)]
193    pub fn format_upsert(&self, record: &DataRecord, conflict_columns: &[&str]) -> String {
194        let insert = self.format_record(record);
195        let mut update_parts = Vec::new();
196        for field in record
197            .items
198            .iter()
199            .filter(|f| *f.get_meta() != DataType::Ignore)
200        {
201            let name = field.get_name();
202            if !conflict_columns.contains(&name) {
203                let col = self.quote_identifier(name);
204                update_parts.push(format!("{} = EXCLUDED.{}", &col, &col));
205            }
206        }
207        if update_parts.is_empty() {
208            insert
209        } else {
210            let quoted_conflicts: Vec<String> = conflict_columns
211                .iter()
212                .map(|c| self.quote_identifier(c))
213                .collect();
214            format!(
215                "{} ON CONFLICT ({}) DO UPDATE SET {};",
216                insert.trim_end_matches(';'),
217                quoted_conflicts.join(", "),
218                update_parts.join(", ")
219            )
220        }
221    }
222}
223
// Unit tests for `SqlInsert`, covering both the deprecated `DataFormat` API
// (`format_record`) and the newer `ValueFormatter` / `RecordFormatter` API
// (`format_value`, `fmt_record`).
#[cfg(test)]
#[allow(deprecated)]
mod tests {
    use super::*;
    // Explicit import so the deprecated trait's methods (e.g. `format_record`) resolve.
    use crate::formatter::DataFormat;
    use wp_model_core::model::{DataField, DataRecord};
    // Basic single-row INSERT through the deprecated API with quoted identifiers.
    #[test]
    fn test_sql_basic() {
        let f = SqlInsert {
            table_name: "t".into(),
            quote_identifiers: true,
            obj_formatter: crate::SqlFormat::from(&TextFmt::Json),
        };
        let r = DataRecord {
            id: Default::default(),
            items: vec![
                FieldStorage::from_owned(DataField::from_chars("name", "Alice")),
                FieldStorage::from_owned(DataField::from_digit("age", 30)),
            ],
        };
        let s = f.format_record(&r);
        assert!(s.contains("INSERT INTO \"t\" (\"name\", \"age\") VALUES"));
    }

    // `Default` yields an empty table name with quoting enabled.
    #[test]
    fn test_sql_default() {
        let sql = SqlInsert::default();
        assert_eq!(sql.table_name, "");
        assert!(sql.quote_identifiers);
    }

    // `new_with_json` stores the table name and enables quoting.
    #[test]
    fn test_sql_new_with_json() {
        let sql = SqlInsert::new_with_json("users");
        assert_eq!(sql.table_name, "users");
        assert!(sql.quote_identifiers);
    }

    #[test]
    fn test_format_null() {
        let sql = SqlInsert::default();
        assert_eq!(sql.format_value(&Value::Null), "NULL");
    }

    // Booleans render as SQL keywords, not quoted strings.
    #[test]
    fn test_format_bool() {
        let sql = SqlInsert::default();
        assert_eq!(sql.format_value(&Value::Bool(true)), "TRUE");
        assert_eq!(sql.format_value(&Value::Bool(false)), "FALSE");
    }

    // Strings are single-quoted; the empty string becomes `''`.
    #[test]
    fn test_format_string() {
        let sql = SqlInsert::default();
        assert_eq!(sql.format_value(&Value::Chars("hello".into())), "'hello'");
        assert_eq!(sql.format_value(&Value::Chars("".into())), "''");
    }

    #[test]
    fn test_format_string_escape() {
        let sql = SqlInsert::default();
        // Single quotes should be escaped by doubling
        assert_eq!(sql.format_value(&Value::Chars("it's".into())), "'it''s'");
        assert_eq!(
            sql.format_value(&Value::Chars("say 'hi'".into())),
            "'say ''hi'''"
        );
    }

    // Integers render unquoted, including negatives.
    #[test]
    fn test_format_i64() {
        let sql = SqlInsert::default();
        assert_eq!(sql.format_value(&Value::Digit(0)), "0");
        assert_eq!(sql.format_value(&Value::Digit(42)), "42");
        assert_eq!(sql.format_value(&Value::Digit(-100)), "-100");
    }

    // Finite floats use Rust's `Display`, which prints `0.0` as `0`.
    #[test]
    fn test_format_f64_normal() {
        let sql = SqlInsert::default();
        assert_eq!(sql.format_value(&Value::Float(3.24)), "3.24");
        assert_eq!(sql.format_value(&Value::Float(0.0)), "0");
    }

    // NaN maps to NULL; infinities map to quoted 'Infinity' / '-Infinity'.
    #[test]
    fn test_format_f64_special() {
        let sql = SqlInsert::default();
        assert_eq!(sql.format_value(&Value::Float(f64::NAN)), "NULL");
        assert_eq!(sql.format_value(&Value::Float(f64::INFINITY)), "'Infinity'");
        assert_eq!(
            sql.format_value(&Value::Float(f64::NEG_INFINITY)),
            "'-Infinity'"
        );
    }

    // IP addresses are emitted as quoted strings.
    #[test]
    fn test_format_ip() {
        use std::net::IpAddr;
        use std::str::FromStr;
        let sql = SqlInsert::default();
        let ip = IpAddr::from_str("192.168.1.1").unwrap();
        assert_eq!(sql.format_value(&Value::IpAddr(ip)), "'192.168.1.1'");
    }

    // Only checks quoting and that the year survives; the exact timestamp text
    // is left to chrono's `Display`.
    #[test]
    fn test_format_datetime() {
        let sql = SqlInsert::default();
        let dt = chrono::NaiveDateTime::parse_from_str("2024-01-15 10:30:45", "%Y-%m-%d %H:%M:%S")
            .unwrap();
        let result = sql.format_value(&Value::Time(dt));
        assert!(result.starts_with('\''));
        assert!(result.ends_with('\''));
        assert!(result.contains("2024"));
    }

    #[test]
    fn test_quote_identifier() {
        let sql = SqlInsert::new_with_json("t");
        assert_eq!(sql.quote_identifier("name"), "\"name\"");
        assert_eq!(sql.quote_identifier("user_id"), "\"user_id\"");
    }

    #[test]
    fn test_quote_identifier_escape() {
        let sql = SqlInsert::new_with_json("t");
        // Double quotes in identifier should be escaped by doubling
        assert_eq!(sql.quote_identifier("col\"name"), "\"col\"\"name\"");
    }

    // With quoting disabled, identifiers pass through verbatim.
    #[test]
    fn test_quote_identifier_disabled() {
        let sql = SqlInsert {
            table_name: "t".into(),
            quote_identifiers: false,
            obj_formatter: crate::SqlFormat::from(&TextFmt::Json),
        };
        assert_eq!(sql.quote_identifier("name"), "name");
    }

    // Full single-row INSERT through the new `RecordFormatter` API.
    #[test]
    fn test_format_record() {
        let sql = SqlInsert::new_with_json("users");
        let record = DataRecord {
            id: Default::default(),
            items: vec![
                FieldStorage::from_owned(DataField::from_chars("name", "Alice")),
                FieldStorage::from_owned(DataField::from_digit("age", 30)),
                FieldStorage::from_owned(DataField::from_bool("active", true)),
            ],
        };
        let result = sql.fmt_record(&record);
        assert!(result.starts_with("INSERT INTO \"users\""));
        assert!(result.contains("(\"name\", \"age\", \"active\")"));
        assert!(result.contains("VALUES ('Alice', 30, TRUE)"));
        assert!(result.ends_with(';'));
    }

    // An empty batch produces no SQL at all.
    #[test]
    fn test_format_batch_empty() {
        let sql = SqlInsert::new_with_json("users");
        let records: Vec<DataRecord> = vec![];
        assert_eq!(sql.format_batch(&records), "");
    }

    // Multi-row INSERT: one header, one parenthesised row per record.
    #[test]
    fn test_format_batch() {
        let sql = SqlInsert::new_with_json("users");
        let records = vec![
            DataRecord {
                id: Default::default(),
                items: vec![
                    FieldStorage::from_owned(DataField::from_chars("name", "Alice")),
                    FieldStorage::from_owned(DataField::from_digit("age", 30)),
                ],
            },
            DataRecord {
                id: Default::default(),
                items: vec![
                    FieldStorage::from_owned(DataField::from_chars("name", "Bob")),
                    FieldStorage::from_owned(DataField::from_digit("age", 25)),
                ],
            },
        ];
        let result = sql.format_batch(&records);
        assert!(result.contains("INSERT INTO \"users\""));
        assert!(result.contains("('Alice', 30)"));
        assert!(result.contains("('Bob', 25)"));
        assert!(result.ends_with(';'));
    }

    #[test]
    fn test_generate_create_table_empty() {
        let sql = SqlInsert::new_with_json("users");
        let records: Vec<DataRecord> = vec![];
        assert_eq!(sql.generate_create_table(&records), "");
    }

    // Column types are inferred from the value variants of the first record.
    #[test]
    fn test_generate_create_table() {
        let sql = SqlInsert::new_with_json("users");
        let records = vec![DataRecord {
            id: Default::default(),
            items: vec![
                FieldStorage::from_owned(DataField::from_chars("name", "Alice")),
                FieldStorage::from_owned(DataField::from_digit("age", 30)),
                FieldStorage::from_owned(DataField::from_bool("active", true)),
                FieldStorage::from_owned(DataField::from_float("score", 95.5)),
            ],
        }];
        let result = sql.generate_create_table(&records);
        assert!(result.contains("CREATE TABLE IF NOT EXISTS \"users\""));
        assert!(result.contains("\"name\" TEXT"));
        assert!(result.contains("\"age\" BIGINT"));
        assert!(result.contains("\"active\" BOOLEAN"));
        assert!(result.contains("\"score\" DOUBLE PRECISION"));
    }

    // Non-conflict columns are updated from the EXCLUDED pseudo-row.
    #[test]
    fn test_format_upsert() {
        let sql = SqlInsert::new_with_json("users");
        let record = DataRecord {
            id: Default::default(),
            items: vec![
                FieldStorage::from_owned(DataField::from_chars("id", "u1")),
                FieldStorage::from_owned(DataField::from_chars("name", "Alice")),
                FieldStorage::from_owned(DataField::from_digit("age", 30)),
            ],
        };
        let result = sql.format_upsert(&record, &["id"]);
        assert!(result.contains("INSERT INTO \"users\""));
        assert!(result.contains("ON CONFLICT (\"id\")"));
        assert!(result.contains("DO UPDATE SET"));
        assert!(result.contains("\"name\" = EXCLUDED.\"name\""));
        assert!(result.contains("\"age\" = EXCLUDED.\"age\""));
    }

    #[test]
    fn test_format_upsert_no_update_columns() {
        let sql = SqlInsert::new_with_json("users");
        let record = DataRecord {
            id: Default::default(),
            items: vec![FieldStorage::from_owned(DataField::from_chars("id", "u1"))],
        };
        // When all columns are conflict columns, no update is needed
        let result = sql.format_upsert(&record, &["id"]);
        // Should just be a regular insert with semicolon
        assert!(result.contains("INSERT INTO"));
        assert!(!result.contains("ON CONFLICT"));
    }

    /// Record fixture with a nested object field between two scalar fields.
    fn make_record_with_obj() -> DataRecord {
        let mut obj = ObjectValue::new();
        obj.insert(
            "ssl_cipher".to_string(),
            FieldStorage::from_owned(DataField::from_chars("ssl_cipher", "ECDHE")),
        );
        DataRecord {
            id: Default::default(),
            items: vec![
                FieldStorage::from_owned(DataField::from_digit("status", 200)),
                FieldStorage::from_owned(DataField::from_obj("extends", obj)),
                FieldStorage::from_owned(DataField::from_digit("length", 50)),
            ],
        }
    }

    // Nested objects must be rendered on a single line (deprecated API).
    #[test]
    fn test_format_record_with_obj_no_newlines() {
        let sql = SqlInsert::new_with_json("t");
        let record = make_record_with_obj();
        let result = sql.format_record(&record);
        assert!(
            !result.contains('\n'),
            "record output should not contain newlines: {}",
            result
        );
        assert!(result.contains("ECDHE"));
    }

    // Nested objects must be rendered on a single line (new API).
    #[test]
    fn test_fmt_record_with_obj_no_newlines() {
        let sql = SqlInsert::new_with_json("t");
        let record = make_record_with_obj();
        let result = sql.fmt_record(&record);
        assert!(
            !result.contains('\n'),
            "record output should not contain newlines: {}",
            result
        );
    }

    // The deprecated and new APIs must agree on records with nested objects.
    #[test]
    fn test_old_new_api_consistency_nested() {
        let sql = SqlInsert::new_with_json("t");
        let record = make_record_with_obj();
        assert_eq!(sql.format_record(&record), sql.fmt_record(&record));
    }
}
522
// ============================================================================
// New trait implementations: ValueFormatter + RecordFormatter
// ============================================================================
526
527#[allow(clippy::items_after_test_module)]
528impl ValueFormatter for SqlInsert {
529    type Output = String;
530
531    fn format_value(&self, value: &Value) -> String {
532        match value {
533            Value::Null => "NULL".to_string(),
534            Value::Bool(v) => if *v { "TRUE" } else { "FALSE" }.to_string(),
535            Value::Chars(v) => format!("'{}'", self.escape_string(v)),
536            Value::Digit(v) => v.to_string(),
537            Value::Float(v) => {
538                if v.is_nan() {
539                    "NULL".into()
540                } else if v.is_infinite() {
541                    if v.is_sign_positive() {
542                        "'Infinity'".into()
543                    } else {
544                        "'-Infinity'".into()
545                    }
546                } else {
547                    v.to_string()
548                }
549            }
550            Value::IpAddr(v) => format!("'{}'", self.escape_string(&v.to_string())),
551            Value::Time(v) => format!("'{}'", self.escape_string(&v.to_string())),
552            Value::Obj(_obj) => {
553                let inner = match &self.obj_formatter {
554                    crate::SqlFormat::Json(f) => f.format_value(value),
555                    crate::SqlFormat::Kv(f) => f.format_value(value),
556                    crate::SqlFormat::Raw(f) => f.format_value(value),
557                    crate::SqlFormat::ProtoText(f) => f.format_value(value),
558                };
559                format!("'{}'", self.escape_string(&inner))
560            }
561            Value::Array(arr) => {
562                let inner = match &self.obj_formatter {
563                    crate::SqlFormat::Json(f) => {
564                        let items: Vec<String> = arr
565                            .iter()
566                            .map(|field| f.format_value(field.get_value()))
567                            .collect();
568                        format!("[{}]", items.join(","))
569                    }
570                    crate::SqlFormat::Kv(f) => {
571                        let mut output = String::new();
572                        output.push('[');
573                        for (i, field) in arr.iter().enumerate() {
574                            if i > 0 {
575                                output.push_str(", ");
576                            }
577                            output.push_str(&f.format_value(field.get_value()));
578                        }
579                        output.push(']');
580                        output
581                    }
582                    crate::SqlFormat::Raw(f) => {
583                        if arr.is_empty() {
584                            "[]".to_string()
585                        } else {
586                            let content: Vec<String> = arr
587                                .iter()
588                                .map(|field| f.format_value(field.get_value()))
589                                .collect();
590                            format!("[{}]", content.join(", "))
591                        }
592                    }
593                    crate::SqlFormat::ProtoText(f) => {
594                        let items: Vec<String> = arr
595                            .iter()
596                            .map(|field| f.format_value(field.get_value()))
597                            .collect();
598                        format!("[{}]", items.join(", "))
599                    }
600                };
601                format!("'{}'", self.escape_string(&inner))
602            }
603            _ => format!("'{}'", self.escape_string(&value.to_string())),
604        }
605    }
606}
607
608impl RecordFormatter for SqlInsert {
609    fn fmt_field(&self, field: &FieldStorage) -> String {
610        if *field.get_meta() == DataType::Ignore {
611            String::new()
612        } else {
613            self.format_value(field.get_value())
614        }
615    }
616
617    fn fmt_record(&self, record: &DataRecord) -> String {
618        let columns: Vec<String> = record
619            .items
620            .iter()
621            .filter(|f| *f.get_meta() != DataType::Ignore)
622            .map(|f| self.quote_identifier(f.get_name()))
623            .collect();
624        let values: Vec<String> = record
625            .items
626            .iter()
627            .filter(|f| *f.get_meta() != DataType::Ignore)
628            .map(|f| self.fmt_field(f))
629            .collect();
630        format!(
631            "INSERT INTO {} ({}) VALUES ({});",
632            self.quote_identifier(&self.table_name),
633            columns.join(", "),
634            values.join(", ")
635        )
636    }
637}