tank_core/writer/
sql_writer.rs

1use crate::{
2    BinaryOp, BinaryOpType, ColumnDef, ColumnRef, DataSet, EitherIterator, Entity, Expression,
3    Fragment, Interval, Join, JoinType, Operand, Order, Ordered, PrimaryKeyType, TableRef, UnaryOp,
4    UnaryOpType, Value, possibly_parenthesized, separated_by, writer::Context,
5};
6use std::{collections::HashMap, fmt::Write};
7use time::{Date, Time};
8
// Appends the decimal representation of an integer expression to the `String` named by
// `$buff`, using `itoa` instead of going through `core::fmt`.
macro_rules! write_integer {
    ($buff:ident, $value:expr) => {{
        $buff.push_str(itoa::Buffer::new().format($value));
    }};
}
// Appends a floating point value to the `String` named by `$buff`.
//
// - Finite values are formatted with `ryu`.
// - Infinities are delegated to `SqlWriter::write_value_infinity` on `$this`.
// - NaN is written as a cast from the string 'NaN' to the writer's double type, because a bare
//   `NaN` token would be parsed by SQL as an identifier rather than a number.
//
// `$value` is evaluated exactly once.
macro_rules! write_float {
    ($this:ident, $context:ident, $buff:ident, $value:expr) => {{
        let value = $value;
        if value.is_infinite() {
            $this.write_value_infinity($context, $buff, value.is_sign_negative());
        } else if value.is_nan() {
            $this.write_expression_binary_op(
                $context,
                $buff,
                &BinaryOp {
                    op: BinaryOpType::Cast,
                    lhs: &Operand::LitStr("NaN"),
                    rhs: &Operand::Type(Value::Float64(None)),
                },
            );
        } else {
            let mut buffer = ryu::Buffer::new();
            $buff.push_str(buffer.format(value));
        }
    }};
}
25
26pub trait SqlWriter {
    /// Returns this writer as a `&dyn SqlWriter` trait object.
    ///
    /// The default methods of this trait pass the result to `write_query` implementations
    /// (e.g. `value.arg.write_query(self.as_dyn(), context, buff)`). Implementors are presumably
    /// expected to simply return `self` — NOTE(review): confirm.
    fn as_dyn(&self) -> &dyn SqlWriter;
28
29    fn alias_declaration(&self, context: Context) -> bool {
30        match context.fragment {
31            Fragment::SqlSelectFrom | Fragment::SqlJoin => true,
32            _ => false,
33        }
34    }
35
36    fn write_escaped(
37        &self,
38        _context: Context,
39        buff: &mut String,
40        value: &str,
41        search: char,
42        replace: &str,
43    ) {
44        let mut position = 0;
45        for (i, c) in value.char_indices() {
46            if c == search {
47                buff.push_str(&value[position..i]);
48                buff.push_str(replace);
49                position = i + 1;
50            }
51        }
52        buff.push_str(&value[position..]);
53    }
54
55    fn write_identifier_quoted(&self, context: Context, buff: &mut String, value: &str) {
56        buff.push('"');
57        self.write_escaped(context, buff, value, '"', r#""""#);
58        buff.push('"');
59    }
60
61    fn write_table_ref(&self, context: Context, buff: &mut String, value: &TableRef) {
62        if self.alias_declaration(context) || value.alias.is_empty() {
63            if !value.schema.is_empty() {
64                self.write_identifier_quoted(context, buff, &value.schema);
65                buff.push('.');
66            }
67            self.write_identifier_quoted(context, buff, &value.name);
68        }
69        if !value.alias.is_empty() {
70            let _ = write!(buff, " {}", value.alias);
71        }
72    }
73
74    fn write_column_ref(&self, context: Context, buff: &mut String, value: &ColumnRef) {
75        if context.qualify_columns && !value.table.is_empty() {
76            if !value.schema.is_empty() {
77                self.write_identifier_quoted(context, buff, &value.schema);
78                buff.push('.');
79            }
80            self.write_identifier_quoted(context, buff, &value.table);
81            buff.push('.');
82        }
83        self.write_identifier_quoted(context, buff, &value.name);
84    }
85
86    fn write_column_type(&self, context: Context, buff: &mut String, value: &Value) {
87        match value {
88            Value::Boolean(..) => buff.push_str("BOOLEAN"),
89            Value::Int8(..) => buff.push_str("TINYINT"),
90            Value::Int16(..) => buff.push_str("SMALLINT"),
91            Value::Int32(..) => buff.push_str("INTEGER"),
92            Value::Int64(..) => buff.push_str("BIGINT"),
93            Value::Int128(..) => buff.push_str("HUGEINT"),
94            Value::UInt8(..) => buff.push_str("UTINYINT"),
95            Value::UInt16(..) => buff.push_str("USMALLINT"),
96            Value::UInt32(..) => buff.push_str("UINTEGER"),
97            Value::UInt64(..) => buff.push_str("UBIGINT"),
98            Value::UInt128(..) => buff.push_str("UHUGEINT"),
99            Value::Float32(..) => buff.push_str("FLOAT"),
100            Value::Float64(..) => buff.push_str("DOUBLE"),
101            Value::Decimal(.., precision, scale) => {
102                buff.push_str("DECIMAL");
103                if (precision, scale) != (&0, &0) {
104                    let _ = write!(buff, "({},{})", precision, scale);
105                }
106            }
107            Value::Char(..) => buff.push_str("CHAR(1)"),
108            Value::Varchar(..) => buff.push_str("VARCHAR"),
109            Value::Blob(..) => buff.push_str("BLOB"),
110            Value::Date(..) => buff.push_str("DATE"),
111            Value::Time(..) => buff.push_str("TIME"),
112            Value::Timestamp(..) => buff.push_str("TIMESTAMP"),
113            Value::TimestampWithTimezone(..) => buff.push_str("TIMESTAMP WITH TIME ZONE"),
114            Value::Interval(..) => buff.push_str("INTERVAL"),
115            Value::Uuid(..) => buff.push_str("UUID"),
116            Value::Array(.., inner, size) => {
117                self.write_column_type(context, buff, inner);
118                let _ = write!(buff, "[{}]", size);
119            }
120            Value::List(.., inner) => {
121                self.write_column_type(context, buff, inner);
122                buff.push_str("[]");
123            }
124            Value::Map(.., key, value) => {
125                buff.push_str("MAP(");
126                self.write_column_type(context, buff, key);
127                buff.push(',');
128                self.write_column_type(context, buff, value);
129                buff.push(')');
130            }
131            _ => panic!(
132                "Unexpected tank::Value, cannot get the sql type from {:?} variant",
133                value
134            ),
135        };
136    }
137
138    fn write_value(&self, context: Context, buff: &mut String, value: &Value) {
139        match value {
140            Value::Null
141            | Value::Boolean(None, ..)
142            | Value::Int8(None, ..)
143            | Value::Int16(None, ..)
144            | Value::Int32(None, ..)
145            | Value::Int64(None, ..)
146            | Value::Int128(None, ..)
147            | Value::UInt8(None, ..)
148            | Value::UInt16(None, ..)
149            | Value::UInt32(None, ..)
150            | Value::UInt64(None, ..)
151            | Value::UInt128(None, ..)
152            | Value::Float32(None, ..)
153            | Value::Float64(None, ..)
154            | Value::Decimal(None, ..)
155            | Value::Char(None, ..)
156            | Value::Varchar(None, ..)
157            | Value::Blob(None, ..)
158            | Value::Date(None, ..)
159            | Value::Time(None, ..)
160            | Value::Timestamp(None, ..)
161            | Value::TimestampWithTimezone(None, ..)
162            | Value::Interval(None, ..)
163            | Value::Uuid(None, ..)
164            | Value::Array(None, ..)
165            | Value::List(None, ..)
166            | Value::Map(None, ..)
167            | Value::Struct(None, ..) => self.write_value_none(context, buff),
168            Value::Boolean(Some(v), ..) => self.write_value_bool(context, buff, *v),
169            Value::Int8(Some(v), ..) => write_integer!(buff, *v),
170            Value::Int16(Some(v), ..) => write_integer!(buff, *v),
171            Value::Int32(Some(v), ..) => write_integer!(buff, *v),
172            Value::Int64(Some(v), ..) => write_integer!(buff, *v),
173            Value::Int128(Some(v), ..) => write_integer!(buff, *v),
174            Value::UInt8(Some(v), ..) => write_integer!(buff, *v),
175            Value::UInt16(Some(v), ..) => write_integer!(buff, *v),
176            Value::UInt32(Some(v), ..) => write_integer!(buff, *v),
177            Value::UInt64(Some(v), ..) => write_integer!(buff, *v),
178            Value::UInt128(Some(v), ..) => write_integer!(buff, *v),
179            Value::Float32(Some(v), ..) => write_float!(self, context, buff, *v),
180            Value::Float64(Some(v), ..) => write_float!(self, context, buff, *v),
181            Value::Decimal(Some(v), ..) => drop(write!(buff, "{}", v)),
182            Value::Char(Some(v), ..) => {
183                buff.push('\'');
184                buff.push(*v);
185                buff.push('\'');
186            }
187            Value::Varchar(Some(v), ..) => self.write_value_string(context, buff, v),
188            Value::Blob(Some(v), ..) => self.write_value_blob(context, buff, v.as_ref()),
189            Value::Date(Some(v), ..) => {
190                buff.push('\'');
191                self.write_value_date(context, buff, v);
192                buff.push('\'');
193            }
194            Value::Time(Some(v), ..) => {
195                buff.push('\'');
196                self.write_value_time(context, buff, v);
197                buff.push('\'');
198            }
199            Value::Timestamp(Some(v), ..) => {
200                buff.push('\'');
201                self.write_value_date(context, buff, &v.date());
202                buff.push('T');
203                self.write_value_time(context, buff, &v.time());
204                buff.push('\'');
205            }
206            Value::TimestampWithTimezone(Some(v), ..) => {
207                buff.push('\'');
208                self.write_value_date(context, buff, &v.date());
209                buff.push('T');
210                self.write_value_time(context, buff, &v.time());
211                let _ = write!(
212                    buff,
213                    "{:+02}:{:02}",
214                    v.offset().whole_hours(),
215                    v.offset().whole_minutes()
216                );
217                buff.push('\'');
218            }
219            Value::Interval(Some(v), ..) => self.write_value_interval(context, buff, v),
220            Value::Uuid(Some(v), ..) => drop(write!(buff, "'{}'", v)),
221            Value::List(Some(..), ..) | Value::Array(Some(..), ..) => {
222                let v = match value {
223                    Value::List(Some(v), ..) => v.iter(),
224                    Value::Array(Some(v), ..) => v.iter(),
225                    _ => unreachable!(),
226                };
227                buff.push('[');
228                separated_by(
229                    buff,
230                    v,
231                    |buff, v| {
232                        self.write_value(context, buff, v);
233                    },
234                    ",",
235                );
236                buff.push(']');
237            }
238            Value::Map(Some(v), ..) => self.write_value_map(context, buff, v),
239            Value::Struct(Some(_v), ..) => {
240                todo!()
241            }
242        };
243    }
244
245    fn write_value_none(&self, _context: Context, buff: &mut String) {
246        buff.push_str("NULL");
247    }
248
249    fn write_value_bool(&self, _context: Context, buff: &mut String, value: bool) {
250        buff.push_str(["false", "true"][value as usize]);
251    }
252
253    fn write_value_infinity(&self, context: Context, buff: &mut String, negative: bool) {
254        let mut buffer = ryu::Buffer::new();
255        self.write_expression_binary_op(
256            context,
257            buff,
258            &BinaryOp {
259                op: BinaryOpType::Cast,
260                lhs: &Operand::LitStr(buffer.format(if negative {
261                    f64::NEG_INFINITY
262                } else {
263                    f64::INFINITY
264                })),
265                rhs: &Operand::Type(Value::Float64(None)),
266            },
267        );
268    }
269
270    fn write_value_string(&self, _context: Context, buff: &mut String, value: &str) {
271        buff.push('\'');
272        let mut position = 0;
273        for (i, c) in value.char_indices() {
274            if c == '\'' {
275                buff.push_str(&value[position..i]);
276                buff.push_str("''");
277                position = i + 1;
278            } else if c == '\n' {
279                buff.push_str(&value[position..i]);
280                buff.push_str("\\n");
281                position = i + 1;
282            }
283        }
284        buff.push_str(&value[position..]);
285        buff.push('\'');
286    }
287
288    fn write_value_blob(&self, _context: Context, buff: &mut String, value: &[u8]) {
289        buff.push('\'');
290        for b in value {
291            let _ = write!(buff, "\\x{:X}", b);
292        }
293        buff.push('\'');
294    }
295
296    fn write_value_date(&self, _context: Context, buff: &mut String, value: &Date) {
297        let _ = write!(
298            buff,
299            "{:04}-{:02}-{:02}",
300            value.year(),
301            value.month() as u8,
302            value.day()
303        );
304    }
305
306    fn write_value_time(&self, _context: Context, buff: &mut String, value: &Time) {
307        let mut subsecond = value.nanosecond();
308        let mut width = 9;
309        while width > 1 && subsecond % 10 == 0 {
310            subsecond /= 10;
311            width -= 1;
312        }
313        let _ = write!(
314            buff,
315            "{:02}:{:02}:{:02}.{:0width$}",
316            value.hour(),
317            value.minute(),
318            value.second(),
319            subsecond
320        );
321    }
322
323    fn value_interval_units(&self) -> &[(&str, i128)] {
324        static UNITS: &[(&str, i128)] = &[
325            ("DAY", Interval::NANOS_IN_DAY),
326            ("HOUR", Interval::NANOS_IN_SEC * 3600),
327            ("MINUTE", Interval::NANOS_IN_SEC * 60),
328            ("SECOND", Interval::NANOS_IN_SEC),
329            ("MICROSECOND", 1_000),
330            ("NANOSECOND", 1),
331        ];
332        UNITS
333    }
334
335    fn write_value_interval(&self, _context: Context, buff: &mut String, value: &Interval) {
336        buff.push_str("INTERVAL ");
337        if value.is_zero() {
338            buff.push_str("0 SECONDS");
339            return;
340        }
341        macro_rules! write_unit {
342            ($buff:ident, $val:expr, $unit:expr) => {
343                let _ = write!(
344                    $buff,
345                    "{} {}{}",
346                    $val,
347                    $unit,
348                    if $val != 1 { "S" } else { "" }
349                );
350            };
351        }
352        let months = value.months;
353        let nanos = value.nanos + value.days as i128 * Interval::NANOS_IN_DAY;
354        let multiple_units = nanos != 0 && value.months != 0;
355        if multiple_units {
356            buff.push('\'');
357        }
358        if months != 0 {
359            if months % 12 == 0 {
360                write_unit!(buff, months / 12, "YEAR");
361            } else {
362                write_unit!(buff, months, "MONTH");
363            }
364        }
365        for &(name, factor) in self.value_interval_units() {
366            if nanos % factor == 0 {
367                let value = nanos / factor;
368                if value != 0 {
369                    if months != 0 {
370                        buff.push(' ');
371                    }
372                    write_unit!(buff, value, name);
373                    break;
374                }
375            }
376        }
377        if multiple_units {
378            buff.push('\'');
379        }
380    }
381
382    fn write_value_map(&self, context: Context, buff: &mut String, value: &HashMap<Value, Value>) {
383        buff.push('{');
384        separated_by(
385            buff,
386            value,
387            |buff, (k, v)| {
388                self.write_value(context, buff, k);
389                buff.push(':');
390                self.write_value(context, buff, v);
391            },
392            ",",
393        );
394        buff.push('}');
395    }
396
397    fn expression_unary_op_precedence<'a>(&self, value: &UnaryOpType) -> i32 {
398        match value {
399            UnaryOpType::Negative => 1250,
400            UnaryOpType::Not => 250,
401        }
402    }
403
404    fn expression_binary_op_precedence<'a>(&self, value: &BinaryOpType) -> i32 {
405        match value {
406            BinaryOpType::Or => 100,
407            BinaryOpType::And => 200,
408            BinaryOpType::Equal => 300,
409            BinaryOpType::NotEqual => 300,
410            BinaryOpType::Less => 300,
411            BinaryOpType::Greater => 300,
412            BinaryOpType::LessEqual => 300,
413            BinaryOpType::GreaterEqual => 300,
414            BinaryOpType::Is => 400,
415            BinaryOpType::IsNot => 400,
416            BinaryOpType::Like => 400,
417            BinaryOpType::NotLike => 400,
418            BinaryOpType::Regexp => 400,
419            BinaryOpType::NotRegexp => 400,
420            BinaryOpType::Glob => 400,
421            BinaryOpType::NotGlob => 400,
422            BinaryOpType::BitwiseOr => 500,
423            BinaryOpType::BitwiseAnd => 600,
424            BinaryOpType::ShiftLeft => 700,
425            BinaryOpType::ShiftRight => 700,
426            BinaryOpType::Subtraction => 800,
427            BinaryOpType::Addition => 800,
428            BinaryOpType::Multiplication => 900,
429            BinaryOpType::Division => 900,
430            BinaryOpType::Remainder => 900,
431            BinaryOpType::Indexing => 1000,
432            BinaryOpType::Cast => 1100,
433            BinaryOpType::Alias => 1200,
434        }
435    }
436
437    fn write_expression_operand(&self, context: Context, buff: &mut String, value: &Operand) {
438        match value {
439            Operand::LitBool(v) => self.write_value_bool(context, buff, *v),
440            Operand::LitFloat(v) => write_float!(self, context, buff, *v),
441            Operand::LitIdent(v) => drop(buff.push_str(v)),
442            Operand::LitField(v) => separated_by(buff, *v, |buff, v| buff.push_str(v), "."),
443            Operand::LitInt(v) => write_integer!(buff, *v),
444            Operand::LitStr(v) => self.write_value_string(context, buff, v),
445            Operand::LitArray(v) => {
446                buff.push('[');
447                separated_by(
448                    buff,
449                    *v,
450                    |buff, v| {
451                        v.write_query(self.as_dyn(), context, buff);
452                    },
453                    ", ",
454                );
455                buff.push(']');
456            }
457            Operand::Null => drop(buff.push_str("NULL")),
458            Operand::Type(v) => self.write_column_type(context, buff, v),
459            Operand::Variable(v) => self.write_value(context, buff, v),
460            Operand::Call(f, args) => {
461                buff.push_str(f);
462                buff.push('(');
463                separated_by(
464                    buff,
465                    *args,
466                    |buff, v| {
467                        v.write_query(self.as_dyn(), context, buff);
468                    },
469                    ",",
470                );
471                buff.push(')');
472            }
473            Operand::Asterisk => drop(buff.push('*')),
474            Operand::QuestionMark => drop(buff.push('?')),
475        };
476    }
477
478    fn write_expression_unary_op(
479        self: &Self,
480        context: Context,
481        buff: &mut String,
482        value: &UnaryOp<&dyn Expression>,
483    ) {
484        match value.op {
485            UnaryOpType::Negative => buff.push('-'),
486            UnaryOpType::Not => buff.push_str("NOT "),
487        };
488        possibly_parenthesized!(
489            buff,
490            value.arg.precedence(self.as_dyn()) <= self.expression_unary_op_precedence(&value.op),
491            value.arg.write_query(self.as_dyn(), context, buff)
492        );
493    }
494
495    fn write_expression_binary_op(
496        &self,
497        context: Context,
498        buff: &mut String,
499        value: &BinaryOp<&dyn Expression, &dyn Expression>,
500    ) {
501        let (prefix, infix, suffix, lhs_parenthesized, rhs_parenthesized) = match value.op {
502            BinaryOpType::Indexing => ("", "[", "]", false, true),
503            BinaryOpType::Cast => ("CAST(", " AS ", ")", true, true),
504            BinaryOpType::Multiplication => ("", " * ", "", false, false),
505            BinaryOpType::Division => ("", " / ", "", false, false),
506            BinaryOpType::Remainder => ("", " % ", "", false, false),
507            BinaryOpType::Addition => ("", " + ", "", false, false),
508            BinaryOpType::Subtraction => ("", " - ", "", false, false),
509            BinaryOpType::ShiftLeft => ("", " << ", "", false, false),
510            BinaryOpType::ShiftRight => ("", " >> ", "", false, false),
511            BinaryOpType::BitwiseAnd => ("", " & ", "", false, false),
512            BinaryOpType::BitwiseOr => ("", " | ", "", false, false),
513            BinaryOpType::Is => ("", " Is ", "", false, false),
514            BinaryOpType::IsNot => ("", " IS NOT ", "", false, false),
515            BinaryOpType::Like => ("", " LIKE ", "", false, false),
516            BinaryOpType::NotLike => ("", " NOT LIKE ", "", false, false),
517            BinaryOpType::Regexp => ("", " REGEXP ", "", false, false),
518            BinaryOpType::NotRegexp => ("", " NOT REGEXP ", "", false, false),
519            BinaryOpType::Glob => ("", " GLOB ", "", false, false),
520            BinaryOpType::NotGlob => ("", " NOT GLOB ", "", false, false),
521            BinaryOpType::Equal => ("", " = ", "", false, false),
522            BinaryOpType::NotEqual => ("", " != ", "", false, false),
523            BinaryOpType::Less => ("", " < ", "", false, false),
524            BinaryOpType::LessEqual => ("", " <= ", "", false, false),
525            BinaryOpType::Greater => ("", " > ", "", false, false),
526            BinaryOpType::GreaterEqual => ("", " >= ", "", false, false),
527            BinaryOpType::And => ("", " AND ", "", false, false),
528            BinaryOpType::Or => ("", " OR ", "", false, false),
529            BinaryOpType::Alias => ("", " AS ", "", false, false),
530        };
531        let precedence = self.expression_binary_op_precedence(&value.op);
532        buff.push_str(prefix);
533        possibly_parenthesized!(
534            buff,
535            !lhs_parenthesized && value.lhs.precedence(self.as_dyn()) < precedence,
536            value.lhs.write_query(self.as_dyn(), context, buff)
537        );
538        buff.push_str(infix);
539        possibly_parenthesized!(
540            buff,
541            !rhs_parenthesized && value.rhs.precedence(self.as_dyn()) <= precedence,
542            value.rhs.write_query(self.as_dyn(), context, buff)
543        );
544        buff.push_str(suffix);
545    }
546
547    fn write_expression_ordered(
548        &self,
549        context: Context,
550        buff: &mut String,
551        value: &Ordered<&dyn Expression>,
552    ) {
553        value.expression.write_query(self.as_dyn(), context, buff);
554        if context.fragment == Fragment::SqlSelectOrderBy {
555            let _ = write!(
556                buff,
557                " {}",
558                match value.order {
559                    Order::ASC => "ASC",
560                    Order::DESC => "DESC",
561                }
562            );
563        }
564    }
565
566    fn write_join_type(&self, _context: Context, buff: &mut String, join_type: &JoinType) {
567        buff.push_str(match &join_type {
568            JoinType::Default => "JOIN",
569            JoinType::Inner => "INNER JOIN",
570            JoinType::Outer => "OUTER JOIN",
571            JoinType::Left => "LEFT JOIN",
572            JoinType::Right => "RIGHT JOIN",
573            JoinType::Cross => "CROSS",
574            JoinType::Natural => "NATURAL JOIN",
575        });
576    }
577
578    fn write_join(
579        &self,
580        _context: Context,
581        buff: &mut String,
582        join: &Join<&dyn DataSet, &dyn DataSet, &dyn Expression>,
583    ) {
584        let context = Context {
585            fragment: Fragment::SqlJoin,
586            qualify_columns: true,
587        };
588        join.lhs.write_query(self.as_dyn(), context, buff);
589        buff.push(' ');
590        self.write_join_type(context, buff, &join.join);
591        buff.push(' ');
592        join.rhs.write_query(self.as_dyn(), context, buff);
593        if let Some(on) = &join.on {
594            buff.push_str(" ON ");
595            let context = Context {
596                qualify_columns: true,
597                ..context
598            };
599            on.write_query(self.as_dyn(), context, buff);
600        }
601    }
602
603    fn write_transaction_begin(&self, buff: &mut String) {
604        buff.push_str("BEGIN;");
605    }
606
607    fn write_transaction_commit(&self, buff: &mut String) {
608        buff.push_str("COMMIT;");
609    }
610
611    fn write_transaction_rollback(&self, buff: &mut String) {
612        buff.push_str("ROLLBACK;");
613    }
614
615    fn write_create_schema<E>(&self, buff: &mut String, if_not_exists: bool)
616    where
617        Self: Sized,
618        E: Entity,
619    {
620        buff.push_str("CREATE SCHEMA ");
621        let context = Context {
622            fragment: Fragment::SqlCreateSchema,
623            qualify_columns: E::qualified_columns(),
624        };
625        if if_not_exists {
626            buff.push_str("IF NOT EXISTS ");
627        }
628        self.write_identifier_quoted(context, buff, E::table_ref().schema);
629        buff.push(';');
630    }
631
632    fn write_drop_schema<E>(&self, buff: &mut String, if_exists: bool)
633    where
634        Self: Sized,
635        E: Entity,
636    {
637        buff.push_str("DROP SCHEMA ");
638        let context = Context {
639            fragment: Fragment::SqlDropSchema,
640            qualify_columns: E::qualified_columns(),
641        };
642        if if_exists {
643            buff.push_str("IF EXISTS ");
644        }
645        self.write_identifier_quoted(context, buff, E::table_ref().schema);
646        buff.push(';');
647    }
648
649    fn write_create_table<E>(&self, buff: &mut String, if_not_exists: bool)
650    where
651        Self: Sized,
652        E: Entity,
653    {
654        buff.push_str("CREATE TABLE ");
655        if if_not_exists {
656            buff.push_str("IF NOT EXISTS ");
657        }
658        let context = Context {
659            fragment: Fragment::SqlCreateTable,
660            qualify_columns: E::qualified_columns(),
661        };
662        self.write_table_ref(context, buff, E::table_ref());
663        buff.push_str(" (\n");
664        separated_by(
665            buff,
666            E::columns(),
667            |buff, v| {
668                self.write_create_table_column_fragment(context, buff, v);
669            },
670            ",\n",
671        );
672        let primary_key = E::primary_key_def();
673        if primary_key.len() > 1 {
674            buff.push_str(",\nPRIMARY KEY (");
675            separated_by(
676                buff,
677                primary_key,
678                |buff, v| {
679                    self.write_identifier_quoted(
680                        context.with_context(Fragment::SqlCreateTablePrimaryKey),
681                        buff,
682                        v.name(),
683                    );
684                },
685                ", ",
686            );
687            buff.push(')');
688        }
689        for unique in E::unique_defs() {
690            if unique.len() > 1 {
691                buff.push_str(",\nUNIQUE (");
692                separated_by(
693                    buff,
694                    unique,
695                    |buff, v| {
696                        self.write_identifier_quoted(
697                            context.with_context(Fragment::SqlCreateTableUnique),
698                            buff,
699                            v.name(),
700                        );
701                    },
702                    ", ",
703                );
704                buff.push(')');
705            }
706        }
707        buff.push_str("\n)");
708        buff.push(';');
709        self.write_column_comments::<E>(context, buff);
710    }
711
712    fn write_column_comments<E>(&self, _context: Context, buff: &mut String)
713    where
714        Self: Sized,
715        E: Entity,
716    {
717        let context = Context {
718            fragment: Fragment::SqlCommentOnColumn,
719            qualify_columns: true,
720        };
721        for c in E::columns().iter().filter(|c| !c.comment.is_empty()) {
722            buff.push_str("\nCOMMENT ON COLUMN ");
723            self.write_column_ref(context, buff, c.into());
724            buff.push_str(" IS ");
725            self.write_value_string(context, buff, c.comment);
726            buff.push(';');
727        }
728    }
729
730    fn write_create_table_column_fragment(
731        &self,
732        context: Context,
733        buff: &mut String,
734        column: &ColumnDef,
735    ) where
736        Self: Sized,
737    {
738        self.write_identifier_quoted(context, buff, &column.name());
739        buff.push(' ');
740        if !column.column_type.is_empty() {
741            buff.push_str(&column.column_type);
742        } else {
743            SqlWriter::write_column_type(self, context, buff, &column.value);
744        }
745        if !column.nullable && column.primary_key == PrimaryKeyType::None {
746            buff.push_str(" NOT NULL");
747        }
748        if let Some(default) = &column.default {
749            buff.push_str(" DEFAULT ");
750            default.write_query(self.as_dyn(), context, buff);
751        }
752        if column.primary_key == PrimaryKeyType::PrimaryKey {
753            // Composite primary key will be printed elsewhere
754            buff.push_str(" PRIMARY KEY");
755        }
756        if column.unique && column.primary_key != PrimaryKeyType::PrimaryKey {
757            buff.push_str(" UNIQUE");
758        }
759        if let Some(references) = column.references {
760            buff.push_str(" REFERENCES ");
761            self.write_table_ref(context, buff, &references.table_ref());
762            buff.push('(');
763            self.write_column_ref(context, buff, &references);
764            buff.push(')');
765        }
766    }
767
768    fn write_drop_table<E>(&self, buff: &mut String, if_exists: bool)
769    where
770        Self: Sized,
771        E: Entity,
772    {
773        buff.push_str("DROP TABLE ");
774        let context = Context {
775            fragment: Fragment::SqlDropTable,
776            qualify_columns: E::qualified_columns(),
777        };
778        if if_exists {
779            buff.push_str("IF EXISTS ");
780        }
781        self.write_table_ref(context, buff, E::table_ref());
782        buff.push(';');
783    }
784
785    fn write_select<Item, Cols, Data, Cond>(
786        &self,
787        buff: &mut String,
788        columns: Cols,
789        from: &Data,
790        condition: &Cond,
791        limit: Option<u32>,
792    ) where
793        Self: Sized,
794        Item: Expression,
795        Cols: IntoIterator<Item = Item> + Clone,
796        Data: DataSet,
797        Cond: Expression,
798    {
799        buff.push_str("SELECT ");
800        let mut has_order_by = false;
801        let context = Context {
802            fragment: Fragment::SqlSelect,
803            qualify_columns: Data::qualified_columns(),
804        };
805        separated_by(
806            buff,
807            columns.clone(),
808            |buff, col| {
809                col.write_query(self, context, buff);
810                has_order_by = has_order_by || col.is_ordered();
811            },
812            ", ",
813        );
814        buff.push_str("\nFROM ");
815        from.write_query(self, context.with_context(Fragment::SqlSelectFrom), buff);
816        buff.push_str("\nWHERE ");
817        condition.write_query(self, context.with_context(Fragment::SqlSelectWhere), buff);
818        if has_order_by {
819            buff.push_str("\nORDER BY ");
820            for col in columns.into_iter().filter(Expression::is_ordered) {
821                col.write_query(self, context.with_context(Fragment::SqlSelectOrderBy), buff);
822            }
823        }
824        if let Some(limit) = limit {
825            let _ = write!(buff, "\nLIMIT {}", limit);
826        }
827        buff.push(';');
828    }
829
830    fn write_insert<'b, E, It>(&self, buff: &mut String, entities: It, update: bool)
831    where
832        Self: Sized,
833        E: Entity + 'b,
834        It: IntoIterator<Item = &'b E>,
835    {
836        let mut rows = entities.into_iter().map(Entity::row_filtered).peekable();
837        let Some(mut row) = rows.next() else {
838            return;
839        };
840        buff.push_str("INSERT INTO ");
841        let mut context = Context {
842            fragment: Fragment::SqlInsertInto,
843            qualify_columns: E::qualified_columns(),
844        };
845        self.write_table_ref(context, buff, E::table_ref());
846        buff.push_str(" (");
847        let columns = E::columns().iter();
848        let single = rows.peek().is_none();
849        if single {
850            // Inserting a single row uses row_labeled to filter buff Passive::NotSet columns
851            separated_by(
852                buff,
853                row.iter(),
854                |buff, v| {
855                    self.write_identifier_quoted(context, buff, v.0);
856                },
857                ", ",
858            );
859        } else {
860            // Inserting more rows will list all columns, Passive::NotSet columns will result in DEFAULT value
861            separated_by(
862                buff,
863                columns.clone(),
864                |buff, v| {
865                    self.write_identifier_quoted(context, buff, v.name());
866                },
867                ", ",
868            );
869        };
870        buff.push_str(") VALUES\n");
871        context.fragment = Fragment::SqlInsertIntoValues;
872        let mut first_row = None;
873        let mut separate = false;
874        loop {
875            if separate {
876                buff.push_str(",\n");
877            }
878            buff.push('(');
879            let mut fields = row.iter();
880            let mut field = fields.next();
881            separated_by(
882                buff,
883                E::columns(),
884                |buff, col| {
885                    if Some(col.name()) == field.map(|v| v.0) {
886                        self.write_value(
887                            context,
888                            buff,
889                            field
890                                .map(|v| &v.1)
891                                .expect(&format!("Column {} does not have a value", col.name())),
892                        );
893                        field = fields.next();
894                    } else if !single {
895                        buff.push_str("DEFAULT");
896                    }
897                },
898                ", ",
899            );
900            buff.push(')');
901            separate = true;
902            if first_row.is_none() {
903                first_row = row.into();
904            }
905            if let Some(next) = rows.next() {
906                row = next;
907            } else {
908                break;
909            };
910        }
911        let first_row = first_row
912            .expect("Should have at least one row")
913            .into_iter()
914            .map(|(v, _)| v);
915        if update {
916            self.write_insert_update_fragment::<E, _>(
917                context,
918                buff,
919                if single {
920                    EitherIterator::Left(
921                        // If there is only one row to insert then list only the columns that appear
922                        columns.filter(|c| first_row.clone().find(|n| *n == c.name()).is_some()),
923                    )
924                } else {
925                    EitherIterator::Right(columns)
926                },
927            );
928        }
929        buff.push(';');
930    }
931
932    fn write_insert_update_fragment<'a, E, It>(
933        &self,
934        mut context: Context,
935        buff: &mut String,
936        columns: It,
937    ) where
938        Self: Sized,
939        E: Entity,
940        It: Iterator<Item = &'a ColumnDef>,
941    {
942        let pk = E::primary_key_def();
943        if pk.len() == 0 {
944            return;
945        }
946        buff.push_str("\nON CONFLICT");
947        context.fragment = Fragment::SqlInsertIntoOnConflict;
948        if pk.len() > 0 {
949            buff.push_str(" (");
950            separated_by(
951                buff,
952                pk,
953                |buff, v| {
954                    self.write_identifier_quoted(context, buff, v.name());
955                },
956                ", ",
957            );
958            buff.push(')');
959        }
960        buff.push_str(" DO UPDATE SET\n");
961        separated_by(
962            buff,
963            columns.filter(|c| c.primary_key == PrimaryKeyType::None),
964            |buff, v| {
965                self.write_identifier_quoted(context, buff, v.name());
966                buff.push_str(" = EXCLUDED.");
967                self.write_identifier_quoted(context, buff, v.name());
968            },
969            ",\n",
970        );
971    }
972
973    fn write_delete<E, Expr>(&self, buff: &mut String, condition: &Expr)
974    where
975        Self: Sized,
976        E: Entity,
977        Expr: Expression,
978    {
979        buff.push_str("DELETE FROM ");
980        let context = Context {
981            fragment: Fragment::SqlDeleteFrom,
982            qualify_columns: E::qualified_columns(),
983        };
984        self.write_table_ref(context, buff, E::table_ref());
985        buff.push_str("\nWHERE ");
986        condition.write_query(
987            self,
988            context.with_context(Fragment::SqlDeleteFromWhere),
989            buff,
990        );
991        buff.push(';');
992    }
993}
994
995pub struct GenericSqlWriter;
996impl GenericSqlWriter {
997    pub fn new() -> Self {
998        Self {}
999    }
1000}
1001impl SqlWriter for GenericSqlWriter {
1002    fn as_dyn(&self) -> &dyn SqlWriter {
1003        self
1004    }
1005}