Skip to main content

sqlxo_traits/
core.rs

1use smallvec::SmallVec;
2use sqlx::{
3	postgres::PgRow,
4	prelude::Type,
5	Postgres,
6};
7
8pub trait QueryModel =
9	Send + Clone + Unpin + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>;
10
11pub trait FilterQuery = Filterable + Clone;
12
13pub trait QuerySort = Sortable + Clone;
14
15pub trait Filterable {
16	type Entity: QueryModel;
17
18	fn write<W: SqlWrite>(&self, w: &mut W);
19}
20
21pub trait SqlWrite {
22	fn push(&mut self, s: &str);
23
24	fn bind<T>(&mut self, value: T)
25	where
26		T: sqlx::Encode<'static, Postgres> + Send + 'static,
27		T: Type<Postgres>;
28}
29
30pub trait QueryContext: Send + Sync + 'static {
31	const TABLE: &'static str;
32
33	type Model: QueryModel
34		+ Send
35		+ Sync
36		+ JoinNavigationModel
37		+ WebJoinGraph
38		+ PrimaryKey;
39	type Query: FilterQuery + Send + Sync;
40	type Sort: QuerySort + Send + Sync;
41	type Join: SqlJoin + Send + Sync;
42}
43
44pub trait Sortable {
45	type Entity: QueryModel;
46
47	fn sort_clause(&self) -> String;
48}
49
/// Kind of join used for one hop of a join path.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum JoinKind {
	/// `LEFT` join.
	Left,
	/// `INNER` join.
	Inner,
}
55
/// Declared fetch strategy for a relation: `Eager` or `Lazy`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RelationFetchMode {
	/// Fetched eagerly.
	Eager,
	/// Fetched lazily.
	Lazy,
}
61
/// Intermediate table a join hop is routed through.
///
/// NOTE(review): presumably a junction table for many-to-many relations;
/// confirm against the SQL builder.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct JoinThroughDescriptor {
	/// Name of the intermediate table.
	pub table:         &'static str,
	/// Alias fragment used for the intermediate table.
	pub alias_segment: &'static str,
	/// Column on the intermediate table matched against the left side.
	pub left_field:    &'static str,
	/// Column on the intermediate table matched against the right side.
	pub right_field:   &'static str,
}
69
70#[derive(Debug, Clone, Copy, PartialEq, Eq)]
71pub struct JoinDescriptor {
72	pub left_table:    &'static str,
73	pub left_field:    &'static str,
74	pub right_table:   &'static str,
75	pub right_field:   &'static str,
76	pub alias_segment: &'static str,
77	pub identifier:    &'static str,
78	pub through:       Option<JoinThroughDescriptor>,
79}
80
81#[derive(Debug, Clone, Copy, PartialEq, Eq)]
82pub struct JoinSegment {
83	pub descriptor: JoinDescriptor,
84	pub kind:       JoinKind,
85}
86
87#[derive(Debug, Clone, PartialEq, Eq)]
88pub struct JoinPath {
89	segments: Vec<JoinSegment>,
90	start:    usize,
91}
92
93impl JoinPath {
94	pub fn from_join<J: SqlJoin>(join: J, kind: JoinKind) -> Self {
95		Self::new(join.descriptor(), kind)
96	}
97
98	pub fn new(descriptor: JoinDescriptor, kind: JoinKind) -> Self {
99		Self {
100			segments: vec![JoinSegment { descriptor, kind }],
101			start:    0,
102		}
103	}
104
105	pub fn then<J: SqlJoin>(mut self, join: J, kind: JoinKind) -> Self {
106		let descriptor = join.descriptor();
107
108		if let Some(prev) = self.segments.last() {
109			assert_eq!(
110				prev.descriptor.right_table, descriptor.left_table,
111				"Invalid join path: expected next hop to start at `{}` but \
112				 found `{}`",
113				prev.descriptor.right_table, descriptor.left_table,
114			);
115		}
116
117		self.segments.push(JoinSegment { descriptor, kind });
118		self
119	}
120
121	pub fn segments(&self) -> &[JoinSegment] {
122		&self.segments[self.start..]
123	}
124
125	pub fn len(&self) -> usize {
126		self.segments().len()
127	}
128
129	pub fn append(&mut self, tail: &JoinPath) {
130		if tail.is_empty() {
131			return;
132		}
133
134		let Some(last) = self.segments().last() else {
135			return;
136		};
137		let Some(first) = tail.segments().first() else {
138			return;
139		};
140
141		assert_eq!(
142			last.descriptor.right_table, first.descriptor.left_table,
143			"Invalid join append: left table `{}` does not match `{}`",
144			last.descriptor.right_table, first.descriptor.left_table,
145		);
146
147		self.segments.extend_from_slice(tail.segments());
148	}
149
150	pub fn strip_prefix(&self, len: usize) -> Option<Self> {
151		let new_start = self.start + len;
152		if new_start > self.segments.len() {
153			return None;
154		}
155
156		Some(Self {
157			segments: self.segments.clone(),
158			start:    new_start,
159		})
160	}
161
162	pub fn tail(&self) -> Option<Self> {
163		if self.segments.len() - self.start <= 1 {
164			None
165		} else {
166			Some(Self {
167				segments: self.segments.clone(),
168				start:    self.start + 1,
169			})
170		}
171	}
172
173	pub fn is_empty(&self) -> bool {
174		self.len() == 0
175	}
176
177	pub fn first_table(&self) -> Option<&'static str> {
178		self.segments().first().map(|seg| seg.descriptor.left_table)
179	}
180
181	pub fn alias(&self) -> String {
182		self.alias_prefix(self.len())
183	}
184
185	pub fn alias_prefix(&self, len: usize) -> String {
186		assert!(len <= self.len());
187		let mut alias = String::new();
188		let end = self.start + len;
189		for segment in &self.segments[..end] {
190			alias.push_str(segment.descriptor.alias_segment);
191		}
192		alias
193	}
194}
195
196pub trait SqlJoin {
197	fn descriptor(&self) -> JoinDescriptor;
198}
199
/// A column selected from an aliased table and exposed under its own alias.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AliasedColumn {
	/// Alias of the table the column is read from.
	pub table_alias: String,
	/// Column name.
	pub column:      &'static str,
	/// Alias the column is exposed under.
	pub alias:       String,
}

