1use super::{query::QueryExt, DatabaseDriver, DatabaseRow, DecodeRow, EncodeColumn, Schema};
2use std::borrow::Cow;
3use zino_core::{
4 datetime::{Date, DateTime, Time},
5 error::Error,
6 extension::{JsonObjectExt, JsonValueExt},
7 model::{Column, Query, QueryOrder},
8 AvroValue, JsonValue, Map, Record, SharedString, Uuid,
9};
10
11#[cfg(feature = "orm-sqlx")]
12use sqlx::{Column as _, Row, TypeInfo, ValueRef};
13
14impl EncodeColumn<DatabaseDriver> for Column<'_> {
15 fn column_type(&self) -> &str {
16 if let Some(column_type) = self.extra().get_str("column_type") {
17 return column_type;
18 }
19 match self.type_name() {
20 "bool" => "BOOLEAN",
21 "u64" | "i64" | "usize" | "isize" | "Option<u64>" | "Option<i64>" | "u32" | "i32"
22 | "u16" | "i16" | "u8" | "i8" | "Option<u32>" | "Option<i32>" => "INTEGER",
23 "f64" | "f32" => "REAL",
24 "Date" | "NaiveDate" => "DATE",
25 "Time" | "NaiveTime" => "TIME",
26 "DateTime" | "NaiveDateTime" => "DATETIME",
27 "Vec<u8>" => "BLOB",
28 _ => "TEXT",
29 }
30 }
31
32 fn encode_value<'a>(&self, value: Option<&'a JsonValue>) -> Cow<'a, str> {
33 if let Some(value) = value {
34 match value {
35 JsonValue::Null => "NULL".into(),
36 JsonValue::Bool(b) => {
37 let value = if *b { "TRUE" } else { "FALSE" };
38 value.into()
39 }
40 JsonValue::Number(n) => n.to_string().into(),
41 JsonValue::String(s) => {
42 if s.is_empty() {
43 if let Some(value) = self.default_value() {
44 self.format_value(value).into_owned().into()
45 } else {
46 "''".into()
47 }
48 } else if s == "null" {
49 "NULL".into()
50 } else if s == "not_null" {
51 "NOT NULL".into()
52 } else {
53 self.format_value(s)
54 }
55 }
56 JsonValue::Array(vec) => {
57 let values = vec
58 .iter()
59 .map(|v| match v {
60 JsonValue::String(v) => Query::escape_string(v),
61 _ => self.encode_value(Some(v)).into_owned(),
62 })
63 .collect::<Vec<_>>();
64 format!(r#"json_array({})"#, values.join(",")).into()
65 }
66 JsonValue::Object(_) => Query::escape_string(value).into(),
67 }
68 } else if self.default_value().is_some() {
69 "DEFAULT".into()
70 } else {
71 "NULL".into()
72 }
73 }
74
75 fn format_value<'a>(&self, value: &'a str) -> Cow<'a, str> {
76 match self.type_name() {
77 "bool" => {
78 let value = if value == "true" { "TRUE" } else { "FALSE" };
79 value.into()
80 }
81 "u64" | "i64" | "u32" | "i32" | "u16" | "i16" | "u8" | "i8" | "usize" | "isize"
82 | "Option<u64>" | "Option<i64>" | "Option<u32>" | "Option<i32>" => {
83 if value.parse::<i64>().is_ok() {
84 value.into()
85 } else {
86 "NULL".into()
87 }
88 }
89 "f64" | "f32" => {
90 if value.parse::<f64>().is_ok() {
91 value.into()
92 } else {
93 "NULL".into()
94 }
95 }
96 "DateTime" | "NaiveDateTime" => match value {
97 "epoch" => "datetime(0, 'unixepoch')".into(),
98 "now" => "datetime('now', 'localtime')".into(),
99 "today" => "datetime('now', 'start of day')".into(),
100 "tomorrow" => "datetime('now', 'start of day', '+1 day')".into(),
101 "yesterday" => "datetime('now', 'start of day', '-1 day')".into(),
102 _ => Query::escape_string(value).into(),
103 },
104 "Date" | "NaiveDate" => match value {
105 "epoch" => "'1970-01-01'".into(),
106 "today" => "date('now', 'localtime')".into(),
107 "tomorrow" => "date('now', '+1 day')".into(),
108 "yesterday" => "date('now', '-1 day')".into(),
109 _ => Query::escape_string(value).into(),
110 },
111 "Time" | "NaiveTime" => match value {
112 "now" => "time('now', 'localtime')".into(),
113 "midnight" => "'00:00:00'".into(),
114 _ => Query::escape_string(value).into(),
115 },
116 "Vec<u8>" => format!("'{value}'").into(),
117 "Vec<String>" | "Vec<Uuid>" | "Vec<u64>" | "Vec<i64>" | "Vec<u32>" | "Vec<i32>" => {
118 if value.contains(',') {
119 let values = value
120 .split(',')
121 .map(Query::escape_string)
122 .collect::<Vec<_>>();
123 format!(r#"json_array({})"#, values.join(",")).into()
124 } else {
125 let value = Query::escape_string(value);
126 format!(r#"json_array({value})"#).into()
127 }
128 }
129 _ => Query::escape_string(value).into(),
130 }
131 }
132
133 fn format_filter(&self, field: &str, value: &JsonValue) -> String {
134 let type_name = self.type_name();
135 let field = Query::format_field(field);
136 if let Some(filter) = value.as_object() {
137 let mut conditions = Vec::with_capacity(filter.len());
138 if type_name == "Map" {
139 for (key, value) in filter {
140 let key = Query::escape_string(key);
141 let value = self.encode_value(Some(value));
142 let condition =
143 format!(r#"json_tree.key = {key} AND json_tree.value = {value}"#);
144 conditions.push(condition);
145 }
146 return Query::join_conditions(conditions, " OR ");
147 } else {
148 for (name, value) in filter {
149 let name = name.as_str();
150 let operator = match name {
151 "$eq" => "=",
152 "$ne" => "<>",
153 "$lt" => "<",
154 "$le" => "<=",
155 "$gt" => ">",
156 "$ge" => ">=",
157 "$in" => "IN",
158 "$nin" => "NOT IN",
159 "$betw" => "BETWEEN",
160 "$like" => "LIKE",
161 "$ilike" => "ILIKE",
162 "$rlike" => "REGEXP",
163 "$is" => "IS",
164 "$size" => "json_array_length",
165 _ => {
166 if cfg!(debug_assertions) && name.starts_with('$') {
167 tracing::warn!("unsupported operator `{name}` for SQLite");
168 }
169 name
170 }
171 };
172 if let Some(subquery) = value.as_object().and_then(|m| m.get_str("$subquery")) {
173 let condition = format!(r#"{field} {operator} {subquery}"#);
174 conditions.push(condition);
175 } else if operator == "IN" || operator == "NOT IN" {
176 if let Some(values) = value.as_array() {
177 if values.is_empty() {
178 let condition = if operator == "IN" { "FALSE" } else { "TRUE" };
179 conditions.push(condition.to_owned());
180 } else {
181 let value = values
182 .iter()
183 .map(|v| self.encode_value(Some(v)))
184 .collect::<Vec<_>>()
185 .join(", ");
186 let condition = format!(r#"{field} {operator} ({value})"#);
187 conditions.push(condition);
188 }
189 }
190 } else if operator == "BETWEEN" {
191 if let Some(values) = value.as_array() {
192 if let [min_value, max_value] = values.as_slice() {
193 let min_value = self.encode_value(Some(min_value));
194 let max_value = self.encode_value(Some(max_value));
195 let condition =
196 format!(r#"({field} BETWEEN {min_value} AND {max_value})"#);
197 conditions.push(condition);
198 }
199 } else if let Some(values) = value.parse_str_array() {
200 if let [min_value, max_value] = values.as_slice() {
201 let min_value = self.format_value(min_value);
202 let max_value = self.format_value(max_value);
203 let condition =
204 format!(r#"({field} BETWEEN {min_value} AND {max_value})"#);
205 conditions.push(condition);
206 }
207 }
208 } else if operator == "ILIKE" {
209 let value = self.encode_value(Some(value));
210 let condition = format!(r#"LOWER({field}) LIKE LOWER({value})"#);
211 conditions.push(condition);
212 } else if operator == "json_array_length" {
213 if let Some(Ok(length)) = value.parse_usize() {
214 let condition = format!(r#"json_array_length({field}) = {length}"#);
215 conditions.push(condition);
216 }
217 } else {
218 let value = self.encode_value(Some(value));
219 let condition = format!(r#"{field} {operator} {value}"#);
220 conditions.push(condition);
221 }
222 }
223 if conditions.is_empty() {
224 return String::new();
225 } else {
226 return conditions.join(" AND ");
227 }
228 }
229 } else if value.is_null() {
230 return format!(r#"{field} IS NULL"#);
231 } else if self.has_attribute("exact_filter") {
232 let value = self.encode_value(Some(value));
233 return format!(r#"{field} = {value}"#);
234 } else if let Some(value) = value.as_str() {
235 if value == "null" {
236 return format!(r#"{field} IS NULL"#);
237 } else if value == "not_null" {
238 return format!(r#"{field} IS NOT NULL"#);
239 } else if let Some((min_value, max_value)) =
240 value.split_once(',').filter(|_| self.is_datetime_type())
241 {
242 let min_value = self.format_value(min_value);
243 let max_value = self.format_value(max_value);
244 return format!(r#"{field} >= {min_value} AND {field} < {max_value}"#);
245 }
246 }
247
248 match type_name {
249 "bool" => {
250 let value = self.encode_value(Some(value));
251 if value == "TRUE" {
252 format!(r#"{field} IS TRUE"#)
253 } else {
254 format!(r#"{field} IS NOT TRUE"#)
255 }
256 }
257 "u64" | "i64" | "u32" | "i32" | "u16" | "i16" | "u8" | "i8" | "usize" | "isize"
258 | "Option<u64>" | "Option<i64>" | "Option<u32>" | "Option<i32>" => {
259 if let Some(value) = value.as_str() {
260 if value == "nonzero" {
261 format!(r#"{field} <> 0"#)
262 } else if value.contains(',') {
263 let value = value.split(',').collect::<Vec<_>>().join(",");
264 format!(r#"{field} IN ({value})"#)
265 } else {
266 let value = self.format_value(value);
267 format!(r#"{field} = {value}"#)
268 }
269 } else {
270 let value = self.encode_value(Some(value));
271 format!(r#"{field} = {value}"#)
272 }
273 }
274 "String" | "Option<String>" => {
275 if let Some(value) = value.as_str() {
276 if value == "empty" {
277 format!(r#"({field} = '') IS NOT FALSE"#)
279 } else if value == "nonempty" {
280 format!(r#"({field} = '') IS FALSE"#)
281 } else if self.fuzzy_search() {
282 if value.contains(',') {
283 let exprs = value
284 .split(',')
285 .map(|s| {
286 let value = Query::escape_string(format!("%{s}%"));
287 format!(r#"{field} LIKE {value}"#)
288 })
289 .collect::<Vec<_>>();
290 format!("({})", exprs.join(" OR "))
291 } else {
292 let value = Query::escape_string(format!("%{value}%"));
293 format!(r#"{field} LIKE {value}"#)
294 }
295 } else if value.contains(',') {
296 let value = value
297 .split(',')
298 .map(Query::escape_string)
299 .collect::<Vec<_>>()
300 .join(", ");
301 format!(r#"{field} IN ({value})"#)
302 } else {
303 let value = Query::escape_string(value);
304 format!(r#"{field} = {value}"#)
305 }
306 } else {
307 let value = self.encode_value(Some(value));
308 format!(r#"{field} = {value}"#)
309 }
310 }
311 "DateTime" | "NaiveDateTime" => {
312 if let Some(value) = value.as_str() {
313 let length = value.len();
314 let value = self.format_value(value);
315 match length {
316 4 => format!(r#"strftime('%Y', {field}) = {value}"#),
317 7 => format!(r#"strftime('%Y-%m', {field}) = {value}"#),
318 10 => format!(r#"strftime('%Y-%m-%d', {field}) = {value}"#),
319 _ => format!(r#"{field} = {value}"#),
320 }
321 } else {
322 let value = self.encode_value(Some(value));
323 format!(r#"{field} = {value}"#)
324 }
325 }
326 "Date" | "NaiveDate" => {
327 if let Some(value) = value.as_str() {
328 let length = value.len();
329 let value = self.format_value(value);
330 match length {
331 4 => format!(r#"strftime('%Y', {field}) = {value}"#),
332 7 => format!(r#"strftime('%Y-%m', {field}) = {value}"#),
333 _ => format!(r#"{field} = {value}"#),
334 }
335 } else {
336 let value = self.encode_value(Some(value));
337 format!(r#"{field} = {value}"#)
338 }
339 }
340 "Time" | "NaiveTime" => {
341 if let Some(value) = value.as_str() {
342 let length = value.len();
343 let value = self.format_value(value);
344 match length {
345 2 => format!(r#"strftime('%H', {field}) = {value}"#),
346 5 => format!(r#"strftime('%H:%M', {field}) = {value}"#),
347 8 => format!(r#"strftime('%H:%M:%S', {field}) = {value}"#),
348 _ => format!(r#"{field} = {value}"#),
349 }
350 } else {
351 let value = self.encode_value(Some(value));
352 format!(r#"{field} = {value}"#)
353 }
354 }
355 "Uuid" | "Option<Uuid>" => {
356 if let Some(value) = value.as_str() {
357 if value.contains(',') {
358 let value = value
359 .split(',')
360 .map(Query::escape_string)
361 .collect::<Vec<_>>()
362 .join(", ");
363 format!(r#"{field} IN ({value})"#)
364 } else {
365 let value = Query::escape_string(value);
366 format!(r#"{field} = {value}"#)
367 }
368 } else {
369 let value = self.encode_value(Some(value));
370 format!(r#"{field} = {value}"#)
371 }
372 }
373 "Vec<String>" | "Vec<Uuid>" | "Vec<u64>" | "Vec<i64>" | "Vec<u32>" | "Vec<i32>" => {
374 if let Some(value) = value.as_str() {
375 if value == "nonempty" {
376 format!(r#"json_array_length({field}) > 0"#)
377 } else {
378 let exprs = value
379 .split(',')
380 .map(|v| {
381 let value = Query::escape_string(v);
382 format!(r#"json_each.value = {value}"#)
383 })
384 .collect::<Vec<_>>();
385 format!("({})", exprs.join(" OR "))
386 }
387 } else if let Some(values) = value.as_array() {
388 let exprs = values
389 .iter()
390 .map(|v| {
391 let value = self.encode_value(Some(v));
392 format!(r#"json_each.value = {value}"#)
393 })
394 .collect::<Vec<_>>();
395 format!("({})", exprs.join(" OR "))
396 } else {
397 let value = self.encode_value(Some(value));
398 format!(r#"{field} = {value}"#)
399 }
400 }
401 _ => {
402 let value = self.encode_value(Some(value));
403 format!(r#"{field} = {value}"#)
404 }
405 }
406 }
407}
408
#[cfg(feature = "orm-sqlx")]
impl DecodeRow<DatabaseRow> for Map {
    type Error = Error;

    /// Decodes a SQLite row into a JSON object keyed by column name.
    ///
    /// Each non-NULL value is decoded according to its column's SQLite type name.
    /// - `TEXT` and `BLOB` values that look like a JSON array or object are parsed as JSON,
    ///   falling back to the raw value when they are not valid JSON.
    /// - 16-byte blobs are rendered as UUID strings.
    /// - Values for which `is_ignorable()` returns true are left out of the map.
    ///
    /// # Errors
    ///
    /// Returns an error if a raw value cannot be fetched or decoded as its column type.
    fn decode_row(row: &DatabaseRow) -> Result<Self, Self::Error> {
        use super::decode::decode_raw;

        let mut map = Map::new();
        for col in row.columns() {
            let field = col.name();
            let index = col.ordinal();
            let raw_value = row.try_get_raw(index)?;
            let value = if raw_value.is_null() {
                JsonValue::Null
            } else {
                match col.type_info().name() {
                    "BOOLEAN" => decode_raw::<bool>(field, raw_value)?.into(),
                    "INTEGER" | "BIGINT" => decode_raw::<i64>(field, raw_value)?.into(),
                    "REAL" => decode_raw::<f64>(field, raw_value)?.into(),
                    "TEXT" => {
                        let value = decode_raw::<String>(field, raw_value)?;
                        if value.starts_with('[') && value.ends_with(']')
                            || value.starts_with('{') && value.ends_with('}')
                        {
                            // Plain text that merely looks like JSON (e.g. `[note]`) must not
                            // fail the whole row; keep it as a string, like the BLOB branch.
                            serde_json::from_str::<JsonValue>(&value)
                                .unwrap_or_else(|_| value.into())
                        } else {
                            value.into()
                        }
                    }
                    "DATETIME" => decode_raw::<DateTime>(field, raw_value)?.into(),
                    "DATE" => decode_raw::<Date>(field, raw_value)?.into(),
                    "TIME" => decode_raw::<Time>(field, raw_value)?.into(),
                    "BLOB" => {
                        let bytes = decode_raw::<Vec<u8>>(field, raw_value)?;
                        if bytes.starts_with(b"[") && bytes.ends_with(b"]")
                            || bytes.starts_with(b"{") && bytes.ends_with(b"}")
                        {
                            serde_json::from_slice::<JsonValue>(&bytes)
                                .unwrap_or_else(|_| bytes.into())
                        } else if bytes.len() == 16 {
                            if let Ok(value) = Uuid::from_slice(&bytes) {
                                value.to_string().into()
                            } else {
                                bytes.into()
                            }
                        } else {
                            bytes.into()
                        }
                    }
                    _ => decode_raw::<String>(field, raw_value)?.into(),
                }
            };
            if !value.is_ignorable() {
                map.insert(field.to_owned(), value);
            }
        }
        Ok(map)
    }
}
467
#[cfg(feature = "orm-sqlx")]
impl DecodeRow<DatabaseRow> for Record {
    type Error = Error;

    /// Decodes a SQLite row into an Avro record of `(column name, value)` pairs.
    ///
    /// Each non-NULL value is decoded according to its column's SQLite type name. Every
    /// column is kept, with NULLs recorded as `AvroValue::Null`.
    /// - `TEXT` and `BLOB` values that look like a JSON array or object are parsed as JSON,
    ///   falling back to the raw value when they are not valid JSON.
    /// - 16-byte blobs are converted to UUIDs.
    /// - `DATETIME` values are stored as their string form.
    ///
    /// # Errors
    ///
    /// Returns an error if a raw value cannot be fetched or decoded as its column type.
    fn decode_row(row: &DatabaseRow) -> Result<Self, Self::Error> {
        use super::decode::decode_raw;

        let columns = row.columns();
        let mut record = Record::with_capacity(columns.len());
        for col in columns {
            let field = col.name();
            let index = col.ordinal();
            let raw_value = row.try_get_raw(index)?;
            let value = if raw_value.is_null() {
                AvroValue::Null
            } else {
                match col.type_info().name() {
                    "BOOLEAN" => decode_raw::<bool>(field, raw_value)?.into(),
                    "INTEGER" | "BIGINT" => decode_raw::<i64>(field, raw_value)?.into(),
                    "REAL" => decode_raw::<f64>(field, raw_value)?.into(),
                    "TEXT" => {
                        let value = decode_raw::<String>(field, raw_value)?;
                        if value.starts_with('[') && value.ends_with(']')
                            || value.starts_with('{') && value.ends_with('}')
                        {
                            // Plain text that merely looks like JSON must not fail the whole
                            // row; keep it as a string, like the BLOB branch.
                            serde_json::from_str::<JsonValue>(&value)
                                .map(AvroValue::from)
                                .unwrap_or_else(|_| value.into())
                        } else {
                            value.into()
                        }
                    }
                    "DATETIME" => decode_raw::<DateTime>(field, raw_value)?.to_string().into(),
                    "DATE" => decode_raw::<Date>(field, raw_value)?.into(),
                    "TIME" => decode_raw::<Time>(field, raw_value)?.into(),
                    "BLOB" => {
                        let bytes = decode_raw::<Vec<u8>>(field, raw_value)?;
                        if bytes.starts_with(b"[") && bytes.ends_with(b"]")
                            || bytes.starts_with(b"{") && bytes.ends_with(b"}")
                        {
                            serde_json::from_slice::<JsonValue>(&bytes)
                                .map(|value| value.into())
                                .unwrap_or_else(|_| bytes.into())
                        } else if bytes.len() == 16 {
                            if let Ok(value) = Uuid::from_slice(&bytes) {
                                value.into()
                            } else {
                                bytes.into()
                            }
                        } else {
                            bytes.into()
                        }
                    }
                    _ => decode_raw::<String>(field, raw_value)?.into(),
                }
            };
            record.push((field.to_owned(), value));
        }
        Ok(record)
    }
}
526
#[cfg(feature = "orm-sqlx")]
impl QueryExt<DatabaseDriver> for Query {
    type QueryResult = sqlx::sqlite::SqliteQueryResult;

    /// Splits a SQLite query result into `(last insert rowid, rows affected)`.
    #[inline]
    fn parse_query_result(query_result: Self::QueryResult) -> (Option<i64>, u64) {
        (
            Some(query_result.last_insert_rowid()),
            query_result.rows_affected(),
        )
    }

    /// Returns the projected fields of this query.
    #[inline]
    fn query_fields(&self) -> &[String] {
        self.fields()
    }

    /// Returns the filters of this query.
    #[inline]
    fn query_filters(&self) -> &Map {
        self.filters()
    }

    /// Returns the sort order of this query.
    #[inline]
    fn query_order(&self) -> &[QueryOrder] {
        self.sort_order()
    }

    /// Returns the row offset of this query.
    #[inline]
    fn query_offset(&self) -> usize {
        self.offset()
    }

    /// Returns the row limit of this query.
    #[inline]
    fn query_limit(&self) -> usize {
        self.limit()
    }

    /// SQLite uses the positional `?` placeholder regardless of the parameter index.
    #[inline]
    fn placeholder(_n: usize) -> SharedString {
        "?".into()
    }

    /// Prepares a SQL query, binding named parameters with `?` placeholders.
    #[inline]
    fn prepare_query<'a>(
        query: &'a str,
        params: Option<&'a Map>,
    ) -> (Cow<'a, str>, Vec<&'a JsonValue>) {
        crate::query::prepare_sql_query(query, params, '?')
    }

    /// Quotes a field name with backticks, quoting each dot-separated segment separately.
    ///
    /// A field that already contains a backtick is returned unchanged.
    fn format_field(field: &str) -> Cow<'_, str> {
        if field.contains('`') {
            return Cow::Borrowed(field);
        }
        // Without a dot this produces a single quoted segment.
        let quoted = field
            .split('.')
            .map(|segment| format!("`{segment}`"))
            .collect::<Vec<_>>()
            .join(".");
        Cow::Owned(quoted)
    }

    /// Renders the `SELECT` column list for model `M`.
    ///
    /// Each field is rendered in one of three ways:
    /// - `alias:expr` fields become `expr AS alias`.
    /// - Dotted fields have each segment quoted.
    /// - Bare fields are qualified with the model name.
    ///
    /// An empty field list yields `*`.
    fn format_table_fields<M: Schema>(&self) -> Cow<'_, str> {
        let model_name = M::model_name();
        let fields = self.query_fields();
        if fields.is_empty() {
            return "*".into();
        }
        let columns = fields
            .iter()
            .map(|field| match field.split_once(':') {
                Some((alias, expr)) => {
                    let alias = Self::format_field(alias.trim());
                    format!(r#"{expr} AS {alias}"#)
                }
                None if field.contains('.') => field
                    .split('.')
                    .map(|segment| format!("`{segment}`"))
                    .collect::<Vec<_>>()
                    .join("."),
                None => format!(r#"`{model_name}`.`{field}`"#),
            })
            .collect::<Vec<_>>();
        columns.join(", ").into()
    }

    /// Renders the `FROM` clause for model `M`.
    ///
    /// The escaped table name is aliased as the model name. For every filtered column of
    /// JSON array type a `json_each(...)` virtual table is appended, and `json_tree(...)` for
    /// `Map` columns, so that filters can refer to `json_each` / `json_tree`.
    fn format_table_name<M: Schema>(&self) -> String {
        let model_name = M::model_name();
        let filters = self.query_filters();
        let virtual_tables = M::columns()
            .iter()
            .filter(|col| filters.contains_key(col.name()))
            .filter_map(|col| {
                let function = match col.type_name() {
                    "Vec<String>" | "Vec<Uuid>" | "Vec<u64>" | "Vec<i64>" | "Vec<u32>"
                    | "Vec<i32>" => "json_each",
                    "Map" => "json_tree",
                    _ => return None,
                };
                Some(format!("{function}(`{model_name}`.`{}`)", col.name()))
            });

        let table_name = <Self as QueryExt<DatabaseDriver>>::table_name_escaped::<M>();
        let mut table_reference = format!(r#"{table_name} AS `{model_name}`"#);
        for virtual_table in virtual_tables {
            table_reference.push_str(", ");
            table_reference.push_str(&virtual_table);
        }
        table_reference
    }

    /// Returns the table name of `M` with each dot-separated segment quoted in backticks.
    fn table_name_escaped<M: Schema>() -> String {
        M::table_name()
            .split('.')
            .map(|segment| format!("`{segment}`"))
            .collect::<Vec<_>>()
            .join(".")
    }

    /// Builds a `MATCH` condition from the `$fields` and `$search` entries of a filter.
    ///
    /// Returns `None` if either entry is missing.
    fn parse_text_search(filter: &Map) -> Option<String> {
        let fields = filter.parse_str_array("$fields")?;
        let search = filter.parse_string("$search")?;
        let columns = fields.join(", ");
        let pattern = Query::escape_string(search.as_ref());
        Some(format!("{columns} MATCH {pattern}"))
    }
}