1use smallvec::SmallVec;
2use sqlx::{
3 postgres::PgRow,
4 prelude::Type,
5 Postgres,
6};
7
8pub trait QueryModel =
9 Send + Clone + Unpin + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>;
10
11pub trait FilterQuery = Filterable + Clone;
12
13pub trait QuerySort = Sortable + Clone;
14
15pub trait Filterable {
16 type Entity: QueryModel;
17
18 fn write<W: SqlWrite>(&self, w: &mut W);
19}
20
21pub trait SqlWrite {
22 fn push(&mut self, s: &str);
23
24 fn bind<T>(&mut self, value: T)
25 where
26 T: sqlx::Encode<'static, Postgres> + Send + 'static,
27 T: Type<Postgres>;
28}
29
30pub trait QueryContext: Send + Sync + 'static {
31 const TABLE: &'static str;
32
33 type Model: QueryModel
34 + Send
35 + Sync
36 + JoinNavigationModel
37 + WebJoinGraph
38 + PrimaryKey;
39 type Query: FilterQuery + Send + Sync;
40 type Sort: QuerySort + Send + Sync;
41 type Join: SqlJoin + Send + Sync;
42}
43
44pub trait Sortable {
45 type Entity: QueryModel;
46
47 fn sort_clause(&self) -> String;
48}
49
/// SQL join flavour used for one hop of a [`JoinPath`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinKind {
    /// `LEFT` join.
    Left,
    /// `INNER` join.
    Inner,
}

/// Whether a relation is fetched eagerly or lazily.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RelationFetchMode {
    Eager,
    Lazy,
}

/// Intermediate table a join passes through (presumably a junction table for
/// many-to-many relations — NOTE(review): confirm).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct JoinThroughDescriptor {
    /// Name of the intermediate table.
    pub table: &'static str,
    /// Alias fragment for the intermediate table.
    pub alias_segment: &'static str,
    /// Intermediate-table column on the left side.
    pub left_field: &'static str,
    /// Intermediate-table column on the right side.
    pub right_field: &'static str,
}

/// Static description of one join hop from `left_table` to `right_table`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct JoinDescriptor {
    /// Table the hop starts from.
    pub left_table: &'static str,
    /// Column on `left_table` used in the join condition.
    pub left_field: &'static str,
    /// Table the hop leads to; the next hop must start here.
    pub right_table: &'static str,
    /// Column on `right_table` used in the join condition.
    pub right_field: &'static str,
    /// Fragment concatenated into the path's table alias.
    pub alias_segment: &'static str,
    /// Name identifying this relation.
    pub identifier: &'static str,
    /// Optional intermediate table the hop goes through.
    pub through: Option<JoinThroughDescriptor>,
}