impl AliasedColumn {
	/// Builds an `AliasedColumn`, converting both aliases into owned strings.
	pub fn new(
		table_alias: impl Into<String>,
		column: &'static str,
		alias: impl Into<String>,
	) -> Self {
		let owned_table_alias: String = table_alias.into();
		let owned_alias: String = alias.into();
		Self {
			table_alias: owned_table_alias,
			column,
			alias: owned_alias,
		}
	}
}
220
/// Load state of a related value.
///
/// Unlike `Option`, it separates "not fetched" ([`NotLoaded`](Self::NotLoaded))
/// from "fetched but absent" ([`Missing`](Self::Missing)).
#[derive(PartialEq, Default, Eq)]
pub enum Relation<T> {
	/// Not fetched. This is the [`Default`] state.
	#[default]
	NotLoaded,
	/// Fetched, but there is no value.
	Missing,
	/// Fetched and present.
	Loaded(T),
}

impl<T: Clone> Clone for Relation<T> {
	fn clone(&self) -> Self {
		match self {
			Self::NotLoaded => Self::NotLoaded,
			Self::Missing => Self::Missing,
			Self::Loaded(v) => Self::Loaded(v.clone()),
		}
	}
}

/// Formats as `Relation::NotLoaded`, `Relation::Missing` or
/// `Relation::Loaded(<value>)`.
impl<T: std::fmt::Debug> std::fmt::Debug for Relation<T> {
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		match self {
			Self::NotLoaded => f.write_str("Relation::NotLoaded"),
			Self::Missing => f.write_str("Relation::Missing"),
			Self::Loaded(v) => {
				f.debug_tuple("Relation::Loaded").field(v).finish()
			}
		}
	}
}

