Skip to main content

vitrail_pg_core/
query.rs

1use std::collections::HashMap;
2use std::marker::PhantomData;
3
4use serde_json::Value as JsonValue;
5use sqlx::postgres::{PgArguments, PgRow};
6use sqlx::{Postgres, Row as _, ValueRef as _};
7
8pub use futures_util::future::BoxFuture;
9use rust_decimal::Decimal;
10use uuid::Uuid;
11
12use crate::PgExecutor;
13use crate::filter::{FilterBuilder, compile_filter_sql, schema_model as resolve_schema_model};
14use crate::schema::{FieldType, Model, ScalarType, Schema};
15
16/// Runtime contract implemented by executable query values.
17pub trait QuerySpec: Send + Sync {
18    type Output: Send + 'static;
19
20    #[doc(hidden)]
21    fn fetch_many<'a>(
22        &'a self,
23        executor: &'a dyn PgExecutor,
24    ) -> BoxFuture<'a, Result<Vec<Self::Output>, sqlx::Error>>;
25
26    #[doc(hidden)]
27    fn fetch_optional<'a>(
28        &'a self,
29        executor: &'a dyn PgExecutor,
30    ) -> BoxFuture<'a, Result<Option<Self::Output>, sqlx::Error>> {
31        Box::pin(async move { Ok(self.fetch_many(executor).await?.into_iter().next()) })
32    }
33
34    #[doc(hidden)]
35    fn fetch_first<'a>(
36        &'a self,
37        executor: &'a dyn PgExecutor,
38    ) -> BoxFuture<'a, Result<Self::Output, sqlx::Error>> {
39        Box::pin(async move {
40            self.fetch_optional(executor)
41                .await?
42                .ok_or(sqlx::Error::RowNotFound)
43        })
44    }
45}
46
47pub trait SchemaAccess: Send + Sync + 'static {
48    fn schema() -> &'static Schema;
49}
50
51#[derive(Clone, Debug, PartialEq)]
52pub struct QuerySelection {
53    pub model: &'static str,
54    pub scalar_fields: Vec<&'static str>,
55    pub relations: Vec<QueryRelationSelection>,
56    pub filter: Option<QueryFilter>,
57    pub order_by: Vec<QueryOrder>,
58    pub skip: Option<QueryPagination>,
59    pub limit: Option<QueryPagination>,
60}
61
62#[derive(Clone, Debug, PartialEq)]
63pub struct QueryRelationSelection {
64    pub field: &'static str,
65    pub selection: QuerySelection,
66}
67
/// Sort direction of a scalar ordering term.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryOrderDirection {
    /// Ascending order.
    Asc,
    /// Descending order.
    Desc,
}

/// One ordering term: either a scalar field with a direction, or a list of
/// ordering terms applied through a relation field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryOrder {
    /// Order by a scalar field of the current model.
    Scalar {
        field: &'static str,
        direction: QueryOrderDirection,
    },
    /// Order by terms of the model reached through relation `field`.
    Relation {
        field: &'static str,
        orders: Vec<QueryOrder>,
    },
}

impl QueryOrder {
    /// Builds a [`QueryOrder::Scalar`] term.
    pub fn scalar(field: &'static str, direction: QueryOrderDirection) -> Self {
        QueryOrder::Scalar { field, direction }
    }

    /// Builds a [`QueryOrder::Relation`] term.
    pub fn relation(field: &'static str, orders: Vec<QueryOrder>) -> Self {
        QueryOrder::Relation { field, orders }
    }
}
95
/// A pagination amount: either a literal count or the name of a query variable.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryPagination {
    /// A literal row count.
    Value(i64),
    /// Name of a query variable that supplies the row count.
    Variable(&'static str),
}

impl QueryPagination {
    /// Builds a literal [`QueryPagination::Value`].
    pub fn value(value: i64) -> Self {
        QueryPagination::Value(value)
    }