/// A [`JoinDescriptor`] paired with the join kind to use for it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct JoinSegment {
    pub descriptor: JoinDescriptor,
    pub kind: JoinKind,
}
86
87#[derive(Debug, Clone, PartialEq, Eq)]
88pub struct JoinPath {
89 segments: Vec<JoinSegment>,
90 start: usize,
91}
92
93impl JoinPath {
94 pub fn from_join<J: SqlJoin>(join: J, kind: JoinKind) -> Self {
95 Self::new(join.descriptor(), kind)
96 }
97
98 pub fn new(descriptor: JoinDescriptor, kind: JoinKind) -> Self {
99 Self {
100 segments: vec![JoinSegment { descriptor, kind }],
101 start: 0,
102 }
103 }
104
105 pub fn then<J: SqlJoin>(mut self, join: J, kind: JoinKind) -> Self {
106 let descriptor = join.descriptor();
107
108 if let Some(prev) = self.segments.last() {
109 assert_eq!(
110 prev.descriptor.right_table, descriptor.left_table,
111 "Invalid join path: expected next hop to start at `{}` but \
112 found `{}`",
113 prev.descriptor.right_table, descriptor.left_table,
114 );
115 }
116
117 self.segments.push(JoinSegment { descriptor, kind });
118 self
119 }
120
121 pub fn segments(&self) -> &[JoinSegment] {
122 &self.segments[self.start..]
123 }
124
125 pub fn len(&self) -> usize {
126 self.segments().len()
127 }
128
129 pub fn append(&mut self, tail: &JoinPath) {
130 if tail.is_empty() {
131 return;
132 }
133
134 let Some(last) = self.segments().last() else {
135 return;
136 };
137 let Some(first) = tail.segments().first() else {
138 return;
139 };
140
141 assert_eq!(
142 last.descriptor.right_table, first.descriptor.left_table,
143 "Invalid join append: left table `{}` does not match `{}`",
144 last.descriptor.right_table, first.descriptor.left_table,
145 );
146
147 self.segments.extend_from_slice(tail.segments());
148 }
149
150 pub fn strip_prefix(&self, len: usize) -> Option<Self> {
151 let new_start = self.start + len;
152 if new_start > self.segments.len() {
153 return None;
154 }
155
156 Some(Self {
157 segments: self.segments.clone(),
158 start: new_start,
159 })
160 }
161
162 pub fn tail(&self) -> Option<Self> {
163 if self.segments.len() - self.start <= 1 {
164 None
165 } else {
166 Some(Self {
167 segments: self.segments.clone(),
168 start: self.start + 1,
169 })
170 }
171 }
172
173 pub fn is_empty(&self) -> bool {
174 self.len() == 0
175 }
176
177 pub fn first_table(&self) -> Option<&'static str> {
178 self.segments().first().map(|seg| seg.descriptor.left_table)
179 }
180
181 pub fn alias(&self) -> String {
182 self.alias_prefix(self.len())
183 }
184
185 pub fn alias_prefix(&self, len: usize) -> String {
186 assert!(len <= self.len());
187 let mut alias = String::new();
188 let end = self.start + len;
189 for segment in &self.segments[..end] {
190 alias.push_str(segment.descriptor.alias_segment);
191 }
192 alias
193 }
194}
195
196pub trait SqlJoin {
197 fn descriptor(&self) -> JoinDescriptor;
198}
199
/// A `column` of the table aliased `table_alias`, selected under the output
/// name `alias` (presumably rendered as `table_alias.column AS alias` —
/// NOTE(review): confirm against the SQL builder).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AliasedColumn {
    pub table_alias: String,
    pub column: &'static str,
    pub alias: String,
}

impl AliasedColumn {
    /// Creates an aliased column, converting both aliases into owned strings.
    pub fn new(
        table_alias: impl Into<String>,
        column: &'static str,
        alias: impl Into<String>,
    ) -> Self {
        let table_alias: String = table_alias.into();
        let alias: String = alias.into();
        Self {
            table_alias,
            column,
            alias,
        }
    }
}
220
/// Load state of a related value.
///
/// - `NotLoaded` (the default): nothing has been recorded for the relation.
/// - `Missing`: recorded as absent.
/// - `Loaded(T)`: recorded with a value.
#[derive(Clone, PartialEq, Eq, Default)]
pub enum Relation<T> {
    #[default]
    NotLoaded,
    Missing,
    Loaded(T),
}