impl<T> Relation<T> {
	/// Returns `true` once a fetch happened, i.e. for `Missing` and `Loaded`.
	///
	/// Only `NotLoaded` yields `false`.
	pub fn is_loaded(&self) -> bool {
		!self.is_not_loaded()
	}

	/// Returns `true` for `NotLoaded`.
	pub fn is_not_loaded(&self) -> bool {
		matches!(self, Self::NotLoaded)
	}

	/// Returns `true` for `Missing`.
	pub fn is_missing(&self) -> bool {
		matches!(self, Self::Missing)
	}

	/// Returns `true` for `Loaded`.
	pub fn is_present(&self) -> bool {
		matches!(self, Self::Loaded(_))
	}

	/// `Option`-style alias of [`is_present`](Self::is_present).
	pub fn is_some(&self) -> bool {
		self.is_present()
	}

	/// Returns `true` unless a value is present (`NotLoaded` or `Missing`).
	pub fn is_none(&self) -> bool {
		!self.is_some()
	}

	/// Borrows the contained value, preserving the state.
	pub fn as_ref(&self) -> Relation<&T> {
		match self {
			Self::NotLoaded => Relation::NotLoaded,
			Self::Missing => Relation::Missing,
			Self::Loaded(value) => Relation::Loaded(value),
		}
	}

	/// Mutably borrows the contained value, preserving the state.
	pub fn as_mut(&mut self) -> Relation<&mut T> {
		match self {
			Self::NotLoaded => Relation::NotLoaded,
			Self::Missing => Relation::Missing,
			Self::Loaded(value) => Relation::Loaded(value),
		}
	}

	/// Returns `Some(&value)` when loaded, otherwise `None`.
	pub fn as_option(&self) -> Option<&T> {
		match self {
			Self::Loaded(value) => Some(value),
			Self::NotLoaded | Self::Missing => None,
		}
	}

	/// Returns `Some(&mut value)` when loaded, otherwise `None`.
	pub fn as_option_mut(&mut self) -> Option<&mut T> {
		match self {
			Self::Loaded(value) => Some(value),
			Self::NotLoaded | Self::Missing => None,
		}
	}

	/// Converts into an `Option`, collapsing `NotLoaded` and `Missing` to
	/// `None`.
	pub fn into_option(self) -> Option<T> {
		match self {
			Self::Loaded(value) => Some(value),
			Self::NotLoaded | Self::Missing => None,
		}
	}

	/// Maps the loaded value with `f`; other states are carried over unchanged.
	pub fn map<U, F>(self, f: F) -> Relation<U>
	where
		F: FnOnce(T) -> U,
	{
		match self {
			Self::Loaded(value) => Relation::Loaded(f(value)),
			Self::NotLoaded => Relation::NotLoaded,
			Self::Missing => Relation::Missing,
		}
	}

	/// Applies `f` to the loaded value, or returns `default`.
	pub fn map_or<U, F>(self, default: U, f: F) -> U
	where
		F: FnOnce(T) -> U,
	{
		match self {
			Self::Loaded(value) => f(value),
			Self::NotLoaded | Self::Missing => default,
		}
	}

	/// Applies `f` to the loaded value, or returns `default()`.
	pub fn map_or_else<U, D, F>(self, default: D, f: F) -> U
	where
		D: FnOnce() -> U,
		F: FnOnce(T) -> U,
	{
		match self {
			Self::Loaded(value) => f(value),
			Self::NotLoaded | Self::Missing => default(),
		}
	}