    /// Builds a [`QueryPagination::Variable`] referring to `name`.
    pub fn variable(name: &'static str) -> Self {
        QueryPagination::Variable(name)
    }
}
111
112#[derive(Clone, Debug, Default, PartialEq)]
113pub struct QueryVariables {
114    values: Vec<QueryVariableValue>,
115    value_indices: HashMap<String, usize>,
116}
117
118impl QueryVariables {
119    pub fn new() -> Self {
120        Self {
121            values: Vec::new(),
122            value_indices: HashMap::new(),
123        }
124    }
125
126    pub fn from_values(values: Vec<(impl Into<String>, QueryVariableValue)>) -> Self {
127        let mut query_variables = Self::new();
128
129        for (name, value) in values {
130            query_variables
131                .push(name, value)
132                .expect("query variable names must be unique");
133        }
134
135        query_variables
136    }
137
138    pub fn push(
139        &mut self,
140        name: impl Into<String>,
141        value: QueryVariableValue,
142    ) -> Result<usize, sqlx::Error> {
143        let name = name.into();
144
145        if self.value_indices.contains_key(&name) {
146            return Err(schema_error(format!("duplicate query variable `{name}`")));
147        }
148
149        let index = self.values.len();
150        self.values.push(value);
151        self.value_indices.insert(name, index);
152        Ok(index)
153    }
154
155    pub fn get(&self, name: &str) -> Option<&QueryVariableValue> {
156        self.value_indices
157            .get(name)
158            .and_then(|index| self.values.get(*index))
159    }
160
161    pub fn len(&self) -> usize {
162        self.values.len()
163    }
164
165    pub fn is_empty(&self) -> bool {
166        self.values.is_empty()
167    }
168}
169
170pub trait QueryVariableSet: Send + 'static {
171    fn into_query_variables(self) -> QueryVariables;
172}
173
174impl QueryVariableSet for QueryVariables {
175    fn into_query_variables(self) -> QueryVariables {
176        self
177    }
178}
179
180impl QueryVariableSet for () {
181    fn into_query_variables(self) -> QueryVariables {
182        QueryVariables::new()
183    }
184}
185
186pub trait StringValueType: Sized + Send + 'static {
187    fn from_db_string(value: String) -> Result<Self, sqlx::Error>;
188
189    fn into_db_string(self) -> String;
190}
191
192impl StringValueType for String {
193    fn from_db_string(value: String) -> Result<Self, sqlx::Error> {
194        Ok(value)
195    }
196
197    fn into_db_string(self) -> String {
198        self
199    }
200}
201
202pub trait QueryScalar: Send {
203    fn into_query_variable_value(self) -> QueryVariableValue;
204}
205
206pub trait QueryResultValue: Sized + Send + 'static {
207    fn from_row(row: &PgRow, alias: &str) -> Result<Self, sqlx::Error>;
208
209    fn from_json(value: &JsonValue) -> Result<Self, sqlx::Error>;
210}
211
212#[derive(Clone, Debug, PartialEq)]
213pub enum QueryVariableValue {
214    Null,
215    Int(i64),
216    String(String),
217    Bool(bool),
218    Float(f64),
219    Decimal(Decimal),
220    Bytes(Vec<u8>),
221    DateTime(chrono::DateTime<chrono::Utc>),
222    Uuid(Uuid),
223    List(Vec<QueryVariableValue>),
224}
225
226pub trait QueryListScalar: Send {
227    fn into_list_query_variable_value(self) -> QueryVariableValue;
228}
229
230impl From<i64> for QueryVariableValue {
231    fn from(value: i64) -> Self {
232        Self::Int(value)
233    }
234}
235
236impl QueryListScalar for i64 {
237    fn into_list_query_variable_value(self) -> QueryVariableValue {
238        self.into()
239    }
240}
241
242impl From<String> for QueryVariableValue {
243    fn from(value: String) -> Self {
244        Self::String(value)
245    }
246}
247
248impl From<&str> for QueryVariableValue {
249    fn from(value: &str) -> Self {
250        Self::String(value.to_owned())
251    }
252}
253
254impl QueryListScalar for &str {
255    fn into_list_query_variable_value(self) -> QueryVariableValue {
256        self.into()
257    }
258}
259
260impl<T> QueryListScalar for T
261where
262    T: StringValueType,
263{
264    fn into_list_query_variable_value(self) -> QueryVariableValue {
265        QueryVariableValue::String(self.into_db_string())
266    }
267}
268
269impl From<bool> for QueryVariableValue {
270    fn from(value: bool) -> Self {
271        Self::Bool(value)
272    }
273}
274
275impl QueryListScalar for bool {
276    fn into_list_query_variable_value(self) -> QueryVariableValue {
277        self.into()
278    }
279}
280
281impl From<f64> for QueryVariableValue {
282    fn from(value: f64) -> Self {
283        Self::Float(value)
284    }
285}
286
287impl QueryListScalar for f64 {
288    fn into_list_query_variable_value(self) -> QueryVariableValue {
289        self.into()
290    }
291}
292
293impl From<Decimal> for QueryVariableValue {
294    fn from(value: Decimal) -> Self {
295        Self::Decimal(value)
296    }
297}
298
299impl QueryListScalar for Decimal {
300    fn into_list_query_variable_value(self) -> QueryVariableValue {
301        self.into()
302    }
303}
304
305impl From<Vec<u8>> for QueryVariableValue {
306    fn from(value: Vec<u8>) -> Self {
307        Self::Bytes(value)
308    }
309}
310
311impl QueryListScalar for Vec<u8> {
312    fn into_list_query_variable_value(self) -> QueryVariableValue {
313        self.into()
314    }
315}
316
317impl From<&[u8]> for QueryVariableValue {
318    fn from(value: &[u8]) -> Self {
319        Self::Bytes(value.to_vec())
320    }
321}
322
323impl From<chrono::DateTime<chrono::Utc>> for QueryVariableValue {
324    fn from(value: chrono::DateTime<chrono::Utc>) -> Self {
325        Self::DateTime(value)
326    }
327}
328
329impl QueryListScalar for chrono::DateTime<chrono::Utc> {
330    fn into_list_query_variable_value(self) -> QueryVariableValue {
331        self.into()
332    }
333}
334
335impl From<Uuid> for QueryVariableValue {
336    fn from(value: Uuid) -> Self {
337        Self::Uuid(value)
338    }
339}
340
341impl QueryListScalar for Uuid {
342    fn into_list_query_variable_value(self) -> QueryVariableValue {
343        self.into()
344    }
345}
346
347impl<T> From<Vec<T>> for QueryVariableValue
348where
349    T: QueryListScalar,
350{
351    fn from(values: Vec<T>) -> Self {
352        Self::List(
353            values
354                .into_iter()
355                .map(QueryListScalar::into_list_query_variable_value)
356                .collect(),
357        )
358    }
359}
360
361impl<T> From<Option<T>> for QueryVariableValue
362where
363    T: Into<QueryVariableValue>,
364{
365    fn from(value: Option<T>) -> Self {
366        match value {
367            Some(value) => value.into(),
368            None => Self::Null,
369        }
370    }
371}
372
373impl QueryScalar for i64 {
374    fn into_query_variable_value(self) -> QueryVariableValue {
375        self.into()
376    }
377}
378
379impl QueryScalar for &str {
380    fn into_query_variable_value(self) -> QueryVariableValue {
381        self.into()
382    }
383}
384
385impl QueryScalar for bool {
386    fn into_query_variable_value(self) -> QueryVariableValue {
387        self.into()
388    }
389}
390
391impl QueryScalar for f64 {
392    fn into_query_variable_value(self) -> QueryVariableValue {
393        self.into()
394    }
395}
396
397impl QueryScalar for Decimal {
398    fn into_query_variable_value(self) -> QueryVariableValue {
399        self.into()
400    }
401}
402
403impl QueryScalar for Vec<u8> {
404    fn into_query_variable_value(self) -> QueryVariableValue {
405        self.into()
406    }
407}
408
409impl QueryScalar for &[u8] {
410    fn into_query_variable_value(self) -> QueryVariableValue {
411        self.into()
412    }
413}
414
415impl QueryScalar for chrono::DateTime<chrono::Utc> {
416    fn into_query_variable_value(self) -> QueryVariableValue {
417        self.into()
418    }
419}
420
421impl QueryScalar for Uuid {
422    fn into_query_variable_value(self) -> QueryVariableValue {
423        self.into()
424    }
425}
426
427impl<T> QueryScalar for T
428where
429    T: StringValueType,
430{
431    fn into_query_variable_value(self) -> QueryVariableValue {
432        QueryVariableValue::String(self.into_db_string())
433    }
434}
435
436impl<T> QueryScalar for Vec<T>
437where
438    T: QueryListScalar,
439{
440    fn into_query_variable_value(self) -> QueryVariableValue {
441        self.into()
442    }
443}
444
445impl<T, const N: usize> QueryScalar for [T; N]
446where
447    T: QueryListScalar,
448{
449    fn into_query_variable_value(self) -> QueryVariableValue {
450        QueryVariableValue::List(
451            self.into_iter()
452                .map(QueryListScalar::into_list_query_variable_value)
453                .collect(),
454        )
455    }
456}
457
458impl<T> QueryScalar for Option<T>
459where
460    T: QueryScalar,
461{
462    fn into_query_variable_value(self) -> QueryVariableValue {
463        match self {
464            Some(value) => value.into_query_variable_value(),
465            None => QueryVariableValue::Null,
466        }
467    }
468}
469
470#[derive(Clone, Debug, PartialEq)]
471pub enum QueryFilterValue {
472    Variable(String),
473    Value(QueryVariableValue),
474}
475
476impl QueryFilterValue {
477    pub fn variable(name: impl Into<String>) -> Self {
478        Self::Variable(name.into())
479    }
480
481    pub fn value<T>(value: T) -> Self
482    where
483        T: QueryScalar,
484    {
485        Self::Value(value.into_query_variable_value())
486    }
487}
488
489impl<T> From<T> for QueryFilterValue
490where
491    T: QueryScalar,
492{
493    fn from(value: T) -> Self {
494        Self::Value(value.into_query_variable_value())
495    }
496}
497
498#[derive(Clone, Debug, PartialEq)]
499pub enum QueryFilterValues {
500    Variable(String),
501    Values(Vec<QueryFilterValue>),
502}
503
504impl QueryFilterValues {
505    pub fn variable(name: impl Into<String>) -> Self {
506        Self::Variable(name.into())
507    }
508
509    pub fn values<T>(values: impl IntoIterator<Item = T>) -> Self
510    where
511        T: QueryListScalar,
512    {
513        Self::Values(
514            values
515                .into_iter()
516                .map(|value| QueryFilterValue::Value(value.into_list_query_variable_value()))
517                .collect::<Vec<_>>(),
518        )
519    }
520}
521
522impl<T> From<Vec<T>> for QueryFilterValues
523where
524    T: QueryListScalar,
525{
526    fn from(values: Vec<T>) -> Self {
527        Self::values(values)
528    }
529}
530
531impl<T, const N: usize> From<[T; N]> for QueryFilterValues
532where
533    T: QueryListScalar,
534{
535    fn from(values: [T; N]) -> Self {
536        Self::values(values)
537    }
538}
539
540#[derive(Clone, Debug, PartialEq)]
541pub enum QueryFilter {
542    And(Vec<QueryFilter>),
543    Or(Vec<QueryFilter>),
544    Not(Box<QueryFilter>),
545    Eq {
546        field: &'static str,
547        value: QueryFilterValue,
548    },
549    Ne {
550        field: &'static str,
551        value: QueryFilterValue,
552    },
553    In {
554        field: &'static str,
555        values: QueryFilterValues,
556    },
557    Relation {
558        field: &'static str,
559        filter: Box<QueryFilter>,
560    },
561}
562
563impl QueryFilter {
564    pub fn eq(field: &'static str, value: impl Into<QueryFilterValue>) -> Self {
565        Self::Eq {
566            field,
567            value: value.into(),
568        }
569    }
570
571    pub fn ne(field: &'static str, value: impl Into<QueryFilterValue>) -> Self {
572        Self::Ne {
573            field,
574            value: value.into(),
575        }
576    }
577
578    pub fn r#in(field: &'static str, values: impl Into<QueryFilterValues>) -> Self {
579        Self::In {
580            field,
581            values: values.into(),
582        }
583    }
584
585    pub fn is_null(field: &'static str) -> Self {
586        Self::Eq {
587            field,
588            value: QueryFilterValue::Value(QueryVariableValue::Null),
589        }
590    }
591
592    pub fn is_not_null(field: &'static str) -> Self {
593        Self::Ne {
594            field,
595            value: QueryFilterValue::Value(QueryVariableValue::Null),
596        }
597    }
598
599    pub fn relation(field: &'static str, filter: QueryFilter) -> Self {
600        Self::Relation {
601            field,
602            filter: Box::new(filter),
603        }
604    }
605}
606
607pub trait QueryValue: Sized + Send + 'static {
608    fn from_json(value: &JsonValue) -> Result<Self, sqlx::Error>;
609}
610
611pub trait QueryModel: Sized + Send + 'static {
612    type Schema: SchemaAccess;
613    type Variables: QueryVariableSet;
614
615    fn model_name() -> &'static str;
616
617    fn selection() -> QuerySelection;
618
619    fn selection_with_variables(_variables: &QueryVariables) -> QuerySelection {
620        Self::selection()
621    }
622
623    fn from_row(row: &PgRow, prefix: &str) -> Result<Self, sqlx::Error>;
624}
625
626#[derive(Clone, Debug)]
627pub struct Query<S, T, V = ()> {
628    selection: Option<QuerySelection>,
629    variables: QueryVariables,
630    _marker: PhantomData<(S, T, V)>,
631}
632
633impl<S, T> Query<S, T, ()>
634where
635    T: QueryModel<Variables = ()>,
636{
637    pub fn new() -> Self {
638        Self {
639            selection: None,
640            variables: QueryVariables::new(),
641            _marker: PhantomData,
642        }
643    }
644
645    pub fn with_selection(selection: QuerySelection) -> Self {
646        Self {
647            selection: Some(selection),
648            variables: QueryVariables::new(),
649            _marker: PhantomData,
650        }
651    }
652}
653
654impl<S, T> Default for Query<S, T, ()>
655where
656    T: QueryModel<Variables = ()>,
657{
658    fn default() -> Self {
659        Self::new()
660    }
661}
662
663impl<S, T> Query<S, T, ()>
664where
665    T: QueryModel,
666{
667    pub fn new_with_variables(variables: T::Variables) -> Query<S, T, T::Variables> {
668        Query {
669            selection: None,
670            variables: variables.into_query_variables(),
671            _marker: PhantomData,
672        }
673    }
674
675    pub fn with_selection_and_variables(
676        selection: QuerySelection,
677        variables: T::Variables,
678    ) -> Query<S, T, T::Variables> {
679        Query {
680            selection: Some(selection),
681            ..Self::new_with_variables(variables)
682        }
683    }
684
685    pub fn with_variables(self, variables: T::Variables) -> Query<S, T, T::Variables> {
686        Query {
687            selection: self.selection,
688            variables: variables.into_query_variables(),
689            _marker: PhantomData,
690        }
691    }
692}
693
694impl<S, T, V> Query<S, T, V>
695where
696    S: SchemaAccess,
697    T: QueryModel<Schema = S, Variables = V>,
698    V: QueryVariableSet,
699{
700    fn selection(&self) -> QuerySelection {
701        self.selection
702            .clone()
703            .unwrap_or_else(|| T::selection_with_variables(&self.variables))
704    }
705
706    pub fn to_sql(&self) -> Result<String, sqlx::Error> {
707        let selection = self.selection();
708        let (sql, _) = build_query_sql(S::schema(), &selection, &self.variables)?;
709        Ok(sql)
710    }
711}
712
713impl<S, T, V> QuerySpec for Query<S, T, V>
714where
715    S: SchemaAccess,
716    T: QueryModel<Schema = S, Variables = V> + Sync,
717    V: QueryVariableSet + Sync,
718{
719    type Output = T;
720
721    fn fetch_many<'a>(
722        &'a self,
723        executor: &'a dyn PgExecutor,
724    ) -> BoxFuture<'a, Result<Vec<Self::Output>, sqlx::Error>> {
725        Box::pin(async move {
726            let selection = self.selection();
727            let (sql, bindings) = build_query_sql(S::schema(), &selection, &self.variables)?;
728            let rows = executor
729                .fetch_all(bind_query(sqlx::query(&sql), &bindings))
730                .await?;
731            let mut values = Vec::with_capacity(rows.len());
732            let root_prefix = selection.model;
733
734            for row in rows {
735                values.push(T::from_row(&row, root_prefix)?);
736            }
737
738            Ok(values)
739        })
740    }
741
742    fn fetch_optional<'a>(
743        &'a self,
744        executor: &'a dyn PgExecutor,
745    ) -> BoxFuture<'a, Result<Option<Self::Output>, sqlx::Error>> {
746        Box::pin(async move {
747            let mut selection = self.selection();
748            selection.limit = Some(QueryPagination::value(1));
749
750            let (sql, bindings) = build_query_sql(S::schema(), &selection, &self.variables)?;
751            let row = executor
752                .fetch_optional(bind_query(sqlx::query(&sql), &bindings))
753                .await?;
754            let root_prefix = selection.model;
755
756            row.map(|row| T::from_row(&row, root_prefix)).transpose()
757        })
758    }
759}
760
761pub fn query_model_is_null<T: QueryModel>(row: &PgRow, prefix: &str) -> Result<bool, sqlx::Error> {
762    selection_is_null(row, prefix, &T::selection())
763}
764
765fn selection_is_null(
766    row: &PgRow,
767    prefix: &str,
768    selection: &QuerySelection,
769) -> Result<bool, sqlx::Error> {
770    for field in &selection.scalar_fields {
771        let alias = alias_name(prefix, field);
772        if !row.try_get_raw(alias.as_str())?.is_null() {
773            return Ok(false);
774        }
775    }
776
777    for relation in &selection.relations {
778        let alias = alias_name(prefix, relation.field);
779        if !row.try_get_raw(alias.as_str())?.is_null() {
780            return Ok(false);
781        }
782    }
783
784    Ok(true)
785}
786
/// Builds the column alias for `field` under `prefix`, as `{prefix}__{field}`.
pub fn alias_name(prefix: &str, field: &str) -> String {
    let mut alias = String::with_capacity(prefix.len() + 2 + field.len());
    alias.push_str(prefix);
    alias.push_str("__");
    alias.push_str(field);
    alias
}
790
791pub fn json_array_field(value: &JsonValue, index: usize) -> Result<&JsonValue, sqlx::Error> {
792    value.get(index).ok_or_else(|| {
793        schema_error(format!(
794            "missing JSON array index `{index}` in query result"
795        ))
796    })
797}
798
799pub fn json_as_i64(value: &JsonValue) -> Result<i64, sqlx::Error> {
800    value
801        .as_i64()
802        .ok_or_else(|| schema_error("expected JSON integer in query result".to_owned()))
803}
804
805pub fn json_as_string(value: &JsonValue) -> Result<String, sqlx::Error> {
806    value
807        .as_str()
808        .map(ToOwned::to_owned)
809        .ok_or_else(|| schema_error("expected JSON string in query result".to_owned()))
810}
811
812pub fn json_as_bool(value: &JsonValue) -> Result<bool, sqlx::Error> {
813    value
814        .as_bool()
815        .ok_or_else(|| schema_error("expected JSON boolean in query result".to_owned()))
816}
817
818pub fn json_value<T>(value: &JsonValue) -> Result<T, sqlx::Error>
819where
820    T: QueryResultValue,
821{
822    T::from_json(value)
823}
824
825pub fn json_string_value<T>(value: &JsonValue) -> Result<T, sqlx::Error>
826where
827    T: StringValueType,
828{
829    T::from_db_string(json_as_string(value)?)
830}
831
832pub fn json_as_f64(value: &JsonValue) -> Result<f64, sqlx::Error> {
833    value
834        .as_f64()
835        .ok_or_else(|| schema_error("expected JSON float in query result".to_owned()))
836}
837
838pub fn json_as_bytes(value: &JsonValue) -> Result<Vec<u8>, sqlx::Error> {
839    match value {
840        JsonValue::String(value) => decode_hex_bytes(value),
841        JsonValue::Array(values) => values
842            .iter()
843            .map(|value| {
844                let byte = value.as_u64().ok_or_else(|| {
845                    schema_error("expected JSON byte array in query result".to_owned())
846                })?;
847
848                u8::try_from(byte).map_err(|_| {
849                    schema_error(format!(
850                        "expected JSON byte array values in range 0..=255, got `{byte}`"
851                    ))
852                })
853            })
854            .collect(),
855        _ => Err(schema_error(
856            "expected JSON byte string or byte array in query result".to_owned(),
857        )),
858    }
859}
860
861fn decode_hex_bytes(value: &str) -> Result<Vec<u8>, sqlx::Error> {
862    let value = value.strip_prefix("\\x").unwrap_or(value);
863
864    if !value.len().is_multiple_of(2) {
865        return Err(schema_error(format!(
866            "invalid bytes in query result `{value}`: hex string must have an even length"
867        )));
868    }
869
870    let mut bytes = Vec::with_capacity(value.len() / 2);
871    let mut index = 0;
872
873    while index < value.len() {
874        let chunk = &value[index..index + 2];
875        let byte = u8::from_str_radix(chunk, 16).map_err(|error| {
876            schema_error(format!("invalid bytes in query result `{value}`: {error}"))
877        })?;
878        bytes.push(byte);
879        index += 2;
880    }
881
882    Ok(bytes)
883}
884
885pub fn json_as_decimal(value: &JsonValue) -> Result<Decimal, sqlx::Error> {
886    match value {
887        JsonValue::String(value) => parse_decimal(value),
888        JsonValue::Number(value) => parse_decimal(&value.to_string()),
889        _ => Err(schema_error(
890            "expected JSON decimal string or number in query result".to_owned(),
891        )),
892    }
893}
894
895pub fn parse_decimal(value: &str) -> Result<Decimal, sqlx::Error> {
896    Decimal::from_str_exact(value)
897        .or_else(|_| Decimal::from_scientific(value))
898        .or_else(|_| {
899            let normalized = normalize_decimal_string(value);
900            Decimal::from_str_exact(&normalized).or_else(|_| Decimal::from_scientific(&normalized))
901        })
902        .map_err(|error| {
903            schema_error(format!(
904                "invalid decimal in query result `{value}`: {error}"
905            ))
906        })
907}
908
/// Removes trailing zeros from the fractional part of a plain decimal string
/// (`"1.500"` becomes `"1.5"`, `"2.000"` becomes `"2"`).
///
/// Only strings whose fractional part is all ASCII digits are normalized.
/// Anything else, notably scientific notation such as `"1.5e300"`, is returned
/// unchanged. Trimming there would eat exponent digits and silently change the
/// value (`"1.5e300"` would become `"1.5e3"`).
fn normalize_decimal_string(value: &str) -> String {
    match value.split_once('.') {
        Some((integer, fractional)) if fractional.bytes().all(|byte| byte.is_ascii_digit()) => {
            let fractional = fractional.trim_end_matches('0');
            if fractional.is_empty() {
                integer.to_owned()
            } else {
                format!("{integer}.{fractional}")
            }
        }
        _ => value.to_owned(),
    }
}
921
922pub fn row_as_decimal(row: &PgRow, alias: &str) -> Result<Decimal, sqlx::Error> {
923    row.try_get(alias)
924}
925
926pub fn json_as_datetime_utc(
927    value: &JsonValue,
928) -> Result<chrono::DateTime<chrono::Utc>, sqlx::Error> {
929    let value = value
930        .as_str()
931        .ok_or_else(|| schema_error("expected JSON datetime string in query result".to_owned()))?;
932
933    if let Ok(datetime) = chrono::DateTime::parse_from_rfc3339(value) {
934        return Ok(datetime.with_timezone(&chrono::Utc));
935    }
936
937    for format in ["%Y-%m-%dT%H:%M:%S%.f", "%Y-%m-%d %H:%M:%S%.f"] {
938        if let Ok(datetime) = chrono::NaiveDateTime::parse_from_str(value, format) {
939            return Ok(datetime.and_utc());
940        }
941    }
942
943    Err(schema_error(format!(
944        "invalid JSON datetime in query result: unsupported format `{value}`"
945    )))
946}
947
948pub fn row_as_datetime_utc(
949    row: &PgRow,
950    alias: &str,
951) -> Result<chrono::DateTime<chrono::Utc>, sqlx::Error> {
952    if let Ok(value) = row.try_get::<chrono::DateTime<chrono::Utc>, _>(alias) {
953        return Ok(value);
954    }
955
956    let value: chrono::NaiveDateTime = row.try_get(alias)?;
957    Ok(value.and_utc())
958}
959
960pub fn row_as_bytes(row: &PgRow, alias: &str) -> Result<Vec<u8>, sqlx::Error> {
961    row.try_get(alias)
962}
963
964pub fn row_value<T>(row: &PgRow, alias: &str) -> Result<T, sqlx::Error>
965where
966    T: QueryResultValue,
967{
968    T::from_row(row, alias)
969}
970
971pub fn row_string_value<T>(row: &PgRow, alias: &str) -> Result<T, sqlx::Error>
972where
973    T: StringValueType,
974{
975    T::from_db_string(row.try_get::<String, _>(alias)?)
976}
977
978impl QueryResultValue for i64 {
979    fn from_row(row: &PgRow, alias: &str) -> Result<Self, sqlx::Error> {
980        row.try_get(alias)
981    }
982
983    fn from_json(value: &JsonValue) -> Result<Self, sqlx::Error> {
984        json_as_i64(value)
985    }
986}
987
988impl QueryResultValue for bool {
989    fn from_row(row: &PgRow, alias: &str) -> Result<Self, sqlx::Error> {
990        row.try_get(alias)
991    }
992
993    fn from_json(value: &JsonValue) -> Result<Self, sqlx::Error> {
994        json_as_bool(value)
995    }
996}
997
998impl QueryResultValue for f64 {
999    fn from_row(row: &PgRow, alias: &str) -> Result<Self, sqlx::Error> {
1000        row.try_get(alias)
1001    }
1002
1003    fn from_json(value: &JsonValue) -> Result<Self, sqlx::Error> {
1004        json_as_f64(value)
1005    }
1006}
1007
1008impl QueryResultValue for Decimal {
1009    fn from_row(row: &PgRow, alias: &str) -> Result<Self, sqlx::Error> {
1010        row_as_decimal(row, alias)
1011    }
1012
1013    fn from_json(value: &JsonValue) -> Result<Self, sqlx::Error> {
1014        json_as_decimal(value)
1015    }
1016}
1017
1018impl QueryResultValue for chrono::DateTime<chrono::Utc> {
1019    fn from_row(row: &PgRow, alias: &str) -> Result<Self, sqlx::Error> {
1020        row_as_datetime_utc(row, alias)
1021    }
1022
1023    fn from_json(value: &JsonValue) -> Result<Self, sqlx::Error> {
1024        json_as_datetime_utc(value)
1025    }
1026}
1027
1028impl QueryResultValue for Vec<u8> {
1029    fn from_row(row: &PgRow, alias: &str) -> Result<Self, sqlx::Error> {
1030        row_as_bytes(row, alias)
1031    }
1032
1033    fn from_json(value: &JsonValue) -> Result<Self, sqlx::Error> {
1034        json_as_bytes(value)
1035    }
1036}
1037
1038impl QueryResultValue for Uuid {
1039    fn from_row(row: &PgRow, alias: &str) -> Result<Self, sqlx::Error> {
1040        row.try_get(alias)
1041    }
1042
1043    fn from_json(value: &JsonValue) -> Result<Self, sqlx::Error> {
1044        let value = json_as_string(value)?;
1045        Uuid::parse_str(&value)
1046            .map_err(|error| schema_error(format!("invalid JSON UUID in query result: {error}")))
1047    }
1048}
1049
1050impl<T> QueryResultValue for T
1051where
1052    T: StringValueType,
1053{
1054    fn from_row(row: &PgRow, alias: &str) -> Result<Self, sqlx::Error> {
1055        row_string_value(row, alias)
1056    }
1057
1058    fn from_json(value: &JsonValue) -> Result<Self, sqlx::Error> {
1059        json_string_value(value)
1060    }
1061}
1062
1063impl<T> QueryResultValue for Option<T>
1064where
1065    T: QueryResultValue,
1066{
1067    fn from_row(row: &PgRow, alias: &str) -> Result<Self, sqlx::Error> {
1068        if row.try_get_raw(alias)?.is_null() {
1069            Ok(None)
1070        } else {
1071            T::from_row(row, alias).map(Some)
1072        }
1073    }
1074
1075    fn from_json(value: &JsonValue) -> Result<Self, sqlx::Error> {
1076        if value.is_null() {
1077            Ok(None)
1078        } else {
1079            T::from_json(value).map(Some)
1080        }
1081    }
1082}
1083
1084fn build_query_sql(
1085    schema: &Schema,
1086    selection: &QuerySelection,
1087    variables: &QueryVariables,
1088) -> Result<(String, Vec<QueryVariableValue>), sqlx::Error> {
1089    let root_model = resolve_schema_model(schema, selection.model, "query")?;
1090
1091    let mut builder = SqlBuilder {
1092        schema,
1093        variables,
1094        bindings: Vec::new(),
1095        joins: Vec::new(),
1096        next_alias: 1,
1097    };
1098
1099    let selects = builder.root_selects(root_model, selection, selection.model, "t0")?;
1100    let where_clause = selection
1101        .filter
1102        .as_ref()
1103        .map(|filter| builder.filter_sql(root_model, filter, "t0"))
1104        .transpose()?;
1105
1106    let mut order_joins = Vec::new();
1107    let order_by_clause =
1108        builder.order_by_sql(root_model, &selection.order_by, "t0", &mut order_joins)?;
1109    builder.joins.extend(order_joins);
1110
1111    let pagination_clause =
1112        builder.pagination_clause(selection.skip.as_ref(), selection.limit.as_ref())?;
1113
1114    let sql = format!(
1115        "SELECT {} FROM {} AS \"t0\"{}{}{}{}",
1116        selects.join(", "),
1117        quoted_ident(root_model.name()),
1118        if builder.joins.is_empty() {
1119            String::new()
1120        } else {
1121            format!(" {}", builder.joins.join(" "))
1122        },
1123        where_clause
1124            .map(|where_clause| format!(" WHERE {where_clause}"))
1125            .unwrap_or_default(),
1126        order_by_clause
1127            .map(|order_by_clause| format!(" ORDER BY {order_by_clause}"))
1128            .unwrap_or_default(),
1129        pagination_clause,
1130    );
1131
1132    Ok((sql, builder.bindings))
1133}
1134
/// Compares two model names, ignoring ASCII case.
fn model_names_match(left: &str, right: &str) -> bool {
    left.len() == right.len()
        && left
            .bytes()
            .zip(right.bytes())
            .all(|(a, b)| a.eq_ignore_ascii_case(&b))
}
1138
1139fn infer_relation_fields<'a>(
1140    model: &'a Model,
1141    field: &'a crate::Field,
1142    target_model: &'a Model,
1143) -> Result<(Vec<&'a str>, Vec<&'a str>), sqlx::Error> {
1144    let reverse_relation = target_model
1145        .fields()
1146        .iter()
1147        .find(|candidate| {
1148            model_names_match(candidate.ty().name(), model.name()) && candidate.relation().is_some()
1149        })
1150        .ok_or_else(|| {
1151            schema_error(format!(
1152                "could not infer relation metadata for `{}.{}`",
1153                model.name(),
1154                field.name()
1155            ))
1156        })?;
1157
1158    let reverse_relation = reverse_relation
1159        .relation()
1160        .expect("reverse relation existence checked above");
1161
1162    Ok((
1163        reverse_relation
1164            .fields()
1165            .iter()
1166            .map(String::as_str)
1167            .collect(),
1168        reverse_relation
1169            .references()
1170            .iter()
1171            .map(String::as_str)
1172            .collect(),
1173    ))
1174}
1175
1176struct SqlBuilder<'a> {
1177    schema: &'a Schema,
1178    variables: &'a QueryVariables,
1179    bindings: Vec<QueryVariableValue>,
1180    joins: Vec<String>,
1181    next_alias: usize,
1182}
1183
1184struct RelationSql<'a> {
1185    many: bool,
1186    source_model_name: &'a str,
1187    relation_field_name: &'a str,
1188    target_model: &'a Model,
1189    selection: QuerySelection,
1190    parent_table_alias: String,
1191    nested_alias: String,
1192    nested_fields: Vec<&'a str>,
1193    parent_fields: Vec<&'a str>,
1194}
1195
1196impl<'a> SqlBuilder<'a> {
1197    fn root_selects(
1198        &mut self,
1199        model: &'a Model,
1200        selection: &QuerySelection,
1201        prefix: &str,
1202        table_alias: &str,
1203    ) -> Result<Vec<String>, sqlx::Error> {
1204        let mut selects = Vec::new();
1205
1206        for field_name in &selection.scalar_fields {
1207            let field = model.field_named(field_name).ok_or_else(|| {
1208                schema_error(format!(
1209                    "unknown field `{}.{}` in query selection",
1210                    model.name(),
1211                    field_name
1212                ))
1213            })?;
1214
1215            let scalar = match field.ty() {
1216                FieldType::Scalar(scalar) => scalar.scalar(),
1217                FieldType::Relation { .. } => {
1218                    return Err(schema_error(format!(
1219                        "field `{}.{}` is not scalar and cannot appear in `select`",
1220                        model.name(),
1221                        field_name
1222                    )));
1223                }
1224            };
1225
1226            selects.push(select_expr(
1227                table_alias,
1228                field.name(),
1229                scalar,
1230                &alias_name(prefix, field.name()),
1231            ));
1232        }
1233
1234        for relation in &selection.relations {
1235            selects.push(self.relation_select(model, relation, prefix, table_alias)?);
1236        }
1237
1238        Ok(selects)
1239    }
1240
1241    fn relation_select(
1242        &mut self,
1243        model: &'a Model,
1244        relation: &QueryRelationSelection,
1245        prefix: &str,
1246        table_alias: &str,
1247    ) -> Result<String, sqlx::Error> {
1248        let field = model.field_named(relation.field).ok_or_else(|| {
1249            schema_error(format!(
1250                "unknown relation `{}.{}` in query include",
1251                model.name(),
1252                relation.field
1253            ))
1254        })?;
1255
1256        if field.kind().is_scalar() {
1257            return Err(schema_error(format!(
1258                "field `{}.{}` is not a relation and cannot appear in `include`",
1259                model.name(),
1260                relation.field
1261            )));
1262        }
1263
1264        let target_model =
1265            resolve_schema_model(self.schema, field.ty().name(), "query").map_err(|_| {
1266                schema_error(format!(
1267                    "relation `{}.{}` points at unknown model `{}`",
1268                    model.name(),
1269                    relation.field,
1270                    field.ty().name()
1271                ))
1272            })?;
1273
1274        let (nested_fields, parent_fields) = self.relation_fields(model, field, target_model)?;
1275
1276        let join_alias = format!("t{}", self.next_alias);
1277        self.next_alias += 1;
1278
1279        let nested_alias = format!("t{}", self.next_alias);
1280        self.next_alias += 1;
1281
1282        let subquery = self.relation_subquery_sql(RelationSql {
1283            many: field.ty().is_many(),
1284            source_model_name: model.name(),
1285            relation_field_name: relation.field,
1286            target_model,
1287            selection: relation.selection.clone(),
1288            parent_table_alias: table_alias.to_owned(),
1289            nested_alias: nested_alias.clone(),
1290            nested_fields,
1291            parent_fields,
1292        })?;
1293
1294        self.joins.push(format!(
1295            "LEFT JOIN LATERAL ({subquery}) AS \"{join_alias}\" ON TRUE"
1296        ));
1297
1298        let alias = alias_name(prefix, relation.field);
1299        Ok(format!("\"{join_alias}\".\"data\" AS \"{alias}\""))
1300    }
1301
1302    fn relation_subquery_sql(&mut self, relation: RelationSql<'a>) -> Result<String, sqlx::Error> {
1303        let mut where_clauses = vec![relation_predicates(
1304            &relation.nested_alias,
1305            &relation.nested_fields,
1306            &relation.parent_table_alias,
1307            &relation.parent_fields,
1308        )];
1309        let mut joins = Vec::new();
1310        let row_expr = self.json_row_expr(
1311            relation.target_model,
1312            &relation.selection,
1313            &relation.nested_alias,
1314            &mut joins,
1315        )?;
1316        let order_by_clause = self.order_by_sql(
1317            relation.target_model,
1318            &relation.selection.order_by,
1319            &relation.nested_alias,
1320            &mut joins,
1321        )?;
1322        let joins_sql = if joins.is_empty() {
1323            String::new()
1324        } else {
1325            format!(" {}", joins.join(" "))
1326        };
1327
1328        if let Some(filter) = relation.selection.filter.as_ref() {
1329            where_clauses.push(self.filter_sql(
1330                relation.target_model,
1331                filter,
1332                &relation.nested_alias,
1333            )?);
1334        }
1335
1336        let where_clause = where_clauses.join(" AND ");
1337        let select_order_by_clause = order_by_clause
1338            .clone()
1339            .map(|order_by_clause| format!(" ORDER BY {order_by_clause}"))
1340            .unwrap_or_default();
1341
1342        if relation.many {
1343            if relation.selection.skip.is_some() || relation.selection.limit.is_some() {
1344                let pagination_clause = self.pagination_clause(
1345                    relation.selection.skip.as_ref(),
1346                    relation.selection.limit.as_ref(),
1347                )?;
1348                let aggregate_table_alias = "__vitrail_nested_rows";
1349                let aggregate_order_column = "__vitrail_nested_order";
1350                let select_order_by_clause = if select_order_by_clause.is_empty() {
1351                    aggregate_order_by(relation.target_model, &relation.nested_alias)
1352                } else {
1353                    select_order_by_clause
1354                };
1355
1356                Ok(format!(
1357                    "SELECT COALESCE(json_agg(\"{aggregate_table_alias}\".\"data\" ORDER BY \"{aggregate_table_alias}\".\"{aggregate_order_column}\"), '[]'::json) AS \"data\" FROM (SELECT {row_expr} AS \"data\", row_number() OVER ({select_order_by_clause}) AS \"{aggregate_order_column}\" FROM {} AS \"{}\"{} WHERE {where_clause}{select_order_by_clause}{pagination_clause}) AS \"{aggregate_table_alias}\"",
1358                    quoted_ident(relation.target_model.name()),
1359                    relation.nested_alias,
1360                    joins_sql,
1361                ))
1362            } else {
1363                let aggregate_order_by_clause =
1364                    if let Some(order_by_clause) = order_by_clause.as_ref() {
1365                        format!(" ORDER BY {order_by_clause}")
1366                    } else {
1367                        aggregate_order_by(relation.target_model, &relation.nested_alias)
1368                    };
1369
1370                Ok(format!(
1371                    "SELECT COALESCE(json_agg({row_expr}{aggregate_order_by_clause}), '[]'::json) AS \"data\" FROM {} AS \"{}\"{} WHERE {where_clause}",
1372                    quoted_ident(relation.target_model.name()),
1373                    relation.nested_alias,
1374                    joins_sql,
1375                ))
1376            }
1377        } else {
1378            if relation.selection.skip.is_some() || relation.selection.limit.is_some() {
1379                return Err(schema_error(format!(
1380                    "relation `{}.{}` is to-one and cannot use `skip` or `limit`",
1381                    relation.source_model_name, relation.relation_field_name
1382                )));
1383            }
1384
1385            Ok(format!(
1386                "SELECT {row_expr} AS \"data\" FROM {} AS \"{}\"{} WHERE {where_clause}{select_order_by_clause} LIMIT 1",
1387                quoted_ident(relation.target_model.name()),
1388                relation.nested_alias,
1389                joins_sql,
1390            ))
1391        }
1392    }
1393
1394    fn filter_sql(
1395        &mut self,
1396        model: &'a Model,
1397        filter: &QueryFilter,
1398        table_alias: &str,
1399    ) -> Result<String, sqlx::Error> {
1400        compile_filter_sql(self, model, filter, table_alias)
1401    }
1402
1403    fn pagination_clause(
1404        &mut self,
1405        skip: Option<&QueryPagination>,
1406        limit: Option<&QueryPagination>,
1407    ) -> Result<String, sqlx::Error> {
1408        let mut clause = String::new();
1409
1410        if let Some(limit) = limit {
1411            let limit = self.pagination_placeholder(limit, "limit")?;
1412            clause.push_str(&format!(" LIMIT {limit}"));
1413        }
1414
1415        if let Some(skip) = skip {
1416            let skip = self.pagination_placeholder(skip, "skip")?;
1417            clause.push_str(&format!(" OFFSET {skip}"));
1418        }
1419
1420        Ok(clause)
1421    }
1422
1423    fn pagination_placeholder(
1424        &mut self,
1425        pagination: &QueryPagination,
1426        kind: &str,
1427    ) -> Result<String, sqlx::Error> {
1428        let value = match pagination {
1429            QueryPagination::Value(value) => QueryVariableValue::Int(*value),
1430            QueryPagination::Variable(name) => {
1431                let value = self.variables.get(name).ok_or_else(|| {
1432                    schema_error(format!("missing query variable `{name}` for `{kind}`"))
1433                })?;
1434
1435                match value {
1436                    QueryVariableValue::Int(value) => QueryVariableValue::Int(*value),
1437                    other => {
1438                        return Err(schema_error(format!(
1439                            "query `{kind}` variable `{name}` must be an integer, got {other:?}"
1440                        )));
1441                    }
1442                }
1443            }
1444        };
1445
1446        let QueryVariableValue::Int(value) = value else {
1447            unreachable!("pagination values must be integers")
1448        };
1449
1450        if value < 0 {
1451            return Err(schema_error(format!(
1452                "query `{kind}` must be greater than or equal to 0"
1453            )));
1454        }
1455
1456        self.push_binding(QueryVariableValue::Int(value), ScalarType::Int)
1457    }
1458
1459    fn order_by_sql(
1460        &mut self,
1461        model: &'a Model,
1462        orders: &[QueryOrder],
1463        table_alias: &str,
1464        joins: &mut Vec<String>,
1465    ) -> Result<Option<String>, sqlx::Error> {
1466        if orders.is_empty() {
1467            return Ok(None);
1468        }
1469
1470        let mut items = Vec::new();
1471        let mut relation_join_aliases = HashMap::new();
1472
1473        for order in orders {
1474            self.push_order_sql(
1475                model,
1476                order,
1477                table_alias,
1478                joins,
1479                &mut relation_join_aliases,
1480                &mut items,
1481            )?;
1482        }
1483
1484        Ok(Some(items.join(", ")))
1485    }
1486
1487    fn push_order_sql(
1488        &mut self,
1489        model: &'a Model,
1490        order: &QueryOrder,
1491        table_alias: &str,
1492        joins: &mut Vec<String>,
1493        relation_join_aliases: &mut HashMap<String, String>,
1494        items: &mut Vec<String>,
1495    ) -> Result<(), sqlx::Error> {
1496        match order {
1497            QueryOrder::Scalar { field, direction } => {
1498                let field = model.field_named(field).ok_or_else(|| {
1499                    schema_error(format!(
1500                        "unknown field `{}.{}` in query ordering",
1501                        model.name(),
1502                        field
1503                    ))
1504                })?;
1505
1506                let scalar = match field.ty() {
1507                    FieldType::Scalar(scalar) => scalar.scalar(),
1508                    FieldType::Relation { .. } => {
1509                        return Err(schema_error(format!(
1510                            "field `{}.{}` is not scalar and cannot terminate `order_by`",
1511                            model.name(),
1512                            field.name()
1513                        )));
1514                    }
1515                };
1516
1517                items.push(format!(
1518                    "{} {}",
1519                    column_expr(table_alias, field.name(), scalar),
1520                    match direction {
1521                        QueryOrderDirection::Asc => "ASC",
1522                        QueryOrderDirection::Desc => "DESC",
1523                    }
1524                ));
1525                Ok(())
1526            }
1527            QueryOrder::Relation { field, orders } => {
1528                let field = model.field_named(field).ok_or_else(|| {
1529                    schema_error(format!(
1530                        "unknown relation `{}.{}` in query ordering",
1531                        model.name(),
1532                        field
1533                    ))
1534                })?;
1535
1536                if field.kind().is_scalar() {
1537                    return Err(schema_error(format!(
1538                        "field `{}.{}` is not a relation and cannot be traversed in `order_by`",
1539                        model.name(),
1540                        field.name()
1541                    )));
1542                }
1543
1544                if field.ty().is_many() {
1545                    return Err(schema_error(format!(
1546                        "relation `{}.{}` is to-many and cannot be used in `order_by`",
1547                        model.name(),
1548                        field.name()
1549                    )));
1550                }
1551
1552                if orders.is_empty() {
1553                    return Err(schema_error(format!(
1554                        "relation `{}.{}` must contain at least one nested `order_by` entry",
1555                        model.name(),
1556                        field.name()
1557                    )));
1558                }
1559
1560                let target_model = resolve_schema_model(self.schema, field.ty().name(), "query")
1561                    .map_err(|_| {
1562                        schema_error(format!(
1563                            "relation `{}.{}` points at unknown model `{}`",
1564                            model.name(),
1565                            field.name(),
1566                            field.ty().name()
1567                        ))
1568                    })?;
1569
1570                let (nested_fields, parent_fields) =
1571                    self.relation_fields(model, field, target_model)?;
1572                let predicate_template = relation_predicates(
1573                    "__vitrail_order_join__",
1574                    &nested_fields,
1575                    table_alias,
1576                    &parent_fields,
1577                );
1578                let join_key = format!("{}::{predicate_template}", target_model.name());
1579                let join_alias = if let Some(join_alias) = relation_join_aliases.get(&join_key) {
1580                    join_alias.clone()
1581                } else {
1582                    let join_alias = format!("t{}", self.next_alias);
1583                    self.next_alias += 1;
1584                    joins.push(format!(
1585                        "LEFT JOIN {} AS \"{join_alias}\" ON {}",
1586                        quoted_ident(target_model.name()),
1587                        relation_predicates(
1588                            &join_alias,
1589                            &nested_fields,
1590                            table_alias,
1591                            &parent_fields,
1592                        ),
1593                    ));
1594                    relation_join_aliases.insert(join_key, join_alias.clone());
1595                    join_alias
1596                };
1597
1598                for nested_order in orders {
1599                    self.push_order_sql(
1600                        target_model,
1601                        nested_order,
1602                        &join_alias,
1603                        joins,
1604                        relation_join_aliases,
1605                        items,
1606                    )?;
1607                }
1608
1609                Ok(())
1610            }
1611        }
1612    }
1613
1614    fn push_binding(
1615        &mut self,
1616        value: QueryVariableValue,
1617        _scalar: ScalarType,
1618    ) -> Result<String, sqlx::Error> {
1619        self.bindings.push(value);
1620        Ok(format!("${}", self.bindings.len()))
1621    }
1622
1623    fn json_row_expr(
1624        &mut self,
1625        model: &'a Model,
1626        selection: &QuerySelection,
1627        table_alias: &str,
1628        joins: &mut Vec<String>,
1629    ) -> Result<String, sqlx::Error> {
1630        let mut items = Vec::new();
1631
1632        for field_name in &selection.scalar_fields {
1633            let field = model.field_named(field_name).ok_or_else(|| {
1634                schema_error(format!(
1635                    "unknown field `{}.{}` in query selection",
1636                    model.name(),
1637                    field_name
1638                ))
1639            })?;
1640
1641            let scalar = match field.ty() {
1642                FieldType::Scalar(scalar) => scalar.scalar(),
1643                FieldType::Relation { .. } => {
1644                    return Err(schema_error(format!(
1645                        "field `{}.{}` is not scalar and cannot appear in `select`",
1646                        model.name(),
1647                        field_name
1648                    )));
1649                }
1650            };
1651
1652            items.push(json_column_expr(table_alias, field.name(), scalar));
1653        }
1654
1655        for relation in &selection.relations {
1656            items.push(self.nested_relation_json_expr(model, relation, table_alias, joins)?);
1657        }
1658
1659        Ok(format!("json_build_array({})", items.join(", ")))
1660    }
1661
1662    fn nested_relation_json_expr(
1663        &mut self,
1664        model: &'a Model,
1665        relation: &QueryRelationSelection,
1666        table_alias: &str,
1667        joins: &mut Vec<String>,
1668    ) -> Result<String, sqlx::Error> {
1669        let field = model.field_named(relation.field).ok_or_else(|| {
1670            schema_error(format!(
1671                "unknown relation `{}.{}` in query include",
1672                model.name(),
1673                relation.field
1674            ))
1675        })?;
1676
1677        if field.kind().is_scalar() {
1678            return Err(schema_error(format!(
1679                "field `{}.{}` is not a relation and cannot appear in `include`",
1680                model.name(),
1681                relation.field
1682            )));
1683        }
1684
1685        let target_model =
1686            resolve_schema_model(self.schema, field.ty().name(), "query").map_err(|_| {
1687                schema_error(format!(
1688                    "relation `{}.{}` points at unknown model `{}`",
1689                    model.name(),
1690                    relation.field,
1691                    field.ty().name()
1692                ))
1693            })?;
1694
1695        let (nested_fields, parent_fields) = self.relation_fields(model, field, target_model)?;
1696
1697        let join_alias = format!("t{}", self.next_alias);
1698        self.next_alias += 1;
1699
1700        let nested_alias = format!("t{}", self.next_alias);
1701        self.next_alias += 1;
1702
1703        let subquery = self.relation_subquery_sql(RelationSql {
1704            many: field.ty().is_many(),
1705            source_model_name: model.name(),
1706            relation_field_name: relation.field,
1707            target_model,
1708            selection: relation.selection.clone(),
1709            parent_table_alias: table_alias.to_owned(),
1710            nested_alias: nested_alias.clone(),
1711            nested_fields,
1712            parent_fields,
1713        })?;
1714
1715        joins.push(format!(
1716            "LEFT JOIN LATERAL ({subquery}) AS \"{join_alias}\" ON TRUE"
1717        ));
1718
1719        Ok(format!("\"{join_alias}\".\"data\""))
1720    }
1721
1722    fn relation_fields(
1723        &self,
1724        model: &'a Model,
1725        field: &'a crate::Field,
1726        target_model: &'a Model,
1727    ) -> Result<(Vec<&'a str>, Vec<&'a str>), sqlx::Error> {
1728        match field.relation() {
1729            Some(relation_info) => Ok((
1730                relation_info
1731                    .references()
1732                    .iter()
1733                    .map(String::as_str)
1734                    .collect::<Vec<_>>(),
1735                relation_info
1736                    .fields()
1737                    .iter()
1738                    .map(String::as_str)
1739                    .collect::<Vec<_>>(),
1740            )),
1741            None => infer_relation_fields(model, field, target_model),
1742        }
1743    }
1744}
1745
1746impl<'a> FilterBuilder<'a> for SqlBuilder<'a> {
1747    fn schema(&self) -> &'a Schema {
1748        self.schema
1749    }
1750
1751    fn variables(&self) -> &'a QueryVariables {
1752        self.variables
1753    }
1754
1755    fn push_filter_binding(
1756        &mut self,
1757        value: QueryVariableValue,
1758        scalar: ScalarType,
1759    ) -> Result<String, sqlx::Error> {
1760        self.push_binding(value, scalar)
1761    }
1762
1763    fn next_filter_alias(&mut self) -> String {
1764        let alias = format!("t{}", self.next_alias);
1765        self.next_alias += 1;
1766        alias
1767    }
1768
1769    fn operation_name(&self) -> &'static str {
1770        "query"
1771    }
1772}
1773
1774fn aggregate_order_by(model: &Model, table_alias: &str) -> String {
1775    let primary_key_columns = model.primary_key_columns();
1776    let field_names = if primary_key_columns.is_empty() {
1777        model
1778            .field_named("id")
1779            .map(|field| vec![field.name()])
1780            .or_else(|| {
1781                model
1782                    .fields()
1783                    .iter()
1784                    .find(|field| field.kind().is_scalar())
1785                    .map(|field| vec![field.name()])
1786            })
1787            .unwrap_or_else(|| vec!["id"])
1788    } else {
1789        primary_key_columns
1790    };
1791
1792    format!(
1793        " ORDER BY {}",
1794        field_names
1795            .into_iter()
1796            .map(|field_name| format!("\"{table_alias}\".{}", quoted_ident(field_name)))
1797            .collect::<Vec<_>>()
1798            .join(", ")
1799    )
1800}
1801
1802fn relation_predicates(
1803    nested_alias: &str,
1804    nested_fields: &[&str],
1805    parent_alias: &str,
1806    parent_fields: &[&str],
1807) -> String {
1808    nested_fields
1809        .iter()
1810        .zip(parent_fields)
1811        .map(|(nested_field, parent_field)| {
1812            format!(
1813                "\"{nested_alias}\".{} = \"{parent_alias}\".{}",
1814                quoted_ident(nested_field),
1815                quoted_ident(parent_field),
1816            )
1817        })
1818        .collect::<Vec<_>>()
1819        .join(" AND ")
1820}
1821
/// Quotes `ident` as a PostgreSQL identifier: wraps it in `"` and doubles every embedded `"`.
pub(crate) fn quoted_ident(ident: &str) -> String {
    let mut quoted = String::with_capacity(ident.len() + 2);
    quoted.push('"');
    for ch in ident.chars() {
        if ch == '"' {
            // An embedded quote is escaped by writing it twice.
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
1825
1826pub(crate) fn column_expr(table_alias: &str, field_name: &str, scalar: ScalarType) -> String {
1827    let column_sql = format!("\"{table_alias}\".{}", quoted_ident(field_name));
1828    match scalar {
1829        ScalarType::Int => format!("({column_sql})::bigint"),
1830        ScalarType::DateTime => format!("({column_sql} AT TIME ZONE 'UTC')"),
1831        _ => column_sql,
1832    }
1833}
1834
1835pub(crate) fn json_column_expr(table_alias: &str, field_name: &str, scalar: ScalarType) -> String {
1836    let column_sql = format!("\"{table_alias}\".{}", quoted_ident(field_name));
1837    match scalar {
1838        ScalarType::Int => format!("({column_sql})::bigint"),
1839        ScalarType::DateTime => format!("({column_sql} AT TIME ZONE 'UTC')"),
1840        ScalarType::Decimal => format!("({column_sql})::text"),
1841        ScalarType::Bytes => format!("encode({column_sql}, 'hex')"),
1842        _ => column_sql,
1843    }
1844}
1845
1846pub(crate) fn select_expr(
1847    table_alias: &str,
1848    field_name: &str,
1849    scalar: ScalarType,
1850    alias: &str,
1851) -> String {
1852    let expr = column_expr(table_alias, field_name, scalar);
1853    format!("{expr} AS \"{alias}\"")
1854}
1855
1856fn bind_query<'q>(
1857    mut query: sqlx::query::Query<'q, Postgres, PgArguments>,
1858    bindings: &'q [QueryVariableValue],
1859) -> sqlx::query::Query<'q, Postgres, PgArguments> {
1860    for binding in bindings {
1861        query = match binding {
1862            QueryVariableValue::Null => query.bind(Option::<i64>::None),
1863            QueryVariableValue::Int(value) => query.bind(*value),
1864            QueryVariableValue::String(value) => query.bind(value),
1865            QueryVariableValue::Bool(value) => query.bind(*value),
1866            QueryVariableValue::Float(value) => query.bind(*value),
1867            QueryVariableValue::Decimal(value) => query.bind(*value),
1868            QueryVariableValue::Bytes(value) => query.bind(value),
1869            QueryVariableValue::DateTime(value) => query.bind(*value),
1870            QueryVariableValue::Uuid(value) => query.bind(*value),
1871            QueryVariableValue::List(values) => {
1872                let first = values
1873                    .first()
1874                    .expect("list-valued query variables must not be empty when bound");
1875
1876                match first {
1877                    QueryVariableValue::Null => {
1878                        unreachable!("list-valued query variables must not contain null items")
1879                    }
1880                    QueryVariableValue::Int(_) => query.bind(
1881                        values
1882                            .iter()
1883                            .map(|value| match value {
1884                                QueryVariableValue::Int(value) => *value,
1885                                _ => unreachable!("list-valued query variables must be homogenous"),
1886                            })
1887                            .collect::<Vec<_>>(),
1888                    ),
1889                    QueryVariableValue::String(_) => query.bind(
1890                        values
1891                            .iter()
1892                            .map(|value| match value {
1893                                QueryVariableValue::String(value) => value.clone(),
1894                                _ => unreachable!("list-valued query variables must be homogenous"),
1895                            })
1896                            .collect::<Vec<_>>(),
1897                    ),
1898                    QueryVariableValue::Bool(_) => query.bind(
1899                        values
1900                            .iter()
1901                            .map(|value| match value {
1902                                QueryVariableValue::Bool(value) => *value,
1903                                _ => unreachable!("list-valued query variables must be homogenous"),
1904                            })
1905                            .collect::<Vec<_>>(),
1906                    ),
1907                    QueryVariableValue::Float(_) => query.bind(
1908                        values
1909                            .iter()
1910                            .map(|value| match value {
1911                                QueryVariableValue::Float(value) => *value,
1912                                _ => unreachable!("list-valued query variables must be homogenous"),
1913                            })
1914                            .collect::<Vec<_>>(),
1915                    ),
1916                    QueryVariableValue::Decimal(_) => query.bind(
1917                        values
1918                            .iter()
1919                            .map(|value| match value {
1920                                QueryVariableValue::Decimal(value) => *value,
1921                                _ => unreachable!("list-valued query variables must be homogenous"),
1922                            })
1923                            .collect::<Vec<_>>(),
1924                    ),
1925                    QueryVariableValue::Bytes(_) => query.bind(
1926                        values
1927                            .iter()
1928                            .map(|value| match value {
1929                                QueryVariableValue::Bytes(value) => value.clone(),
1930                                _ => unreachable!("list-valued query variables must be homogenous"),
1931                            })
1932                            .collect::<Vec<_>>(),
1933                    ),
1934                    QueryVariableValue::DateTime(_) => query.bind(
1935                        values
1936                            .iter()
1937                            .map(|value| match value {
1938                                QueryVariableValue::DateTime(value) => *value,
1939                                _ => unreachable!("list-valued query variables must be homogenous"),
1940                            })
1941                            .collect::<Vec<_>>(),
1942                    ),
1943                    QueryVariableValue::Uuid(_) => query.bind(
1944                        values
1945                            .iter()
1946                            .map(|value| match value {
1947                                QueryVariableValue::Uuid(value) => *value,
1948                                _ => unreachable!("list-valued query variables must be homogenous"),
1949                            })
1950                            .collect::<Vec<_>>(),
1951                    ),
1952                    QueryVariableValue::List(_) => {
1953                        unreachable!("list-valued query variables must not contain nested lists")
1954                    }
1955                }
1956            }
1957        };
1958    }
1959
1960    query
1961}
1962
1963pub fn schema_error(message: String) -> sqlx::Error {
1964    sqlx::Error::Protocol(message)
1965}