// Hand-written so the variant names print with the `Relation::` prefix.
impl<T: std::fmt::Debug> std::fmt::Debug for Relation<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let unit_name = match self {
            Self::NotLoaded => "Relation::NotLoaded",
            Self::Missing => "Relation::Missing",
            Self::Loaded(inner) => {
                return f.debug_tuple("Relation::Loaded").field(inner).finish();
            }
        };
        f.write_str(unit_name)
    }
}
250
251impl<T> Relation<T> {
252 pub fn is_loaded(&self) -> bool {
253 !self.is_not_loaded()
254 }
255
256 pub fn is_not_loaded(&self) -> bool {
257 matches!(self, Self::NotLoaded)
258 }
259
260 pub fn is_missing(&self) -> bool {
261 matches!(self, Self::Missing)
262 }
263
264 pub fn is_present(&self) -> bool {
265 matches!(self, Self::Loaded(_))
266 }
267
268 pub fn is_some(&self) -> bool {
269 self.is_present()
270 }
271
272 pub fn is_none(&self) -> bool {
273 !self.is_some()
274 }
275
276 pub fn as_ref(&self) -> Relation<&T> {
277 match self {
278 Self::NotLoaded => Relation::NotLoaded,
279 Self::Missing => Relation::Missing,
280 Self::Loaded(value) => Relation::Loaded(value),
281 }
282 }
283
284 pub fn as_mut(&mut self) -> Relation<&mut T> {
285 match self {
286 Self::NotLoaded => Relation::NotLoaded,
287 Self::Missing => Relation::Missing,
288 Self::Loaded(value) => Relation::Loaded(value),
289 }
290 }
291
292 pub fn as_option(&self) -> Option<&T> {
293 match self {
294 Self::Loaded(value) => Some(value),
295 Self::NotLoaded | Self::Missing => None,
296 }
297 }
298
299 pub fn as_option_mut(&mut self) -> Option<&mut T> {
300 match self {
301 Self::Loaded(value) => Some(value),
302 Self::NotLoaded | Self::Missing => None,
303 }
304 }
305
306 pub fn into_option(self) -> Option<T> {
307 match self {
308 Self::Loaded(value) => Some(value),
309 Self::NotLoaded | Self::Missing => None,
310 }
311 }
312
313 pub fn map<U, F>(self, f: F) -> Relation<U>
314 where
315 F: FnOnce(T) -> U,
316 {
317 match self {
318 Self::Loaded(value) => Relation::Loaded(f(value)),
319 Self::NotLoaded => Relation::NotLoaded,
320 Self::Missing => Relation::Missing,
321 }
322 }
323
324 pub fn map_or<U, F>(self, default: U, f: F) -> U
325 where
326 F: FnOnce(T) -> U,
327 {
328 match self {
329 Self::Loaded(value) => f(value),
330 Self::NotLoaded | Self::Missing => default,
331 }
332 }
333
334 pub fn map_or_else<U, D, F>(self, default: D, f: F) -> U
335 where
336 D: FnOnce() -> U,
337 F: FnOnce(T) -> U,
338 {
339 match self {
340 Self::Loaded(value) => f(value),
341 Self::NotLoaded | Self::Missing => default(),
342 }
343 }
344
345 pub fn ok_or<E>(self, err: E) -> Result<T, E> {
346 match self {
347 Self::Loaded(value) => Ok(value),
348 Self::NotLoaded | Self::Missing => Err(err),
349 }
350 }
351
352 pub fn ok_or_else<E, F>(self, err: F) -> Result<T, E>
353 where
354 F: FnOnce() -> E,
355 {
356 match self {
357 Self::Loaded(value) => Ok(value),
358 Self::NotLoaded | Self::Missing => Err(err()),
359 }
360 }
361
362 pub fn unwrap(self) -> T {
363 match self {
364 Self::Loaded(value) => value,
365 Self::NotLoaded => {
366 panic!("called `Relation::unwrap()` on a `NotLoaded` relation")
367 }
368 Self::Missing => {
369 panic!("called `Relation::unwrap()` on a `Missing` relation")
370 }
371 }
372 }
373
374 pub fn expect(self, msg: &str) -> T {
375 match self {
376 Self::Loaded(value) => value,
377 Self::NotLoaded | Self::Missing => panic!("{msg}"),
378 }
379 }
380
381 pub fn unwrap_or(self, default: T) -> T {
382 match self {
383 Self::Loaded(value) => value,
384 Self::NotLoaded | Self::Missing => default,
385 }
386 }
387
388 pub fn unwrap_or_else<F>(self, f: F) -> T
389 where
390 F: FnOnce() -> T,
391 {
392 match self {
393 Self::Loaded(value) => value,
394 Self::NotLoaded | Self::Missing => f(),
395 }
396 }
397
398 pub fn unwrap_or_default(self) -> T
399 where
400 T: Default,
401 {
402 self.unwrap_or_else(T::default)
403 }
404}
405
406pub type JoinValue<T> = Relation<T>;
407
408pub fn merge_join_collections<T>(
409 target: &mut Relation<Vec<T>>,
410 incoming: Relation<Vec<T>>,
411) where
412 T: JoinIdentifiable,
413{
414 match target {
415 Relation::NotLoaded => {
416 *target = incoming;
417 }
418 Relation::Missing => {
419 if let Relation::Loaded(values) = incoming {
420 *target = Relation::Loaded(values);
421 }
422 }
423 Relation::Loaded(existing) => {
424 if let Relation::Loaded(mut values) = incoming {
425 let mut keys: Vec<T::Key> =
426 existing.iter().map(|item| item.join_key()).collect();
427 for value in values.drain(..) {
428 let key = value.join_key();
429 if keys.iter().all(|existing_key| existing_key != &key) {
430 keys.push(key);
431 existing.push(value);
432 }
433 }
434 }
435 }
436 }
437}
438
439pub trait JoinLoadable: Sized {
440 fn project_join_columns(
441 alias: &str,
442 out: &mut SmallVec<[AliasedColumn; 4]>,
443 );
444
445 fn hydrate_from_join(
446 row: &PgRow,
447 alias: &str,
448 ) -> Result<Option<Self>, sqlx::Error>;
449}
450
451pub trait JoinIdentifiable {
452 type Key: PartialEq;
453
454 fn join_key(&self) -> Self::Key;
455}
456
457pub trait JoinNavigationModel {
458 fn collect_join_columns(
459 joins: Option<&[JoinPath]>,
460 base_alias: &str,
461 ) -> SmallVec<[AliasedColumn; 4]>;
462
463 fn hydrate_navigations(
464 &mut self,
465 joins: Option<&[JoinPath]>,
466 row: &PgRow,
467 base_alias: &str,
468 ) -> Result<(), sqlx::Error>;
469
470 fn has_collection_joins(_joins: Option<&[JoinPath]>) -> bool {
471 false
472 }
473
474 fn merge_collection_rows(
475 rows: Vec<Self>,
476 _joins: Option<&[JoinPath]>,
477 ) -> Vec<Self>
478 where
479 Self: Sized,
480 {
481 rows
482 }
483
484 fn relation_fetch_mode(_identifier: &str) -> Option<RelationFetchMode> {
485 None
486 }
487
488 fn collect_default_join_paths(
489 _include_lazy: bool,
490 _visiting_tables: &mut Vec<&'static str>,
491 ) -> SmallVec<[JoinPath; 4]> {
492 SmallVec::new()
493 }
494
495 fn default_join_paths(include_lazy: bool) -> SmallVec<[JoinPath; 4]>
496 where
497 Self: Sized,
498 {
499 let mut visiting_tables = Vec::new();
500 Self::collect_default_join_paths(include_lazy, &mut visiting_tables)
501 }
502}
503
504pub trait WebJoinGraph {
505 fn resolve_join_path(segments: &[&str], kind: JoinKind)
506 -> Option<JoinPath>;
507}
508
509pub trait PrimaryKey {
510 const PRIMARY_KEY: &'static [&'static str];
511}
512
513pub trait Model {}
514
515pub trait Deletable {
516 const IS_SOFT_DELETE: bool;
517 const DELETE_MARKER_FIELD: Option<&'static str>;
518}
519
520pub trait GetDeleteMarker {
521 fn delete_marker_field() -> Option<&'static str>;
522}
523
524impl<T> GetDeleteMarker for T {
525 default fn delete_marker_field() -> Option<&'static str> {
526 None
527 }
528}
529
530impl<T: Deletable> GetDeleteMarker for T {
531 fn delete_marker_field() -> Option<&'static str> {
532 T::DELETE_MARKER_FIELD
533 }
534}
535
536pub trait Updatable {
537 type UpdateModel: UpdateModel<Entity = Self>;
538 const UPDATE_MARKER_FIELD: Option<&'static str>;
539}
540
541pub trait UpdateModel: Clone + Send + Sync {
542 type Entity: QueryModel;
543
544 fn apply_updates(
545 &self,
546 qb: &mut sqlx::QueryBuilder<'static, sqlx::Postgres>,
547 has_previous: bool,
548 ) -> Vec<&'static str>;
549
550 fn append_relation_ctes(
551 &self,
552 _qb: &mut sqlx::QueryBuilder<'static, sqlx::Postgres>,
553 _affected_alias: &str,
554 ) {
555 }
556
557 fn append_relation_dependency_columns(
558 &self,
559 _qb: &mut sqlx::QueryBuilder<'static, sqlx::Postgres>,
560 ) {
561 }
562
563 fn apply_relation_payload(&self, _entity: &mut Self::Entity) {}
564}
565
566pub trait Creatable {
567 type CreateModel: CreateModel<Entity = Self>;
568 const INSERT_MARKER_FIELD: Option<&'static str>;
569}
570
571pub trait CreateModel: Clone + Send + Sync {
572 type Entity: QueryModel;
573
574 fn apply_inserts(
575 &self,
576 qb: &mut sqlx::QueryBuilder<'static, sqlx::Postgres>,
577 insert_marker_field: Option<&'static str>,
578 );
579
580 fn append_relation_ctes(
581 &self,
582 _qb: &mut sqlx::QueryBuilder<'static, sqlx::Postgres>,
583 _affected_alias: &str,
584 ) {
585 }
586
587 fn append_relation_dependency_columns(
588 &self,
589 _qb: &mut sqlx::QueryBuilder<'static, sqlx::Postgres>,
590 ) {
591 }
592
593 fn apply_relation_payload(&self, _entity: &mut Self::Entity) {}
594}
595
/// Weight label (`A`–`D`) attached to a full-text search field; defaults to `A`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SearchWeight {
    #[default]
    A,
    B,
    C,
    D,
}

