1use std::collections::HashMap;
2use std::marker::PhantomData;
3
4use serde_json::Value as JsonValue;
5use sqlx::postgres::{PgArguments, PgRow};
6use sqlx::{Postgres, Row as _, ValueRef as _};
7
8pub use futures_util::future::BoxFuture;
9use rust_decimal::Decimal;
10use uuid::Uuid;
11
12use crate::PgExecutor;
13use crate::filter::{FilterBuilder, compile_filter_sql, schema_model as resolve_schema_model};
14use crate::schema::{FieldType, Model, ScalarType, Schema};
15
16pub trait QuerySpec: Send + Sync {
18 type Output: Send + 'static;
19
20 #[doc(hidden)]
21 fn fetch_many<'a>(
22 &'a self,
23 executor: &'a dyn PgExecutor,
24 ) -> BoxFuture<'a, Result<Vec<Self::Output>, sqlx::Error>>;
25
26 #[doc(hidden)]
27 fn fetch_optional<'a>(
28 &'a self,
29 executor: &'a dyn PgExecutor,
30 ) -> BoxFuture<'a, Result<Option<Self::Output>, sqlx::Error>> {
31 Box::pin(async move { Ok(self.fetch_many(executor).await?.into_iter().next()) })
32 }
33
34 #[doc(hidden)]
35 fn fetch_first<'a>(
36 &'a self,
37 executor: &'a dyn PgExecutor,
38 ) -> BoxFuture<'a, Result<Self::Output, sqlx::Error>> {
39 Box::pin(async move {
40 self.fetch_optional(executor)
41 .await?
42 .ok_or(sqlx::Error::RowNotFound)
43 })
44 }
45}
46
47pub trait SchemaAccess: Send + Sync + 'static {
48 fn schema() -> &'static Schema;
49}
50
51#[derive(Clone, Debug, PartialEq)]
52pub struct QuerySelection {
53 pub model: &'static str,
54 pub scalar_fields: Vec<&'static str>,
55 pub relations: Vec<QueryRelationSelection>,
56 pub filter: Option<QueryFilter>,
57 pub order_by: Vec<QueryOrder>,
58 pub skip: Option<QueryPagination>,
59 pub limit: Option<QueryPagination>,
60}
61
62#[derive(Clone, Debug, PartialEq)]
63pub struct QueryRelationSelection {
64 pub field: &'static str,
65 pub selection: QuerySelection,
66}
67
/// Direction of a single ordering term.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryOrderDirection {
    /// Ascending order.
    Asc,
    /// Descending order.
    Desc,
}
73
74#[derive(Clone, Debug, Eq, PartialEq)]
75pub enum QueryOrder {
76 Scalar {
77 field: &'static str,
78 direction: QueryOrderDirection,
79 },
80 Relation {
81 field: &'static str,
82 orders: Vec<QueryOrder>,
83 },
84}
85
86impl QueryOrder {
87 pub fn scalar(field: &'static str, direction: QueryOrderDirection) -> Self {
88 Self::Scalar { field, direction }
89 }
90
91 pub fn relation(field: &'static str, orders: Vec<QueryOrder>) -> Self {
92 Self::Relation { field, orders }
93 }
94}
95
/// A pagination setting given either as a literal or as a named query variable.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryPagination {
    /// A literal value.
    Value(i64),
    /// The name of a query variable that supplies the value.
    Variable(&'static str),
}
101
102impl QueryPagination {
103 pub fn value(value: i64) -> Self {
104 Self::Value(value)
105 }
106
107 pub fn variable(name: &'static str) -> Self {
108 Self::Variable(name)
109 }
110}
111
112#[derive(Clone, Debug, Default, PartialEq)]
113pub struct QueryVariables {
114 values: Vec<QueryVariableValue>,
115 value_indices: HashMap<String, usize>,
116}
117
118impl QueryVariables {
119 pub fn new() -> Self {
120 Self {
121 values: Vec::new(),
122 value_indices: HashMap::new(),
123 }
124 }
125
126 pub fn from_values(values: Vec<(impl Into<String>, QueryVariableValue)>) -> Self {
127 let mut query_variables = Self::new();
128
129 for (name, value) in values {
130 query_variables
131 .push(name, value)
132 .expect("query variable names must be unique");
133 }
134
135 query_variables
136 }
137
138 pub fn push(
139 &mut self,
140 name: impl Into<String>,
141 value: QueryVariableValue,
142 ) -> Result<usize, sqlx::Error> {
143 let name = name.into();
144
145 if self.value_indices.contains_key(&name) {
146 return Err(schema_error(format!("duplicate query variable `{name}`")));
147 }
148
149 let index = self.values.len();
150 self.values.push(value);
151 self.value_indices.insert(name, index);
152 Ok(index)
153 }
154
155 pub fn get(&self, name: &str) -> Option<&QueryVariableValue> {
156 self.value_indices
157 .get(name)
158 .and_then(|index| self.values.get(*index))
159 }
160
161 pub fn len(&self) -> usize {
162 self.values.len()
163 }
164
165 pub fn is_empty(&self) -> bool {
166 self.values.is_empty()
167 }
168}
169
170pub trait QueryVariableSet: Send + 'static {
171 fn into_query_variables(self) -> QueryVariables;
172}
173
174impl QueryVariableSet for QueryVariables {
175 fn into_query_variables(self) -> QueryVariables {
176 self
177 }
178}
179
180impl QueryVariableSet for () {
181 fn into_query_variables(self) -> QueryVariables {
182 QueryVariables::new()
183 }
184}
185
186pub trait StringValueType: Sized + Send + 'static {
187 fn from_db_string(value: String) -> Result<Self, sqlx::Error>;
188
189 fn into_db_string(self) -> String;
190}
191
192impl StringValueType for String {
193 fn from_db_string(value: String) -> Result<Self, sqlx::Error> {
194 Ok(value)
195 }
196
197 fn into_db_string(self) -> String {
198 self
199 }
200}
201
202pub trait QueryScalar: Send {
203 fn into_query_variable_value(self) -> QueryVariableValue;
204}
205
206pub trait QueryResultValue: Sized + Send + 'static {
207 fn from_row(row: &PgRow, alias: &str) -> Result<Self, sqlx::Error>;
208
209 fn from_json(value: &JsonValue) -> Result<Self, sqlx::Error>;
210}
211
212#[derive(Clone, Debug, PartialEq)]
213pub enum QueryVariableValue {
214 Null,
215 Int(i64),
216 String(String),
217 Bool(bool),
218 Float(f64),
219 Decimal(Decimal),
220 Bytes(Vec<u8>),
221 DateTime(chrono::DateTime<chrono::Utc>),
222 Uuid(Uuid),
223 List(Vec<QueryVariableValue>),
224}
225
226pub trait QueryListScalar: Send {
227 fn into_list_query_variable_value(self) -> QueryVariableValue;
228}
229
230impl From<i64> for QueryVariableValue {
231 fn from(value: i64) -> Self {
232 Self::Int(value)
233 }
234}
235
236impl QueryListScalar for i64 {
237 fn into_list_query_variable_value(self) -> QueryVariableValue {
238 self.into()
239 }
240}
241
242impl From<String> for QueryVariableValue {
243 fn from(value: String) -> Self {
244 Self::String(value)
245 }
246}
247
248impl From<&str> for QueryVariableValue {
249 fn from(value: &str) -> Self {
250 Self::String(value.to_owned())
251 }
252}
253
254impl QueryListScalar for &str {
255 fn into_list_query_variable_value(self) -> QueryVariableValue {
256 self.into()
257 }
258}
259
260impl<T> QueryListScalar for T
261where
262 T: StringValueType,
263{
264 fn into_list_query_variable_value(self) -> QueryVariableValue {
265 QueryVariableValue::String(self.into_db_string())
266 }
267}
268
269impl From<bool> for QueryVariableValue {
270 fn from(value: bool) -> Self {
271 Self::Bool(value)
272 }
273}
274
275impl QueryListScalar for bool {
276 fn into_list_query_variable_value(self) -> QueryVariableValue {
277 self.into()
278 }
279}
280
281impl From<f64> for QueryVariableValue {
282 fn from(value: f64) -> Self {
283 Self::Float(value)
284 }
285}
286
287impl QueryListScalar for f64 {
288 fn into_list_query_variable_value(self) -> QueryVariableValue {
289 self.into()
290 }
291}
292
293impl From<Decimal> for QueryVariableValue {
294 fn from(value: Decimal) -> Self {
295 Self::Decimal(value)
296 }
297}
298
299impl QueryListScalar for Decimal {
300 fn into_list_query_variable_value(self) -> QueryVariableValue {
301 self.into()
302 }
303}
304
305impl From<Vec<u8>> for QueryVariableValue {
306 fn from(value: Vec<u8>) -> Self {
307 Self::Bytes(value)
308 }
309}
310
311impl QueryListScalar for Vec<u8> {
312 fn into_list_query_variable_value(self) -> QueryVariableValue {
313 self.into()
314 }
315}
316
317impl From<&[u8]> for QueryVariableValue {
318 fn from(value: &[u8]) -> Self {
319 Self::Bytes(value.to_vec())
320 }
321}
322
323impl From<chrono::DateTime<chrono::Utc>> for QueryVariableValue {
324 fn from(value: chrono::DateTime<chrono::Utc>) -> Self {
325 Self::DateTime(value)
326 }
327}
328
329impl QueryListScalar for chrono::DateTime<chrono::Utc> {
330 fn into_list_query_variable_value(self) -> QueryVariableValue {
331 self.into()
332 }
333}
334
335impl From<Uuid> for QueryVariableValue {
336 fn from(value: Uuid) -> Self {
337 Self::Uuid(value)
338 }
339}
340
341impl QueryListScalar for Uuid {
342 fn into_list_query_variable_value(self) -> QueryVariableValue {
343 self.into()
344 }
345}
346
347impl<T> From<Vec<T>> for QueryVariableValue
348where
349 T: QueryListScalar,
350{
351 fn from(values: Vec<T>) -> Self {
352 Self::List(
353 values
354 .into_iter()
355 .map(QueryListScalar::into_list_query_variable_value)
356 .collect(),
357 )
358 }
359}
360
361impl<T> From<Option<T>> for QueryVariableValue
362where
363 T: Into<QueryVariableValue>,
364{
365 fn from(value: Option<T>) -> Self {
366 match value {
367 Some(value) => value.into(),
368 None => Self::Null,
369 }
370 }
371}
372
373impl QueryScalar for i64 {
374 fn into_query_variable_value(self) -> QueryVariableValue {
375 self.into()
376 }
377}
378
379impl QueryScalar for &str {
380 fn into_query_variable_value(self) -> QueryVariableValue {
381 self.into()
382 }
383}
384
385impl QueryScalar for bool {
386 fn into_query_variable_value(self) -> QueryVariableValue {
387 self.into()
388 }
389}
390
391impl QueryScalar for f64 {
392 fn into_query_variable_value(self) -> QueryVariableValue {
393 self.into()
394 }
395}
396
397impl QueryScalar for Decimal {
398 fn into_query_variable_value(self) -> QueryVariableValue {
399 self.into()
400 }
401}
402
403impl QueryScalar for Vec<u8> {
404 fn into_query_variable_value(self) -> QueryVariableValue {
405 self.into()
406 }
407}
408
409impl QueryScalar for &[u8] {
410 fn into_query_variable_value(self) -> QueryVariableValue {
411 self.into()
412 }
413}
414
415impl QueryScalar for chrono::DateTime<chrono::Utc> {
416 fn into_query_variable_value(self) -> QueryVariableValue {
417 self.into()
418 }
419}
420
421impl QueryScalar for Uuid {
422 fn into_query_variable_value(self) -> QueryVariableValue {
423 self.into()
424 }
425}
426
427impl<T> QueryScalar for T
428where
429 T: StringValueType,
430{
431 fn into_query_variable_value(self) -> QueryVariableValue {
432 QueryVariableValue::String(self.into_db_string())
433 }
434}
435
436impl<T> QueryScalar for Vec<T>
437where
438 T: QueryListScalar,
439{
440 fn into_query_variable_value(self) -> QueryVariableValue {
441 self.into()
442 }
443}
444
445impl<T, const N: usize> QueryScalar for [T; N]
446where
447 T: QueryListScalar,
448{
449 fn into_query_variable_value(self) -> QueryVariableValue {
450 QueryVariableValue::List(
451 self.into_iter()
452 .map(QueryListScalar::into_list_query_variable_value)
453 .collect(),
454 )
455 }
456}
457
458impl<T> QueryScalar for Option<T>
459where
460 T: QueryScalar,
461{
462 fn into_query_variable_value(self) -> QueryVariableValue {
463 match self {
464 Some(value) => value.into_query_variable_value(),
465 None => QueryVariableValue::Null,
466 }
467 }
468}
469
470#[derive(Clone, Debug, PartialEq)]
471pub enum QueryFilterValue {
472 Variable(String),
473 Value(QueryVariableValue),
474}
475
476impl QueryFilterValue {
477 pub fn variable(name: impl Into<String>) -> Self {
478 Self::Variable(name.into())
479 }
480
481 pub fn value<T>(value: T) -> Self
482 where
483 T: QueryScalar,
484 {
485 Self::Value(value.into_query_variable_value())
486 }
487}
488
489impl<T> From<T> for QueryFilterValue
490where
491 T: QueryScalar,
492{
493 fn from(value: T) -> Self {
494 Self::Value(value.into_query_variable_value())
495 }
496}
497
498#[derive(Clone, Debug, PartialEq)]
499pub enum QueryFilterValues {
500 Variable(String),
501 Values(Vec<QueryFilterValue>),
502}
503
504impl QueryFilterValues {
505 pub fn variable(name: impl Into<String>) -> Self {
506 Self::Variable(name.into())
507 }
508
509 pub fn values<T>(values: impl IntoIterator<Item = T>) -> Self
510 where
511 T: QueryListScalar,
512 {
513 Self::Values(
514 values
515 .into_iter()
516 .map(|value| QueryFilterValue::Value(value.into_list_query_variable_value()))
517 .collect::<Vec<_>>(),
518 )
519 }
520}
521
522impl<T> From<Vec<T>> for QueryFilterValues
523where
524 T: QueryListScalar,
525{
526 fn from(values: Vec<T>) -> Self {
527 Self::values(values)
528 }
529}
530
531impl<T, const N: usize> From<[T; N]> for QueryFilterValues
532where
533 T: QueryListScalar,
534{
535 fn from(values: [T; N]) -> Self {
536 Self::values(values)
537 }
538}
539
540#[derive(Clone, Debug, PartialEq)]
541pub enum QueryFilter {
542 And(Vec<QueryFilter>),
543 Or(Vec<QueryFilter>),
544 Not(Box<QueryFilter>),
545 Eq {
546 field: &'static str,
547 value: QueryFilterValue,
548 },
549 Ne {
550 field: &'static str,
551 value: QueryFilterValue,
552 },
553 In {
554 field: &'static str,
555 values: QueryFilterValues,
556 },
557 Relation {
558 field: &'static str,
559 filter: Box<QueryFilter>,
560 },
561}
562
563impl QueryFilter {
564 pub fn eq(field: &'static str, value: impl Into<QueryFilterValue>) -> Self {
565 Self::Eq {
566 field,
567 value: value.into(),
568 }
569 }
570
571 pub fn ne(field: &'static str, value: impl Into<QueryFilterValue>) -> Self {
572 Self::Ne {
573 field,
574 value: value.into(),
575 }
576 }
577
578 pub fn r#in(field: &'static str, values: impl Into<QueryFilterValues>) -> Self {
579 Self::In {
580 field,
581 values: values.into(),
582 }
583 }
584
585 pub fn is_null(field: &'static str) -> Self {
586 Self::Eq {
587 field,
588 value: QueryFilterValue::Value(QueryVariableValue::Null),
589 }
590 }
591
592 pub fn is_not_null(field: &'static str) -> Self {
593 Self::Ne {
594 field,
595 value: QueryFilterValue::Value(QueryVariableValue::Null),
596 }
597 }
598
599 pub fn relation(field: &'static str, filter: QueryFilter) -> Self {
600 Self::Relation {
601 field,
602 filter: Box::new(filter),
603 }
604 }
605}
606
607pub trait QueryValue: Sized + Send + 'static {
608 fn from_json(value: &JsonValue) -> Result<Self, sqlx::Error>;
609}
610
611pub trait QueryModel: Sized + Send + 'static {
612 type Schema: SchemaAccess;
613 type Variables: QueryVariableSet;
614
615 fn model_name() -> &'static str;
616
617 fn selection() -> QuerySelection;
618
619 fn selection_with_variables(_variables: &QueryVariables) -> QuerySelection {
620 Self::selection()
621 }
622
623 fn from_row(row: &PgRow, prefix: &str) -> Result<Self, sqlx::Error>;
624}
625
626#[derive(Clone, Debug)]
627pub struct Query<S, T, V = ()> {
628 selection: Option<QuerySelection>,
629 variables: QueryVariables,
630 _marker: PhantomData<(S, T, V)>,
631}
632
633impl<S, T> Query<S, T, ()>
634where
635 T: QueryModel<Variables = ()>,
636{
637 pub fn new() -> Self {
638 Self {
639 selection: None,
640 variables: QueryVariables::new(),
641 _marker: PhantomData,
642 }
643 }
644
645 pub fn with_selection(selection: QuerySelection) -> Self {
646 Self {
647 selection: Some(selection),
648 variables: QueryVariables::new(),
649 _marker: PhantomData,
650 }
651 }
652}
653
654impl<S, T> Default for Query<S, T, ()>
655where
656 T: QueryModel<Variables = ()>,
657{
658 fn default() -> Self {
659 Self::new()
660 }
661}
662
663impl<S, T> Query<S, T, ()>
664where
665 T: QueryModel,
666{
667 pub fn new_with_variables(variables: T::Variables) -> Query<S, T, T::Variables> {
668 Query {
669 selection: None,
670 variables: variables.into_query_variables(),
671 _marker: PhantomData,
672 }
673 }
674
675 pub fn with_selection_and_variables(
676 selection: QuerySelection,
677 variables: T::Variables,
678 ) -> Query<S, T, T::Variables> {
679 Query {
680 selection: Some(selection),
681 ..Self::new_with_variables(variables)
682 }
683 }
684
685 pub fn with_variables(self, variables: T::Variables) -> Query<S, T, T::Variables> {
686 Query {
687 selection: self.selection,
688 variables: variables.into_query_variables(),
689 _marker: PhantomData,
690 }
691 }
692}
693
694impl<S, T, V> Query<S, T, V>
695where
696 S: SchemaAccess,
697 T: QueryModel<Schema = S, Variables = V>,
698 V: QueryVariableSet,
699{
700 fn selection(&self) -> QuerySelection {
701 self.selection
702 .clone()
703 .unwrap_or_else(|| T::selection_with_variables(&self.variables))
704 }
705
706 pub fn to_sql(&self) -> Result<String, sqlx::Error> {
707 let selection = self.selection();
708 let (sql, _) = build_query_sql(S::schema(), &selection, &self.variables)?;
709 Ok(sql)
710 }
711}
712
713impl<S, T, V> QuerySpec for Query<S, T, V>
714where
715 S: SchemaAccess,
716 T: QueryModel<Schema = S, Variables = V> + Sync,
717 V: QueryVariableSet + Sync,
718{
719 type Output = T;
720
721 fn fetch_many<'a>(
722 &'a self,
723 executor: &'a dyn PgExecutor,
724 ) -> BoxFuture<'a, Result<Vec<Self::Output>, sqlx::Error>> {
725 Box::pin(async move {
726 let selection = self.selection();
727 let (sql, bindings) = build_query_sql(S::schema(), &selection, &self.variables)?;
728 let rows = executor
729 .fetch_all(bind_query(sqlx::query(&sql), &bindings))
730 .await?;
731 let mut values = Vec::with_capacity(rows.len());
732 let root_prefix = selection.model;
733
734 for row in rows {
735 values.push(T::from_row(&row, root_prefix)?);
736 }
737
738 Ok(values)
739 })
740 }
741
742 fn fetch_optional<'a>(
743 &'a self,
744 executor: &'a dyn PgExecutor,
745 ) -> BoxFuture<'a, Result<Option<Self::Output>, sqlx::Error>> {
746 Box::pin(async move {
747 let mut selection = self.selection();
748 selection.limit = Some(QueryPagination::value(1));
749
750 let (sql, bindings) = build_query_sql(S::schema(), &selection, &self.variables)?;
751 let row = executor
752 .fetch_optional(bind_query(sqlx::query(&sql), &bindings))
753 .await?;
754 let root_prefix = selection.model;
755
756 row.map(|row| T::from_row(&row, root_prefix)).transpose()
757 })
758 }
759}
760
761pub fn query_model_is_null<T: QueryModel>(row: &PgRow, prefix: &str) -> Result<bool, sqlx::Error> {
762 selection_is_null(row, prefix, &T::selection())
763}
764
765fn selection_is_null(
766 row: &PgRow,
767 prefix: &str,
768 selection: &QuerySelection,
769) -> Result<bool, sqlx::Error> {
770 for field in &selection.scalar_fields {
771 let alias = alias_name(prefix, field);
772 if !row.try_get_raw(alias.as_str())?.is_null() {
773 return Ok(false);
774 }
775 }
776
777 for relation in &selection.relations {
778 let alias = alias_name(prefix, relation.field);
779 if !row.try_get_raw(alias.as_str())?.is_null() {
780 return Ok(false);
781 }
782 }
783
784 Ok(true)
785}
786
/// Builds the column alias for `field` under `prefix`: the two parts joined by
/// a double underscore.
pub fn alias_name(prefix: &str, field: &str) -> String {
    let mut alias = String::with_capacity(prefix.len() + 2 + field.len());
    alias.push_str(prefix);
    alias.push_str("__");
    alias.push_str(field);
    alias
}
790
791pub fn json_array_field(value: &JsonValue, index: usize) -> Result<&JsonValue, sqlx::Error> {
792 value.get(index).ok_or_else(|| {
793 schema_error(format!(
794 "missing JSON array index `{index}` in query result"
795 ))
796 })
797}
798
799pub fn json_as_i64(value: &JsonValue) -> Result<i64, sqlx::Error> {
800 value
801 .as_i64()
802 .ok_or_else(|| schema_error("expected JSON integer in query result".to_owned()))
803}
804
805pub fn json_as_string(value: &JsonValue) -> Result<String, sqlx::Error> {
806 value
807 .as_str()
808 .map(ToOwned::to_owned)
809 .ok_or_else(|| schema_error("expected JSON string in query result".to_owned()))
810}
811
812pub fn json_as_bool(value: &JsonValue) -> Result<bool, sqlx::Error> {
813 value
814 .as_bool()
815 .ok_or_else(|| schema_error("expected JSON boolean in query result".to_owned()))
816}
817
818pub fn json_value<T>(value: &JsonValue) -> Result<T, sqlx::Error>
819where
820 T: QueryResultValue,
821{
822 T::from_json(value)
823}
824
825pub fn json_string_value<T>(value: &JsonValue) -> Result<T, sqlx::Error>
826where
827 T: StringValueType,
828{
829 T::from_db_string(json_as_string(value)?)
830}
831
832pub fn json_as_f64(value: &JsonValue) -> Result<f64, sqlx::Error> {
833 value
834 .as_f64()
835 .ok_or_else(|| schema_error("expected JSON float in query result".to_owned()))
836}
837
838pub fn json_as_bytes(value: &JsonValue) -> Result<Vec<u8>, sqlx::Error> {
839 match value {
840 JsonValue::String(value) => decode_hex_bytes(value),
841 JsonValue::Array(values) => values
842 .iter()
843 .map(|value| {
844 let byte = value.as_u64().ok_or_else(|| {
845 schema_error("expected JSON byte array in query result".to_owned())
846 })?;
847
848 u8::try_from(byte).map_err(|_| {
849 schema_error(format!(
850 "expected JSON byte array values in range 0..=255, got `{byte}`"
851 ))
852 })
853 })
854 .collect(),
855 _ => Err(schema_error(
856 "expected JSON byte string or byte array in query result".to_owned(),
857 )),
858 }
859}
860
861fn decode_hex_bytes(value: &str) -> Result<Vec<u8>, sqlx::Error> {
862 let value = value.strip_prefix("\\x").unwrap_or(value);
863
864 if !value.len().is_multiple_of(2) {
865 return Err(schema_error(format!(
866 "invalid bytes in query result `{value}`: hex string must have an even length"
867 )));
868 }
869
870 let mut bytes = Vec::with_capacity(value.len() / 2);
871 let mut index = 0;
872
873 while index < value.len() {
874 let chunk = &value[index..index + 2];
875 let byte = u8::from_str_radix(chunk, 16).map_err(|error| {
876 schema_error(format!("invalid bytes in query result `{value}`: {error}"))
877 })?;
878 bytes.push(byte);
879 index += 2;
880 }
881
882 Ok(bytes)
883}
884
885pub fn json_as_decimal(value: &JsonValue) -> Result<Decimal, sqlx::Error> {
886 match value {
887 JsonValue::String(value) => parse_decimal(value),
888 JsonValue::Number(value) => parse_decimal(&value.to_string()),
889 _ => Err(schema_error(
890 "expected JSON decimal string or number in query result".to_owned(),
891 )),
892 }
893}
894
895pub fn parse_decimal(value: &str) -> Result<Decimal, sqlx::Error> {
896 Decimal::from_str_exact(value)
897 .or_else(|_| Decimal::from_scientific(value))
898 .or_else(|_| {
899 let normalized = normalize_decimal_string(value);
900 Decimal::from_str_exact(&normalized).or_else(|_| Decimal::from_scientific(&normalized))
901 })
902 .map_err(|error| {
903 schema_error(format!(
904 "invalid decimal in query result `{value}`: {error}"
905 ))
906 })
907}
908
/// Removes trailing zeros after the first `.` of `value`, and the `.` itself
/// when nothing remains after it. Text without a `.` is returned unchanged.
fn normalize_decimal_string(value: &str) -> String {
    let Some((integer, fractional)) = value.split_once('.') else {
        return value.to_owned();
    };

    match fractional.trim_end_matches('0') {
        "" => integer.to_owned(),
        fractional => format!("{integer}.{fractional}"),
    }
}
921
922pub fn row_as_decimal(row: &PgRow, alias: &str) -> Result<Decimal, sqlx::Error> {
923 row.try_get(alias)
924}
925
926pub fn json_as_datetime_utc(
927 value: &JsonValue,
928) -> Result<chrono::DateTime<chrono::Utc>, sqlx::Error> {
929 let value = value
930 .as_str()
931 .ok_or_else(|| schema_error("expected JSON datetime string in query result".to_owned()))?;
932
933 if let Ok(datetime) = chrono::DateTime::parse_from_rfc3339(value) {
934 return Ok(datetime.with_timezone(&chrono::Utc));
935 }
936
937 for format in ["%Y-%m-%dT%H:%M:%S%.f", "%Y-%m-%d %H:%M:%S%.f"] {
938 if let Ok(datetime) = chrono::NaiveDateTime::parse_from_str(value, format) {
939 return Ok(datetime.and_utc());
940 }
941 }
942
943 Err(schema_error(format!(
944 "invalid JSON datetime in query result: unsupported format `{value}`"
945 )))
946}
947
948pub fn row_as_datetime_utc(
949 row: &PgRow,
950 alias: &str,
951) -> Result<chrono::DateTime<chrono::Utc>, sqlx::Error> {
952 if let Ok(value) = row.try_get::<chrono::DateTime<chrono::Utc>, _>(alias) {
953 return Ok(value);
954 }
955
956 let value: chrono::NaiveDateTime = row.try_get(alias)?;
957 Ok(value.and_utc())
958}
959
960pub fn row_as_bytes(row: &PgRow, alias: &str) -> Result<Vec<u8>, sqlx::Error> {
961 row.try_get(alias)
962}
963
964pub fn row_value<T>(row: &PgRow, alias: &str) -> Result<T, sqlx::Error>
965where
966 T: QueryResultValue,
967{
968 T::from_row(row, alias)
969}
970
971pub fn row_string_value<T>(row: &PgRow, alias: &str) -> Result<T, sqlx::Error>
972where
973 T: StringValueType,
974{
975 T::from_db_string(row.try_get::<String, _>(alias)?)
976}
977
978impl QueryResultValue for i64 {
979 fn from_row(row: &PgRow, alias: &str) -> Result<Self, sqlx::Error> {
980 row.try_get(alias)
981 }
982
983 fn from_json(value: &JsonValue) -> Result<Self, sqlx::Error> {
984 json_as_i64(value)
985 }
986}
987
988impl QueryResultValue for bool {
989 fn from_row(row: &PgRow, alias: &str) -> Result<Self, sqlx::Error> {
990 row.try_get(alias)
991 }
992
993 fn from_json(value: &JsonValue) -> Result<Self, sqlx::Error> {
994 json_as_bool(value)
995 }
996}
997
998impl QueryResultValue for f64 {
999 fn from_row(row: &PgRow, alias: &str) -> Result<Self, sqlx::Error> {
1000 row.try_get(alias)
1001 }
1002
1003 fn from_json(value: &JsonValue) -> Result<Self, sqlx::Error> {
1004 json_as_f64(value)
1005 }
1006}
1007
1008impl QueryResultValue for Decimal {
1009 fn from_row(row: &PgRow, alias: &str) -> Result<Self, sqlx::Error> {
1010 row_as_decimal(row, alias)
1011 }
1012
1013 fn from_json(value: &JsonValue) -> Result<Self, sqlx::Error> {
1014 json_as_decimal(value)
1015 }
1016}
1017
1018impl QueryResultValue for chrono::DateTime<chrono::Utc> {
1019 fn from_row(row: &PgRow, alias: &str) -> Result<Self, sqlx::Error> {
1020 row_as_datetime_utc(row, alias)
1021 }
1022
1023 fn from_json(value: &JsonValue) -> Result<Self, sqlx::Error> {
1024 json_as_datetime_utc(value)
1025 }
1026}
1027
1028impl QueryResultValue for Vec<u8> {
1029 fn from_row(row: &PgRow, alias: &str) -> Result<Self, sqlx::Error> {
1030 row_as_bytes(row, alias)
1031 }
1032
1033 fn from_json(value: &JsonValue) -> Result<Self, sqlx::Error> {
1034 json_as_bytes(value)
1035 }
1036}
1037
1038impl QueryResultValue for Uuid {
1039 fn from_row(row: &PgRow, alias: &str) -> Result<Self, sqlx::Error> {
1040 row.try_get(alias)
1041 }
1042
1043 fn from_json(value: &JsonValue) -> Result<Self, sqlx::Error> {
1044 let value = json_as_string(value)?;
1045 Uuid::parse_str(&value)
1046 .map_err(|error| schema_error(format!("invalid JSON UUID in query result: {error}")))
1047 }
1048}
1049
1050impl<T> QueryResultValue for T
1051where
1052 T: StringValueType,
1053{
1054 fn from_row(row: &PgRow, alias: &str) -> Result<Self, sqlx::Error> {
1055 row_string_value(row, alias)
1056 }
1057
1058 fn from_json(value: &JsonValue) -> Result<Self, sqlx::Error> {
1059 json_string_value(value)
1060 }
1061}
1062
1063impl<T> QueryResultValue for Option<T>
1064where
1065 T: QueryResultValue,
1066{
1067 fn from_row(row: &PgRow, alias: &str) -> Result<Self, sqlx::Error> {
1068 if row.try_get_raw(alias)?.is_null() {
1069 Ok(None)
1070 } else {
1071 T::from_row(row, alias).map(Some)
1072 }
1073 }
1074
1075 fn from_json(value: &JsonValue) -> Result<Self, sqlx::Error> {
1076 if value.is_null() {
1077 Ok(None)
1078 } else {
1079 T::from_json(value).map(Some)
1080 }
1081 }
1082}
1083
1084fn build_query_sql(
1085 schema: &Schema,
1086 selection: &QuerySelection,
1087 variables: &QueryVariables,
1088) -> Result<(String, Vec<QueryVariableValue>), sqlx::Error> {
1089 let root_model = resolve_schema_model(schema, selection.model, "query")?;
1090
1091 let mut builder = SqlBuilder {
1092 schema,
1093 variables,
1094 bindings: Vec::new(),
1095 joins: Vec::new(),
1096 next_alias: 1,
1097 };
1098
1099 let selects = builder.root_selects(root_model, selection, selection.model, "t0")?;
1100 let where_clause = selection
1101 .filter
1102 .as_ref()
1103 .map(|filter| builder.filter_sql(root_model, filter, "t0"))
1104 .transpose()?;
1105
1106 let mut order_joins = Vec::new();
1107 let order_by_clause =
1108 builder.order_by_sql(root_model, &selection.order_by, "t0", &mut order_joins)?;
1109 builder.joins.extend(order_joins);
1110
1111 let pagination_clause =
1112 builder.pagination_clause(selection.skip.as_ref(), selection.limit.as_ref())?;
1113
1114 let sql = format!(
1115 "SELECT {} FROM {} AS \"t0\"{}{}{}{}",
1116 selects.join(", "),
1117 quoted_ident(root_model.name()),
1118 if builder.joins.is_empty() {
1119 String::new()
1120 } else {
1121 format!(" {}", builder.joins.join(" "))
1122 },
1123 where_clause
1124 .map(|where_clause| format!(" WHERE {where_clause}"))
1125 .unwrap_or_default(),
1126 order_by_clause
1127 .map(|order_by_clause| format!(" ORDER BY {order_by_clause}"))
1128 .unwrap_or_default(),
1129 pagination_clause,
1130 );
1131
1132 Ok((sql, builder.bindings))
1133}
1134
/// Compares two model names, ignoring ASCII case only.
fn model_names_match(left: &str, right: &str) -> bool {
    let (left, right) = (left.as_bytes(), right.as_bytes());
    left.eq_ignore_ascii_case(right)
}
1138
1139fn infer_relation_fields<'a>(
1140 model: &'a Model,
1141 field: &'a crate::Field,
1142 target_model: &'a Model,
1143) -> Result<(Vec<&'a str>, Vec<&'a str>), sqlx::Error> {
1144 let reverse_relation = target_model
1145 .fields()
1146 .iter()
1147 .find(|candidate| {
1148 model_names_match(candidate.ty().name(), model.name()) && candidate.relation().is_some()
1149 })
1150 .ok_or_else(|| {
1151 schema_error(format!(
1152 "could not infer relation metadata for `{}.{}`",
1153 model.name(),
1154 field.name()
1155 ))
1156 })?;
1157
1158 let reverse_relation = reverse_relation
1159 .relation()
1160 .expect("reverse relation existence checked above");
1161
1162 Ok((
1163 reverse_relation
1164 .fields()
1165 .iter()
1166 .map(String::as_str)
1167 .collect(),
1168 reverse_relation
1169 .references()
1170 .iter()
1171 .map(String::as_str)
1172 .collect(),
1173 ))
1174}
1175
1176struct SqlBuilder<'a> {
1177 schema: &'a Schema,
1178 variables: &'a QueryVariables,
1179 bindings: Vec<QueryVariableValue>,
1180 joins: Vec<String>,
1181 next_alias: usize,
1182}
1183
1184struct RelationSql<'a> {
1185 many: bool,
1186 source_model_name: &'a str,
1187 relation_field_name: &'a str,
1188 target_model: &'a Model,
1189 selection: QuerySelection,
1190 parent_table_alias: String,
1191 nested_alias: String,
1192 nested_fields: Vec<&'a str>,
1193 parent_fields: Vec<&'a str>,
1194}
1195
1196impl<'a> SqlBuilder<'a> {
1197 fn root_selects(
1198 &mut self,
1199 model: &'a Model,
1200 selection: &QuerySelection,
1201 prefix: &str,
1202 table_alias: &str,
1203 ) -> Result<Vec<String>, sqlx::Error> {
1204 let mut selects = Vec::new();
1205
1206 for field_name in &selection.scalar_fields {
1207 let field = model.field_named(field_name).ok_or_else(|| {
1208 schema_error(format!(
1209 "unknown field `{}.{}` in query selection",
1210 model.name(),
1211 field_name
1212 ))
1213 })?;
1214
1215 let scalar = match field.ty() {
1216 FieldType::Scalar(scalar) => scalar.scalar(),
1217 FieldType::Relation { .. } => {
1218 return Err(schema_error(format!(
1219 "field `{}.{}` is not scalar and cannot appear in `select`",
1220 model.name(),
1221 field_name
1222 )));
1223 }
1224 };
1225
1226 selects.push(select_expr(
1227 table_alias,
1228 field.name(),
1229 scalar,
1230 &alias_name(prefix, field.name()),
1231 ));
1232 }
1233
1234 for relation in &selection.relations {
1235 selects.push(self.relation_select(model, relation, prefix, table_alias)?);
1236 }
1237
1238 Ok(selects)
1239 }
1240
1241 fn relation_select(
1242 &mut self,
1243 model: &'a Model,
1244 relation: &QueryRelationSelection,
1245 prefix: &str,
1246 table_alias: &str,
1247 ) -> Result<String, sqlx::Error> {
1248 let field = model.field_named(relation.field).ok_or_else(|| {
1249 schema_error(format!(
1250 "unknown relation `{}.{}` in query include",
1251 model.name(),
1252 relation.field
1253 ))
1254 })?;
1255
1256 if field.kind().is_scalar() {
1257 return Err(schema_error(format!(
1258 "field `{}.{}` is not a relation and cannot appear in `include`",
1259 model.name(),
1260 relation.field
1261 )));
1262 }
1263
1264 let target_model =
1265 resolve_schema_model(self.schema, field.ty().name(), "query").map_err(|_| {
1266 schema_error(format!(
1267 "relation `{}.{}` points at unknown model `{}`",
1268 model.name(),
1269 relation.field,
1270 field.ty().name()
1271 ))
1272 })?;
1273
1274 let (nested_fields, parent_fields) = self.relation_fields(model, field, target_model)?;
1275
1276 let join_alias = format!("t{}", self.next_alias);
1277 self.next_alias += 1;
1278
1279 let nested_alias = format!("t{}", self.next_alias);
1280 self.next_alias += 1;
1281
1282 let subquery = self.relation_subquery_sql(RelationSql {
1283 many: field.ty().is_many(),
1284 source_model_name: model.name(),
1285 relation_field_name: relation.field,
1286 target_model,
1287 selection: relation.selection.clone(),
1288 parent_table_alias: table_alias.to_owned(),
1289 nested_alias: nested_alias.clone(),
1290 nested_fields,
1291 parent_fields,
1292 })?;
1293
1294 self.joins.push(format!(
1295 "LEFT JOIN LATERAL ({subquery}) AS \"{join_alias}\" ON TRUE"
1296 ));
1297
1298 let alias = alias_name(prefix, relation.field);
1299 Ok(format!("\"{join_alias}\".\"data\" AS \"{alias}\""))
1300 }
1301
1302 fn relation_subquery_sql(&mut self, relation: RelationSql<'a>) -> Result<String, sqlx::Error> {
1303 let mut where_clauses = vec![relation_predicates(
1304 &relation.nested_alias,
1305 &relation.nested_fields,
1306 &relation.parent_table_alias,
1307 &relation.parent_fields,
1308 )];
1309 let mut joins = Vec::new();
1310 let row_expr = self.json_row_expr(
1311 relation.target_model,
1312 &relation.selection,
1313 &relation.nested_alias,
1314 &mut joins,
1315 )?;
1316 let order_by_clause = self.order_by_sql(
1317 relation.target_model,
1318 &relation.selection.order_by,
1319 &relation.nested_alias,
1320 &mut joins,
1321 )?;
1322 let joins_sql = if joins.is_empty() {
1323 String::new()
1324 } else {
1325 format!(" {}", joins.join(" "))
1326 };
1327
1328 if let Some(filter) = relation.selection.filter.as_ref() {
1329 where_clauses.push(self.filter_sql(
1330 relation.target_model,
1331 filter,
1332 &relation.nested_alias,
1333 )?);
1334 }
1335
1336 let where_clause = where_clauses.join(" AND ");
1337 let select_order_by_clause = order_by_clause
1338 .clone()
1339 .map(|order_by_clause| format!(" ORDER BY {order_by_clause}"))
1340 .unwrap_or_default();
1341
1342 if relation.many {
1343 if relation.selection.skip.is_some() || relation.selection.limit.is_some() {
1344 let pagination_clause = self.pagination_clause(
1345 relation.selection.skip.as_ref(),
1346 relation.selection.limit.as_ref(),
1347 )?;
1348 let aggregate_table_alias = "__vitrail_nested_rows";
1349 let aggregate_order_column = "__vitrail_nested_order";
1350 let select_order_by_clause = if select_order_by_clause.is_empty() {
1351 aggregate_order_by(relation.target_model, &relation.nested_alias)
1352 } else {
1353 select_order_by_clause
1354 };
1355
1356 Ok(format!(
1357 "SELECT COALESCE(json_agg(\"{aggregate_table_alias}\".\"data\" ORDER BY \"{aggregate_table_alias}\".\"{aggregate_order_column}\"), '[]'::json) AS \"data\" FROM (SELECT {row_expr} AS \"data\", row_number() OVER ({select_order_by_clause}) AS \"{aggregate_order_column}\" FROM {} AS \"{}\"{} WHERE {where_clause}{select_order_by_clause}{pagination_clause}) AS \"{aggregate_table_alias}\"",
1358 quoted_ident(relation.target_model.name()),
1359 relation.nested_alias,
1360 joins_sql,
1361 ))
1362 } else {
1363 let aggregate_order_by_clause =
1364 if let Some(order_by_clause) = order_by_clause.as_ref() {
1365 format!(" ORDER BY {order_by_clause}")
1366 } else {
1367 aggregate_order_by(relation.target_model, &relation.nested_alias)
1368 };
1369
1370 Ok(format!(
1371 "SELECT COALESCE(json_agg({row_expr}{aggregate_order_by_clause}), '[]'::json) AS \"data\" FROM {} AS \"{}\"{} WHERE {where_clause}",
1372 quoted_ident(relation.target_model.name()),
1373 relation.nested_alias,
1374 joins_sql,
1375 ))
1376 }
1377 } else {
1378 if relation.selection.skip.is_some() || relation.selection.limit.is_some() {
1379 return Err(schema_error(format!(
1380 "relation `{}.{}` is to-one and cannot use `skip` or `limit`",
1381 relation.source_model_name, relation.relation_field_name
1382 )));
1383 }
1384
1385 Ok(format!(
1386 "SELECT {row_expr} AS \"data\" FROM {} AS \"{}\"{} WHERE {where_clause}{select_order_by_clause} LIMIT 1",
1387 quoted_ident(relation.target_model.name()),
1388 relation.nested_alias,
1389 joins_sql,
1390 ))
1391 }
1392 }
1393
1394 fn filter_sql(
1395 &mut self,
1396 model: &'a Model,
1397 filter: &QueryFilter,
1398 table_alias: &str,
1399 ) -> Result<String, sqlx::Error> {
1400 compile_filter_sql(self, model, filter, table_alias)
1401 }
1402
1403 fn pagination_clause(
1404 &mut self,
1405 skip: Option<&QueryPagination>,
1406 limit: Option<&QueryPagination>,
1407 ) -> Result<String, sqlx::Error> {
1408 let mut clause = String::new();
1409
1410 if let Some(limit) = limit {
1411 let limit = self.pagination_placeholder(limit, "limit")?;
1412 clause.push_str(&format!(" LIMIT {limit}"));
1413 }
1414
1415 if let Some(skip) = skip {
1416 let skip = self.pagination_placeholder(skip, "skip")?;
1417 clause.push_str(&format!(" OFFSET {skip}"));
1418 }
1419
1420 Ok(clause)
1421 }
1422
1423 fn pagination_placeholder(
1424 &mut self,
1425 pagination: &QueryPagination,
1426 kind: &str,
1427 ) -> Result<String, sqlx::Error> {
1428 let value = match pagination {
1429 QueryPagination::Value(value) => QueryVariableValue::Int(*value),
1430 QueryPagination::Variable(name) => {
1431 let value = self.variables.get(name).ok_or_else(|| {
1432 schema_error(format!("missing query variable `{name}` for `{kind}`"))
1433 })?;
1434
1435 match value {
1436 QueryVariableValue::Int(value) => QueryVariableValue::Int(*value),
1437 other => {
1438 return Err(schema_error(format!(
1439 "query `{kind}` variable `{name}` must be an integer, got {other:?}"
1440 )));
1441 }
1442 }
1443 }
1444 };
1445
1446 let QueryVariableValue::Int(value) = value else {
1447 unreachable!("pagination values must be integers")
1448 };
1449
1450 if value < 0 {
1451 return Err(schema_error(format!(
1452 "query `{kind}` must be greater than or equal to 0"
1453 )));
1454 }
1455
1456 self.push_binding(QueryVariableValue::Int(value), ScalarType::Int)
1457 }
1458
1459 fn order_by_sql(
1460 &mut self,
1461 model: &'a Model,
1462 orders: &[QueryOrder],
1463 table_alias: &str,
1464 joins: &mut Vec<String>,
1465 ) -> Result<Option<String>, sqlx::Error> {
1466 if orders.is_empty() {
1467 return Ok(None);
1468 }
1469
1470 let mut items = Vec::new();
1471 let mut relation_join_aliases = HashMap::new();
1472
1473 for order in orders {
1474 self.push_order_sql(
1475 model,
1476 order,
1477 table_alias,
1478 joins,
1479 &mut relation_join_aliases,
1480 &mut items,
1481 )?;
1482 }
1483
1484 Ok(Some(items.join(", ")))
1485 }
1486
1487 fn push_order_sql(
1488 &mut self,
1489 model: &'a Model,
1490 order: &QueryOrder,
1491 table_alias: &str,
1492 joins: &mut Vec<String>,
1493 relation_join_aliases: &mut HashMap<String, String>,
1494 items: &mut Vec<String>,
1495 ) -> Result<(), sqlx::Error> {
1496 match order {
1497 QueryOrder::Scalar { field, direction } => {
1498 let field = model.field_named(field).ok_or_else(|| {
1499 schema_error(format!(
1500 "unknown field `{}.{}` in query ordering",
1501 model.name(),
1502 field
1503 ))
1504 })?;
1505
1506 let scalar = match field.ty() {
1507 FieldType::Scalar(scalar) => scalar.scalar(),
1508 FieldType::Relation { .. } => {
1509 return Err(schema_error(format!(
1510 "field `{}.{}` is not scalar and cannot terminate `order_by`",
1511 model.name(),
1512 field.name()
1513 )));
1514 }
1515 };
1516
1517 items.push(format!(
1518 "{} {}",
1519 column_expr(table_alias, field.name(), scalar),
1520 match direction {
1521 QueryOrderDirection::Asc => "ASC",
1522 QueryOrderDirection::Desc => "DESC",
1523 }
1524 ));
1525 Ok(())
1526 }
1527 QueryOrder::Relation { field, orders } => {
1528 let field = model.field_named(field).ok_or_else(|| {
1529 schema_error(format!(
1530 "unknown relation `{}.{}` in query ordering",
1531 model.name(),
1532 field
1533 ))
1534 })?;
1535
1536 if field.kind().is_scalar() {
1537 return Err(schema_error(format!(
1538 "field `{}.{}` is not a relation and cannot be traversed in `order_by`",
1539 model.name(),
1540 field.name()
1541 )));
1542 }
1543
1544 if field.ty().is_many() {
1545 return Err(schema_error(format!(
1546 "relation `{}.{}` is to-many and cannot be used in `order_by`",
1547 model.name(),
1548 field.name()
1549 )));
1550 }
1551
1552 if orders.is_empty() {
1553 return Err(schema_error(format!(
1554 "relation `{}.{}` must contain at least one nested `order_by` entry",
1555 model.name(),
1556 field.name()
1557 )));
1558 }
1559
1560 let target_model = resolve_schema_model(self.schema, field.ty().name(), "query")
1561 .map_err(|_| {
1562 schema_error(format!(
1563 "relation `{}.{}` points at unknown model `{}`",
1564 model.name(),
1565 field.name(),
1566 field.ty().name()
1567 ))
1568 })?;
1569
1570 let (nested_fields, parent_fields) =
1571 self.relation_fields(model, field, target_model)?;
1572 let predicate_template = relation_predicates(
1573 "__vitrail_order_join__",
1574 &nested_fields,
1575 table_alias,
1576 &parent_fields,
1577 );
1578 let join_key = format!("{}::{predicate_template}", target_model.name());
1579 let join_alias = if let Some(join_alias) = relation_join_aliases.get(&join_key) {
1580 join_alias.clone()
1581 } else {
1582 let join_alias = format!("t{}", self.next_alias);
1583 self.next_alias += 1;
1584 joins.push(format!(
1585 "LEFT JOIN {} AS \"{join_alias}\" ON {}",
1586 quoted_ident(target_model.name()),
1587 relation_predicates(
1588 &join_alias,
1589 &nested_fields,
1590 table_alias,
1591 &parent_fields,
1592 ),
1593 ));
1594 relation_join_aliases.insert(join_key, join_alias.clone());
1595 join_alias
1596 };
1597
1598 for nested_order in orders {
1599 self.push_order_sql(
1600 target_model,
1601 nested_order,
1602 &join_alias,
1603 joins,
1604 relation_join_aliases,
1605 items,
1606 )?;
1607 }
1608
1609 Ok(())
1610 }
1611 }
1612 }
1613
1614 fn push_binding(
1615 &mut self,
1616 value: QueryVariableValue,
1617 _scalar: ScalarType,
1618 ) -> Result<String, sqlx::Error> {
1619 self.bindings.push(value);
1620 Ok(format!("${}", self.bindings.len()))
1621 }
1622
1623 fn json_row_expr(
1624 &mut self,
1625 model: &'a Model,
1626 selection: &QuerySelection,
1627 table_alias: &str,
1628 joins: &mut Vec<String>,
1629 ) -> Result<String, sqlx::Error> {
1630 let mut items = Vec::new();
1631
1632 for field_name in &selection.scalar_fields {
1633 let field = model.field_named(field_name).ok_or_else(|| {
1634 schema_error(format!(
1635 "unknown field `{}.{}` in query selection",
1636 model.name(),
1637 field_name
1638 ))
1639 })?;
1640
1641 let scalar = match field.ty() {
1642 FieldType::Scalar(scalar) => scalar.scalar(),
1643 FieldType::Relation { .. } => {
1644 return Err(schema_error(format!(
1645 "field `{}.{}` is not scalar and cannot appear in `select`",
1646 model.name(),
1647 field_name
1648 )));
1649 }
1650 };
1651
1652 items.push(json_column_expr(table_alias, field.name(), scalar));
1653 }
1654
1655 for relation in &selection.relations {
1656 items.push(self.nested_relation_json_expr(model, relation, table_alias, joins)?);
1657 }
1658
1659 Ok(format!("json_build_array({})", items.join(", ")))
1660 }
1661
1662 fn nested_relation_json_expr(
1663 &mut self,
1664 model: &'a Model,
1665 relation: &QueryRelationSelection,
1666 table_alias: &str,
1667 joins: &mut Vec<String>,
1668 ) -> Result<String, sqlx::Error> {
1669 let field = model.field_named(relation.field).ok_or_else(|| {
1670 schema_error(format!(
1671 "unknown relation `{}.{}` in query include",
1672 model.name(),
1673 relation.field
1674 ))
1675 })?;
1676
1677 if field.kind().is_scalar() {
1678 return Err(schema_error(format!(
1679 "field `{}.{}` is not a relation and cannot appear in `include`",
1680 model.name(),
1681 relation.field
1682 )));
1683 }
1684
1685 let target_model =
1686 resolve_schema_model(self.schema, field.ty().name(), "query").map_err(|_| {
1687 schema_error(format!(
1688 "relation `{}.{}` points at unknown model `{}`",
1689 model.name(),
1690 relation.field,
1691 field.ty().name()
1692 ))
1693 })?;
1694
1695 let (nested_fields, parent_fields) = self.relation_fields(model, field, target_model)?;
1696
1697 let join_alias = format!("t{}", self.next_alias);
1698 self.next_alias += 1;
1699
1700 let nested_alias = format!("t{}", self.next_alias);
1701 self.next_alias += 1;
1702
1703 let subquery = self.relation_subquery_sql(RelationSql {
1704 many: field.ty().is_many(),
1705 source_model_name: model.name(),
1706 relation_field_name: relation.field,
1707 target_model,
1708 selection: relation.selection.clone(),
1709 parent_table_alias: table_alias.to_owned(),
1710 nested_alias: nested_alias.clone(),
1711 nested_fields,
1712 parent_fields,
1713 })?;
1714
1715 joins.push(format!(
1716 "LEFT JOIN LATERAL ({subquery}) AS \"{join_alias}\" ON TRUE"
1717 ));
1718
1719 Ok(format!("\"{join_alias}\".\"data\""))
1720 }
1721
1722 fn relation_fields(
1723 &self,
1724 model: &'a Model,
1725 field: &'a crate::Field,
1726 target_model: &'a Model,
1727 ) -> Result<(Vec<&'a str>, Vec<&'a str>), sqlx::Error> {
1728 match field.relation() {
1729 Some(relation_info) => Ok((
1730 relation_info
1731 .references()
1732 .iter()
1733 .map(String::as_str)
1734 .collect::<Vec<_>>(),
1735 relation_info
1736 .fields()
1737 .iter()
1738 .map(String::as_str)
1739 .collect::<Vec<_>>(),
1740 )),
1741 None => infer_relation_fields(model, field, target_model),
1742 }
1743 }
1744}
1745
1746impl<'a> FilterBuilder<'a> for SqlBuilder<'a> {
1747 fn schema(&self) -> &'a Schema {
1748 self.schema
1749 }
1750
1751 fn variables(&self) -> &'a QueryVariables {
1752 self.variables
1753 }
1754
1755 fn push_filter_binding(
1756 &mut self,
1757 value: QueryVariableValue,
1758 scalar: ScalarType,
1759 ) -> Result<String, sqlx::Error> {
1760 self.push_binding(value, scalar)
1761 }
1762
1763 fn next_filter_alias(&mut self) -> String {
1764 let alias = format!("t{}", self.next_alias);
1765 self.next_alias += 1;
1766 alias
1767 }
1768
1769 fn operation_name(&self) -> &'static str {
1770 "query"
1771 }
1772}
1773
1774fn aggregate_order_by(model: &Model, table_alias: &str) -> String {
1775 let primary_key_columns = model.primary_key_columns();
1776 let field_names = if primary_key_columns.is_empty() {
1777 model
1778 .field_named("id")
1779 .map(|field| vec![field.name()])
1780 .or_else(|| {
1781 model
1782 .fields()
1783 .iter()
1784 .find(|field| field.kind().is_scalar())
1785 .map(|field| vec![field.name()])
1786 })
1787 .unwrap_or_else(|| vec!["id"])
1788 } else {
1789 primary_key_columns
1790 };
1791
1792 format!(
1793 " ORDER BY {}",
1794 field_names
1795 .into_iter()
1796 .map(|field_name| format!("\"{table_alias}\".{}", quoted_ident(field_name)))
1797 .collect::<Vec<_>>()
1798 .join(", ")
1799 )
1800}
1801
1802fn relation_predicates(
1803 nested_alias: &str,
1804 nested_fields: &[&str],
1805 parent_alias: &str,
1806 parent_fields: &[&str],
1807) -> String {
1808 nested_fields
1809 .iter()
1810 .zip(parent_fields)
1811 .map(|(nested_field, parent_field)| {
1812 format!(
1813 "\"{nested_alias}\".{} = \"{parent_alias}\".{}",
1814 quoted_ident(nested_field),
1815 quoted_ident(parent_field),
1816 )
1817 })
1818 .collect::<Vec<_>>()
1819 .join(" AND ")
1820}
1821
/// Wraps `ident` in double quotes for use as a SQL identifier, doubling any embedded
/// double quote so it cannot terminate the identifier early.
pub(crate) fn quoted_ident(ident: &str) -> String {
    let mut quoted = String::with_capacity(ident.len() + 2);
    quoted.push('"');
    for ch in ident.chars() {
        if ch == '"' {
            // An embedded quote is escaped by writing it twice.
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
1825
1826pub(crate) fn column_expr(table_alias: &str, field_name: &str, scalar: ScalarType) -> String {
1827 let column_sql = format!("\"{table_alias}\".{}", quoted_ident(field_name));
1828 match scalar {
1829 ScalarType::Int => format!("({column_sql})::bigint"),
1830 ScalarType::DateTime => format!("({column_sql} AT TIME ZONE 'UTC')"),
1831 _ => column_sql,
1832 }
1833}
1834
1835pub(crate) fn json_column_expr(table_alias: &str, field_name: &str, scalar: ScalarType) -> String {
1836 let column_sql = format!("\"{table_alias}\".{}", quoted_ident(field_name));
1837 match scalar {
1838 ScalarType::Int => format!("({column_sql})::bigint"),
1839 ScalarType::DateTime => format!("({column_sql} AT TIME ZONE 'UTC')"),
1840 ScalarType::Decimal => format!("({column_sql})::text"),
1841 ScalarType::Bytes => format!("encode({column_sql}, 'hex')"),
1842 _ => column_sql,
1843 }
1844}
1845
1846pub(crate) fn select_expr(
1847 table_alias: &str,
1848 field_name: &str,
1849 scalar: ScalarType,
1850 alias: &str,
1851) -> String {
1852 let expr = column_expr(table_alias, field_name, scalar);
1853 format!("{expr} AS \"{alias}\"")
1854}
1855
1856fn bind_query<'q>(
1857 mut query: sqlx::query::Query<'q, Postgres, PgArguments>,
1858 bindings: &'q [QueryVariableValue],
1859) -> sqlx::query::Query<'q, Postgres, PgArguments> {
1860 for binding in bindings {
1861 query = match binding {
1862 QueryVariableValue::Null => query.bind(Option::<i64>::None),
1863 QueryVariableValue::Int(value) => query.bind(*value),
1864 QueryVariableValue::String(value) => query.bind(value),
1865 QueryVariableValue::Bool(value) => query.bind(*value),
1866 QueryVariableValue::Float(value) => query.bind(*value),
1867 QueryVariableValue::Decimal(value) => query.bind(*value),
1868 QueryVariableValue::Bytes(value) => query.bind(value),
1869 QueryVariableValue::DateTime(value) => query.bind(*value),
1870 QueryVariableValue::Uuid(value) => query.bind(*value),
1871 QueryVariableValue::List(values) => {
1872 let first = values
1873 .first()
1874 .expect("list-valued query variables must not be empty when bound");
1875
1876 match first {
1877 QueryVariableValue::Null => {
1878 unreachable!("list-valued query variables must not contain null items")
1879 }
1880 QueryVariableValue::Int(_) => query.bind(
1881 values
1882 .iter()
1883 .map(|value| match value {
1884 QueryVariableValue::Int(value) => *value,
1885 _ => unreachable!("list-valued query variables must be homogenous"),
1886 })
1887 .collect::<Vec<_>>(),
1888 ),
1889 QueryVariableValue::String(_) => query.bind(
1890 values
1891 .iter()
1892 .map(|value| match value {
1893 QueryVariableValue::String(value) => value.clone(),
1894 _ => unreachable!("list-valued query variables must be homogenous"),
1895 })
1896 .collect::<Vec<_>>(),
1897 ),
1898 QueryVariableValue::Bool(_) => query.bind(
1899 values
1900 .iter()
1901 .map(|value| match value {
1902 QueryVariableValue::Bool(value) => *value,
1903 _ => unreachable!("list-valued query variables must be homogenous"),
1904 })
1905 .collect::<Vec<_>>(),
1906 ),
1907 QueryVariableValue::Float(_) => query.bind(
1908 values
1909 .iter()
1910 .map(|value| match value {
1911 QueryVariableValue::Float(value) => *value,
1912 _ => unreachable!("list-valued query variables must be homogenous"),
1913 })
1914 .collect::<Vec<_>>(),
1915 ),
1916 QueryVariableValue::Decimal(_) => query.bind(
1917 values
1918 .iter()
1919 .map(|value| match value {
1920 QueryVariableValue::Decimal(value) => *value,
1921 _ => unreachable!("list-valued query variables must be homogenous"),
1922 })
1923 .collect::<Vec<_>>(),
1924 ),
1925 QueryVariableValue::Bytes(_) => query.bind(
1926 values
1927 .iter()
1928 .map(|value| match value {
1929 QueryVariableValue::Bytes(value) => value.clone(),
1930 _ => unreachable!("list-valued query variables must be homogenous"),
1931 })
1932 .collect::<Vec<_>>(),
1933 ),
1934 QueryVariableValue::DateTime(_) => query.bind(
1935 values
1936 .iter()
1937 .map(|value| match value {
1938 QueryVariableValue::DateTime(value) => *value,
1939 _ => unreachable!("list-valued query variables must be homogenous"),
1940 })
1941 .collect::<Vec<_>>(),
1942 ),
1943 QueryVariableValue::Uuid(_) => query.bind(
1944 values
1945 .iter()
1946 .map(|value| match value {
1947 QueryVariableValue::Uuid(value) => *value,
1948 _ => unreachable!("list-valued query variables must be homogenous"),
1949 })
1950 .collect::<Vec<_>>(),
1951 ),
1952 QueryVariableValue::List(_) => {
1953 unreachable!("list-valued query variables must not contain nested lists")
1954 }
1955 }
1956 }
1957 };
1958 }
1959
1960 query
1961}
1962
1963pub fn schema_error(message: String) -> sqlx::Error {
1964 sqlx::Error::Protocol(message)
1965}