	/// Returns `Ok(value)` when loaded, otherwise `Err(err)`.
	pub fn ok_or<E>(self, err: E) -> Result<T, E> {
		match self {
			Self::Loaded(value) => Ok(value),
			Self::NotLoaded | Self::Missing => Err(err),
		}
	}

	/// Returns `Ok(value)` when loaded, otherwise `Err(err())`.
	pub fn ok_or_else<E, F>(self, err: F) -> Result<T, E>
	where
		F: FnOnce() -> E,
	{
		match self {
			Self::Loaded(value) => Ok(value),
			Self::NotLoaded | Self::Missing => Err(err()),
		}
	}

	/// Returns the loaded value.
	///
	/// # Panics
	///
	/// Panics with a state-specific message if the relation is `NotLoaded` or
	/// `Missing`. The panic is reported at the caller's location.
	#[track_caller]
	pub fn unwrap(self) -> T {
		match self {
			Self::Loaded(value) => value,
			Self::NotLoaded => {
				panic!("called `Relation::unwrap()` on a `NotLoaded` relation")
			}
			Self::Missing => {
				panic!("called `Relation::unwrap()` on a `Missing` relation")
			}
		}
	}

	/// Returns the loaded value.
	///
	/// # Panics
	///
	/// Panics with `msg` if the relation is `NotLoaded` or `Missing`. The
	/// panic is reported at the caller's location.
	#[track_caller]
	pub fn expect(self, msg: &str) -> T {
		match self {
			Self::Loaded(value) => value,
			Self::NotLoaded | Self::Missing => panic!("{msg}"),
		}
	}

	/// Returns the loaded value, or `default`.
	pub fn unwrap_or(self, default: T) -> T {
		match self {
			Self::Loaded(value) => value,
			Self::NotLoaded | Self::Missing => default,
		}
	}

	/// Returns the loaded value, or the result of `f()`.
	pub fn unwrap_or_else<F>(self, f: F) -> T
	where
		F: FnOnce() -> T,
	{
		match self {
			Self::Loaded(value) => value,
			Self::NotLoaded | Self::Missing => f(),
		}
	}

	/// Returns the loaded value, or `T::default()`.
	pub fn unwrap_or_default(self) -> T
	where
		T: Default,
	{
		self.unwrap_or_else(T::default)
	}
}
405
406pub type JoinValue<T> = Relation<T>;
407
408pub fn merge_join_collections<T>(
409	target: &mut Relation<Vec<T>>,
410	incoming: Relation<Vec<T>>,
411) where
412	T: JoinIdentifiable,
413{
414	match target {
415		Relation::NotLoaded => {
416			*target = incoming;
417		}
418		Relation::Missing => {
419			if let Relation::Loaded(values) = incoming {
420				*target = Relation::Loaded(values);
421			}
422		}
423		Relation::Loaded(existing) => {
424			if let Relation::Loaded(mut values) = incoming {
425				let mut keys: Vec<T::Key> =
426					existing.iter().map(|item| item.join_key()).collect();
427				for value in values.drain(..) {
428					let key = value.join_key();
429					if keys.iter().all(|existing_key| existing_key != &key) {
430						keys.push(key);
431						existing.push(value);
432					}
433				}
434			}
435		}
436	}
437}
438
439pub trait JoinLoadable: Sized {
440	fn project_join_columns(
441		alias: &str,
442		out: &mut SmallVec<[AliasedColumn; 4]>,
443	);
444
445	fn hydrate_from_join(
446		row: &PgRow,
447		alias: &str,
448	) -> Result<Option<Self>, sqlx::Error>;
449}
450
/// Supplies the key used to deduplicate items of a joined collection.
pub trait JoinIdentifiable {
	/// Key type; only equality comparison is required.
	type Key: PartialEq;