impl SearchWeight {
    /// The weight's letter, e.g. `'B'` for [`SearchWeight::B`].
    pub fn to_char(self) -> char {
        match self {
            SearchWeight::D => 'D',
            SearchWeight::C => 'C',
            SearchWeight::B => 'B',
            SearchWeight::A => 'A',
        }
    }

    /// The weight's letter as a single-quoted SQL string literal, e.g. `'B'`.
    pub fn sql_literal(self) -> &'static str {
        match self {
            SearchWeight::D => "'D'",
            SearchWeight::C => "'C'",
            SearchWeight::B => "'B'",
            SearchWeight::A => "'A'",
        }
    }
}
624
/// Read-only view of a full-text search request's options.
pub trait FullTextSearchConfig {
    /// Whether a rank should be computed.
    fn include_rank(&self) -> bool;

    /// Similarity threshold for fuzzy matching; `None` (the default) disables it.
    fn fuzzy_threshold(&self) -> Option<f64> {
        None
    }

    /// Tokens used for fuzzy matching; `None` by default.
    fn fuzzy_tokens(&self) -> Option<&[String]> {
        None
    }

    /// Minimum ratio of terms that must match; `0.5` by default.
    fn min_term_match_ratio(&self) -> f64 {
        0.5
    }
}

/// Builder-style construction of a [`FullTextSearchConfig`].
pub trait FullTextSearchConfigBuilder: FullTextSearchConfig {
    /// Starts a config for the search text `query`.
    fn new_with_query(query: String) -> Self;
    /// Applies an optional language.
    fn apply_language(self, language: Option<String>) -> Self;
    /// Applies an optional rank flag.
    fn apply_rank(self, include_rank: Option<bool>) -> Self;

