zino_orm/
sqlite.rs

1use super::{query::QueryExt, DatabaseDriver, DatabaseRow, DecodeRow, EncodeColumn, Schema};
2use std::borrow::Cow;
3use zino_core::{
4    datetime::{Date, DateTime, Time},
5    error::Error,
6    extension::{JsonObjectExt, JsonValueExt},
7    model::{Column, Query, QueryOrder},
8    AvroValue, JsonValue, Map, Record, SharedString, Uuid,
9};
10
11#[cfg(feature = "orm-sqlx")]
12use sqlx::{Column as _, Row, TypeInfo, ValueRef};
13
14impl EncodeColumn<DatabaseDriver> for Column<'_> {
15    fn column_type(&self) -> &str {
16        if let Some(column_type) = self.extra().get_str("column_type") {
17            return column_type;
18        }
19        match self.type_name() {
20            "bool" => "BOOLEAN",
21            "u64" | "i64" | "usize" | "isize" | "Option<u64>" | "Option<i64>" | "u32" | "i32"
22            | "u16" | "i16" | "u8" | "i8" | "Option<u32>" | "Option<i32>" => "INTEGER",
23            "f64" | "f32" => "REAL",
24            "Date" | "NaiveDate" => "DATE",
25            "Time" | "NaiveTime" => "TIME",
26            "DateTime" | "NaiveDateTime" => "DATETIME",
27            "Vec<u8>" => "BLOB",
28            _ => "TEXT",
29        }
30    }
31
32    fn encode_value<'a>(&self, value: Option<&'a JsonValue>) -> Cow<'a, str> {
33        if let Some(value) = value {
34            match value {
35                JsonValue::Null => "NULL".into(),
36                JsonValue::Bool(b) => {
37                    let value = if *b { "TRUE" } else { "FALSE" };
38                    value.into()
39                }
40                JsonValue::Number(n) => n.to_string().into(),
41                JsonValue::String(s) => {
42                    if s.is_empty() {
43                        if let Some(value) = self.default_value() {
44                            self.format_value(value).into_owned().into()
45                        } else {
46                            "''".into()
47                        }
48                    } else if s == "null" {
49                        "NULL".into()
50                    } else if s == "not_null" {
51                        "NOT NULL".into()
52                    } else {
53                        self.format_value(s)
54                    }
55                }
56                JsonValue::Array(vec) => {
57                    let values = vec
58                        .iter()
59                        .map(|v| match v {
60                            JsonValue::String(v) => Query::escape_string(v),
61                            _ => self.encode_value(Some(v)).into_owned(),
62                        })
63                        .collect::<Vec<_>>();
64                    format!(r#"json_array({})"#, values.join(",")).into()
65                }
66                JsonValue::Object(_) => Query::escape_string(value).into(),
67            }
68        } else if self.default_value().is_some() {
69            "DEFAULT".into()
70        } else {
71            "NULL".into()
72        }
73    }
74
75    fn format_value<'a>(&self, value: &'a str) -> Cow<'a, str> {
76        match self.type_name() {
77            "bool" => {
78                let value = if value == "true" { "TRUE" } else { "FALSE" };
79                value.into()
80            }
81            "u64" | "i64" | "u32" | "i32" | "u16" | "i16" | "u8" | "i8" | "usize" | "isize"
82            | "Option<u64>" | "Option<i64>" | "Option<u32>" | "Option<i32>" => {
83                if value.parse::<i64>().is_ok() {
84                    value.into()
85                } else {
86                    "NULL".into()
87                }
88            }
89            "f64" | "f32" => {
90                if value.parse::<f64>().is_ok() {
91                    value.into()
92                } else {
93                    "NULL".into()
94                }
95            }
96            "DateTime" | "NaiveDateTime" => match value {
97                "epoch" => "datetime(0, 'unixepoch')".into(),
98                "now" => "datetime('now', 'localtime')".into(),
99                "today" => "datetime('now', 'start of day')".into(),
100                "tomorrow" => "datetime('now', 'start of day', '+1 day')".into(),
101                "yesterday" => "datetime('now', 'start of day', '-1 day')".into(),
102                _ => Query::escape_string(value).into(),
103            },
104            "Date" | "NaiveDate" => match value {
105                "epoch" => "'1970-01-01'".into(),
106                "today" => "date('now', 'localtime')".into(),
107                "tomorrow" => "date('now', '+1 day')".into(),
108                "yesterday" => "date('now', '-1 day')".into(),
109                _ => Query::escape_string(value).into(),
110            },
111            "Time" | "NaiveTime" => match value {
112                "now" => "time('now', 'localtime')".into(),
113                "midnight" => "'00:00:00'".into(),
114                _ => Query::escape_string(value).into(),
115            },
116            "Vec<u8>" => format!("'{value}'").into(),
117            "Vec<String>" | "Vec<Uuid>" | "Vec<u64>" | "Vec<i64>" | "Vec<u32>" | "Vec<i32>" => {
118                if value.contains(',') {
119                    let values = value
120                        .split(',')
121                        .map(Query::escape_string)
122                        .collect::<Vec<_>>();
123                    format!(r#"json_array({})"#, values.join(",")).into()
124                } else {
125                    let value = Query::escape_string(value);
126                    format!(r#"json_array({value})"#).into()
127                }
128            }
129            _ => Query::escape_string(value).into(),
130        }
131    }
132
133    fn format_filter(&self, field: &str, value: &JsonValue) -> String {
134        let type_name = self.type_name();
135        let field = Query::format_field(field);
136        if let Some(filter) = value.as_object() {
137            let mut conditions = Vec::with_capacity(filter.len());
138            if type_name == "Map" {
139                for (key, value) in filter {
140                    let key = Query::escape_string(key);
141                    let value = self.encode_value(Some(value));
142                    let condition =
143                        format!(r#"json_tree.key = {key} AND json_tree.value = {value}"#);
144                    conditions.push(condition);
145                }
146                return Query::join_conditions(conditions, " OR ");
147            } else {
148                for (name, value) in filter {
149                    let name = name.as_str();
150                    let operator = match name {
151                        "$eq" => "=",
152                        "$ne" => "<>",
153                        "$lt" => "<",
154                        "$le" => "<=",
155                        "$gt" => ">",
156                        "$ge" => ">=",
157                        "$in" => "IN",
158                        "$nin" => "NOT IN",
159                        "$betw" => "BETWEEN",
160                        "$like" => "LIKE",
161                        "$ilike" => "ILIKE",
162                        "$rlike" => "REGEXP",
163                        "$is" => "IS",
164                        "$size" => "json_array_length",
165                        _ => {
166                            if cfg!(debug_assertions) && name.starts_with('$') {
167                                tracing::warn!("unsupported operator `{name}` for SQLite");
168                            }
169                            name
170                        }
171                    };
172                    if let Some(subquery) = value.as_object().and_then(|m| m.get_str("$subquery")) {
173                        let condition = format!(r#"{field} {operator} {subquery}"#);
174                        conditions.push(condition);
175                    } else if operator == "IN" || operator == "NOT IN" {
176                        if let Some(values) = value.as_array() {
177                            if values.is_empty() {
178                                let condition = if operator == "IN" { "FALSE" } else { "TRUE" };
179                                conditions.push(condition.to_owned());
180                            } else {
181                                let value = values
182                                    .iter()
183                                    .map(|v| self.encode_value(Some(v)))
184                                    .collect::<Vec<_>>()
185                                    .join(", ");
186                                let condition = format!(r#"{field} {operator} ({value})"#);
187                                conditions.push(condition);
188                            }
189                        }
190                    } else if operator == "BETWEEN" {
191                        if let Some(values) = value.as_array() {
192                            if let [min_value, max_value] = values.as_slice() {
193                                let min_value = self.encode_value(Some(min_value));
194                                let max_value = self.encode_value(Some(max_value));
195                                let condition =
196                                    format!(r#"({field} BETWEEN {min_value} AND {max_value})"#);
197                                conditions.push(condition);
198                            }
199                        } else if let Some(values) = value.parse_str_array() {
200                            if let [min_value, max_value] = values.as_slice() {
201                                let min_value = self.format_value(min_value);
202                                let max_value = self.format_value(max_value);
203                                let condition =
204                                    format!(r#"({field} BETWEEN {min_value} AND {max_value})"#);
205                                conditions.push(condition);
206                            }
207                        }
208                    } else if operator == "ILIKE" {
209                        let value = self.encode_value(Some(value));
210                        let condition = format!(r#"LOWER({field}) LIKE LOWER({value})"#);
211                        conditions.push(condition);
212                    } else if operator == "json_array_length" {
213                        if let Some(Ok(length)) = value.parse_usize() {
214                            let condition = format!(r#"json_array_length({field}) = {length}"#);
215                            conditions.push(condition);
216                        }
217                    } else {
218                        let value = self.encode_value(Some(value));
219                        let condition = format!(r#"{field} {operator} {value}"#);
220                        conditions.push(condition);
221                    }
222                }
223                if conditions.is_empty() {
224                    return String::new();
225                } else {
226                    return conditions.join(" AND ");
227                }
228            }
229        } else if value.is_null() {
230            return format!(r#"{field} IS NULL"#);
231        } else if self.has_attribute("exact_filter") {
232            let value = self.encode_value(Some(value));
233            return format!(r#"{field} = {value}"#);
234        } else if let Some(value) = value.as_str() {
235            if value == "null" {
236                return format!(r#"{field} IS NULL"#);
237            } else if value == "not_null" {
238                return format!(r#"{field} IS NOT NULL"#);
239            } else if let Some((min_value, max_value)) =
240                value.split_once(',').filter(|_| self.is_datetime_type())
241            {
242                let min_value = self.format_value(min_value);
243                let max_value = self.format_value(max_value);
244                return format!(r#"{field} >= {min_value} AND {field} < {max_value}"#);
245            }
246        }
247
248        match type_name {
249            "bool" => {
250                let value = self.encode_value(Some(value));
251                if value == "TRUE" {
252                    format!(r#"{field} IS TRUE"#)
253                } else {
254                    format!(r#"{field} IS NOT TRUE"#)
255                }
256            }
257            "u64" | "i64" | "u32" | "i32" | "u16" | "i16" | "u8" | "i8" | "usize" | "isize"
258            | "Option<u64>" | "Option<i64>" | "Option<u32>" | "Option<i32>" => {
259                if let Some(value) = value.as_str() {
260                    if value == "nonzero" {
261                        format!(r#"{field} <> 0"#)
262                    } else if value.contains(',') {
263                        let value = value.split(',').collect::<Vec<_>>().join(",");
264                        format!(r#"{field} IN ({value})"#)
265                    } else {
266                        let value = self.format_value(value);
267                        format!(r#"{field} = {value}"#)
268                    }
269                } else {
270                    let value = self.encode_value(Some(value));
271                    format!(r#"{field} = {value}"#)
272                }
273            }
274            "String" | "Option<String>" => {
275                if let Some(value) = value.as_str() {
276                    if value == "empty" {
277                        // either NULL or empty
278                        format!(r#"({field} = '') IS NOT FALSE"#)
279                    } else if value == "nonempty" {
280                        format!(r#"({field} = '') IS FALSE"#)
281                    } else if self.fuzzy_search() {
282                        if value.contains(',') {
283                            let exprs = value
284                                .split(',')
285                                .map(|s| {
286                                    let value = Query::escape_string(format!("%{s}%"));
287                                    format!(r#"{field} LIKE {value}"#)
288                                })
289                                .collect::<Vec<_>>();
290                            format!("({})", exprs.join(" OR "))
291                        } else {
292                            let value = Query::escape_string(format!("%{value}%"));
293                            format!(r#"{field} LIKE {value}"#)
294                        }
295                    } else if value.contains(',') {
296                        let value = value
297                            .split(',')
298                            .map(Query::escape_string)
299                            .collect::<Vec<_>>()
300                            .join(", ");
301                        format!(r#"{field} IN ({value})"#)
302                    } else {
303                        let value = Query::escape_string(value);
304                        format!(r#"{field} = {value}"#)
305                    }
306                } else {
307                    let value = self.encode_value(Some(value));
308                    format!(r#"{field} = {value}"#)
309                }
310            }
311            "DateTime" | "NaiveDateTime" => {
312                if let Some(value) = value.as_str() {
313                    let length = value.len();
314                    let value = self.format_value(value);
315                    match length {
316                        4 => format!(r#"strftime('%Y', {field}) = {value}"#),
317                        7 => format!(r#"strftime('%Y-%m', {field}) = {value}"#),
318                        10 => format!(r#"strftime('%Y-%m-%d', {field}) = {value}"#),
319                        _ => format!(r#"{field} = {value}"#),
320                    }
321                } else {
322                    let value = self.encode_value(Some(value));
323                    format!(r#"{field} = {value}"#)
324                }
325            }
326            "Date" | "NaiveDate" => {
327                if let Some(value) = value.as_str() {
328                    let length = value.len();
329                    let value = self.format_value(value);
330                    match length {
331                        4 => format!(r#"strftime('%Y', {field}) = {value}"#),
332                        7 => format!(r#"strftime('%Y-%m', {field}) = {value}"#),
333                        _ => format!(r#"{field} = {value}"#),
334                    }
335                } else {
336                    let value = self.encode_value(Some(value));
337                    format!(r#"{field} = {value}"#)
338                }
339            }
340            "Time" | "NaiveTime" => {
341                if let Some(value) = value.as_str() {
342                    let length = value.len();
343                    let value = self.format_value(value);
344                    match length {
345                        2 => format!(r#"strftime('%H', {field}) = {value}"#),
346                        5 => format!(r#"strftime('%H:%M', {field}) = {value}"#),
347                        8 => format!(r#"strftime('%H:%M:%S', {field}) = {value}"#),
348                        _ => format!(r#"{field} = {value}"#),
349                    }
350                } else {
351                    let value = self.encode_value(Some(value));
352                    format!(r#"{field} = {value}"#)
353                }
354            }
355            "Uuid" | "Option<Uuid>" => {
356                if let Some(value) = value.as_str() {
357                    if value.contains(',') {
358                        let value = value
359                            .split(',')
360                            .map(Query::escape_string)
361                            .collect::<Vec<_>>()
362                            .join(", ");
363                        format!(r#"{field} IN ({value})"#)
364                    } else {
365                        let value = Query::escape_string(value);
366                        format!(r#"{field} = {value}"#)
367                    }
368                } else {
369                    let value = self.encode_value(Some(value));
370                    format!(r#"{field} = {value}"#)
371                }
372            }
373            "Vec<String>" | "Vec<Uuid>" | "Vec<u64>" | "Vec<i64>" | "Vec<u32>" | "Vec<i32>" => {
374                if let Some(value) = value.as_str() {
375                    if value == "nonempty" {
376                        format!(r#"json_array_length({field}) > 0"#)
377                    } else {
378                        let exprs = value
379                            .split(',')
380                            .map(|v| {
381                                let value = Query::escape_string(v);
382                                format!(r#"json_each.value = {value}"#)
383                            })
384                            .collect::<Vec<_>>();
385                        format!("({})", exprs.join(" OR "))
386                    }
387                } else if let Some(values) = value.as_array() {
388                    let exprs = values
389                        .iter()
390                        .map(|v| {
391                            let value = self.encode_value(Some(v));
392                            format!(r#"json_each.value = {value}"#)
393                        })
394                        .collect::<Vec<_>>();
395                    format!("({})", exprs.join(" OR "))
396                } else {
397                    let value = self.encode_value(Some(value));
398                    format!(r#"{field} = {value}"#)
399                }
400            }
401            _ => {
402                let value = self.encode_value(Some(value));
403                format!(r#"{field} = {value}"#)
404            }
405        }
406    }
407}
408
#[cfg(feature = "orm-sqlx")]
impl DecodeRow<DatabaseRow> for Map {
    type Error = Error;

    /// Decodes a SQLite row into a JSON map keyed by column name.
    ///
    /// Each value is decoded according to its column's SQLite type name.
    /// * `TEXT` and `BLOB` values shaped like a JSON array or object are parsed as JSON;
    ///   if parsing fails, they fall back to the plain value.
    /// * 16-byte blobs that form a valid UUID are rendered as UUID strings.
    /// * Values for which `is_ignorable()` holds are left out of the map.
    ///
    /// # Errors
    ///
    /// Fails if a raw column value cannot be fetched or decoded.
    fn decode_row(row: &DatabaseRow) -> Result<Self, Self::Error> {
        use super::decode::decode_raw;

        let mut map = Map::new();
        for col in row.columns() {
            let field = col.name();
            let raw_value = row.try_get_raw(col.ordinal())?;
            let value = if raw_value.is_null() {
                JsonValue::Null
            } else {
                match col.type_info().name() {
                    "BOOLEAN" => decode_raw::<bool>(field, raw_value)?.into(),
                    "INTEGER" | "BIGINT" => decode_raw::<i64>(field, raw_value)?.into(),
                    "REAL" => decode_raw::<f64>(field, raw_value)?.into(),
                    "TEXT" => {
                        let text = decode_raw::<String>(field, raw_value)?;
                        if text.starts_with('[') && text.ends_with(']')
                            || text.starts_with('{') && text.ends_with('}')
                        {
                            // Ordinary text such as `[Draft]` has the same shape, so a
                            // parse failure must not abort decoding of the whole row.
                            serde_json::from_str::<JsonValue>(&text)
                                .unwrap_or_else(|_| text.into())
                        } else {
                            text.into()
                        }
                    }
                    "DATETIME" => decode_raw::<DateTime>(field, raw_value)?.into(),
                    "DATE" => decode_raw::<Date>(field, raw_value)?.into(),
                    "TIME" => decode_raw::<Time>(field, raw_value)?.into(),
                    "BLOB" => {
                        let bytes = decode_raw::<Vec<u8>>(field, raw_value)?;
                        if bytes.starts_with(b"[") && bytes.ends_with(b"]")
                            || bytes.starts_with(b"{") && bytes.ends_with(b"}")
                        {
                            serde_json::from_slice::<JsonValue>(&bytes)
                                .unwrap_or_else(|_| bytes.into())
                        } else if bytes.len() == 16 {
                            if let Ok(value) = Uuid::from_slice(&bytes) {
                                value.to_string().into()
                            } else {
                                bytes.into()
                            }
                        } else {
                            bytes.into()
                        }
                    }
                    _ => decode_raw::<String>(field, raw_value)?.into(),
                }
            };
            if !value.is_ignorable() {
                map.insert(field.to_owned(), value);
            }
        }
        Ok(map)
    }
}
467
#[cfg(feature = "orm-sqlx")]
impl DecodeRow<DatabaseRow> for Record {
    type Error = Error;

    /// Decodes a SQLite row into an Avro record of `(column name, value)` pairs.
    ///
    /// Columns keep their query order, and `NULL`s are kept as `AvroValue::Null`.
    /// * `TEXT` and `BLOB` values shaped like a JSON array or object are parsed as JSON;
    ///   if parsing fails, they fall back to the plain value.
    /// * 16-byte blobs that form a valid UUID are decoded as UUIDs.
    /// * `DATETIME` values are stored as their string form.
    ///
    /// # Errors
    ///
    /// Fails if a raw column value cannot be fetched or decoded.
    fn decode_row(row: &DatabaseRow) -> Result<Self, Self::Error> {
        use super::decode::decode_raw;

        let columns = row.columns();
        let mut record = Record::with_capacity(columns.len());
        for col in columns {
            let field = col.name();
            let raw_value = row.try_get_raw(col.ordinal())?;
            let value = if raw_value.is_null() {
                AvroValue::Null
            } else {
                match col.type_info().name() {
                    "BOOLEAN" => decode_raw::<bool>(field, raw_value)?.into(),
                    "INTEGER" | "BIGINT" => decode_raw::<i64>(field, raw_value)?.into(),
                    "REAL" => decode_raw::<f64>(field, raw_value)?.into(),
                    "TEXT" => {
                        let text = decode_raw::<String>(field, raw_value)?;
                        if text.starts_with('[') && text.ends_with(']')
                            || text.starts_with('{') && text.ends_with('}')
                        {
                            // Ordinary text such as `[Draft]` has the same shape, so a
                            // parse failure must not abort decoding of the whole row.
                            serde_json::from_str::<JsonValue>(&text)
                                .map(|value| value.into())
                                .unwrap_or_else(|_| text.into())
                        } else {
                            text.into()
                        }
                    }
                    "DATETIME" => decode_raw::<DateTime>(field, raw_value)?.to_string().into(),
                    "DATE" => decode_raw::<Date>(field, raw_value)?.into(),
                    "TIME" => decode_raw::<Time>(field, raw_value)?.into(),
                    "BLOB" => {
                        let bytes = decode_raw::<Vec<u8>>(field, raw_value)?;
                        if bytes.starts_with(b"[") && bytes.ends_with(b"]")
                            || bytes.starts_with(b"{") && bytes.ends_with(b"}")
                        {
                            serde_json::from_slice::<JsonValue>(&bytes)
                                .map(|value| value.into())
                                .unwrap_or_else(|_| bytes.into())
                        } else if bytes.len() == 16 {
                            if let Ok(value) = Uuid::from_slice(&bytes) {
                                value.into()
                            } else {
                                bytes.into()
                            }
                        } else {
                            bytes.into()
                        }
                    }
                    _ => decode_raw::<String>(field, raw_value)?.into(),
                }
            };
            record.push((field.to_owned(), value));
        }
        Ok(record)
    }
}
526
#[cfg(feature = "orm-sqlx")]
impl QueryExt<DatabaseDriver> for Query {
    type QueryResult = sqlx::sqlite::SqliteQueryResult;

    /// Returns `(last insert rowid, rows affected)` for a SQLite query result.
    #[inline]
    fn parse_query_result(query_result: Self::QueryResult) -> (Option<i64>, u64) {
        let last_insert_id = query_result.last_insert_rowid();
        let rows_affected = query_result.rows_affected();
        (Some(last_insert_id), rows_affected)
    }

    #[inline]
    fn query_fields(&self) -> &[String] {
        self.fields()
    }

    #[inline]
    fn query_filters(&self) -> &Map {
        self.filters()
    }

    #[inline]
    fn query_order(&self) -> &[QueryOrder] {
        self.sort_order()
    }

    #[inline]
    fn query_offset(&self) -> usize {
        self.offset()
    }

    #[inline]
    fn query_limit(&self) -> usize {
        self.limit()
    }

    /// SQLite uses anonymous `?` placeholders regardless of the parameter position.
    #[inline]
    fn placeholder(_n: usize) -> SharedString {
        "?".into()
    }

    #[inline]
    fn prepare_query<'a>(
        query: &'a str,
        params: Option<&'a Map>,
    ) -> (Cow<'a, str>, Vec<&'a JsonValue>) {
        crate::query::prepare_sql_query(query, params, '?')
    }

    /// Quotes an identifier with backticks.
    ///
    /// Identifiers that already contain a backtick are returned unchanged. Dotted
    /// identifiers (`table.column`) are quoted segment by segment.
    fn format_field(field: &str) -> Cow<'_, str> {
        if field.contains('`') {
            field.into()
        } else if field.contains('.') {
            field
                .split('.')
                .map(|s| ["`", s, "`"].concat())
                .collect::<Vec<_>>()
                .join(".")
                .into()
        } else {
            ["`", field, "`"].concat().into()
        }
    }

    /// Renders the projection list of a `SELECT`.
    ///
    /// * No fields selects `*`.
    /// * `alias:expr` becomes `expr AS alias`.
    /// * Qualified (`a.b`) or already quoted fields are quoted via `format_field`.
    /// * Bare fields are qualified with the model name.
    fn format_table_fields<M: Schema>(&self) -> Cow<'_, str> {
        let model_name = M::model_name();
        let fields = self.query_fields();
        if fields.is_empty() {
            return "*".into();
        }
        fields
            .iter()
            .map(|field| {
                if let Some((alias, expr)) = field.split_once(':') {
                    let alias = Self::format_field(alias.trim());
                    format!(r#"{expr} AS {alias}"#)
                } else if field.contains('.') || field.contains('`') {
                    // Already qualified or quoted: quoting it again with the model prefix
                    // would produce invalid identifiers such as ``a``.
                    Self::format_field(field).into_owned()
                } else {
                    format!(r#"`{model_name}`.`{field}`"#)
                }
            })
            .collect::<Vec<_>>()
            .join(", ")
            .into()
    }

    /// Renders the `FROM` clause: the escaped table aliased to the model name, plus one
    /// `json_each`/`json_tree` virtual table for each filtered `Vec<_>`/`Map` column.
    ///
    /// NOTE(review): these virtual tables are cross-joined, so a row whose array has
    /// several elements may appear once per element (e.g. for a `"nonempty"` filter).
    /// Confirm whether callers need `DISTINCT`.
    fn format_table_name<M: Schema>(&self) -> String {
        let model_name = M::model_name();
        let filters = self.query_filters();
        let mut virtual_tables = Vec::new();
        for col in M::columns() {
            let col_name = col.name();
            if !filters.contains_key(col_name) {
                continue;
            }
            match col.type_name() {
                "Vec<String>" | "Vec<Uuid>" | "Vec<u64>" | "Vec<i64>" | "Vec<u32>"
                | "Vec<i32>" => {
                    virtual_tables.push(format!("json_each(`{model_name}`.`{col_name}`)"));
                }
                "Map" => {
                    virtual_tables.push(format!("json_tree(`{model_name}`.`{col_name}`)"));
                }
                _ => (),
            }
        }

        let table_name = Self::table_name_escaped::<M>();
        if virtual_tables.is_empty() {
            format!(r#"{table_name} AS `{model_name}`"#)
        } else {
            format!(
                r#"{table_name} AS `{model_name}`, {}"#,
                virtual_tables.join(", ")
            )
        }
    }

    /// Returns the model's table name quoted with backticks, segment by segment when it
    /// is schema-qualified (`schema.table`).
    fn table_name_escaped<M: Schema>() -> String {
        let table_name = M::table_name();
        if table_name.contains('.') {
            table_name
                .split('.')
                .map(|s| ["`", s, "`"].concat())
                .collect::<Vec<_>>()
                .join(".")
        } else {
            ["`", table_name, "`"].concat()
        }
    }

    /// Builds a `MATCH` condition from the filter's `$fields` and `$search` entries.
    /// Returns `None` when either entry is missing.
    fn parse_text_search(filter: &Map) -> Option<String> {
        let fields = filter.parse_str_array("$fields")?;
        filter.parse_string("$search").map(|search| {
            let fields = fields.join(", ");
            let search = Query::escape_string(search.as_ref());
            format!("{fields} MATCH {search}")
        })
    }
}