// tank_core/writer/sql_writer.rs

1use crate::{
2    Action, BinaryOp, BinaryOpType, ColumnDef, ColumnRef, DataSet, EitherIterator, Entity,
3    Expression, Fragment, Interval, Join, JoinType, Operand, Order, Ordered, PrimaryKeyType,
4    TableRef, UnaryOp, UnaryOpType, Value, possibly_parenthesized, separated_by, writer::Context,
5};
6use futures::future::Either;
7use std::{collections::HashMap, fmt::Write};
8use time::{Date, OffsetDateTime, PrimitiveDateTime, Time};
9
/// Appends the decimal text of an integer `$value` to the `String` named `$buff`.
///
/// Formatting goes through `itoa`, which skips the `core::fmt` machinery.
macro_rules! write_integer {
    ($buff:ident, $value:expr) => {{
        $buff.push_str(itoa::Buffer::new().format($value));
    }};
}
/// Appends a floating point `$value` to the `String` named `$buff`.
///
/// - Finite values are formatted with `ryu`.
/// - Infinities are delegated to `$this.write_value_infinity`.
/// - NaN is written as a `Cast` of the string literal `"NaN"` to the `Float64` column type,
///   via `$this.write_expression_binary_op`.
///
/// Non-finite values have no bare SQL literal form, hence the casts.
/// `$value` is evaluated exactly once.
macro_rules! write_float {
    ($this:ident, $context:ident, $buff:ident, $value:expr) => {{
        let float = $value;
        if float.is_infinite() {
            $this.write_value_infinity($context, $buff, float.is_sign_negative());
        } else if float.is_nan() {
            // `ryu` would print a bare `NaN`, which SQL parses as an identifier, not a number.
            $this.write_expression_binary_op(
                $context,
                $buff,
                &$crate::BinaryOp {
                    op: $crate::BinaryOpType::Cast,
                    lhs: &$crate::Operand::LitStr("NaN"),
                    rhs: &$crate::Operand::Type($crate::Value::Float64(None)),
                },
            );
        } else {
            let mut buffer = ryu::Buffer::new();
            $buff.push_str(buffer.format(float));
        }
    }};
}
26
27pub trait SqlWriter {
    /// Returns a `&dyn SqlWriter` view of this writer (presumably just `self`).
    ///
    /// The default methods of this trait pass a `&dyn SqlWriter` to `write_query`. A default
    /// method cannot unsize `&Self` on its own because `Self` is not known to be `Sized`, so
    /// implementors provide the conversion here.
    fn as_dyn(&self) -> &dyn SqlWriter;
29
30    fn alias_declaration(&self, context: &mut Context) -> bool {
31        match context.fragment {
32            Fragment::SqlSelectFrom | Fragment::SqlJoin => true,
33            _ => false,
34        }
35    }
36
37    fn write_escaped(
38        &self,
39        _context: &mut Context,
40        buff: &mut String,
41        value: &str,
42        search: char,
43        replace: &str,
44    ) {
45        let mut position = 0;
46        for (i, c) in value.char_indices() {
47            if c == search {
48                buff.push_str(&value[position..i]);
49                buff.push_str(replace);
50                position = i + 1;
51            }
52        }
53        buff.push_str(&value[position..]);
54    }
55
56    fn write_identifier_quoted(&self, context: &mut Context, buff: &mut String, value: &str) {
57        buff.push('"');
58        self.write_escaped(context, buff, value, '"', r#""""#);
59        buff.push('"');
60    }
61
62    fn write_table_ref(&self, context: &mut Context, buff: &mut String, value: &TableRef) {
63        if self.alias_declaration(context) || value.alias.is_empty() {
64            if !value.schema.is_empty() {
65                self.write_identifier_quoted(context, buff, &value.schema);
66                buff.push('.');
67            }
68            self.write_identifier_quoted(context, buff, &value.name);
69        }
70        if !value.alias.is_empty() {
71            let _ = write!(buff, " {}", value.alias);
72        }
73    }
74
75    fn write_column_ref(&self, context: &mut Context, buff: &mut String, value: &ColumnRef) {
76        if context.qualify_columns && !value.table.is_empty() {
77            if !value.schema.is_empty() {
78                self.write_identifier_quoted(context, buff, &value.schema);
79                buff.push('.');
80            }
81            self.write_identifier_quoted(context, buff, &value.table);
82            buff.push('.');
83        }
84        self.write_identifier_quoted(context, buff, &value.name);
85    }
86
87    fn write_column_type(&self, context: &mut Context, buff: &mut String, value: &Value) {
88        match value {
89            Value::Boolean(..) => buff.push_str("BOOLEAN"),
90            Value::Int8(..) => buff.push_str("TINYINT"),
91            Value::Int16(..) => buff.push_str("SMALLINT"),
92            Value::Int32(..) => buff.push_str("INTEGER"),
93            Value::Int64(..) => buff.push_str("BIGINT"),
94            Value::Int128(..) => buff.push_str("HUGEINT"),
95            Value::UInt8(..) => buff.push_str("UTINYINT"),
96            Value::UInt16(..) => buff.push_str("USMALLINT"),
97            Value::UInt32(..) => buff.push_str("UINTEGER"),
98            Value::UInt64(..) => buff.push_str("UBIGINT"),
99            Value::UInt128(..) => buff.push_str("UHUGEINT"),
100            Value::Float32(..) => buff.push_str("FLOAT"),
101            Value::Float64(..) => buff.push_str("DOUBLE"),
102            Value::Decimal(.., precision, scale) => {
103                buff.push_str("DECIMAL");
104                if (precision, scale) != (&0, &0) {
105                    let _ = write!(buff, "({},{})", precision, scale);
106                }
107            }
108            Value::Char(..) => buff.push_str("CHAR(1)"),
109            Value::Varchar(..) => buff.push_str("VARCHAR"),
110            Value::Blob(..) => buff.push_str("BLOB"),
111            Value::Date(..) => buff.push_str("DATE"),
112            Value::Time(..) => buff.push_str("TIME"),
113            Value::Timestamp(..) => buff.push_str("TIMESTAMP"),
114            Value::TimestampWithTimezone(..) => buff.push_str("TIMESTAMPTZ"),
115            Value::Interval(..) => buff.push_str("INTERVAL"),
116            Value::Uuid(..) => buff.push_str("UUID"),
117            Value::Array(.., inner, size) => {
118                self.write_column_type(context, buff, inner);
119                let _ = write!(buff, "[{}]", size);
120            }
121            Value::List(.., inner) => {
122                self.write_column_type(context, buff, inner);
123                buff.push_str("[]");
124            }
125            Value::Map(.., key, value) => {
126                buff.push_str("MAP(");
127                self.write_column_type(context, buff, key);
128                buff.push(',');
129                self.write_column_type(context, buff, value);
130                buff.push(')');
131            }
132            _ => log::error!(
133                "Unexpected tank::Value, variant {:?} is not supported",
134                value
135            ),
136        };
137    }
138
139    fn write_value(&self, context: &mut Context, buff: &mut String, value: &Value) {
140        match value {
141            v if v.is_null() => self.write_value_none(context, buff),
142            Value::Boolean(Some(v), ..) => self.write_value_bool(context, buff, *v),
143            Value::Int8(Some(v), ..) => write_integer!(buff, *v),
144            Value::Int16(Some(v), ..) => write_integer!(buff, *v),
145            Value::Int32(Some(v), ..) => write_integer!(buff, *v),
146            Value::Int64(Some(v), ..) => write_integer!(buff, *v),
147            Value::Int128(Some(v), ..) => write_integer!(buff, *v),
148            Value::UInt8(Some(v), ..) => write_integer!(buff, *v),
149            Value::UInt16(Some(v), ..) => write_integer!(buff, *v),
150            Value::UInt32(Some(v), ..) => write_integer!(buff, *v),
151            Value::UInt64(Some(v), ..) => write_integer!(buff, *v),
152            Value::UInt128(Some(v), ..) => write_integer!(buff, *v),
153            Value::Float32(Some(v), ..) => write_float!(self, context, buff, *v),
154            Value::Float64(Some(v), ..) => write_float!(self, context, buff, *v),
155            Value::Decimal(Some(v), ..) => drop(write!(buff, "{}", v)),
156            Value::Char(Some(v), ..) => {
157                buff.push('\'');
158                buff.push(*v);
159                buff.push('\'');
160            }
161            Value::Varchar(Some(v), ..) => self.write_value_string(context, buff, v),
162            Value::Blob(Some(v), ..) => self.write_value_blob(context, buff, v.as_ref()),
163            Value::Date(Some(v), ..) => self.write_value_date(context, buff, v, false),
164            Value::Time(Some(v), ..) => self.write_value_time(context, buff, v, false),
165            Value::Timestamp(Some(v), ..) => self.write_value_timestamp(context, buff, v),
166            Value::TimestampWithTimezone(Some(v), ..) => {
167                self.write_value_timestamptz(context, buff, v)
168            }
169            Value::Interval(Some(v), ..) => self.write_value_interval(context, buff, v),
170            Value::Uuid(Some(v), ..) => drop(write!(buff, "'{}'", v)),
171            Value::Array(Some(..), ..) | Value::List(Some(..), ..) => match value {
172                Value::Array(Some(v), ..) => {
173                    self.write_value_list(context, buff, Either::Left(v), value)
174                }
175                Value::List(Some(v), ..) => {
176                    self.write_value_list(context, buff, Either::Right(v), value)
177                }
178                _ => unreachable!(),
179            },
180            Value::Map(Some(v), ..) => self.write_value_map(context, buff, v),
181            Value::Struct(Some(v), ..) => self.write_value_struct(context, buff, v),
182            _ => {
183                log::error!("Cannot write {:?}", value);
184            }
185        };
186    }
187
188    fn write_value_none(&self, _context: &mut Context, buff: &mut String) {
189        buff.push_str("NULL");
190    }
191
192    fn write_value_bool(&self, _context: &mut Context, buff: &mut String, value: bool) {
193        buff.push_str(["false", "true"][value as usize]);
194    }
195
196    fn write_value_infinity(&self, context: &mut Context, buff: &mut String, negative: bool) {
197        let mut buffer = ryu::Buffer::new();
198        self.write_expression_binary_op(
199            context,
200            buff,
201            &BinaryOp {
202                op: BinaryOpType::Cast,
203                lhs: &Operand::LitStr(buffer.format(if negative {
204                    f64::NEG_INFINITY
205                } else {
206                    f64::INFINITY
207                })),
208                rhs: &Operand::Type(Value::Float64(None)),
209            },
210        );
211    }
212
213    fn write_value_string(&self, _context: &mut Context, buff: &mut String, value: &str) {
214        buff.push('\'');
215        let mut position = 0;
216        for (i, c) in value.char_indices() {
217            if c == '\'' {
218                buff.push_str(&value[position..i]);
219                buff.push_str("''");
220                position = i + 1;
221            } else if c == '\n' {
222                buff.push_str(&value[position..i]);
223                buff.push_str("\\n");
224                position = i + 1;
225            }
226        }
227        buff.push_str(&value[position..]);
228        buff.push('\'');
229    }
230
231    fn write_value_blob(&self, _context: &mut Context, buff: &mut String, value: &[u8]) {
232        buff.push('\'');
233        for b in value {
234            let _ = write!(buff, "\\x{:X}", b);
235        }
236        buff.push('\'');
237    }
238
239    fn write_value_date(
240        &self,
241        _context: &mut Context,
242        buff: &mut String,
243        value: &Date,
244        timestamp: bool,
245    ) {
246        let b = if timestamp { "" } else { "'" };
247        let _ = write!(
248            buff,
249            "{b}{:04}-{:02}-{:02}{b}",
250            value.year(),
251            value.month() as u8,
252            value.day()
253        );
254    }
255
256    fn write_value_time(
257        &self,
258        _context: &mut Context,
259        buff: &mut String,
260        value: &Time,
261        timestamp: bool,
262    ) {
263        let mut subsecond = value.nanosecond();
264        let mut width = 9;
265        while width > 1 && subsecond % 10 == 0 {
266            subsecond /= 10;
267            width -= 1;
268        }
269        let b = if timestamp { "" } else { "'" };
270        let _ = write!(
271            buff,
272            "{b}{:02}:{:02}:{:02}.{:0width$}{b}",
273            value.hour(),
274            value.minute(),
275            value.second(),
276            subsecond
277        );
278    }
279
280    fn write_value_timestamp(
281        &self,
282        context: &mut Context,
283        buff: &mut String,
284        value: &PrimitiveDateTime,
285    ) {
286        buff.push('\'');
287        self.write_value_date(context, buff, &value.date(), true);
288        buff.push('T');
289        self.write_value_time(context, buff, &value.time(), true);
290        buff.push('\'');
291    }
292
293    fn write_value_timestamptz(
294        &self,
295        context: &mut Context,
296        buff: &mut String,
297        value: &OffsetDateTime,
298    ) {
299        buff.push('\'');
300        self.write_value_date(context, buff, &value.date(), true);
301        buff.push('T');
302        self.write_value_time(context, buff, &value.time(), true);
303        let _ = write!(
304            buff,
305            "{:+02}:{:02}",
306            value.offset().whole_hours(),
307            value.offset().whole_minutes()
308        );
309        if value.date().year() <= 0 {
310            buff.push_str(" BC");
311        }
312        buff.push_str("'::TIMESTAMPTZ");
313    }
314
315    fn value_interval_units(&self) -> &[(&str, i128)] {
316        static UNITS: &[(&str, i128)] = &[
317            ("DAY", Interval::NANOS_IN_DAY),
318            ("HOUR", Interval::NANOS_IN_SEC * 3600),
319            ("MINUTE", Interval::NANOS_IN_SEC * 60),
320            ("SECOND", Interval::NANOS_IN_SEC),
321            ("MICROSECOND", 1_000),
322            ("NANOSECOND", 1),
323        ];
324        UNITS
325    }
326
    /// Writes an [`Interval`] as an SQL `INTERVAL` literal.
    ///
    /// - A zero interval is written as `INTERVAL 0 SECONDS`.
    /// - The month part is written as whole `YEAR`s when divisible by 12, otherwise as `MONTH`s.
    /// - Days and nanoseconds are folded into one nanosecond count. That count is written with
    ///   the first unit from `value_interval_units()` (largest first by default) that divides
    ///   it exactly and yields a non-zero amount.
    /// - Amounts other than exactly 1 get a plural `S` (e.g. `-1 DAYS`).
    /// - When both a month part and a nanosecond part are present, the body is quoted
    ///   (`INTERVAL '1 YEAR 2 DAYS'`). A single component is unquoted (`INTERVAL 3 HOURS`).
    fn write_value_interval(&self, _context: &mut Context, buff: &mut String, value: &Interval) {
        buff.push_str("INTERVAL ");
        if value.is_zero() {
            buff.push_str("0 SECONDS");
            return;
        }
        // Writes `<amount> <UNIT>[S]`, pluralizing every amount except exactly 1.
        macro_rules! write_unit {
            ($buff:ident, $val:expr, $unit:expr) => {
                let _ = write!(
                    $buff,
                    "{} {}{}",
                    $val,
                    $unit,
                    if $val != 1 { "S" } else { "" }
                );
            };
        }
        let months = value.months;
        // Days have a fixed length here, so they are merged with the nanoseconds.
        let nanos = value.nanos + value.days as i128 * Interval::NANOS_IN_DAY;
        // Two components need quoting to form a single multi-part interval string.
        let multiple_units = nanos != 0 && value.months != 0;
        if multiple_units {
            buff.push('\'');
        }
        if months != 0 {
            if months % 12 == 0 {
                write_unit!(buff, months / 12, "YEAR");
            } else {
                write_unit!(buff, months, "MONTH");
            }
        }
        // Pick the first (largest, with the default table) unit that represents `nanos`
        // exactly. If no unit divides it, the sub-month part is silently omitted; this cannot
        // happen with the default table, which ends in NANOSECOND (factor 1).
        for &(name, factor) in self.value_interval_units() {
            if nanos % factor == 0 {
                let value = nanos / factor;
                if value != 0 {
                    if months != 0 {
                        buff.push(' ');
                    }
                    write_unit!(buff, value, name);
                    break;
                }
            }
        }
        if multiple_units {
            buff.push('\'');
        }
    }
373
374    fn write_value_list<'a>(
375        &self,
376        context: &mut Context,
377        buff: &mut String,
378        value: Either<&Box<[Value]>, &Vec<Value>>,
379        _ty: &Value,
380    ) {
381        buff.push('[');
382        separated_by(
383            buff,
384            match value {
385                Either::Left(v) => v.iter(),
386                Either::Right(v) => v.iter(),
387            },
388            |buff, v| {
389                self.write_value(context, buff, v);
390            },
391            ",",
392        );
393        buff.push(']');
394    }
395
396    fn write_value_map(
397        &self,
398        context: &mut Context,
399        buff: &mut String,
400        value: &HashMap<Value, Value>,
401    ) {
402        buff.push('{');
403        separated_by(
404            buff,
405            value,
406            |buff, (k, v)| {
407                self.write_value(context, buff, k);
408                buff.push(':');
409                self.write_value(context, buff, v);
410            },
411            ",",
412        );
413        buff.push('}');
414    }
415
416    fn write_value_struct(
417        &self,
418        context: &mut Context,
419        buff: &mut String,
420        value: &Vec<(String, Value)>,
421    ) {
422        buff.push('{');
423        separated_by(
424            buff,
425            value,
426            |buff, (k, v)| {
427                self.write_value_string(context, buff, k);
428                buff.push(':');
429                self.write_value(context, buff, v);
430            },
431            ",",
432        );
433        buff.push('}');
434    }
435
436    fn expression_unary_op_precedence<'a>(&self, value: &UnaryOpType) -> i32 {
437        match value {
438            UnaryOpType::Negative => 1250,
439            UnaryOpType::Not => 250,
440        }
441    }
442
443    fn expression_binary_op_precedence<'a>(&self, value: &BinaryOpType) -> i32 {
444        match value {
445            BinaryOpType::Or => 100,
446            BinaryOpType::And => 200,
447            BinaryOpType::Equal => 300,
448            BinaryOpType::NotEqual => 300,
449            BinaryOpType::Less => 300,
450            BinaryOpType::Greater => 300,
451            BinaryOpType::LessEqual => 300,
452            BinaryOpType::GreaterEqual => 300,
453            BinaryOpType::Is => 400,
454            BinaryOpType::IsNot => 400,
455            BinaryOpType::Like => 400,
456            BinaryOpType::NotLike => 400,
457            BinaryOpType::Regexp => 400,
458            BinaryOpType::NotRegexp => 400,
459            BinaryOpType::Glob => 400,
460            BinaryOpType::NotGlob => 400,
461            BinaryOpType::BitwiseOr => 500,
462            BinaryOpType::BitwiseAnd => 600,
463            BinaryOpType::ShiftLeft => 700,
464            BinaryOpType::ShiftRight => 700,
465            BinaryOpType::Subtraction => 800,
466            BinaryOpType::Addition => 800,
467            BinaryOpType::Multiplication => 900,
468            BinaryOpType::Division => 900,
469            BinaryOpType::Remainder => 900,
470            BinaryOpType::Indexing => 1000,
471            BinaryOpType::Cast => 1100,
472            BinaryOpType::Alias => 1200,
473        }
474    }
475
476    fn write_expression_operand(&self, context: &mut Context, buff: &mut String, value: &Operand) {
477        match value {
478            Operand::LitBool(v) => self.write_value_bool(context, buff, *v),
479            Operand::LitFloat(v) => write_float!(self, context, buff, *v),
480            Operand::LitIdent(v) => drop(buff.push_str(v)),
481            Operand::LitField(v) => separated_by(buff, *v, |buff, v| buff.push_str(v), "."),
482            Operand::LitInt(v) => write_integer!(buff, *v),
483            Operand::LitStr(v) => self.write_value_string(context, buff, v),
484            Operand::LitArray(v) => {
485                buff.push('[');
486                separated_by(
487                    buff,
488                    *v,
489                    |buff, v| {
490                        v.write_query(self.as_dyn(), context, buff);
491                    },
492                    ", ",
493                );
494                buff.push(']');
495            }
496            Operand::Null => drop(buff.push_str("NULL")),
497            Operand::Type(v) => self.write_column_type(context, buff, v),
498            Operand::Variable(v) => self.write_value(context, buff, v),
499            Operand::Call(f, args) => {
500                buff.push_str(f);
501                buff.push('(');
502                separated_by(
503                    buff,
504                    *args,
505                    |buff, v| {
506                        v.write_query(self.as_dyn(), context, buff);
507                    },
508                    ",",
509                );
510                buff.push(')');
511            }
512            Operand::Asterisk => drop(buff.push('*')),
513            Operand::QuestionMark => self.write_expression_operand_question_mark(context, buff),
514        };
515    }
516
517    fn write_expression_operand_question_mark(&self, _context: &mut Context, buff: &mut String) {
518        buff.push('?');
519    }
520
521    fn write_expression_unary_op(
522        &self,
523        context: &mut Context,
524        buff: &mut String,
525        value: &UnaryOp<&dyn Expression>,
526    ) {
527        match value.op {
528            UnaryOpType::Negative => buff.push('-'),
529            UnaryOpType::Not => buff.push_str("NOT "),
530        };
531        possibly_parenthesized!(
532            buff,
533            value.arg.precedence(self.as_dyn()) <= self.expression_unary_op_precedence(&value.op),
534            value.arg.write_query(self.as_dyn(), context, buff)
535        );
536    }
537
538    fn write_expression_binary_op(
539        &self,
540        context: &mut Context,
541        buff: &mut String,
542        value: &BinaryOp<&dyn Expression, &dyn Expression>,
543    ) {
544        let (prefix, infix, suffix, lhs_parenthesized, rhs_parenthesized) = match value.op {
545            BinaryOpType::Indexing => ("", "[", "]", false, true),
546            BinaryOpType::Cast => ("CAST(", " AS ", ")", true, true),
547            BinaryOpType::Multiplication => ("", " * ", "", false, false),
548            BinaryOpType::Division => ("", " / ", "", false, false),
549            BinaryOpType::Remainder => ("", " % ", "", false, false),
550            BinaryOpType::Addition => ("", " + ", "", false, false),
551            BinaryOpType::Subtraction => ("", " - ", "", false, false),
552            BinaryOpType::ShiftLeft => ("", " << ", "", false, false),
553            BinaryOpType::ShiftRight => ("", " >> ", "", false, false),
554            BinaryOpType::BitwiseAnd => ("", " & ", "", false, false),
555            BinaryOpType::BitwiseOr => ("", " | ", "", false, false),
556            BinaryOpType::Is => ("", " Is ", "", false, false),
557            BinaryOpType::IsNot => ("", " IS NOT ", "", false, false),
558            BinaryOpType::Like => ("", " LIKE ", "", false, false),
559            BinaryOpType::NotLike => ("", " NOT LIKE ", "", false, false),
560            BinaryOpType::Regexp => ("", " REGEXP ", "", false, false),
561            BinaryOpType::NotRegexp => ("", " NOT REGEXP ", "", false, false),
562            BinaryOpType::Glob => ("", " GLOB ", "", false, false),
563            BinaryOpType::NotGlob => ("", " NOT GLOB ", "", false, false),
564            BinaryOpType::Equal => ("", " = ", "", false, false),
565            BinaryOpType::NotEqual => ("", " != ", "", false, false),
566            BinaryOpType::Less => ("", " < ", "", false, false),
567            BinaryOpType::LessEqual => ("", " <= ", "", false, false),
568            BinaryOpType::Greater => ("", " > ", "", false, false),
569            BinaryOpType::GreaterEqual => ("", " >= ", "", false, false),
570            BinaryOpType::And => ("", " AND ", "", false, false),
571            BinaryOpType::Or => ("", " OR ", "", false, false),
572            BinaryOpType::Alias => ("", " AS ", "", false, false),
573        };
574        let mut context = context.switch_fragment(if value.op == BinaryOpType::Cast {
575            Fragment::Casting
576        } else {
577            context.fragment
578        });
579        let precedence = self.expression_binary_op_precedence(&value.op);
580        buff.push_str(prefix);
581        possibly_parenthesized!(
582            buff,
583            !lhs_parenthesized && value.lhs.precedence(self.as_dyn()) < precedence,
584            value
585                .lhs
586                .write_query(self.as_dyn(), &mut context.current, buff)
587        );
588        buff.push_str(infix);
589        possibly_parenthesized!(
590            buff,
591            !rhs_parenthesized && value.rhs.precedence(self.as_dyn()) <= precedence,
592            value
593                .rhs
594                .write_query(self.as_dyn(), &mut context.current, buff)
595        );
596        buff.push_str(suffix);
597    }
598
599    fn write_expression_ordered(
600        &self,
601        context: &mut Context,
602        buff: &mut String,
603        value: &Ordered<&dyn Expression>,
604    ) {
605        value.expression.write_query(self.as_dyn(), context, buff);
606        if context.fragment == Fragment::SqlSelectOrderBy {
607            let _ = write!(
608                buff,
609                " {}",
610                match value.order {
611                    Order::ASC => "ASC",
612                    Order::DESC => "DESC",
613                }
614            );
615        }
616    }
617
618    fn write_join_type(&self, _context: &mut Context, buff: &mut String, join_type: &JoinType) {
619        buff.push_str(match &join_type {
620            JoinType::Default => "JOIN",
621            JoinType::Inner => "INNER JOIN",
622            JoinType::Outer => "OUTER JOIN",
623            JoinType::Left => "LEFT JOIN",
624            JoinType::Right => "RIGHT JOIN",
625            JoinType::Cross => "CROSS",
626            JoinType::Natural => "NATURAL JOIN",
627        });
628    }
629
630    fn write_join(
631        &self,
632        context: &mut Context,
633        buff: &mut String,
634        join: &Join<&dyn DataSet, &dyn DataSet, &dyn Expression>,
635    ) {
636        let mut context = context.switch_fragment(Fragment::SqlJoin);
637        context.current.qualify_columns = true;
638        join.lhs
639            .write_query(self.as_dyn(), &mut context.current, buff);
640        buff.push(' ');
641        self.write_join_type(&mut context.current, buff, &join.join);
642        buff.push(' ');
643        join.rhs
644            .write_query(self.as_dyn(), &mut context.current, buff);
645        if let Some(on) = &join.on {
646            buff.push_str(" ON ");
647            on.write_query(self.as_dyn(), &mut context.current, buff);
648        }
649    }
650
651    fn write_transaction_begin(&self, buff: &mut String) {
652        buff.push_str("BEGIN;");
653    }
654
655    fn write_transaction_commit(&self, buff: &mut String) {
656        buff.push_str("COMMIT;");
657    }
658
659    fn write_transaction_rollback(&self, buff: &mut String) {
660        buff.push_str("ROLLBACK;");
661    }
662
663    fn write_create_schema<E>(&self, buff: &mut String, if_not_exists: bool)
664    where
665        Self: Sized,
666        E: Entity,
667    {
668        buff.push_str("CREATE SCHEMA ");
669        let mut context = Context::new(Fragment::SqlCreateSchema, E::qualified_columns());
670        if if_not_exists {
671            buff.push_str("IF NOT EXISTS ");
672        }
673        self.write_identifier_quoted(&mut context, buff, E::table_ref().schema);
674        buff.push(';');
675    }
676
677    fn write_drop_schema<E>(&self, buff: &mut String, if_exists: bool)
678    where
679        Self: Sized,
680        E: Entity,
681    {
682        buff.push_str("DROP SCHEMA ");
683        let mut context = Context::new(Fragment::SqlDropSchema, E::qualified_columns());
684        if if_exists {
685            buff.push_str("IF EXISTS ");
686        }
687        self.write_identifier_quoted(&mut context, buff, E::table_ref().schema);
688        buff.push(';');
689    }
690
691    fn write_create_table<E>(&self, buff: &mut String, if_not_exists: bool)
692    where
693        Self: Sized,
694        E: Entity,
695    {
696        let mut context = Context::new(Fragment::SqlCreateTable, E::qualified_columns());
697        buff.push_str("CREATE TABLE ");
698        if if_not_exists {
699            buff.push_str("IF NOT EXISTS ");
700        }
701        self.write_table_ref(&mut context, buff, E::table_ref());
702        buff.push_str(" (\n");
703        separated_by(
704            buff,
705            E::columns(),
706            |buff, v| {
707                self.write_create_table_column_fragment(&mut context, buff, v);
708            },
709            ",\n",
710        );
711        let primary_key = E::primary_key_def();
712        if primary_key.len() > 1 {
713            buff.push_str(",\nPRIMARY KEY (");
714            separated_by(
715                buff,
716                primary_key,
717                |buff, v| {
718                    self.write_identifier_quoted(
719                        &mut context
720                            .switch_fragment(Fragment::SqlCreateTablePrimaryKey)
721                            .current,
722                        buff,
723                        v.name(),
724                    );
725                },
726                ", ",
727            );
728            buff.push(')');
729        }
730        for unique in E::unique_defs() {
731            if unique.len() > 1 {
732                buff.push_str(",\nUNIQUE (");
733                separated_by(
734                    buff,
735                    unique,
736                    |buff, v| {
737                        self.write_identifier_quoted(
738                            &mut context
739                                .switch_fragment(Fragment::SqlCreateTableUnique)
740                                .current,
741                            buff,
742                            v.name(),
743                        );
744                    },
745                    ", ",
746                );
747                buff.push(')');
748            }
749        }
750        buff.push_str(");");
751        self.write_column_comments::<E>(&mut context, buff);
752    }
753
754    fn write_column_comments<E>(&self, context: &mut Context, buff: &mut String)
755    where
756        Self: Sized,
757        E: Entity,
758    {
759        let mut context = context.switch_fragment(Fragment::SqlCommentOnColumn);
760        context.current.qualify_columns = true;
761        for c in E::columns().iter().filter(|c| !c.comment.is_empty()) {
762            buff.push_str("\nCOMMENT ON COLUMN ");
763            self.write_column_ref(&mut context.current, buff, c.into());
764            buff.push_str(" IS ");
765            self.write_value_string(&mut context.current, buff, c.comment);
766            buff.push(';');
767        }
768    }
769
770    fn write_create_table_column_fragment(
771        &self,
772        context: &mut Context,
773        buff: &mut String,
774        column: &ColumnDef,
775    ) where
776        Self: Sized,
777    {
778        self.write_identifier_quoted(context, buff, &column.name());
779        buff.push(' ');
780        if !column.column_type.is_empty() {
781            buff.push_str(&column.column_type);
782        } else {
783            SqlWriter::write_column_type(self, context, buff, &column.value);
784        }
785        if !column.nullable && column.primary_key == PrimaryKeyType::None {
786            buff.push_str(" NOT NULL");
787        }
788        if let Some(default) = &column.default {
789            buff.push_str(" DEFAULT ");
790            default.write_query(self.as_dyn(), context, buff);
791        }
792        if column.primary_key == PrimaryKeyType::PrimaryKey {
793            // Composite primary key will be printed elsewhere
794            buff.push_str(" PRIMARY KEY");
795        }
796        if column.unique && column.primary_key != PrimaryKeyType::PrimaryKey {
797            buff.push_str(" UNIQUE");
798        }
799        if let Some(references) = column.references {
800            buff.push_str(" REFERENCES ");
801            self.write_table_ref(context, buff, &references.table_ref());
802            buff.push('(');
803            self.write_column_ref(context, buff, &references);
804            buff.push(')');
805            if let Some(on_delete) = &column.on_delete {
806                buff.push_str(" ON DELETE ");
807                self.write_create_table_references_action(context, buff, on_delete);
808            }
809            if let Some(on_update) = &column.on_update {
810                buff.push_str(" ON UPDATE ");
811                self.write_create_table_references_action(context, buff, on_update);
812            }
813        }
814    }
815
816    fn write_create_table_references_action(
817        &self,
818        _context: &mut Context,
819        buff: &mut String,
820        action: &Action,
821    ) {
822        buff.push_str(match action {
823            Action::NoAction => "NO ACTION",
824            Action::Restrict => "RESTRICT",
825            Action::Cascade => "CASCADE",
826            Action::SetNull => "SET NULL",
827            Action::SetDefault => "SET DEFAULT",
828        });
829    }
830
831    fn write_drop_table<E>(&self, buff: &mut String, if_exists: bool)
832    where
833        Self: Sized,
834        E: Entity,
835    {
836        buff.push_str("DROP TABLE ");
837        let mut context = Context::new(Fragment::SqlDropTable, E::qualified_columns());
838        if if_exists {
839            buff.push_str("IF EXISTS ");
840        }
841        self.write_table_ref(&mut context, buff, E::table_ref());
842        buff.push(';');
843    }
844
845    fn write_select<Item, Cols, Data, Cond>(
846        &self,
847        buff: &mut String,
848        columns: Cols,
849        from: &Data,
850        condition: &Cond,
851        limit: Option<u32>,
852    ) where
853        Self: Sized,
854        Item: Expression,
855        Cols: IntoIterator<Item = Item> + Clone,
856        Data: DataSet,
857        Cond: Expression,
858    {
859        buff.push_str("SELECT ");
860        let mut has_order_by = false;
861        let mut context = Context::new(Fragment::SqlSelect, Data::qualified_columns());
862        separated_by(
863            buff,
864            columns.clone(),
865            |buff, col| {
866                col.write_query(self, &mut context, buff);
867                has_order_by = has_order_by || col.is_ordered();
868            },
869            ", ",
870        );
871        buff.push_str("\nFROM ");
872        from.write_query(
873            self,
874            &mut context.switch_fragment(Fragment::SqlSelectFrom).current,
875            buff,
876        );
877        buff.push_str("\nWHERE ");
878        condition.write_query(
879            self,
880            &mut context.switch_fragment(Fragment::SqlSelectWhere).current,
881            buff,
882        );
883        if has_order_by {
884            buff.push_str("\nORDER BY ");
885            for col in columns.into_iter().filter(Expression::is_ordered) {
886                col.write_query(
887                    self,
888                    &mut context.switch_fragment(Fragment::SqlSelectOrderBy).current,
889                    buff,
890                );
891            }
892        }
893        if let Some(limit) = limit {
894            let _ = write!(buff, "\nLIMIT {}", limit);
895        }
896        buff.push(';');
897    }
898
899    fn write_insert<'b, E, It>(&self, buff: &mut String, entities: It, update: bool)
900    where
901        Self: Sized,
902        E: Entity + 'b,
903        It: IntoIterator<Item = &'b E>,
904    {
905        let mut rows = entities.into_iter().map(Entity::row_filtered).peekable();
906        let Some(mut row) = rows.next() else {
907            return;
908        };
909        buff.push_str("INSERT INTO ");
910        let mut context = Context::new(Fragment::SqlInsertInto, E::qualified_columns());
911        self.write_table_ref(&mut context, buff, E::table_ref());
912        buff.push_str(" (");
913        let columns = E::columns().iter();
914        let single = rows.peek().is_none();
915        if single {
916            // Inserting a single row uses row_labeled to filter buff Passive::NotSet columns
917            separated_by(
918                buff,
919                row.iter(),
920                |buff, v| {
921                    self.write_identifier_quoted(&mut context, buff, v.0);
922                },
923                ", ",
924            );
925        } else {
926            // Inserting more rows will list all columns, Passive::NotSet columns will result in DEFAULT value
927            separated_by(
928                buff,
929                columns.clone(),
930                |buff, v| {
931                    self.write_identifier_quoted(&mut context, buff, v.name());
932                },
933                ", ",
934            );
935        };
936        buff.push_str(") VALUES\n");
937        let mut context = context.switch_fragment(Fragment::SqlInsertIntoValues);
938        let mut first_row = None;
939        let mut separate = false;
940        loop {
941            if separate {
942                buff.push_str(",\n");
943            }
944            buff.push('(');
945            let mut fields = row.iter();
946            let mut field = fields.next();
947            separated_by(
948                buff,
949                E::columns(),
950                |buff, col| {
951                    if Some(col.name()) == field.map(|v| v.0) {
952                        self.write_value(
953                            &mut context.current,
954                            buff,
955                            field
956                                .map(|v| &v.1)
957                                .expect(&format!("Column {} does not have a value", col.name())),
958                        );
959                        field = fields.next();
960                    } else if !single {
961                        buff.push_str("DEFAULT");
962                    }
963                },
964                ", ",
965            );
966            buff.push(')');
967            separate = true;
968            if first_row.is_none() {
969                first_row = row.into();
970            }
971            if let Some(next) = rows.next() {
972                row = next;
973            } else {
974                break;
975            };
976        }
977        let first_row = first_row
978            .expect("Should have at least one row")
979            .into_iter()
980            .map(|(v, _)| v);
981        if update {
982            self.write_insert_update_fragment::<E, _>(
983                &mut context.current,
984                buff,
985                if single {
986                    EitherIterator::Left(
987                        // If there is only one row to insert then list only the columns that appear
988                        columns.filter(|c| first_row.clone().find(|n| *n == c.name()).is_some()),
989                    )
990                } else {
991                    EitherIterator::Right(columns)
992                },
993            );
994        }
995        buff.push(';');
996    }
997
998    fn write_insert_update_fragment<'a, E, It>(
999        &self,
1000        context: &mut Context,
1001        buff: &mut String,
1002        columns: It,
1003    ) where
1004        Self: Sized,
1005        E: Entity,
1006        It: Iterator<Item = &'a ColumnDef>,
1007    {
1008        let pk = E::primary_key_def();
1009        if pk.len() == 0 {
1010            return;
1011        }
1012        buff.push_str("\nON CONFLICT");
1013        context.fragment = Fragment::SqlInsertIntoOnConflict;
1014        if pk.len() > 0 {
1015            buff.push_str(" (");
1016            separated_by(
1017                buff,
1018                pk,
1019                |buff, v| {
1020                    self.write_identifier_quoted(context, buff, v.name());
1021                },
1022                ", ",
1023            );
1024            buff.push(')');
1025        }
1026        buff.push_str(" DO UPDATE SET\n");
1027        separated_by(
1028            buff,
1029            columns.filter(|c| c.primary_key == PrimaryKeyType::None),
1030            |buff, v| {
1031                self.write_identifier_quoted(context, buff, v.name());
1032                buff.push_str(" = EXCLUDED.");
1033                self.write_identifier_quoted(context, buff, v.name());
1034            },
1035            ",\n",
1036        );
1037    }
1038
1039    fn write_delete<E, Expr>(&self, buff: &mut String, condition: &Expr)
1040    where
1041        Self: Sized,
1042        E: Entity,
1043        Expr: Expression,
1044    {
1045        buff.push_str("DELETE FROM ");
1046        let mut context = Context::new(Fragment::SqlDeleteFrom, E::qualified_columns());
1047        self.write_table_ref(&mut context, buff, E::table_ref());
1048        buff.push_str("\nWHERE ");
1049        condition.write_query(
1050            self,
1051            &mut context
1052                .switch_fragment(Fragment::SqlDeleteFromWhere)
1053                .current,
1054            buff,
1055        );
1056        buff.push(';');
1057    }
1058}
1059
/// Stateless, dialect-agnostic SQL writer.
///
/// It implements `SqlWriter` by providing only `as_dyn`, so it uses the trait's default
/// rendering for everything else.
#[derive(Debug, Default, Clone, Copy)]
pub struct GenericSqlWriter;
impl GenericSqlWriter {
    /// Creates a new `GenericSqlWriter` (equivalent to `GenericSqlWriter::default()`).
    pub fn new() -> Self {
        Self
    }
}
1066impl SqlWriter for GenericSqlWriter {
1067    fn as_dyn(&self) -> &dyn SqlWriter {
1068        self
1069    }
1070}