    /// Applies an optional fuzzy flag; the default ignores it.
    fn apply_fuzzy(self, _enable_fuzzy: Option<bool>) -> Self
    where
        Self: Sized,
    {
        self
    }

    /// Applies an optional fuzzy threshold; the default ignores it.
    fn apply_fuzzy_threshold(self, _threshold: Option<f64>) -> Self
    where
        Self: Sized,
    {
        self
    }
}

/// Configs that can be narrowed to a specific search join.
pub trait FullTextSearchJoinConfig {
    /// The join identifier type.
    type Join;

    /// Returns the config restricted to / extended with `join`.
    fn with_join(self, join: Self::Join) -> Self;
}
666
667pub trait FullTextSearchable: Sized {
668 type FullTextSearchField: Copy + Eq;
669 type FullTextSearchConfig: FullTextSearchConfig + Send + Sync;
670 type FullTextSearchJoin: Copy + Eq;
671
672 fn write_tsvector<W>(
673 w: &mut W,
674 base_alias: &str,
675 joins: Option<&[JoinPath]>,
676 config: &Self::FullTextSearchConfig,
677 ) where
678 W: SqlWrite;
679
680 fn write_tsquery<W>(w: &mut W, config: &Self::FullTextSearchConfig)
681 where
682 W: SqlWrite;
683
684 fn write_rank<W>(
685 w: &mut W,
686 base_alias: &str,
687 joins: Option<&[JoinPath]>,
688 config: &Self::FullTextSearchConfig,
689 ) where
690 W: SqlWrite;
691
692 fn write_fuzzy<W>(
693 w: &mut W,
694 base_alias: &str,
695 joins: Option<&[JoinPath]>,
696 config: &Self::FullTextSearchConfig,
697 ) where
698 W: SqlWrite,
699 {
700 let _ = (base_alias, joins, config);
701 w.push("FALSE");
702 }
703
704 fn write_search_document<W>(
705 w: &mut W,
706 base_alias: &str,
707 joins: Option<&[JoinPath]>,
708 config: &Self::FullTextSearchConfig,
709 ) where
710 W: SqlWrite,
711 {
712 let _ = (base_alias, joins, config);
713 w.push("''");
714 }
715
716 fn write_search_predicate<W>(
717 w: &mut W,
718 base_alias: &str,
719 joins: Option<&[JoinPath]>,
720 config: &Self::FullTextSearchConfig,
721 ) where
722 W: SqlWrite,
723 {
724 w.push("(");
725 w.push("(");
726 Self::write_tsvector(w, base_alias, joins, config);
727 w.push(") @@ (");
728 Self::write_tsquery(w, config);
729 w.push(")");
730 if config.fuzzy_threshold().is_some() &&
731 config
732 .fuzzy_tokens()
733 .map(|tokens| !tokens.is_empty())
734 .unwrap_or(false)
735 {
736 w.push(" OR ");
737 Self::write_fuzzy(w, base_alias, joins, config);
738 }
739 w.push(")");
740 }
741
742 fn write_search_score<W>(
743 w: &mut W,
744 base_alias: &str,
745 joins: Option<&[JoinPath]>,
746 config: &Self::FullTextSearchConfig,
747 ) where
748 W: SqlWrite,
749 {
750 Self::write_rank(w, base_alias, joins, config);
751 }
752
753 fn resolve_search_join_path(
754 _segments: &[&str],
755 ) -> Option<Self::FullTextSearchJoin> {
756 None
757 }
758}
759
#[cfg(test)]
mod tests {
    use super::Relation;