	/// Returns this item's key.
	fn join_key(&self) -> Self::Key;
}
456
457pub trait JoinNavigationModel {
458	fn collect_join_columns(
459		joins: Option<&[JoinPath]>,
460		base_alias: &str,
461	) -> SmallVec<[AliasedColumn; 4]>;
462
463	fn hydrate_navigations(
464		&mut self,
465		joins: Option<&[JoinPath]>,
466		row: &PgRow,
467		base_alias: &str,
468	) -> Result<(), sqlx::Error>;
469
470	fn has_collection_joins(_joins: Option<&[JoinPath]>) -> bool {
471		false
472	}
473
474	fn merge_collection_rows(
475		rows: Vec<Self>,
476		_joins: Option<&[JoinPath]>,
477	) -> Vec<Self>
478	where
479		Self: Sized,
480	{
481		rows
482	}
483
484	fn relation_fetch_mode(_identifier: &str) -> Option<RelationFetchMode> {
485		None
486	}
487
488	fn collect_default_join_paths(
489		_include_lazy: bool,
490		_visiting_tables: &mut Vec<&'static str>,
491	) -> SmallVec<[JoinPath; 4]> {
492		SmallVec::new()
493	}
494
495	fn default_join_paths(include_lazy: bool) -> SmallVec<[JoinPath; 4]>
496	where
497		Self: Sized,
498	{
499		let mut visiting_tables = Vec::new();
500		Self::collect_default_join_paths(include_lazy, &mut visiting_tables)
501	}
502}
503
504pub trait WebJoinGraph {
505	fn resolve_join_path(segments: &[&str], kind: JoinKind)
506		-> Option<JoinPath>;
507}
508
/// Exposes the column name(s) forming a model's primary key.
pub trait PrimaryKey {
	/// Primary-key columns; more than one entry means a composite key.
	const PRIMARY_KEY: &'static [&'static str];
}
512
/// Marker trait for model types; declares no items.
pub trait Model {}
514
/// Deletion metadata for a model.
pub trait Deletable {
	/// Whether deletion is "soft" (flagged) rather than a physical delete.
	const IS_SOFT_DELETE: bool;
	/// Name of the delete-marker column, if the model has one.
	const DELETE_MARKER_FIELD: Option<&'static str>;
}
519
/// Uniform access to a type's delete-marker column, whether or not it is
/// `Deletable`.
pub trait GetDeleteMarker {
	/// Name of the delete-marker column, if any.
	fn delete_marker_field() -> Option<&'static str>;
}
523
524impl<T> GetDeleteMarker for T {
525	default fn delete_marker_field() -> Option<&'static str> {
526		None
527	}
528}
529
530impl<T: Deletable> GetDeleteMarker for T {
531	fn delete_marker_field() -> Option<&'static str> {
532		T::DELETE_MARKER_FIELD
533	}
534}
535
536pub trait Updatable {
537	type UpdateModel: UpdateModel<Entity = Self>;
538	const UPDATE_MARKER_FIELD: Option<&'static str>;
539}
540
541pub trait UpdateModel: Clone + Send + Sync {
542	type Entity: QueryModel;
543
544	fn apply_updates(
545		&self,
546		qb: &mut sqlx::QueryBuilder<'static, sqlx::Postgres>,
547		has_previous: bool,
548	) -> Vec<&'static str>;
549
550	fn append_relation_ctes(
551		&self,
552		_qb: &mut sqlx::QueryBuilder<'static, sqlx::Postgres>,
553		_affected_alias: &str,
554	) {
555	}
556
557	fn append_relation_dependency_columns(
558		&self,
559		_qb: &mut sqlx::QueryBuilder<'static, sqlx::Postgres>,
560	) {
561	}
562
563	fn apply_relation_payload(&self, _entity: &mut Self::Entity) {}
564}
565
566pub trait Creatable {
567	type CreateModel: CreateModel<Entity = Self>;
568	const INSERT_MARKER_FIELD: Option<&'static str>;
569}
570
571pub trait CreateModel: Clone + Send + Sync {
572	type Entity: QueryModel;
573
574	fn apply_inserts(
575		&self,
576		qb: &mut sqlx::QueryBuilder<'static, sqlx::Postgres>,
577		insert_marker_field: Option<&'static str>,
578	);
579
580	fn append_relation_ctes(
581		&self,
582		_qb: &mut sqlx::QueryBuilder<'static, sqlx::Postgres>,
583		_affected_alias: &str,
584	) {
585	}
586
587	fn append_relation_dependency_columns(
588		&self,
589		_qb: &mut sqlx::QueryBuilder<'static, sqlx::Postgres>,
590	) {
591	}
592
593	fn apply_relation_payload(&self, _entity: &mut Self::Entity) {}
594}
595
/// Full-text-search weight label, `A` through `D`; defaults to `A`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SearchWeight {
	#[default]
	A,
	B,
	C,
	D,
}

