1use super::{DatabaseDriver, DatabaseRow, DecodeRow, EncodeColumn, Schema, query::QueryExt};
2use std::borrow::Cow;
3use zino_core::{
4 AvroValue, JsonValue, Map, Record, SharedString, Uuid,
5 datetime::{Date, DateTime, Time},
6 error::Error,
7 extension::{JsonObjectExt, JsonValueExt},
8 model::{Column, Query, QueryOrder},
9};
10
11#[cfg(feature = "orm-sqlx")]
12use sqlx::{Column as _, Row, TypeInfo, ValueRef};
13
14impl EncodeColumn<DatabaseDriver> for Column<'_> {
15 fn column_type(&self) -> &str {
16 if let Some(column_type) = self.extra().get_str("column_type") {
17 return column_type;
18 }
19 match self.type_name() {
20 "bool" => "BOOLEAN",
21 "u64" | "i64" | "usize" | "isize" | "Option<u64>" | "Option<i64>" | "u32" | "i32"
22 | "u16" | "i16" | "u8" | "i8" | "Option<u32>" | "Option<i32>" => "INTEGER",
23 "f64" | "f32" => "REAL",
24 "Date" | "NaiveDate" => "DATE",
25 "Time" | "NaiveTime" => "TIME",
26 "DateTime" | "NaiveDateTime" => "DATETIME",
27 "Vec<u8>" => "BLOB",
28 _ => "TEXT",
29 }
30 }
31
32 fn encode_value<'a>(&self, value: Option<&'a JsonValue>) -> Cow<'a, str> {
33 if let Some(value) = value {
34 match value {
35 JsonValue::Null => "NULL".into(),
36 JsonValue::Bool(b) => {
37 let value = if *b { "TRUE" } else { "FALSE" };
38 value.into()
39 }
40 JsonValue::Number(n) => n.to_string().into(),
41 JsonValue::String(s) => {
42 if s.is_empty() {
43 if let Some(value) = self.default_value() {
44 self.format_value(value).into_owned().into()
45 } else {
46 "''".into()
47 }
48 } else if s == "null" {
49 "NULL".into()
50 } else if s == "not_null" {
51 "NOT NULL".into()
52 } else {
53 self.format_value(s)
54 }
55 }
56 JsonValue::Array(vec) => {
57 let values = vec
58 .iter()
59 .map(|v| match v {
60 JsonValue::String(v) => Query::escape_string(v),
61 _ => self.encode_value(Some(v)).into_owned(),
62 })
63 .collect::<Vec<_>>();
64 format!(r#"json_array({})"#, values.join(",")).into()
65 }
66 JsonValue::Object(_) => Query::escape_string(value).into(),
67 }
68 } else if self.default_value().is_some() {
69 "DEFAULT".into()
70 } else {
71 "NULL".into()
72 }
73 }
74
75 fn format_value<'a>(&self, value: &'a str) -> Cow<'a, str> {
76 match self.type_name() {
77 "bool" => {
78 let value = if value == "true" { "TRUE" } else { "FALSE" };
79 value.into()
80 }
81 "u64" | "i64" | "u32" | "i32" | "u16" | "i16" | "u8" | "i8" | "usize" | "isize"
82 | "Option<u64>" | "Option<i64>" | "Option<u32>" | "Option<i32>" => {
83 if value.parse::<i64>().is_ok() {
84 value.into()
85 } else {
86 "NULL".into()
87 }
88 }
89 "f64" | "f32" => {
90 if value.parse::<f64>().is_ok() {
91 value.into()
92 } else {
93 "NULL".into()
94 }
95 }
96 "DateTime" | "NaiveDateTime" => match value {
97 "epoch" => "datetime(0, 'unixepoch')".into(),
98 "now" => "datetime('now', 'localtime')".into(),
99 "today" => "datetime('now', 'start of day')".into(),
100 "tomorrow" => "datetime('now', 'start of day', '+1 day')".into(),
101 "yesterday" => "datetime('now', 'start of day', '-1 day')".into(),
102 _ => Query::escape_string(value).into(),
103 },
104 "Date" | "NaiveDate" => match value {
105 "epoch" => "'1970-01-01'".into(),
106 "today" => "date('now', 'localtime')".into(),
107 "tomorrow" => "date('now', '+1 day')".into(),
108 "yesterday" => "date('now', '-1 day')".into(),
109 _ => Query::escape_string(value).into(),
110 },
111 "Time" | "NaiveTime" => match value {
112 "now" => "time('now', 'localtime')".into(),
113 "midnight" => "'00:00:00'".into(),
114 _ => Query::escape_string(value).into(),
115 },
116 "Vec<u8>" => format!("'{value}'").into(),
117 "Vec<String>" | "Vec<Uuid>" | "Vec<u64>" | "Vec<i64>" | "Vec<u32>" | "Vec<i32>" => {
118 if value.contains(',') {
119 let values = value
120 .split(',')
121 .map(Query::escape_string)
122 .collect::<Vec<_>>();
123 format!(r#"json_array({})"#, values.join(",")).into()
124 } else {
125 let value = Query::escape_string(value);
126 format!(r#"json_array({value})"#).into()
127 }
128 }
129 _ => Query::escape_string(value).into(),
130 }
131 }
132
133 fn format_filter(&self, field: &str, value: &JsonValue) -> String {
134 let type_name = self.type_name();
135 let field = Query::format_field(field);
136 if let Some(filter) = value.as_object() {
137 let mut conditions = Vec::with_capacity(filter.len());
138 if type_name == "Map" {
139 for (key, value) in filter {
140 let key = Query::escape_string(key);
141 let value = self.encode_value(Some(value));
142 let condition =
143 format!(r#"json_tree.key = {key} AND json_tree.value = {value}"#);
144 conditions.push(condition);
145 }
146 return Query::join_conditions(conditions, " OR ");
147 } else {
148 for (name, value) in filter {
149 let name = name.as_str();
150 let operator = match name {
151 "$eq" => "=",
152 "$ne" => "<>",
153 "$lt" => "<",
154 "$le" => "<=",
155 "$gt" => ">",
156 "$ge" => ">=",
157 "$in" => "IN",
158 "$nin" => "NOT IN",
159 "$betw" => "BETWEEN",
160 "$like" => "LIKE",
161 "$ilike" => "ILIKE",
162 "$rlike" => "REGEXP",
163 "$is" => "IS",
164 "$size" => "json_array_length",
165 _ => {
166 if cfg!(debug_assertions) && name.starts_with('$') {
167 tracing::warn!("unsupported operator `{name}` for SQLite");
168 }
169 name
170 }
171 };
172 if let Some(subquery) = value.as_object().and_then(|m| m.get_str("$subquery")) {
173 let condition = format!(r#"{field} {operator} {subquery}"#);
174 conditions.push(condition);
175 } else if operator == "IN" || operator == "NOT IN" {
176 if let Some(values) = value.as_array() {
177 if values.is_empty() {
178 let condition = if operator == "IN" { "FALSE" } else { "TRUE" };
179 conditions.push(condition.to_owned());
180 } else {
181 let value = values
182 .iter()
183 .map(|v| self.encode_value(Some(v)))
184 .collect::<Vec<_>>()
185 .join(", ");
186 let condition = format!(r#"{field} {operator} ({value})"#);
187 conditions.push(condition);
188 }
189 }
190 } else if operator == "BETWEEN" {
191 if let Some(values) = value.as_array() {
192 if let [min_value, max_value] = values.as_slice() {
193 let min_value = self.encode_value(Some(min_value));
194 let max_value = self.encode_value(Some(max_value));
195 let condition =
196 format!(r#"({field} BETWEEN {min_value} AND {max_value})"#);
197 conditions.push(condition);
198 }
199 } else if let Some(values) = value.parse_str_array() {
200 if let [min_value, max_value] = values.as_slice() {
201 let min_value = self.format_value(min_value);
202 let max_value = self.format_value(max_value);
203 let condition =
204 format!(r#"({field} BETWEEN {min_value} AND {max_value})"#);
205 conditions.push(condition);
206 }
207 }
208 } else if operator == "ILIKE" {
209 let value = self.encode_value(Some(value));
210 let condition = format!(r#"LOWER({field}) LIKE LOWER({value})"#);
211 conditions.push(condition);
212 } else if operator == "json_array_length" {
213 if let Some(Ok(length)) = value.parse_usize() {
214 let condition = format!(r#"json_array_length({field}) = {length}"#);
215 conditions.push(condition);
216 }
217 } else {
218 let value = self.encode_value(Some(value));
219 let condition = format!(r#"{field} {operator} {value}"#);
220 conditions.push(condition);
221 }
222 }
223 if conditions.is_empty() {
224 return String::new();
225 } else {
226 return conditions.join(" AND ");
227 }
228 }
229 } else if value.is_null() {
230 return format!(r#"{field} IS NULL"#);
231 } else if self.has_attribute("exact_filter") {
232 let value = self.encode_value(Some(value));
233 return format!(r#"{field} = {value}"#);
234 } else if let Some(value) = value.as_str() {
235 if value == "null" {
236 return format!(r#"{field} IS NULL"#);
237 } else if value == "not_null" {
238 return format!(r#"{field} IS NOT NULL"#);
239 } else if let Some((min_value, max_value)) =
240 value.split_once(',').filter(|_| self.is_datetime_type())
241 {
242 let min_value = self.format_value(min_value);
243 let max_value = self.format_value(max_value);
244 return format!(r#"{field} >= {min_value} AND {field} < {max_value}"#);
245 }
246 }
247
248 match type_name {
249 "bool" => {
250 let value = self.encode_value(Some(value));
251 if value == "TRUE" {
252 format!(r#"{field} IS TRUE"#)
253 } else {
254 format!(r#"{field} IS NOT TRUE"#)
255 }
256 }
257 "u64" | "i64" | "u32" | "i32" | "u16" | "i16" | "u8" | "i8" | "usize" | "isize"
258 | "Option<u64>" | "Option<i64>" | "Option<u32>" | "Option<i32>" => {
259 if let Some(value) = value.as_str() {
260 if value == "nonzero" {
261 format!(r#"{field} <> 0"#)
262 } else if value.contains(',') {
263 let value = value.split(',').collect::<Vec<_>>().join(",");
264 format!(r#"{field} IN ({value})"#)
265 } else {
266 let value = self.format_value(value);
267 format!(r#"{field} = {value}"#)
268 }
269 } else {
270 let value = self.encode_value(Some(value));
271 format!(r#"{field} = {value}"#)
272 }
273 }
274 "String" | "Option<String>" => {
275 if let Some(value) = value.as_str() {
276 if value == "empty" {
277 format!(r#"({field} = '') IS NOT FALSE"#)
279 } else if value == "nonempty" {
280 format!(r#"({field} = '') IS FALSE"#)
281 } else if self.fuzzy_search() {
282 if value.contains(',') {
283 let exprs = value
284 .split(',')
285 .map(|s| {
286 let value = Query::escape_string(format!("%{s}%"));
287 format!(r#"{field} LIKE {value}"#)
288 })
289 .collect::<Vec<_>>();
290 format!("({})", exprs.join(" OR "))
291 } else {
292 let value = Query::escape_string(format!("%{value}%"));
293 format!(r#"{field} LIKE {value}"#)
294 }
295 } else if value.contains(',') {
296 let value = value
297 .split(',')
298 .map(Query::escape_string)
299 .collect::<Vec<_>>()
300 .join(", ");
301 format!(r#"{field} IN ({value})"#)
302 } else {
303 let value = Query::escape_string(value);
304 format!(r#"{field} = {value}"#)
305 }
306 } else {
307 let value = self.encode_value(Some(value));
308 format!(r#"{field} = {value}"#)
309 }
310 }
311 "DateTime" | "NaiveDateTime" => {
312 if let Some(value) = value.as_str() {
313 let length = value.len();
314 let value = self.format_value(value);
315 match length {
316 4 => format!(r#"strftime('%Y', {field}) = {value}"#),
317 7 => format!(r#"strftime('%Y-%m', {field}) = {value}"#),
318 10 => format!(r#"strftime('%Y-%m-%d', {field}) = {value}"#),
319 _ => format!(r#"{field} = {value}"#),
320 }
321 } else {
322 let value = self.encode_value(Some(value));
323 format!(r#"{field} = {value}"#)
324 }
325 }
326 "Date" | "NaiveDate" => {
327 if let Some(value) = value.as_str() {
328 let length = value.len();
329 let value = self.format_value(value);
330 match length {
331 4 => format!(r#"strftime('%Y', {field}) = {value}"#),
332 7 => format!(r#"strftime('%Y-%m', {field}) = {value}"#),
333 _ => format!(r#"{field} = {value}"#),
334 }
335 } else {
336 let value = self.encode_value(Some(value));
337 format!(r#"{field} = {value}"#)
338 }
339 }
340 "Time" | "NaiveTime" => {
341 if let Some(value) = value.as_str() {
342 let length = value.len();
343 let value = self.format_value(value);
344 match length {
345 2 => format!(r#"strftime('%H', {field}) = {value}"#),
346 5 => format!(r#"strftime('%H:%M', {field}) = {value}"#),
347 8 => format!(r#"strftime('%H:%M:%S', {field}) = {value}"#),
348 _ => format!(r#"{field} = {value}"#),
349 }
350 } else {
351 let value = self.encode_value(Some(value));
352 format!(r#"{field} = {value}"#)
353 }
354 }
355 "Uuid" | "Option<Uuid>" => {
356 if let Some(value) = value.as_str() {
357 if value.contains(',') {
358 let value = value
359 .split(',')
360 .map(Query::escape_string)
361 .collect::<Vec<_>>()
362 .join(", ");
363 format!(r#"{field} IN ({value})"#)
364 } else {
365 let value = Query::escape_string(value);
366 format!(r#"{field} = {value}"#)
367 }
368 } else {
369 let value = self.encode_value(Some(value));
370 format!(r#"{field} = {value}"#)
371 }
372 }
373 "Vec<String>" | "Vec<Uuid>" | "Vec<u64>" | "Vec<i64>" | "Vec<u32>" | "Vec<i32>" => {
374 if let Some(value) = value.as_str() {
375 if value == "nonempty" {
376 format!(r#"json_array_length({field}) > 0"#)
377 } else {
378 let exprs = value
379 .split(',')
380 .map(|v| {
381 let value = Query::escape_string(v);
382 format!(r#"json_each.value = {value}"#)
383 })
384 .collect::<Vec<_>>();
385 format!("({})", exprs.join(" OR "))
386 }
387 } else if let Some(values) = value.as_array() {
388 let exprs = values
389 .iter()
390 .map(|v| {
391 let value = self.encode_value(Some(v));
392 format!(r#"json_each.value = {value}"#)
393 })
394 .collect::<Vec<_>>();
395 format!("({})", exprs.join(" OR "))
396 } else {
397 let value = self.encode_value(Some(value));
398 format!(r#"{field} = {value}"#)
399 }
400 }
401 _ => {
402 let value = self.encode_value(Some(value));
403 format!(r#"{field} = {value}"#)
404 }
405 }
406 }
407}
408
#[cfg(feature = "orm-sqlx")]
impl DecodeRow<DatabaseRow> for Map {
    type Error = Error;

    /// Decodes a database row into a JSON object keyed by column name.
    ///
    /// Each column is decoded according to its declared type name; when the declared
    /// type info is null, the type of the stored value is used instead. TEXT and BLOB
    /// values that look like a JSON array or object are parsed as JSON, falling back
    /// to the raw value when parsing fails. 16-byte blobs are rendered as UUID
    /// strings when they convert. Values for which `is_ignorable()` returns true are
    /// left out of the map.
    ///
    /// # Errors
    ///
    /// Returns an error if a raw value cannot be fetched or decoded as the type
    /// implied by its column.
    fn decode_row(row: &DatabaseRow) -> Result<Self, Self::Error> {
        use super::decode::decode_raw;

        let mut map = Map::new();
        for col in row.columns() {
            let field = col.name();
            let index = col.ordinal();
            let raw_value = row.try_get_raw(index)?;
            let value = if raw_value.is_null() {
                JsonValue::Null
            } else {
                let type_info = col.type_info();
                let value_type_info = raw_value.type_info();
                let column_type = if type_info.is_null() {
                    value_type_info.name()
                } else {
                    type_info.name()
                };
                match column_type {
                    "BOOLEAN" => decode_raw::<bool>(field, raw_value)?.into(),
                    "INTEGER" => decode_raw::<i64>(field, raw_value)?.into(),
                    "REAL" => decode_raw::<f64>(field, raw_value)?.into(),
                    "TEXT" => {
                        let text = decode_raw::<String>(field, raw_value)?;
                        let looks_like_json = text.starts_with('[') && text.ends_with(']')
                            || text.starts_with('{') && text.ends_with('}');
                        if looks_like_json {
                            // Bracketed text is not necessarily JSON (e.g. `[draft]`):
                            // keep it as a plain string instead of failing the whole
                            // row, mirroring the BLOB branch below.
                            match serde_json::from_str::<JsonValue>(&text) {
                                Ok(json) => json,
                                Err(_) => text.into(),
                            }
                        } else {
                            text.into()
                        }
                    }
                    "DATETIME" => decode_raw::<DateTime>(field, raw_value)?.into(),
                    "DATE" => decode_raw::<Date>(field, raw_value)?.into(),
                    "TIME" => decode_raw::<Time>(field, raw_value)?.into(),
                    "BLOB" => {
                        let bytes = decode_raw::<Vec<u8>>(field, raw_value)?;
                        if bytes.starts_with(b"[") && bytes.ends_with(b"]")
                            || bytes.starts_with(b"{") && bytes.ends_with(b"}")
                        {
                            serde_json::from_slice::<JsonValue>(&bytes)
                                .unwrap_or_else(|_| bytes.into())
                        } else if bytes.len() == 16 {
                            if let Ok(value) = Uuid::from_slice(&bytes) {
                                value.to_string().into()
                            } else {
                                bytes.into()
                            }
                        } else {
                            bytes.into()
                        }
                    }
                    _ => decode_raw::<String>(field, raw_value)?.into(),
                }
            };
            if !value.is_ignorable() {
                map.insert(field.to_owned(), value);
            }
        }
        Ok(map)
    }
}
475
#[cfg(feature = "orm-sqlx")]
impl DecodeRow<DatabaseRow> for Record {
    type Error = Error;

    /// Decodes a database row into a record of `(column name, Avro value)` pairs.
    ///
    /// Every column of the row produces one entry, in column order; SQL NULL becomes
    /// `AvroValue::Null`. Each column is decoded according to its declared type name,
    /// or the stored value's type when the declared type info is null. TEXT and BLOB
    /// values that look like a JSON array or object are parsed as JSON, falling back
    /// to the raw value when parsing fails. DATETIME values are stored in their
    /// string form.
    ///
    /// # Errors
    ///
    /// Returns an error if a raw value cannot be fetched or decoded as the type
    /// implied by its column.
    fn decode_row(row: &DatabaseRow) -> Result<Self, Self::Error> {
        use super::decode::decode_raw;

        let columns = row.columns();
        let mut record = Record::with_capacity(columns.len());
        for col in columns {
            let field = col.name();
            let index = col.ordinal();
            let raw_value = row.try_get_raw(index)?;
            let value = if raw_value.is_null() {
                AvroValue::Null
            } else {
                let type_info = col.type_info();
                let value_type_info = raw_value.type_info();
                let column_type = if type_info.is_null() {
                    value_type_info.name()
                } else {
                    type_info.name()
                };
                match column_type {
                    "BOOLEAN" => decode_raw::<bool>(field, raw_value)?.into(),
                    "INTEGER" => decode_raw::<i64>(field, raw_value)?.into(),
                    "REAL" => decode_raw::<f64>(field, raw_value)?.into(),
                    "TEXT" => {
                        let text = decode_raw::<String>(field, raw_value)?;
                        let looks_like_json = text.starts_with('[') && text.ends_with(']')
                            || text.starts_with('{') && text.ends_with('}');
                        if looks_like_json {
                            // Bracketed text is not necessarily JSON (e.g. `[draft]`):
                            // keep it as a plain string instead of failing the whole
                            // row, mirroring the BLOB branch below.
                            match serde_json::from_str::<JsonValue>(&text) {
                                Ok(json) => json.into(),
                                Err(_) => text.into(),
                            }
                        } else {
                            text.into()
                        }
                    }
                    "DATETIME" => decode_raw::<DateTime>(field, raw_value)?.to_string().into(),
                    "DATE" => decode_raw::<Date>(field, raw_value)?.into(),
                    "TIME" => decode_raw::<Time>(field, raw_value)?.into(),
                    "BLOB" => {
                        let bytes = decode_raw::<Vec<u8>>(field, raw_value)?;
                        if bytes.starts_with(b"[") && bytes.ends_with(b"]")
                            || bytes.starts_with(b"{") && bytes.ends_with(b"}")
                        {
                            serde_json::from_slice::<JsonValue>(&bytes)
                                .map(|value| value.into())
                                .unwrap_or_else(|_| bytes.into())
                        } else if bytes.len() == 16 {
                            if let Ok(value) = Uuid::from_slice(&bytes) {
                                value.into()
                            } else {
                                bytes.into()
                            }
                        } else {
                            bytes.into()
                        }
                    }
                    _ => decode_raw::<String>(field, raw_value)?.into(),
                }
            };
            record.push((field.to_owned(), value));
        }
        Ok(record)
    }
}
542
#[cfg(feature = "orm-sqlx")]
impl QueryExt<DatabaseDriver> for Query {
    type QueryResult = sqlx::sqlite::SqliteQueryResult;

    /// Splits a query result into `(last inserted rowid, rows affected)`.
    #[inline]
    fn parse_query_result(query_result: Self::QueryResult) -> (Option<i64>, u64) {
        (
            Some(query_result.last_insert_rowid()),
            query_result.rows_affected(),
        )
    }

    /// Returns the fields selected by the query.
    #[inline]
    fn query_fields(&self) -> &[String] {
        self.fields()
    }

    /// Returns the filters of the query.
    #[inline]
    fn query_filters(&self) -> &Map {
        self.filters()
    }

    /// Returns the sort order of the query.
    #[inline]
    fn query_order(&self) -> &[QueryOrder] {
        self.sort_order()
    }

    /// Returns the offset of the query.
    #[inline]
    fn query_offset(&self) -> usize {
        self.offset()
    }

    /// Returns the limit of the query.
    #[inline]
    fn query_limit(&self) -> usize {
        self.limit()
    }

    /// Returns the bind placeholder, which is `?` regardless of the position `_n`.
    #[inline]
    fn placeholder(_n: usize) -> SharedString {
        "?".into()
    }

    /// Prepares `query` with the optional `params`, using `?` as the placeholder.
    #[inline]
    fn prepare_query<'a>(
        query: &'a str,
        params: Option<&'a Map>,
    ) -> (Cow<'a, str>, Vec<&'a JsonValue>) {
        crate::query::prepare_sql_query(query, params, '?')
    }

    /// Quotes a field name with backticks, quoting each dot-separated part on its own.
    ///
    /// A field that already contains a backtick is returned unchanged.
    fn format_field(field: &str) -> Cow<'_, str> {
        if field.contains('`') {
            return Cow::Borrowed(field);
        }
        // Splitting a name without dots yields the name itself, so this also covers
        // the plain, unqualified case.
        let quoted = field
            .split('.')
            .map(|part| format!("`{part}`"))
            .collect::<Vec<_>>()
            .join(".");
        Cow::Owned(quoted)
    }

    /// Builds the projection list of the query for model `M`.
    ///
    /// Yields `*` when no fields are selected. An `alias:expr` field becomes
    /// `expr AS alias`, a dotted field is quoted part by part, and any other field
    /// is qualified with the model name.
    fn format_table_fields<M: Schema>(&self) -> Cow<'_, str> {
        let model_name = M::model_name();
        let fields = self.query_fields();
        if fields.is_empty() {
            return Cow::Borrowed("*");
        }

        let mut projections = Vec::with_capacity(fields.len());
        for field in fields {
            let projection = match field.split_once(':') {
                Some((alias, expr)) => {
                    let alias = Self::format_field(alias.trim());
                    format!("{expr} AS {alias}")
                }
                None if field.contains('.') => field
                    .split('.')
                    .map(|part| format!("`{part}`"))
                    .collect::<Vec<_>>()
                    .join("."),
                None => format!("`{model_name}`.`{field}`"),
            };
            projections.push(projection);
        }
        Cow::Owned(projections.join(", "))
    }

    /// Builds the table expression of the query for model `M`.
    ///
    /// The table name comes from the query's `table_name` extra attribute when
    /// present, otherwise from the model, and is aliased to the model name. For every
    /// filtered column holding a list or a map, a `json_each(...)` or `json_tree(...)`
    /// table is appended.
    fn format_table_name<M: Schema>(&self) -> String {
        let table_name = match self.extra().get_str("table_name") {
            Some(name) => name,
            None => M::table_name(),
        };
        let model_name = M::model_name();
        let filters = self.query_filters();
        let mut virtual_tables = Vec::new();
        for col in M::columns() {
            let col_name = col.name();
            if !filters.contains_key(col_name) {
                continue;
            }
            let table_function = match col.type_name() {
                "Vec<String>" | "Vec<Uuid>" | "Vec<u64>" | "Vec<i64>" | "Vec<u32>"
                | "Vec<i32>" => "json_each",
                "Map" => "json_tree",
                _ => continue,
            };
            virtual_tables.push(format!("{table_function}(`{model_name}`.`{col_name}`)"));
        }

        let quoted_table = table_name
            .split('.')
            .map(|part| format!("`{part}`"))
            .collect::<Vec<_>>()
            .join(".");
        let mut table_expr = format!("{quoted_table} AS `{model_name}`");
        if !virtual_tables.is_empty() {
            table_expr.push_str(", ");
            table_expr.push_str(&virtual_tables.join(", "));
        }
        table_expr
    }

    /// Quotes a table name with backticks, quoting each dot-separated part on its own.
    fn escape_table_name(table_name: &str) -> String {
        table_name
            .split('.')
            .map(|part| format!("`{part}`"))
            .collect::<Vec<_>>()
            .join(".")
    }

    /// Builds a `MATCH` condition from the `$fields` and `$search` entries of `filter`.
    ///
    /// Returns `None` when either entry cannot be parsed from the filter.
    fn parse_text_search(filter: &Map) -> Option<String> {
        let fields = filter.parse_str_array("$fields")?;
        let search = filter.parse_string("$search")?;
        let fields = fields.join(", ");
        let search = Query::escape_string(search.as_ref());
        Some(format!("{fields} MATCH {search}"))
    }
}