    #[test]
    fn relation_state_helpers_work() {
        let not_loaded = Relation::<i32>::NotLoaded;
        let missing = Relation::<i32>::Missing;
        let loaded = Relation::Loaded(42);

        // `is_loaded`: anything but `NotLoaded`, so `Missing` counts.
        assert!(!not_loaded.is_loaded());
        assert!(missing.is_loaded());
        assert!(loaded.is_loaded());

        // `is_not_loaded`
        assert!(not_loaded.is_not_loaded());
        assert!(!missing.is_not_loaded());

        // `is_missing`
        assert!(!not_loaded.is_missing());
        assert!(missing.is_missing());
        assert!(!loaded.is_missing());

        // `is_present` / `is_some` / `is_none`
        assert!(!not_loaded.is_present());
        assert!(loaded.is_present());
        assert!(!not_loaded.is_some());
        assert!(!missing.is_some());
        assert!(loaded.is_some());
        assert!(not_loaded.is_none());
        assert!(missing.is_none());
        assert!(!loaded.is_none());
    }

    #[test]
    fn relation_option_like_helpers_work() {
        // `unwrap_or`
        assert_eq!(Relation::Loaded(7).unwrap_or(1), 7);
        assert_eq!(Relation::<i32>::NotLoaded.unwrap_or(1), 1);
        assert_eq!(Relation::<i32>::Missing.unwrap_or(1), 1);

        // `map` / `map_or`
        let doubled = |v: i32| v * 2;
        assert_eq!(Relation::Loaded(7).map(doubled), Relation::Loaded(14));
        assert_eq!(
            Relation::<i32>::NotLoaded.map(doubled),
            Relation::NotLoaded
        );
        assert_eq!(Relation::Loaded(7).map_or(3, doubled), 14);
        assert_eq!(Relation::<i32>::Missing.map_or(3, doubled), 3);

        // `ok_or`
        let err = "missing relation";
        assert_eq!(Relation::Loaded(7).ok_or(err), Ok(7));
        assert_eq!(Relation::<i32>::NotLoaded.ok_or(err), Err(err));
    }
}