impl SearchWeight {
	/// Returns the label as a single character, e.g. `'B'`.
	pub fn to_char(self) -> char {
		match self {
			SearchWeight::A => 'A',
			SearchWeight::B => 'B',
			SearchWeight::C => 'C',
			SearchWeight::D => 'D',
		}
	}

	/// Returns the label as a single-quoted SQL string literal, e.g. `'B'`.
	pub fn sql_literal(self) -> &'static str {
		match self {
			SearchWeight::A => "'A'",
			SearchWeight::B => "'B'",
			SearchWeight::C => "'C'",
			SearchWeight::D => "'D'",
		}
	}
}
624
/// Read access to the settings that drive a full-text search.
pub trait FullTextSearchConfig {
	/// Whether a rank is requested.
	fn include_rank(&self) -> bool;

	/// Threshold for fuzzy matching; `None` by default.
	fn fuzzy_threshold(&self) -> Option<f64> {
		Option::None
	}

	/// Tokens used for fuzzy matching; `None` by default.
	fn fuzzy_tokens(&self) -> Option<&[String]> {
		Option::None
	}

	/// Minimum term-match ratio; `0.5` by default.
	///
	/// NOTE(review): exact semantics depend on implementations; confirm.
	fn min_term_match_ratio(&self) -> f64 {
		0.5
	}
}
640
641pub trait FullTextSearchConfigBuilder: FullTextSearchConfig {
642	fn new_with_query(query: String) -> Self;
643	fn apply_language(self, language: Option<String>) -> Self;
644	fn apply_rank(self, include_rank: Option<bool>) -> Self;
645
646	fn apply_fuzzy(self, _enable_fuzzy: Option<bool>) -> Self
647	where
648		Self: Sized,
649	{
650		self
651	}
652
653	fn apply_fuzzy_threshold(self, _threshold: Option<f64>) -> Self
654	where
655		Self: Sized,
656	{
657		self
658	}
659}
660
/// A search configuration that can be extended with a join.
pub trait FullTextSearchJoinConfig {
	/// Join type accepted by [`with_join`](Self::with_join).
	type Join;

