tank_core/writer/
sql_writer.rs

1use crate::{
2    Action, BinaryOp, BinaryOpType, ColumnDef, ColumnRef, DataSet, EitherIterator, Entity,
3    Expression, Fragment, Interval, Join, JoinType, Operand, Order, Ordered, PrimaryKeyType,
4    TableRef, UnaryOp, UnaryOpType, Value, possibly_parenthesized, separated_by, writer::Context,
5};
6use std::{collections::HashMap, fmt::Write};
7use time::{Date, Time};
8
/// Appends the decimal text of an integer expression to a string buffer.
///
/// `$buff` names the target buffer (it only needs a `push_str` method) and `$value` is the
/// integer to print. The digits are produced by a temporary `itoa::Buffer` instead of going
/// through `std::fmt`.
macro_rules! write_integer {
    ($buff:ident, $value:expr) => {{
        $buff.push_str(itoa::Buffer::new().format($value));
    }};
}
/// Appends a floating point expression to a string buffer.
///
/// Infinite values are delegated to `$this.write_value_infinity(..)` together with their sign.
/// Every other value (NaN included, it is not special-cased here) is formatted by a scratch
/// `ryu::Buffer` and pushed onto `$buff`.
macro_rules! write_float {
    ($this:ident, $context:ident,$buff:ident, $value:expr) => {{
        if !$value.is_infinite() {
            let mut digits = ryu::Buffer::new();
            $buff.push_str(digits.format($value));
        } else {
            $this.write_value_infinity($context, $buff, $value.is_sign_negative());
        }
    }};
}
25
26pub trait SqlWriter {
    /// Returns this writer as a `&dyn SqlWriter` trait object.
    ///
    /// The default methods of this trait hand the result to the `write_query` and `precedence`
    /// calls they make on expressions and data sets.
    fn as_dyn(&self) -> &dyn SqlWriter;
28
29    fn alias_declaration(&self, context: Context) -> bool {
30        match context.fragment {
31            Fragment::SqlSelectFrom | Fragment::SqlJoin => true,
32            _ => false,
33        }
34    }
35
36    fn write_escaped(
37        &self,
38        _context: Context,
39        buff: &mut String,
40        value: &str,
41        search: char,
42        replace: &str,
43    ) {
44        let mut position = 0;
45        for (i, c) in value.char_indices() {
46            if c == search {
47                buff.push_str(&value[position..i]);
48                buff.push_str(replace);
49                position = i + 1;
50            }
51        }
52        buff.push_str(&value[position..]);
53    }
54
55    fn write_identifier_quoted(&self, context: Context, buff: &mut String, value: &str) {
56        buff.push('"');
57        self.write_escaped(context, buff, value, '"', r#""""#);
58        buff.push('"');
59    }
60
61    fn write_table_ref(&self, context: Context, buff: &mut String, value: &TableRef) {
62        if self.alias_declaration(context) || value.alias.is_empty() {
63            if !value.schema.is_empty() {
64                self.write_identifier_quoted(context, buff, &value.schema);
65                buff.push('.');
66            }
67            self.write_identifier_quoted(context, buff, &value.name);
68        }
69        if !value.alias.is_empty() {
70            let _ = write!(buff, " {}", value.alias);
71        }
72    }
73
74    fn write_column_ref(&self, context: Context, buff: &mut String, value: &ColumnRef) {
75        if context.qualify_columns && !value.table.is_empty() {
76            if !value.schema.is_empty() {
77                self.write_identifier_quoted(context, buff, &value.schema);
78                buff.push('.');
79            }
80            self.write_identifier_quoted(context, buff, &value.table);
81            buff.push('.');
82        }
83        self.write_identifier_quoted(context, buff, &value.name);
84    }
85
86    fn write_column_type(&self, context: Context, buff: &mut String, value: &Value) {
87        match value {
88            Value::Boolean(..) => buff.push_str("BOOLEAN"),
89            Value::Int8(..) => buff.push_str("TINYINT"),
90            Value::Int16(..) => buff.push_str("SMALLINT"),
91            Value::Int32(..) => buff.push_str("INTEGER"),
92            Value::Int64(..) => buff.push_str("BIGINT"),
93            Value::Int128(..) => buff.push_str("HUGEINT"),
94            Value::UInt8(..) => buff.push_str("UTINYINT"),
95            Value::UInt16(..) => buff.push_str("USMALLINT"),
96            Value::UInt32(..) => buff.push_str("UINTEGER"),
97            Value::UInt64(..) => buff.push_str("UBIGINT"),
98            Value::UInt128(..) => buff.push_str("UHUGEINT"),
99            Value::Float32(..) => buff.push_str("FLOAT"),
100            Value::Float64(..) => buff.push_str("DOUBLE"),
101            Value::Decimal(.., precision, scale) => {
102                buff.push_str("DECIMAL");
103                if (precision, scale) != (&0, &0) {
104                    let _ = write!(buff, "({},{})", precision, scale);
105                }
106            }
107            Value::Char(..) => buff.push_str("CHAR(1)"),
108            Value::Varchar(..) => buff.push_str("VARCHAR"),
109            Value::Blob(..) => buff.push_str("BLOB"),
110            Value::Date(..) => buff.push_str("DATE"),
111            Value::Time(..) => buff.push_str("TIME"),
112            Value::Timestamp(..) => buff.push_str("TIMESTAMP"),
113            Value::TimestampWithTimezone(..) => buff.push_str("TIMESTAMP WITH TIME ZONE"),
114            Value::Interval(..) => buff.push_str("INTERVAL"),
115            Value::Uuid(..) => buff.push_str("UUID"),
116            Value::Array(.., inner, size) => {
117                self.write_column_type(context, buff, inner);
118                let _ = write!(buff, "[{}]", size);
119            }
120            Value::List(.., inner) => {
121                self.write_column_type(context, buff, inner);
122                buff.push_str("[]");
123            }
124            Value::Map(.., key, value) => {
125                buff.push_str("MAP(");
126                self.write_column_type(context, buff, key);
127                buff.push(',');
128                self.write_column_type(context, buff, value);
129                buff.push(')');
130            }
131            _ => log::error!(
132                "Unexpected tank::Value, variant {:?} is not supported",
133                value
134            ),
135        };
136    }
137
138    fn write_value(&self, context: Context, buff: &mut String, value: &Value) {
139        match value {
140            Value::Null
141            | Value::Boolean(None, ..)
142            | Value::Int8(None, ..)
143            | Value::Int16(None, ..)
144            | Value::Int32(None, ..)
145            | Value::Int64(None, ..)
146            | Value::Int128(None, ..)
147            | Value::UInt8(None, ..)
148            | Value::UInt16(None, ..)
149            | Value::UInt32(None, ..)
150            | Value::UInt64(None, ..)
151            | Value::UInt128(None, ..)
152            | Value::Float32(None, ..)
153            | Value::Float64(None, ..)
154            | Value::Decimal(None, ..)
155            | Value::Char(None, ..)
156            | Value::Varchar(None, ..)
157            | Value::Blob(None, ..)
158            | Value::Date(None, ..)
159            | Value::Time(None, ..)
160            | Value::Timestamp(None, ..)
161            | Value::TimestampWithTimezone(None, ..)
162            | Value::Interval(None, ..)
163            | Value::Uuid(None, ..)
164            | Value::Array(None, ..)
165            | Value::List(None, ..)
166            | Value::Map(None, ..)
167            | Value::Struct(None, ..) => self.write_value_none(context, buff),
168            Value::Boolean(Some(v), ..) => self.write_value_bool(context, buff, *v),
169            Value::Int8(Some(v), ..) => write_integer!(buff, *v),
170            Value::Int16(Some(v), ..) => write_integer!(buff, *v),
171            Value::Int32(Some(v), ..) => write_integer!(buff, *v),
172            Value::Int64(Some(v), ..) => write_integer!(buff, *v),
173            Value::Int128(Some(v), ..) => write_integer!(buff, *v),
174            Value::UInt8(Some(v), ..) => write_integer!(buff, *v),
175            Value::UInt16(Some(v), ..) => write_integer!(buff, *v),
176            Value::UInt32(Some(v), ..) => write_integer!(buff, *v),
177            Value::UInt64(Some(v), ..) => write_integer!(buff, *v),
178            Value::UInt128(Some(v), ..) => write_integer!(buff, *v),
179            Value::Float32(Some(v), ..) => write_float!(self, context, buff, *v),
180            Value::Float64(Some(v), ..) => write_float!(self, context, buff, *v),
181            Value::Decimal(Some(v), ..) => drop(write!(buff, "{}", v)),
182            Value::Char(Some(v), ..) => {
183                buff.push('\'');
184                buff.push(*v);
185                buff.push('\'');
186            }
187            Value::Varchar(Some(v), ..) => self.write_value_string(context, buff, v),
188            Value::Blob(Some(v), ..) => self.write_value_blob(context, buff, v.as_ref()),
189            Value::Date(Some(v), ..) => {
190                buff.push('\'');
191                self.write_value_date(context, buff, v);
192                buff.push('\'');
193            }
194            Value::Time(Some(v), ..) => {
195                buff.push('\'');
196                self.write_value_time(context, buff, v);
197                buff.push('\'');
198            }
199            Value::Timestamp(Some(v), ..) => {
200                buff.push('\'');
201                self.write_value_date(context, buff, &v.date());
202                buff.push('T');
203                self.write_value_time(context, buff, &v.time());
204                buff.push('\'');
205            }
206            Value::TimestampWithTimezone(Some(v), ..) => {
207                buff.push('\'');
208                self.write_value_date(context, buff, &v.date());
209                buff.push('T');
210                self.write_value_time(context, buff, &v.time());
211                let _ = write!(
212                    buff,
213                    "{:+02}:{:02}",
214                    v.offset().whole_hours(),
215                    v.offset().whole_minutes()
216                );
217                buff.push('\'');
218            }
219            Value::Interval(Some(v), ..) => self.write_value_interval(context, buff, v),
220            Value::Uuid(Some(v), ..) => drop(write!(buff, "'{}'", v)),
221            Value::List(Some(..), ..) | Value::Array(Some(..), ..) => {
222                let v = match value {
223                    Value::List(Some(v), ..) => v.iter(),
224                    Value::Array(Some(v), ..) => v.iter(),
225                    _ => unreachable!(),
226                };
227                buff.push('[');
228                separated_by(
229                    buff,
230                    v,
231                    |buff, v| {
232                        self.write_value(context, buff, v);
233                    },
234                    ",",
235                );
236                buff.push(']');
237            }
238            Value::Map(Some(v), ..) => self.write_value_map(context, buff, v),
239            Value::Struct(Some(_v), ..) => {
240                todo!()
241            }
242        };
243    }
244
245    fn write_value_none(&self, _context: Context, buff: &mut String) {
246        buff.push_str("NULL");
247    }
248
249    fn write_value_bool(&self, _context: Context, buff: &mut String, value: bool) {
250        buff.push_str(["false", "true"][value as usize]);
251    }
252
253    fn write_value_infinity(&self, context: Context, buff: &mut String, negative: bool) {
254        let mut buffer = ryu::Buffer::new();
255        self.write_expression_binary_op(
256            context,
257            buff,
258            &BinaryOp {
259                op: BinaryOpType::Cast,
260                lhs: &Operand::LitStr(buffer.format(if negative {
261                    f64::NEG_INFINITY
262                } else {
263                    f64::INFINITY
264                })),
265                rhs: &Operand::Type(Value::Float64(None)),
266            },
267        );
268    }
269
270    fn write_value_string(&self, _context: Context, buff: &mut String, value: &str) {
271        buff.push('\'');
272        let mut position = 0;
273        for (i, c) in value.char_indices() {
274            if c == '\'' {
275                buff.push_str(&value[position..i]);
276                buff.push_str("''");
277                position = i + 1;
278            } else if c == '\n' {
279                buff.push_str(&value[position..i]);
280                buff.push_str("\\n");
281                position = i + 1;
282            }
283        }
284        buff.push_str(&value[position..]);
285        buff.push('\'');
286    }
287
288    fn write_value_blob(&self, _context: Context, buff: &mut String, value: &[u8]) {
289        buff.push('\'');
290        for b in value {
291            let _ = write!(buff, "\\x{:X}", b);
292        }
293        buff.push('\'');
294    }
295
296    fn write_value_date(&self, _context: Context, buff: &mut String, value: &Date) {
297        let _ = write!(
298            buff,
299            "{:04}-{:02}-{:02}",
300            value.year(),
301            value.month() as u8,
302            value.day()
303        );
304    }
305
306    fn write_value_time(&self, _context: Context, buff: &mut String, value: &Time) {
307        let mut subsecond = value.nanosecond();
308        let mut width = 9;
309        while width > 1 && subsecond % 10 == 0 {
310            subsecond /= 10;
311            width -= 1;
312        }
313        let _ = write!(
314            buff,
315            "{:02}:{:02}:{:02}.{:0width$}",
316            value.hour(),
317            value.minute(),
318            value.second(),
319            subsecond
320        );
321    }
322
323    fn value_interval_units(&self) -> &[(&str, i128)] {
324        static UNITS: &[(&str, i128)] = &[
325            ("DAY", Interval::NANOS_IN_DAY),
326            ("HOUR", Interval::NANOS_IN_SEC * 3600),
327            ("MINUTE", Interval::NANOS_IN_SEC * 60),
328            ("SECOND", Interval::NANOS_IN_SEC),
329            ("MICROSECOND", 1_000),
330            ("NANOSECOND", 1),
331        ];
332        UNITS
333    }
334
335    fn write_value_interval(&self, _context: Context, buff: &mut String, value: &Interval) {
336        buff.push_str("INTERVAL ");
337        if value.is_zero() {
338            buff.push_str("0 SECONDS");
339            return;
340        }
341        macro_rules! write_unit {
342            ($buff:ident, $val:expr, $unit:expr) => {
343                let _ = write!(
344                    $buff,
345                    "{} {}{}",
346                    $val,
347                    $unit,
348                    if $val != 1 { "S" } else { "" }
349                );
350            };
351        }
352        let months = value.months;
353        let nanos = value.nanos + value.days as i128 * Interval::NANOS_IN_DAY;
354        let multiple_units = nanos != 0 && value.months != 0;
355        if multiple_units {
356            buff.push('\'');
357        }
358        if months != 0 {
359            if months % 12 == 0 {
360                write_unit!(buff, months / 12, "YEAR");
361            } else {
362                write_unit!(buff, months, "MONTH");
363            }
364        }
365        for &(name, factor) in self.value_interval_units() {
366            if nanos % factor == 0 {
367                let value = nanos / factor;
368                if value != 0 {
369                    if months != 0 {
370                        buff.push(' ');
371                    }
372                    write_unit!(buff, value, name);
373                    break;
374                }
375            }
376        }
377        if multiple_units {
378            buff.push('\'');
379        }
380    }
381
382    fn write_value_map(&self, context: Context, buff: &mut String, value: &HashMap<Value, Value>) {
383        buff.push('{');
384        separated_by(
385            buff,
386            value,
387            |buff, (k, v)| {
388                self.write_value(context, buff, k);
389                buff.push(':');
390                self.write_value(context, buff, v);
391            },
392            ",",
393        );
394        buff.push('}');
395    }
396
397    fn expression_unary_op_precedence<'a>(&self, value: &UnaryOpType) -> i32 {
398        match value {
399            UnaryOpType::Negative => 1250,
400            UnaryOpType::Not => 250,
401        }
402    }
403
404    fn expression_binary_op_precedence<'a>(&self, value: &BinaryOpType) -> i32 {
405        match value {
406            BinaryOpType::Or => 100,
407            BinaryOpType::And => 200,
408            BinaryOpType::Equal => 300,
409            BinaryOpType::NotEqual => 300,
410            BinaryOpType::Less => 300,
411            BinaryOpType::Greater => 300,
412            BinaryOpType::LessEqual => 300,
413            BinaryOpType::GreaterEqual => 300,
414            BinaryOpType::Is => 400,
415            BinaryOpType::IsNot => 400,
416            BinaryOpType::Like => 400,
417            BinaryOpType::NotLike => 400,
418            BinaryOpType::Regexp => 400,
419            BinaryOpType::NotRegexp => 400,
420            BinaryOpType::Glob => 400,
421            BinaryOpType::NotGlob => 400,
422            BinaryOpType::BitwiseOr => 500,
423            BinaryOpType::BitwiseAnd => 600,
424            BinaryOpType::ShiftLeft => 700,
425            BinaryOpType::ShiftRight => 700,
426            BinaryOpType::Subtraction => 800,
427            BinaryOpType::Addition => 800,
428            BinaryOpType::Multiplication => 900,
429            BinaryOpType::Division => 900,
430            BinaryOpType::Remainder => 900,
431            BinaryOpType::Indexing => 1000,
432            BinaryOpType::Cast => 1100,
433            BinaryOpType::Alias => 1200,
434        }
435    }
436
437    fn write_expression_operand(&self, context: Context, buff: &mut String, value: &Operand) {
438        match value {
439            Operand::LitBool(v) => self.write_value_bool(context, buff, *v),
440            Operand::LitFloat(v) => write_float!(self, context, buff, *v),
441            Operand::LitIdent(v) => drop(buff.push_str(v)),
442            Operand::LitField(v) => separated_by(buff, *v, |buff, v| buff.push_str(v), "."),
443            Operand::LitInt(v) => write_integer!(buff, *v),
444            Operand::LitStr(v) => self.write_value_string(context, buff, v),
445            Operand::LitArray(v) => {
446                buff.push('[');
447                separated_by(
448                    buff,
449                    *v,
450                    |buff, v| {
451                        v.write_query(self.as_dyn(), context, buff);
452                    },
453                    ", ",
454                );
455                buff.push(']');
456            }
457            Operand::Null => drop(buff.push_str("NULL")),
458            Operand::Type(v) => self.write_column_type(context, buff, v),
459            Operand::Variable(v) => self.write_value(context, buff, v),
460            Operand::Call(f, args) => {
461                buff.push_str(f);
462                buff.push('(');
463                separated_by(
464                    buff,
465                    *args,
466                    |buff, v| {
467                        v.write_query(self.as_dyn(), context, buff);
468                    },
469                    ",",
470                );
471                buff.push(')');
472            }
473            Operand::Asterisk => drop(buff.push('*')),
474            Operand::QuestionMark => drop(buff.push('?')),
475        };
476    }
477
478    fn write_expression_unary_op(
479        self: &Self,
480        context: Context,
481        buff: &mut String,
482        value: &UnaryOp<&dyn Expression>,
483    ) {
484        match value.op {
485            UnaryOpType::Negative => buff.push('-'),
486            UnaryOpType::Not => buff.push_str("NOT "),
487        };
488        possibly_parenthesized!(
489            buff,
490            value.arg.precedence(self.as_dyn()) <= self.expression_unary_op_precedence(&value.op),
491            value.arg.write_query(self.as_dyn(), context, buff)
492        );
493    }
494
495    fn write_expression_binary_op(
496        &self,
497        context: Context,
498        buff: &mut String,
499        value: &BinaryOp<&dyn Expression, &dyn Expression>,
500    ) {
501        let (prefix, infix, suffix, lhs_parenthesized, rhs_parenthesized) = match value.op {
502            BinaryOpType::Indexing => ("", "[", "]", false, true),
503            BinaryOpType::Cast => ("CAST(", " AS ", ")", true, true),
504            BinaryOpType::Multiplication => ("", " * ", "", false, false),
505            BinaryOpType::Division => ("", " / ", "", false, false),
506            BinaryOpType::Remainder => ("", " % ", "", false, false),
507            BinaryOpType::Addition => ("", " + ", "", false, false),
508            BinaryOpType::Subtraction => ("", " - ", "", false, false),
509            BinaryOpType::ShiftLeft => ("", " << ", "", false, false),
510            BinaryOpType::ShiftRight => ("", " >> ", "", false, false),
511            BinaryOpType::BitwiseAnd => ("", " & ", "", false, false),
512            BinaryOpType::BitwiseOr => ("", " | ", "", false, false),
513            BinaryOpType::Is => ("", " Is ", "", false, false),
514            BinaryOpType::IsNot => ("", " IS NOT ", "", false, false),
515            BinaryOpType::Like => ("", " LIKE ", "", false, false),
516            BinaryOpType::NotLike => ("", " NOT LIKE ", "", false, false),
517            BinaryOpType::Regexp => ("", " REGEXP ", "", false, false),
518            BinaryOpType::NotRegexp => ("", " NOT REGEXP ", "", false, false),
519            BinaryOpType::Glob => ("", " GLOB ", "", false, false),
520            BinaryOpType::NotGlob => ("", " NOT GLOB ", "", false, false),
521            BinaryOpType::Equal => ("", " = ", "", false, false),
522            BinaryOpType::NotEqual => ("", " != ", "", false, false),
523            BinaryOpType::Less => ("", " < ", "", false, false),
524            BinaryOpType::LessEqual => ("", " <= ", "", false, false),
525            BinaryOpType::Greater => ("", " > ", "", false, false),
526            BinaryOpType::GreaterEqual => ("", " >= ", "", false, false),
527            BinaryOpType::And => ("", " AND ", "", false, false),
528            BinaryOpType::Or => ("", " OR ", "", false, false),
529            BinaryOpType::Alias => ("", " AS ", "", false, false),
530        };
531        let precedence = self.expression_binary_op_precedence(&value.op);
532        buff.push_str(prefix);
533        possibly_parenthesized!(
534            buff,
535            !lhs_parenthesized && value.lhs.precedence(self.as_dyn()) < precedence,
536            value.lhs.write_query(self.as_dyn(), context, buff)
537        );
538        buff.push_str(infix);
539        possibly_parenthesized!(
540            buff,
541            !rhs_parenthesized && value.rhs.precedence(self.as_dyn()) <= precedence,
542            value.rhs.write_query(self.as_dyn(), context, buff)
543        );
544        buff.push_str(suffix);
545    }
546
547    fn write_expression_ordered(
548        &self,
549        context: Context,
550        buff: &mut String,
551        value: &Ordered<&dyn Expression>,
552    ) {
553        value.expression.write_query(self.as_dyn(), context, buff);
554        if context.fragment == Fragment::SqlSelectOrderBy {
555            let _ = write!(
556                buff,
557                " {}",
558                match value.order {
559                    Order::ASC => "ASC",
560                    Order::DESC => "DESC",
561                }
562            );
563        }
564    }
565
566    fn write_join_type(&self, _context: Context, buff: &mut String, join_type: &JoinType) {
567        buff.push_str(match &join_type {
568            JoinType::Default => "JOIN",
569            JoinType::Inner => "INNER JOIN",
570            JoinType::Outer => "OUTER JOIN",
571            JoinType::Left => "LEFT JOIN",
572            JoinType::Right => "RIGHT JOIN",
573            JoinType::Cross => "CROSS",
574            JoinType::Natural => "NATURAL JOIN",
575        });
576    }
577
578    fn write_join(
579        &self,
580        _context: Context,
581        buff: &mut String,
582        join: &Join<&dyn DataSet, &dyn DataSet, &dyn Expression>,
583    ) {
584        let context = Context {
585            fragment: Fragment::SqlJoin,
586            qualify_columns: true,
587        };
588        join.lhs.write_query(self.as_dyn(), context, buff);
589        buff.push(' ');
590        self.write_join_type(context, buff, &join.join);
591        buff.push(' ');
592        join.rhs.write_query(self.as_dyn(), context, buff);
593        if let Some(on) = &join.on {
594            buff.push_str(" ON ");
595            let context = Context {
596                qualify_columns: true,
597                ..context
598            };
599            on.write_query(self.as_dyn(), context, buff);
600        }
601    }
602
603    fn write_transaction_begin(&self, buff: &mut String) {
604        buff.push_str("BEGIN;");
605    }
606
607    fn write_transaction_commit(&self, buff: &mut String) {
608        buff.push_str("COMMIT;");
609    }
610
611    fn write_transaction_rollback(&self, buff: &mut String) {
612        buff.push_str("ROLLBACK;");
613    }
614
615    fn write_create_schema<E>(&self, buff: &mut String, if_not_exists: bool)
616    where
617        Self: Sized,
618        E: Entity,
619    {
620        buff.push_str("CREATE SCHEMA ");
621        let context = Context {
622            fragment: Fragment::SqlCreateSchema,
623            qualify_columns: E::qualified_columns(),
624        };
625        if if_not_exists {
626            buff.push_str("IF NOT EXISTS ");
627        }
628        self.write_identifier_quoted(context, buff, E::table_ref().schema);
629        buff.push(';');
630    }
631
632    fn write_drop_schema<E>(&self, buff: &mut String, if_exists: bool)
633    where
634        Self: Sized,
635        E: Entity,
636    {
637        buff.push_str("DROP SCHEMA ");
638        let context = Context {
639            fragment: Fragment::SqlDropSchema,
640            qualify_columns: E::qualified_columns(),
641        };
642        if if_exists {
643            buff.push_str("IF EXISTS ");
644        }
645        self.write_identifier_quoted(context, buff, E::table_ref().schema);
646        buff.push(';');
647    }
648
649    fn write_create_table<E>(&self, buff: &mut String, if_not_exists: bool)
650    where
651        Self: Sized,
652        E: Entity,
653    {
654        buff.push_str("CREATE TABLE ");
655        if if_not_exists {
656            buff.push_str("IF NOT EXISTS ");
657        }
658        let context = Context {
659            fragment: Fragment::SqlCreateTable,
660            qualify_columns: E::qualified_columns(),
661        };
662        self.write_table_ref(context, buff, E::table_ref());
663        buff.push_str(" (\n");
664        separated_by(
665            buff,
666            E::columns(),
667            |buff, v| {
668                self.write_create_table_column_fragment(context, buff, v);
669            },
670            ",\n",
671        );
672        let primary_key = E::primary_key_def();
673        if primary_key.len() > 1 {
674            buff.push_str(",\nPRIMARY KEY (");
675            separated_by(
676                buff,
677                primary_key,
678                |buff, v| {
679                    self.write_identifier_quoted(
680                        context.with_context(Fragment::SqlCreateTablePrimaryKey),
681                        buff,
682                        v.name(),
683                    );
684                },
685                ", ",
686            );
687            buff.push(')');
688        }
689        for unique in E::unique_defs() {
690            if unique.len() > 1 {
691                buff.push_str(",\nUNIQUE (");
692                separated_by(
693                    buff,
694                    unique,
695                    |buff, v| {
696                        self.write_identifier_quoted(
697                            context.with_context(Fragment::SqlCreateTableUnique),
698                            buff,
699                            v.name(),
700                        );
701                    },
702                    ", ",
703                );
704                buff.push(')');
705            }
706        }
707        buff.push_str(");");
708        self.write_column_comments::<E>(context, buff);
709    }
710
711    fn write_column_comments<E>(&self, _context: Context, buff: &mut String)
712    where
713        Self: Sized,
714        E: Entity,
715    {
716        let context = Context {
717            fragment: Fragment::SqlCommentOnColumn,
718            qualify_columns: true,
719        };
720        for c in E::columns().iter().filter(|c| !c.comment.is_empty()) {
721            buff.push_str("\nCOMMENT ON COLUMN ");
722            self.write_column_ref(context, buff, c.into());
723            buff.push_str(" IS ");
724            self.write_value_string(context, buff, c.comment);
725            buff.push(';');
726        }
727    }
728
729    fn write_create_table_column_fragment(
730        &self,
731        context: Context,
732        buff: &mut String,
733        column: &ColumnDef,
734    ) where
735        Self: Sized,
736    {
737        self.write_identifier_quoted(context, buff, &column.name());
738        buff.push(' ');
739        if !column.column_type.is_empty() {
740            buff.push_str(&column.column_type);
741        } else {
742            SqlWriter::write_column_type(self, context, buff, &column.value);
743        }
744        if !column.nullable && column.primary_key == PrimaryKeyType::None {
745            buff.push_str(" NOT NULL");
746        }
747        if let Some(default) = &column.default {
748            buff.push_str(" DEFAULT ");
749            default.write_query(self.as_dyn(), context, buff);
750        }
751        if column.primary_key == PrimaryKeyType::PrimaryKey {
752            // Composite primary key will be printed elsewhere
753            buff.push_str(" PRIMARY KEY");
754        }
755        if column.unique && column.primary_key != PrimaryKeyType::PrimaryKey {
756            buff.push_str(" UNIQUE");
757        }
758        if let Some(references) = column.references {
759            buff.push_str(" REFERENCES ");
760            self.write_table_ref(context, buff, &references.table_ref());
761            buff.push('(');
762            self.write_column_ref(context, buff, &references);
763            buff.push(')');
764            if let Some(on_delete) = &column.on_delete {
765                buff.push_str(" ON DELETE ");
766                self.write_create_table_references_action(context, buff, on_delete);
767            }
768            if let Some(on_update) = &column.on_update {
769                buff.push_str(" ON UPDATE ");
770                self.write_create_table_references_action(context, buff, on_update);
771            }
772        }
773    }
774
775    fn write_create_table_references_action(
776        &self,
777        _context: Context,
778        buff: &mut String,
779        action: &Action,
780    ) {
781        buff.push_str(match action {
782            Action::NoAction => "NO ACTION",
783            Action::Restrict => "RESTRICT",
784            Action::Cascade => "CASCADE",
785            Action::SetNull => "SET NULL",
786            Action::SetDefault => "SET DEFAULT",
787        });
788    }
789
790    fn write_drop_table<E>(&self, buff: &mut String, if_exists: bool)
791    where
792        Self: Sized,
793        E: Entity,
794    {
795        buff.push_str("DROP TABLE ");
796        let context = Context {
797            fragment: Fragment::SqlDropTable,
798            qualify_columns: E::qualified_columns(),
799        };
800        if if_exists {
801            buff.push_str("IF EXISTS ");
802        }
803        self.write_table_ref(context, buff, E::table_ref());
804        buff.push(';');
805    }
806
807    fn write_select<Item, Cols, Data, Cond>(
808        &self,
809        buff: &mut String,
810        columns: Cols,
811        from: &Data,
812        condition: &Cond,
813        limit: Option<u32>,
814    ) where
815        Self: Sized,
816        Item: Expression,
817        Cols: IntoIterator<Item = Item> + Clone,
818        Data: DataSet,
819        Cond: Expression,
820    {
821        buff.push_str("SELECT ");
822        let mut has_order_by = false;
823        let context = Context {
824            fragment: Fragment::SqlSelect,
825            qualify_columns: Data::qualified_columns(),
826        };
827        separated_by(
828            buff,
829            columns.clone(),
830            |buff, col| {
831                col.write_query(self, context, buff);
832                has_order_by = has_order_by || col.is_ordered();
833            },
834            ", ",
835        );
836        buff.push_str("\nFROM ");
837        from.write_query(self, context.with_context(Fragment::SqlSelectFrom), buff);
838        buff.push_str("\nWHERE ");
839        condition.write_query(self, context.with_context(Fragment::SqlSelectWhere), buff);
840        if has_order_by {
841            buff.push_str("\nORDER BY ");
842            for col in columns.into_iter().filter(Expression::is_ordered) {
843                col.write_query(self, context.with_context(Fragment::SqlSelectOrderBy), buff);
844            }
845        }
846        if let Some(limit) = limit {
847            let _ = write!(buff, "\nLIMIT {}", limit);
848        }
849        buff.push(';');
850    }
851
852    fn write_insert<'b, E, It>(&self, buff: &mut String, entities: It, update: bool)
853    where
854        Self: Sized,
855        E: Entity + 'b,
856        It: IntoIterator<Item = &'b E>,
857    {
858        let mut rows = entities.into_iter().map(Entity::row_filtered).peekable();
859        let Some(mut row) = rows.next() else {
860            return;
861        };
862        buff.push_str("INSERT INTO ");
863        let mut context = Context {
864            fragment: Fragment::SqlInsertInto,
865            qualify_columns: E::qualified_columns(),
866        };
867        self.write_table_ref(context, buff, E::table_ref());
868        buff.push_str(" (");
869        let columns = E::columns().iter();
870        let single = rows.peek().is_none();
871        if single {
872            // Inserting a single row uses row_labeled to filter buff Passive::NotSet columns
873            separated_by(
874                buff,
875                row.iter(),
876                |buff, v| {
877                    self.write_identifier_quoted(context, buff, v.0);
878                },
879                ", ",
880            );
881        } else {
882            // Inserting more rows will list all columns, Passive::NotSet columns will result in DEFAULT value
883            separated_by(
884                buff,
885                columns.clone(),
886                |buff, v| {
887                    self.write_identifier_quoted(context, buff, v.name());
888                },
889                ", ",
890            );
891        };
892        buff.push_str(") VALUES\n");
893        context.fragment = Fragment::SqlInsertIntoValues;
894        let mut first_row = None;
895        let mut separate = false;
896        loop {
897            if separate {
898                buff.push_str(",\n");
899            }
900            buff.push('(');
901            let mut fields = row.iter();
902            let mut field = fields.next();
903            separated_by(
904                buff,
905                E::columns(),
906                |buff, col| {
907                    if Some(col.name()) == field.map(|v| v.0) {
908                        self.write_value(
909                            context,
910                            buff,
911                            field
912                                .map(|v| &v.1)
913                                .expect(&format!("Column {} does not have a value", col.name())),
914                        );
915                        field = fields.next();
916                    } else if !single {
917                        buff.push_str("DEFAULT");
918                    }
919                },
920                ", ",
921            );
922            buff.push(')');
923            separate = true;
924            if first_row.is_none() {
925                first_row = row.into();
926            }
927            if let Some(next) = rows.next() {
928                row = next;
929            } else {
930                break;
931            };
932        }
933        let first_row = first_row
934            .expect("Should have at least one row")
935            .into_iter()
936            .map(|(v, _)| v);
937        if update {
938            self.write_insert_update_fragment::<E, _>(
939                context,
940                buff,
941                if single {
942                    EitherIterator::Left(
943                        // If there is only one row to insert then list only the columns that appear
944                        columns.filter(|c| first_row.clone().find(|n| *n == c.name()).is_some()),
945                    )
946                } else {
947                    EitherIterator::Right(columns)
948                },
949            );
950        }
951        buff.push(';');
952    }
953
954    fn write_insert_update_fragment<'a, E, It>(
955        &self,
956        mut context: Context,
957        buff: &mut String,
958        columns: It,
959    ) where
960        Self: Sized,
961        E: Entity,
962        It: Iterator<Item = &'a ColumnDef>,
963    {
964        let pk = E::primary_key_def();
965        if pk.len() == 0 {
966            return;
967        }
968        buff.push_str("\nON CONFLICT");
969        context.fragment = Fragment::SqlInsertIntoOnConflict;
970        if pk.len() > 0 {
971            buff.push_str(" (");
972            separated_by(
973                buff,
974                pk,
975                |buff, v| {
976                    self.write_identifier_quoted(context, buff, v.name());
977                },
978                ", ",
979            );
980            buff.push(')');
981        }
982        buff.push_str(" DO UPDATE SET\n");
983        separated_by(
984            buff,
985            columns.filter(|c| c.primary_key == PrimaryKeyType::None),
986            |buff, v| {
987                self.write_identifier_quoted(context, buff, v.name());
988                buff.push_str(" = EXCLUDED.");
989                self.write_identifier_quoted(context, buff, v.name());
990            },
991            ",\n",
992        );
993    }
994
995    fn write_delete<E, Expr>(&self, buff: &mut String, condition: &Expr)
996    where
997        Self: Sized,
998        E: Entity,
999        Expr: Expression,
1000    {
1001        buff.push_str("DELETE FROM ");
1002        let context = Context {
1003            fragment: Fragment::SqlDeleteFrom,
1004            qualify_columns: E::qualified_columns(),
1005        };
1006        self.write_table_ref(context, buff, E::table_ref());
1007        buff.push_str("\nWHERE ");
1008        condition.write_query(
1009            self,
1010            context.with_context(Fragment::SqlDeleteFromWhere),
1011            buff,
1012        );
1013        buff.push(';');
1014    }
1015}
1016
/// Stateless SQL writer.
///
/// It carries no data, so it is free to create, copy and compare by debug output.
#[derive(Debug, Default, Clone, Copy)]
pub struct GenericSqlWriter;
impl GenericSqlWriter {
    /// Creates a new writer. Equivalent to `GenericSqlWriter::default()`.
    pub const fn new() -> Self {
        Self
    }
}
1023impl SqlWriter for GenericSqlWriter {
1024    fn as_dyn(&self) -> &dyn SqlWriter {
1025        self
1026    }
1027}