1use super::{DatabaseDriver, DatabaseRow, DecodeRow, EncodeColumn, Schema, query::QueryExt};
2use std::borrow::Cow;
3use zino_core::{
4 AvroValue, JsonValue, Map, Record, SharedString, Uuid,
5 datetime::{Date, DateTime, Time},
6 error::Error,
7 extension::{JsonObjectExt, JsonValueExt},
8 model::{Column, Query, QueryOrder},
9};
10
11#[cfg(feature = "orm-sqlx")]
12use sqlx::{Column as _, Row, TypeInfo, ValueRef};
13
14impl EncodeColumn<DatabaseDriver> for Column<'_> {
15 fn column_type(&self) -> &str {
16 if let Some(column_type) = self.extra().get_str("column_type") {
17 return column_type;
18 }
19 match self.type_name() {
20 "bool" => "BOOLEAN",
21 "u64" | "i64" | "usize" | "isize" | "Option<u64>" | "Option<i64>" | "u32" | "i32"
22 | "u16" | "i16" | "u8" | "i8" | "Option<u32>" | "Option<i32>" => "INTEGER",
23 "f64" | "f32" => "REAL",
24 "Date" | "NaiveDate" => "DATE",
25 "Time" | "NaiveTime" => "TIME",
26 "DateTime" | "NaiveDateTime" => "DATETIME",
27 "Vec<u8>" => "BLOB",
28 _ => "TEXT",
29 }
30 }
31
32 fn encode_value<'a>(&self, value: Option<&'a JsonValue>) -> Cow<'a, str> {
33 if let Some(value) = value {
34 match value {
35 JsonValue::Null => "NULL".into(),
36 JsonValue::Bool(b) => {
37 let value = if *b { "TRUE" } else { "FALSE" };
38 value.into()
39 }
40 JsonValue::Number(n) => n.to_string().into(),
41 JsonValue::String(s) => {
42 if s.is_empty() {
43 if let Some(value) = self.default_value() {
44 self.format_value(value).into_owned().into()
45 } else {
46 "''".into()
47 }
48 } else if s == "null" {
49 "NULL".into()
50 } else if s == "not_null" {
51 "NOT NULL".into()
52 } else {
53 self.format_value(s)
54 }
55 }
56 JsonValue::Array(vec) => {
57 let values = vec
58 .iter()
59 .map(|v| match v {
60 JsonValue::String(v) => Query::escape_string(v),
61 _ => self.encode_value(Some(v)).into_owned(),
62 })
63 .collect::<Vec<_>>();
64 format!(r#"json_array({})"#, values.join(",")).into()
65 }
66 JsonValue::Object(_) => Query::escape_string(value).into(),
67 }
68 } else if self.default_value().is_some() {
69 "DEFAULT".into()
70 } else {
71 "NULL".into()
72 }
73 }
74
75 fn format_value<'a>(&self, value: &'a str) -> Cow<'a, str> {
76 match self.type_name() {
77 "bool" => {
78 let value = if value == "true" { "TRUE" } else { "FALSE" };
79 value.into()
80 }
81 "u64" | "i64" | "u32" | "i32" | "u16" | "i16" | "u8" | "i8" | "usize" | "isize"
82 | "Option<u64>" | "Option<i64>" | "Option<u32>" | "Option<i32>" => {
83 if value.parse::<i64>().is_ok() {
84 value.into()
85 } else {
86 "NULL".into()
87 }
88 }
89 "f64" | "f32" => {
90 if value.parse::<f64>().is_ok() {
91 value.into()
92 } else {
93 "NULL".into()
94 }
95 }
96 "DateTime" | "NaiveDateTime" => match value {
97 "epoch" => "datetime(0, 'unixepoch')".into(),
98 "now" => "datetime('now', 'localtime')".into(),
99 "today" => "datetime('now', 'start of day')".into(),
100 "tomorrow" => "datetime('now', 'start of day', '+1 day')".into(),
101 "yesterday" => "datetime('now', 'start of day', '-1 day')".into(),
102 _ => Query::escape_string(value).into(),
103 },
104 "Date" | "NaiveDate" => match value {
105 "epoch" => "'1970-01-01'".into(),
106 "today" => "date('now', 'localtime')".into(),
107 "tomorrow" => "date('now', '+1 day')".into(),
108 "yesterday" => "date('now', '-1 day')".into(),
109 _ => Query::escape_string(value).into(),
110 },
111 "Time" | "NaiveTime" => match value {
112 "now" => "time('now', 'localtime')".into(),
113 "midnight" => "'00:00:00'".into(),
114 _ => Query::escape_string(value).into(),
115 },
116 "Vec<u8>" => format!("'{value}'").into(),
117 "Vec<String>" | "Vec<Uuid>" | "Vec<u64>" | "Vec<i64>" | "Vec<u32>" | "Vec<i32>"
118 | "Vec<f64>" | "Vec<f32>" => {
119 if value.contains(',') {
120 let values = value
121 .split(',')
122 .map(Query::escape_string)
123 .collect::<Vec<_>>();
124 format!(r#"json_array({})"#, values.join(",")).into()
125 } else {
126 let value = Query::escape_string(value);
127 format!(r#"json_array({value})"#).into()
128 }
129 }
130 _ => Query::escape_string(value).into(),
131 }
132 }
133
134 fn format_filter(&self, field: &str, value: &JsonValue) -> String {
135 let type_name = self.type_name();
136 let field = Query::format_field(field);
137 if let Some(filter) = value.as_object() {
138 let mut conditions = Vec::with_capacity(filter.len());
139 if type_name == "Map" {
140 for (key, value) in filter {
141 let key = Query::escape_string(key);
142 let value = self.encode_value(Some(value));
143 let condition =
144 format!(r#"json_tree.key = {key} AND json_tree.value = {value}"#);
145 conditions.push(condition);
146 }
147 return Query::join_conditions(conditions, " OR ");
148 } else {
149 for (name, value) in filter {
150 let name = name.as_str();
151 let operator = match name {
152 "$eq" => "=",
153 "$ne" => "<>",
154 "$lt" => "<",
155 "$le" => "<=",
156 "$gt" => ">",
157 "$ge" => ">=",
158 "$in" => "IN",
159 "$nin" => "NOT IN",
160 "$betw" => "BETWEEN",
161 "$like" => "LIKE",
162 "$ilike" => "ILIKE",
163 "$rlike" => "REGEXP",
164 "$is" => "IS",
165 "$size" => "json_array_length",
166 _ => {
167 if cfg!(debug_assertions) && name.starts_with('$') {
168 tracing::warn!("unsupported operator `{name}` for SQLite");
169 }
170 name
171 }
172 };
173 if let Some(subquery) = value.as_object().and_then(|m| m.get_str("$subquery")) {
174 let condition = format!(r#"{field} {operator} {subquery}"#);
175 conditions.push(condition);
176 } else if operator == "IN" || operator == "NOT IN" {
177 if let Some(values) = value.as_array() {
178 if values.is_empty() {
179 let condition = if operator == "IN" { "FALSE" } else { "TRUE" };
180 conditions.push(condition.to_owned());
181 } else {
182 let value = values
183 .iter()
184 .map(|v| self.encode_value(Some(v)))
185 .collect::<Vec<_>>()
186 .join(", ");
187 let condition = format!(r#"{field} {operator} ({value})"#);
188 conditions.push(condition);
189 }
190 }
191 } else if operator == "BETWEEN" {
192 if let Some(values) = value.as_array() {
193 if let [min_value, max_value] = values.as_slice() {
194 let min_value = self.encode_value(Some(min_value));
195 let max_value = self.encode_value(Some(max_value));
196 let condition =
197 format!(r#"({field} BETWEEN {min_value} AND {max_value})"#);
198 conditions.push(condition);
199 }
200 } else if let Some(values) = value.parse_str_array()
201 && let [min_value, max_value] = values.as_slice()
202 {
203 let min_value = self.format_value(min_value);
204 let max_value = self.format_value(max_value);
205 let condition =
206 format!(r#"({field} BETWEEN {min_value} AND {max_value})"#);
207 conditions.push(condition);
208 }
209 } else if operator == "ILIKE" {
210 let value = self.encode_value(Some(value));
211 let condition = format!(r#"LOWER({field}) LIKE LOWER({value})"#);
212 conditions.push(condition);
213 } else if operator == "json_array_length" {
214 if let Some(Ok(length)) = value.parse_usize() {
215 let condition = format!(r#"json_array_length({field}) = {length}"#);
216 conditions.push(condition);
217 }
218 } else {
219 let value = self.encode_value(Some(value));
220 let condition = format!(r#"{field} {operator} {value}"#);
221 conditions.push(condition);
222 }
223 }
224 if conditions.is_empty() {
225 return String::new();
226 } else {
227 return conditions.join(" AND ");
228 }
229 }
230 } else if value.is_null() {
231 return format!(r#"{field} IS NULL"#);
232 } else if self.has_attribute("exact_filter") {
233 let value = self.encode_value(Some(value));
234 return format!(r#"{field} = {value}"#);
235 } else if let Some(value) = value.as_str() {
236 if value == "null" {
237 return format!(r#"{field} IS NULL"#);
238 } else if value == "not_null" {
239 return format!(r#"{field} IS NOT NULL"#);
240 } else if let Some((min_value, max_value)) =
241 value.split_once(',').filter(|_| self.is_temporal_type())
242 {
243 let min_value = self.format_value(min_value);
244 let max_value = self.format_value(max_value);
245 return format!(r#"{field} >= {min_value} AND {field} < {max_value}"#);
246 }
247 }
248
249 match type_name {
250 "bool" => {
251 let value = self.encode_value(Some(value));
252 if value == "TRUE" {
253 format!(r#"{field} IS TRUE"#)
254 } else {
255 format!(r#"{field} IS NOT TRUE"#)
256 }
257 }
258 "u64" | "i64" | "u32" | "i32" | "u16" | "i16" | "u8" | "i8" | "usize" | "isize"
259 | "Option<u64>" | "Option<i64>" | "Option<u32>" | "Option<i32>" => {
260 if let Some(value) = value.as_str() {
261 if value == "nonzero" {
262 format!(r#"{field} <> 0"#)
263 } else if value.contains(',') {
264 let value = value.split(',').collect::<Vec<_>>().join(",");
265 format!(r#"{field} IN ({value})"#)
266 } else {
267 let value = self.format_value(value);
268 format!(r#"{field} = {value}"#)
269 }
270 } else {
271 let value = self.encode_value(Some(value));
272 format!(r#"{field} = {value}"#)
273 }
274 }
275 "String" | "Option<String>" => {
276 if let Some(value) = value.as_str() {
277 if value == "empty" {
278 format!(r#"({field} = '') IS NOT FALSE"#)
280 } else if value == "nonempty" {
281 format!(r#"({field} = '') IS FALSE"#)
282 } else if self.fuzzy_search() {
283 if value.contains(',') {
284 let exprs = value
285 .split(',')
286 .map(|s| {
287 let value = Query::escape_string(format!("%{s}%"));
288 format!(r#"{field} LIKE {value}"#)
289 })
290 .collect::<Vec<_>>();
291 format!("({})", exprs.join(" OR "))
292 } else {
293 let value = Query::escape_string(format!("%{value}%"));
294 format!(r#"{field} LIKE {value}"#)
295 }
296 } else if value.contains(',') {
297 let value = value
298 .split(',')
299 .map(Query::escape_string)
300 .collect::<Vec<_>>()
301 .join(", ");
302 format!(r#"{field} IN ({value})"#)
303 } else {
304 let value = Query::escape_string(value);
305 format!(r#"{field} = {value}"#)
306 }
307 } else {
308 let value = self.encode_value(Some(value));
309 format!(r#"{field} = {value}"#)
310 }
311 }
312 "DateTime" | "NaiveDateTime" => {
313 if let Some(value) = value.as_str() {
314 let length = value.len();
315 let value = self.format_value(value);
316 match length {
317 4 => format!(r#"strftime('%Y', {field}) = {value}"#),
318 7 => format!(r#"strftime('%Y-%m', {field}) = {value}"#),
319 10 => format!(r#"strftime('%Y-%m-%d', {field}) = {value}"#),
320 _ => format!(r#"{field} = {value}"#),
321 }
322 } else {
323 let value = self.encode_value(Some(value));
324 format!(r#"{field} = {value}"#)
325 }
326 }
327 "Date" | "NaiveDate" => {
328 if let Some(value) = value.as_str() {
329 let length = value.len();
330 let value = self.format_value(value);
331 match length {
332 4 => format!(r#"strftime('%Y', {field}) = {value}"#),
333 7 => format!(r#"strftime('%Y-%m', {field}) = {value}"#),
334 _ => format!(r#"{field} = {value}"#),
335 }
336 } else {
337 let value = self.encode_value(Some(value));
338 format!(r#"{field} = {value}"#)
339 }
340 }
341 "Time" | "NaiveTime" => {
342 if let Some(value) = value.as_str() {
343 let length = value.len();
344 let value = self.format_value(value);
345 match length {
346 2 => format!(r#"strftime('%H', {field}) = {value}"#),
347 5 => format!(r#"strftime('%H:%M', {field}) = {value}"#),
348 8 => format!(r#"strftime('%H:%M:%S', {field}) = {value}"#),
349 _ => format!(r#"{field} = {value}"#),
350 }
351 } else {
352 let value = self.encode_value(Some(value));
353 format!(r#"{field} = {value}"#)
354 }
355 }
356 "Uuid" | "Option<Uuid>" => {
357 if let Some(value) = value.as_str() {
358 if value.contains(',') {
359 let value = value
360 .split(',')
361 .map(Query::escape_string)
362 .collect::<Vec<_>>()
363 .join(", ");
364 format!(r#"{field} IN ({value})"#)
365 } else {
366 let value = Query::escape_string(value);
367 format!(r#"{field} = {value}"#)
368 }
369 } else {
370 let value = self.encode_value(Some(value));
371 format!(r#"{field} = {value}"#)
372 }
373 }
374 "Vec<String>" | "Vec<Uuid>" | "Vec<u64>" | "Vec<i64>" | "Vec<u32>" | "Vec<i32>"
375 | "Vec<f64>" | "Vec<f32>" => {
376 if let Some(value) = value.as_str() {
377 if value == "nonempty" {
378 format!(r#"json_array_length({field}) > 0"#)
379 } else {
380 let exprs = value
381 .split(',')
382 .map(|v| {
383 let value = Query::escape_string(v);
384 format!(r#"json_each.value = {value}"#)
385 })
386 .collect::<Vec<_>>();
387 format!("({})", exprs.join(" OR "))
388 }
389 } else if let Some(values) = value.as_array() {
390 let exprs = values
391 .iter()
392 .map(|v| {
393 let value = self.encode_value(Some(v));
394 format!(r#"json_each.value = {value}"#)
395 })
396 .collect::<Vec<_>>();
397 format!("({})", exprs.join(" OR "))
398 } else {
399 let value = self.encode_value(Some(value));
400 format!(r#"{field} = {value}"#)
401 }
402 }
403 _ => {
404 let value = self.encode_value(Some(value));
405 format!(r#"{field} = {value}"#)
406 }
407 }
408 }
409}
410
#[cfg(feature = "orm-sqlx")]
impl DecodeRow<DatabaseRow> for Map {
    type Error = Error;

    /// Decodes a database row into a JSON map keyed by column name.
    ///
    /// SQL NULL becomes `JsonValue::Null`; other values are decoded according to
    /// the column's declared type name, or the value's own type name when the
    /// column type is reported as null. Values for which `is_ignorable()` returns
    /// true are left out of the map.
    ///
    /// # Errors
    ///
    /// Returns an error when a raw value cannot be fetched or decoded.
    fn decode_row(row: &DatabaseRow) -> Result<Self, Self::Error> {
        use super::decode::decode_raw;

        let mut map = Map::new();
        for col in row.columns() {
            let field = col.name();
            let index = col.ordinal();
            let raw_value = row.try_get_raw(index)?;
            let value = if raw_value.is_null() {
                JsonValue::Null
            } else {
                let type_info = col.type_info();
                let value_type_info = raw_value.type_info();
                let column_type = if type_info.is_null() {
                    value_type_info.name()
                } else {
                    type_info.name()
                };
                match column_type {
                    "BOOLEAN" => decode_raw::<bool>(field, raw_value)?.into(),
                    "INTEGER" => decode_raw::<i64>(field, raw_value)?.into(),
                    "REAL" => decode_raw::<f64>(field, raw_value)?.into(),
                    "TEXT" => {
                        let text = decode_raw::<String>(field, raw_value)?;
                        let looks_like_json = (text.starts_with('[') && text.ends_with(']'))
                            || (text.starts_with('{') && text.ends_with('}'));
                        if looks_like_json {
                            // Plain text such as `[draft]` merely looks like JSON.
                            // Keep it as a string instead of failing the whole row,
                            // in line with the BLOB branch below.
                            match serde_json::from_str::<JsonValue>(&text) {
                                Ok(json) => json,
                                Err(_) => text.into(),
                            }
                        } else {
                            text.into()
                        }
                    }
                    "DATETIME" => decode_raw::<DateTime>(field, raw_value)?.into(),
                    "DATE" => decode_raw::<Date>(field, raw_value)?.into(),
                    "TIME" => decode_raw::<Time>(field, raw_value)?.into(),
                    "BLOB" => {
                        let bytes = decode_raw::<Vec<u8>>(field, raw_value)?;
                        if bytes.starts_with(b"[") && bytes.ends_with(b"]")
                            || bytes.starts_with(b"{") && bytes.ends_with(b"}")
                        {
                            serde_json::from_slice::<JsonValue>(&bytes)
                                .unwrap_or_else(|_| bytes.into())
                        } else if bytes.len() == 16 {
                            // Exactly 16 bytes are tried as a UUID first.
                            if let Ok(value) = Uuid::from_slice(&bytes) {
                                value.to_string().into()
                            } else {
                                bytes.into()
                            }
                        } else {
                            bytes.into()
                        }
                    }
                    _ => decode_raw::<String>(field, raw_value)?.into(),
                }
            };
            if !value.is_ignorable() {
                map.insert(field.to_owned(), value);
            }
        }
        Ok(map)
    }
}
477
#[cfg(feature = "orm-sqlx")]
impl DecodeRow<DatabaseRow> for Record {
    type Error = Error;

    /// Decodes a database row into a record of `(column name, Avro value)` pairs,
    /// one entry per column in column order.
    ///
    /// SQL NULL becomes `AvroValue::Null`; other values are decoded according to
    /// the column's declared type name, or the value's own type name when the
    /// column type is reported as null.
    ///
    /// # Errors
    ///
    /// Returns an error when a raw value cannot be fetched or decoded.
    fn decode_row(row: &DatabaseRow) -> Result<Self, Self::Error> {
        use super::decode::decode_raw;

        let columns = row.columns();
        let mut record = Record::with_capacity(columns.len());
        for col in columns {
            let field = col.name();
            let index = col.ordinal();
            let raw_value = row.try_get_raw(index)?;
            let value = if raw_value.is_null() {
                AvroValue::Null
            } else {
                let type_info = col.type_info();
                let value_type_info = raw_value.type_info();
                let column_type = if type_info.is_null() {
                    value_type_info.name()
                } else {
                    type_info.name()
                };
                match column_type {
                    "BOOLEAN" => decode_raw::<bool>(field, raw_value)?.into(),
                    "INTEGER" => decode_raw::<i64>(field, raw_value)?.into(),
                    "REAL" => decode_raw::<f64>(field, raw_value)?.into(),
                    "TEXT" => {
                        let text = decode_raw::<String>(field, raw_value)?;
                        let looks_like_json = (text.starts_with('[') && text.ends_with(']'))
                            || (text.starts_with('{') && text.ends_with('}'));
                        if looks_like_json {
                            // Plain text such as `[draft]` merely looks like JSON.
                            // Keep it as a string instead of failing the whole row,
                            // in line with the BLOB branch below.
                            match serde_json::from_str::<JsonValue>(&text) {
                                Ok(json) => json.into(),
                                Err(_) => text.into(),
                            }
                        } else {
                            text.into()
                        }
                    }
                    "DATETIME" => decode_raw::<DateTime>(field, raw_value)?.to_string().into(),
                    "DATE" => decode_raw::<Date>(field, raw_value)?.into(),
                    "TIME" => decode_raw::<Time>(field, raw_value)?.into(),
                    "BLOB" => {
                        let bytes = decode_raw::<Vec<u8>>(field, raw_value)?;
                        if bytes.starts_with(b"[") && bytes.ends_with(b"]")
                            || bytes.starts_with(b"{") && bytes.ends_with(b"}")
                        {
                            serde_json::from_slice::<JsonValue>(&bytes)
                                .map(|value| value.into())
                                .unwrap_or_else(|_| bytes.into())
                        } else if bytes.len() == 16 {
                            // Exactly 16 bytes are tried as a UUID first.
                            if let Ok(value) = Uuid::from_slice(&bytes) {
                                value.into()
                            } else {
                                bytes.into()
                            }
                        } else {
                            bytes.into()
                        }
                    }
                    _ => decode_raw::<String>(field, raw_value)?.into(),
                }
            };
            record.push((field.to_owned(), value));
        }
        Ok(record)
    }
}
544
#[cfg(feature = "orm-sqlx")]
impl QueryExt<DatabaseDriver> for Query {
    type QueryResult = sqlx::sqlite::SqliteQueryResult;

    /// Returns the last inserted rowid and the number of affected rows.
    #[inline]
    fn parse_query_result(query_result: Self::QueryResult) -> (Option<i64>, u64) {
        let last_insert_id = query_result.last_insert_rowid();
        let rows_affected = query_result.rows_affected();
        (Some(last_insert_id), rows_affected)
    }

    /// Returns the fields selected by the query.
    #[inline]
    fn query_fields(&self) -> &[String] {
        self.fields()
    }

    /// Returns the filters of the query.
    #[inline]
    fn query_filters(&self) -> &Map {
        self.filters()
    }

    /// Returns the sort order of the query.
    #[inline]
    fn query_order(&self) -> &[QueryOrder] {
        self.sort_order()
    }

    /// Returns the offset of the query.
    #[inline]
    fn query_offset(&self) -> usize {
        self.offset()
    }

    /// Returns the limit of the query.
    #[inline]
    fn query_limit(&self) -> usize {
        self.limit()
    }

    /// Returns the placeholder for a bound parameter, which is always `?`
    /// regardless of the parameter position.
    #[inline]
    fn placeholder(_n: usize) -> SharedString {
        "?".into()
    }

    /// Prepares `query` with optional `params` by delegating to
    /// `crate::query::prepare_sql_query`, passing `'?'` as the last argument.
    #[inline]
    fn prepare_query<'a>(
        query: &'a str,
        params: Option<&'a Map>,
    ) -> (Cow<'a, str>, Vec<&'a JsonValue>) {
        crate::query::prepare_sql_query(query, params, '?')
    }

    /// Quotes a field name with backticks.
    ///
    /// A field that already contains a backtick is returned unchanged, and a
    /// dotted name has each of its segments quoted separately.
    fn format_field(field: &str) -> Cow<'_, str> {
        if field.contains('`') {
            field.into()
        } else if field.contains('.') {
            field
                .split('.')
                .map(|s| ["`", s, "`"].concat())
                .collect::<Vec<_>>()
                .join(".")
                .into()
        } else {
            ["`", field, "`"].concat().into()
        }
    }

    /// Formats the projection list of the query for the model `M`.
    ///
    /// Returns `*` when no fields are selected. Otherwise each field is rendered
    /// as follows and the results are joined with `, `:
    /// - `alias:expr` becomes `expr AS alias`, with the alias quoted;
    /// - a field that is already quoted is kept unchanged;
    /// - a dotted name has each segment quoted;
    /// - a plain name is qualified with the quoted model name.
    fn format_table_fields<M: Schema>(&self) -> Cow<'_, str> {
        let model_name = M::model_name();
        let fields = self.query_fields();
        if fields.is_empty() {
            "*".into()
        } else {
            fields
                .iter()
                .map(|field| {
                    if let Some((alias, expr)) = field.split_once(':') {
                        let alias = Self::format_field(alias.trim());
                        format!(r#"{expr} AS {alias}"#)
                    } else if field.contains('`') || field.contains('"') {
                        // Identifiers are quoted with backticks throughout this
                        // driver, so a backtick marks an already-quoted field that
                        // must not be quoted again. Double-quoted fields keep
                        // passing through unchanged as before.
                        field.to_owned()
                    } else if field.contains('.') {
                        field
                            .split('.')
                            .map(|s| ["`", s, "`"].concat())
                            .collect::<Vec<_>>()
                            .join(".")
                    } else {
                        format!(r#"`{model_name}`.`{field}`"#)
                    }
                })
                .collect::<Vec<_>>()
                .join(", ")
                .into()
        }
    }

    /// Formats the FROM target of the query for the model `M`.
    ///
    /// The table name comes from the `table_name` entry of the query's extra
    /// attributes when present and from `M::table_name()` otherwise; it is
    /// aliased to the model name. For every filtered column with a list or `Map`
    /// type, a `json_each` / `json_tree` table-valued function over that column is
    /// appended so that filter conditions can refer to it.
    fn format_table_name<M: Schema>(&self) -> String {
        let table_name = self
            .extra()
            .get_str("table_name")
            .unwrap_or_else(|| M::table_name());
        let model_name = M::model_name();
        let filters = self.query_filters();
        let mut virtual_tables = Vec::new();
        for col in M::columns() {
            let col_name = col.name();
            if !filters.contains_key(col_name) {
                continue;
            }
            match col.type_name() {
                // Must cover every list type for which filters are expressed in
                // terms of `json_each.value`, including the floating-point ones.
                "Vec<String>" | "Vec<Uuid>" | "Vec<u64>" | "Vec<i64>" | "Vec<u32>"
                | "Vec<i32>" | "Vec<f64>" | "Vec<f32>" => {
                    let virtual_table = format!("json_each(`{model_name}`.`{col_name}`)");
                    virtual_tables.push(virtual_table);
                }
                "Map" => {
                    let virtual_table = format!("json_tree(`{model_name}`.`{col_name}`)");
                    virtual_tables.push(virtual_table);
                }
                _ => (),
            }
        }

        let table_name = Self::escape_table_name(table_name);
        if virtual_tables.is_empty() {
            format!(r#"{table_name} AS `{model_name}`"#)
        } else {
            format!(
                r#"{table_name} AS `{model_name}`, {}"#,
                virtual_tables.join(", ")
            )
        }
    }

    /// Quotes a table name with backticks, quoting each segment of a dotted
    /// name separately.
    fn escape_table_name(table_name: &str) -> String {
        if table_name.contains('.') {
            table_name
                .split('.')
                .map(|s| ["`", s, "`"].concat())
                .collect::<Vec<_>>()
                .join(".")
        } else {
            ["`", table_name, "`"].concat()
        }
    }

    /// Builds a `MATCH` condition from the `$fields` and `$search` entries of
    /// `filter`.
    ///
    /// Returns `None` when either entry cannot be parsed. The field names are
    /// joined with `, ` and inserted as-is; the search text is escaped.
    fn parse_text_search(filter: &Map) -> Option<String> {
        let fields = filter.parse_str_array("$fields")?;
        filter.parse_string("$search").map(|search| {
            let fields = fields.join(", ");
            let search = Query::escape_string(search.as_ref());
            format!("{fields} MATCH {search}")
        })
    }
}