	/// Returns the configuration with `join` added.
	fn with_join(self, join: Self::Join) -> Self;
}
666
667pub trait FullTextSearchable: Sized {
668	type FullTextSearchField: Copy + Eq;
669	type FullTextSearchConfig: FullTextSearchConfig + Send + Sync;
670	type FullTextSearchJoin: Copy + Eq;
671
672	fn write_tsvector<W>(
673		w: &mut W,
674		base_alias: &str,
675		joins: Option<&[JoinPath]>,
676		config: &Self::FullTextSearchConfig,
677	) where
678		W: SqlWrite;
679
680	fn write_tsquery<W>(w: &mut W, config: &Self::FullTextSearchConfig)
681	where
682		W: SqlWrite;
683
684	fn write_rank<W>(
685		w: &mut W,
686		base_alias: &str,
687		joins: Option<&[JoinPath]>,
688		config: &Self::FullTextSearchConfig,
689	) where
690		W: SqlWrite;
691
692	fn write_fuzzy<W>(
693		w: &mut W,
694		base_alias: &str,
695		joins: Option<&[JoinPath]>,
696		config: &Self::FullTextSearchConfig,
697	) where
698		W: SqlWrite,
699	{
700		let _ = (base_alias, joins, config);
701		w.push("FALSE");
702	}
703
704	fn write_search_document<W>(
705		w: &mut W,
706		base_alias: &str,
707		joins: Option<&[JoinPath]>,
708		config: &Self::FullTextSearchConfig,
709	) where
710		W: SqlWrite,
711	{
712		let _ = (base_alias, joins, config);
713		w.push("''");
714	}
715
716	fn write_search_predicate<W>(
717		w: &mut W,
718		base_alias: &str,
719		joins: Option<&[JoinPath]>,
720		config: &Self::FullTextSearchConfig,
721	) where
722		W: SqlWrite,
723	{
724		w.push("(");
725		w.push("(");
726		Self::write_tsvector(w, base_alias, joins, config);
727		w.push(") @@ (");
728		Self::write_tsquery(w, config);
729		w.push(")");
730		if config.fuzzy_threshold().is_some() &&
731			config
732				.fuzzy_tokens()
733				.map(|tokens| !tokens.is_empty())
734				.unwrap_or(false)
735		{
736			w.push(" OR ");
737			Self::write_fuzzy(w, base_alias, joins, config);
738		}
739		w.push(")");
740	}
741
742	fn write_search_score<W>(
743		w: &mut W,
744		base_alias: &str,
745		joins: Option<&[JoinPath]>,
746		config: &Self::FullTextSearchConfig,
747	) where
748		W: SqlWrite,
749	{
750		Self::write_rank(w, base_alias, joins, config);
751	}
752
753	fn resolve_search_join_path(
754		_segments: &[&str],
755	) -> Option<Self::FullTextSearchJoin> {
756		None
757	}
758}
759
#[cfg(test)]
mod tests {
	use super::Relation;

	/// Evaluates every state predicate of `relation` in a fixed order:
	/// `[is_not_loaded, is_loaded, is_missing, is_present, is_some, is_none]`.
	fn state_flags(relation: &Relation<i32>) -> [bool; 6] {
		[
			relation.is_not_loaded(),
			relation.is_loaded(),
			relation.is_missing(),
			relation.is_present(),
			relation.is_some(),
			relation.is_none(),
		]
	}

	#[test]
	fn relation_state_helpers_work() {
		assert_eq!(
			state_flags(&Relation::NotLoaded),
			[true, false, false, false, false, true]
		);
		assert_eq!(
			state_flags(&Relation::Missing),
			[false, true, true, false, false, true]
		);
		assert_eq!(
			state_flags(&Relation::Loaded(42)),
			[false, true, false, true, true, false]
		);
	}

	#[test]
	fn relation_option_like_helpers_work() {
		// `unwrap_or` falls back for both states without a value.
		for (relation, expected) in [
			(Relation::Loaded(7), 7),
			(Relation::NotLoaded, 1),
			(Relation::Missing, 1),
		] {
			assert_eq!(relation.unwrap_or(1), expected);
		}

		assert_eq!(Relation::Loaded(7).map(|v| v * 2), Relation::Loaded(14));
		assert_eq!(
			Relation::<i32>::NotLoaded.map(|v| v * 2),
			Relation::NotLoaded
		);

		assert_eq!(Relation::Loaded(7).map_or(3, |v| v * 2), 14);
		assert_eq!(Relation::<i32>::Missing.map_or(3, |v| v * 2), 3);

		assert_eq!(Relation::Loaded(7).ok_or("missing relation"), Ok(7));
		assert_eq!(
			Relation::<i32>::NotLoaded.ok_or("missing relation"),
			Err("missing relation")
		);
	}
}