1use std::collections::{HashMap, HashSet};
2
3use rustrails_macros::{BelongsToAssociation, HasManyAssociation, HasOneAssociation};
4use rustrails_support::{
5 database,
6 inflector::{foreign_key, singularize},
7 runtime,
8};
9use sea_orm::{ColumnTrait, EntityTrait, Iterable};
10use serde::Serialize;
11use serde_json::{Value, json};
12
13use crate::{Querying, Record, RecordError, Relation};
14
15pub mod belongs_to;
17pub mod has_and_belongs_to_many;
19pub mod has_many;
21pub mod has_one;
23
24pub use belongs_to::BelongsToBuilder;
25pub use has_and_belongs_to_many::HasAndBelongsToManyBuilder;
26pub use has_many::HasManyBuilder;
27pub use has_one::HasOneBuilder;
28
/// The kind of relationship an [`AssociationMeta`] entry describes.
///
/// [`AssociationRegistry::of_type`] filters registered associations by this value.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AssociationType {
    /// A `has_many` association: the owner relates to any number of target records.
    HasMany,
    /// A `has_one` association: the owner relates to at most one target record.
    HasOne,
    /// A `belongs_to` association: the owner stores the key of its target record.
    BelongsTo,
    /// A `has_and_belongs_to_many` association, typically declared with a join table.
    HasAndBelongsToMany,
}
41
/// Policy recorded for an association's dependent records.
///
/// It is set through the association builders' `.dependent(...)` method (for example
/// `HasManyBuilder::dependent`) and stored in [`AssociationMeta::dependent`]. Nothing
/// in this module acts on the value.
///
/// NOTE(review): the variants appear to mirror Rails' `dependent:` options. Confirm
/// the exact semantics against the code that enforces them.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DependentAction {
    /// Presumably destroys each dependent record (Rails `dependent: :destroy`).
    Destroy,
    /// Presumably deletes dependents directly, skipping callbacks (Rails `:delete_all`).
    Delete,
    /// Presumably clears the dependents' foreign keys (Rails `:nullify`).
    Nullify,
    /// Presumably refuses to remove an owner that still has dependents (Rails `:restrict_*`).
    Restrict,
}
54
55#[derive(Debug, Clone, PartialEq, Eq)]
57pub struct AssociationMeta {
58 pub name: String,
60 pub association_type: AssociationType,
62 pub target_table: String,
64 pub foreign_key: String,
66 pub primary_key: String,
68 pub dependent: Option<DependentAction>,
70 pub through: Option<String>,
72 pub polymorphic: bool,
74}
75
76#[derive(Debug, Default)]
78pub struct AssociationRegistry {
79 associations: Vec<AssociationMeta>,
80}
81
82impl AssociationRegistry {
83 #[must_use]
85 pub fn new() -> Self {
86 Self::default()
87 }
88
89 pub fn add(&mut self, meta: AssociationMeta) {
91 self.associations.push(meta);
92 }
93
94 #[must_use]
96 pub fn get(&self, name: &str) -> Option<&AssociationMeta> {
97 self.associations.iter().find(|meta| meta.name == name)
98 }
99
100 #[must_use]
102 pub fn of_type(&self, assoc_type: AssociationType) -> Vec<&AssociationMeta> {
103 self.associations
104 .iter()
105 .filter(|meta| meta.association_type == assoc_type)
106 .collect()
107 }
108
109 #[must_use]
111 pub fn all(&self) -> &[AssociationMeta] {
112 &self.associations
113 }
114}
115
/// A [`Record`] type that exposes the associations declared for it.
pub trait HasAssociations: Record {
    /// Returns this model's association registry.
    ///
    /// The reference is `'static`, so the registry must live for the whole program.
    /// Implementors typically keep it in a lazily initialised static
    /// (for example `LazyLock<AssociationRegistry>`).
    fn associations() -> &'static AssociationRegistry;
}
121
122pub trait HasManyQuery<Target> {
124 fn has_many(&self) -> Result<Vec<Target>, RecordError>
126 where
127 Self: Record + Serialize + HasManyAssociation<Target>,
128 Target: Querying,
129 <Target::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
130 {
131 let Some(owner_id) = self.id() else {
132 return Ok(Vec::new());
133 };
134
135 let definition = <Self as HasManyAssociation<Target>>::association_definition();
136 let foreign_key = definition
137 .foreign_key
138 .map(str::to_owned)
139 .unwrap_or_else(|| default_owner_foreign_key::<Self>());
140
141 load_many_by_field::<Target>(&foreign_key, owner_id)
142 }
143}
144
145pub trait HasManyThroughQuery<Target, Join> {
147 fn has_many(&self) -> Result<Vec<Target>, RecordError>
149 where
150 Self: Record + Serialize + HasManyAssociation<Target>,
151 Target: Querying,
152 Join: Querying + Serialize,
153 <Target::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
154 <Join::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
155 {
156 let Some(owner_id) = self.id() else {
157 return Ok(Vec::new());
158 };
159
160 let join_owner_key = default_owner_foreign_key::<Self>();
161 let join_target_key = default_target_foreign_key::<Target>();
162 let join_rows = load_many_by_field::<Join>(&join_owner_key, owner_id)?;
163 let mut target_ids = HashSet::new();
164 let mut targets = Vec::new();
165
166 for join_row in join_rows {
167 let Some(target_id) = extract_serialized_id(&join_row, &join_target_key)? else {
168 continue;
169 };
170 if target_ids.insert(target_id) {
171 targets.push(Target::find_sync(target_id)?);
172 }
173 }
174
175 Ok(targets)
176 }
177}
178
179pub trait BelongsToQuery<Target> {
181 fn belongs_to(&self) -> Result<Target, RecordError>
183 where
184 Self: Record + Serialize + BelongsToAssociation<Target>,
185 Target: Querying,
186 <Target::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
187 {
188 let definition = <Self as BelongsToAssociation<Target>>::association_definition();
189 let foreign_key = definition
190 .foreign_key
191 .map(str::to_owned)
192 .unwrap_or_else(|| default_target_foreign_key::<Target>());
193 let target_id = extract_serialized_id(self, &foreign_key)?.ok_or(RecordError::NotFound)?;
194
195 Target::find_sync(target_id)
196 }
197}
198
199pub trait HasOneQuery<Target> {
201 fn has_one(&self) -> Result<Option<Target>, RecordError>
203 where
204 Self: Record + Serialize + HasOneAssociation<Target>,
205 Target: Querying,
206 <Target::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
207 {
208 let Some(owner_id) = self.id() else {
209 return Ok(None);
210 };
211
212 let definition = <Self as HasOneAssociation<Target>>::association_definition();
213 let foreign_key = definition
214 .foreign_key
215 .map(str::to_owned)
216 .unwrap_or_else(|| default_owner_foreign_key::<Self>());
217
218 load_one_by_field::<Target>(&foreign_key, owner_id)
219 }
220}
221
222fn default_owner_foreign_key<Model: Record>() -> String {
223 foreign_key(&singularize(Model::table_name()))
224}
225
226fn default_target_foreign_key<Target: Record>() -> String {
227 foreign_key(&singularize(Target::table_name()))
228}
229
230fn load_many_by_field<Target>(field: &str, value: i64) -> Result<Vec<Target>, RecordError>
231where
232 Target: Querying,
233 <Target::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
234{
235 database::with_db(|db| {
236 runtime::block_on(
237 Relation::<Target>::new()
238 .r#where(HashMap::from([(field.to_owned(), json!(value))]))
239 .load(db),
240 )
241 })
242}
243
244fn load_one_by_field<Target>(field: &str, value: i64) -> Result<Option<Target>, RecordError>
245where
246 Target: Querying,
247 <Target::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
248{
249 database::with_db(|db| {
250 runtime::block_on(
251 Relation::<Target>::new()
252 .r#where(HashMap::from([(field.to_owned(), json!(value))]))
253 .first(db),
254 )
255 })
256}
257
258fn extract_serialized_id<T: Serialize>(
259 record: &T,
260 field: &str,
261) -> Result<Option<i64>, RecordError> {
262 let value =
263 serde_json::to_value(record).map_err(|error| RecordError::Invalid(error.to_string()))?;
264 let object = value.as_object().ok_or_else(|| {
265 RecordError::Invalid("associated record must serialize to a JSON object".to_owned())
266 })?;
267
268 match object.get(field) {
269 Some(Value::Null) => Ok(None),
270 Some(value) => Ok(Some(json_value_to_i64(value, field)?)),
271 None => Err(RecordError::Invalid(format!(
272 "missing association key `{field}` on serialized record"
273 ))),
274 }
275}
276
277fn json_value_to_i64(value: &Value, field: &str) -> Result<i64, RecordError> {
278 match value {
279 Value::Number(number) => {
280 if let Some(value) = number.as_i64() {
281 Ok(value)
282 } else if let Some(value) = number.as_u64() {
283 i64::try_from(value).map_err(|_| {
284 RecordError::Invalid(format!("association key `{field}` does not fit in i64"))
285 })
286 } else {
287 Err(RecordError::Invalid(format!(
288 "association key `{field}` must be an integer"
289 )))
290 }
291 }
292 _ => Err(RecordError::Invalid(format!(
293 "association key `{field}` must be numeric"
294 ))),
295 }
296}
297
298#[cfg(test)]
299mod tests {
300 use std::{collections::HashMap, sync::LazyLock};
301
302 use rustrails_macros::{
303 AssociationKind, BelongsToAssociation, HasManyAssociation, HasOneAssociation, belongs_to,
304 has_many, has_one, model,
305 };
306 use rustrails_support::{database, runtime};
307 use serde_json::json;
308
309 use crate::{
310 Persistence,
311 associations::{
312 AssociationRegistry, AssociationType, BelongsToBuilder, BelongsToQuery,
313 DependentAction, HasAndBelongsToManyBuilder, HasAssociations, HasManyBuilder,
314 HasManyQuery, HasOneBuilder, HasOneQuery,
315 },
316 base::test_support::TestUser,
317 };
318
319 model! {
320 QueryBlog {
321 title: String,
322 }
323 table_name: "query_blogs";
324 }
325
326 model! {
327 QueryPost {
328 query_blog_id: i64,
329 title: String,
330 }
331 table_name: "query_posts";
332 }
333
334 model! {
335 QueryProfile {
336 query_blog_id: i64,
337 bio: String,
338 }
339 table_name: "query_profiles";
340 }
341
342 has_many!(QueryBlog => QueryPost, foreign_key: query_blog_id);
343 belongs_to!(QueryPost => QueryBlog, foreign_key: query_blog_id);
344 has_one!(QueryBlog => QueryProfile);
345
346 struct DefaultHasManyAuthor;
347 struct DefaultHasManyPost;
348 struct ForeignKeyHasManyAuthor;
349 struct ForeignKeyHasManyPost;
350 struct ThroughHasManyAuthor;
351 struct ThroughHasManyTag;
352 struct PostTagging;
353 struct DefaultBelongsToComment;
354 struct DefaultBelongsToPost;
355 struct ForeignKeyBelongsToComment;
356 struct ForeignKeyBelongsToBlog;
357 struct DefaultHasOneUser;
358 struct DefaultHasOneProfile;
359
360 has_many!(DefaultHasManyAuthor => DefaultHasManyPost);
361 has_many!(ForeignKeyHasManyAuthor => ForeignKeyHasManyPost, foreign_key: author_id);
362 has_many!(ThroughHasManyAuthor => ThroughHasManyTag, through: PostTagging);
363 belongs_to!(DefaultBelongsToComment => DefaultBelongsToPost);
364 belongs_to!(ForeignKeyBelongsToComment => ForeignKeyBelongsToBlog, foreign_key: blog_id);
365 has_one!(DefaultHasOneUser => DefaultHasOneProfile);
366
367 static TEST_ASSOCIATIONS: LazyLock<AssociationRegistry> = LazyLock::new(|| {
368 let mut registry = AssociationRegistry::new();
369 registry.add(
370 HasManyBuilder::new("comments")
371 .dependent(DependentAction::Destroy)
372 .build(),
373 );
374 registry.add(HasOneBuilder::new("profile").build());
375 registry.add(BelongsToBuilder::new("account").build());
376 registry.add(
377 HasAndBelongsToManyBuilder::new("roles")
378 .through("accounts_roles")
379 .build(),
380 );
381 registry
382 });
383
384 impl HasAssociations for TestUser {
385 fn associations() -> &'static AssociationRegistry {
386 &TEST_ASSOCIATIONS
387 }
388 }
389
390 #[test]
391 fn registry_returns_named_association() {
392 let association = TestUser::associations()
393 .get("comments")
394 .expect("comments association should exist");
395
396 assert_eq!(association.association_type, AssociationType::HasMany);
397 assert_eq!(association.dependent, Some(DependentAction::Destroy));
398 }
399
400 #[test]
401 fn registry_filters_associations_by_type() {
402 let has_many = TestUser::associations().of_type(AssociationType::HasMany);
403 let belongs_to = TestUser::associations().of_type(AssociationType::BelongsTo);
404
405 assert_eq!(has_many.len(), 1);
406 assert_eq!(has_many[0].name, "comments");
407 assert_eq!(belongs_to.len(), 1);
408 assert_eq!(belongs_to[0].name, "account");
409 }
410
411 #[test]
412 fn registry_exposes_all_associations_in_order() {
413 let names = TestUser::associations()
414 .all()
415 .iter()
416 .map(|meta| meta.name.as_str())
417 .collect::<Vec<_>>();
418
419 assert_eq!(names, vec!["comments", "profile", "account", "roles"]);
420 }
421
422 #[test]
423 fn registry_returns_none_for_unknown_association() {
424 assert!(TestUser::associations().get("missing").is_none());
425 }
426 #[test]
427 fn new_registry_starts_empty() {
428 let registry = AssociationRegistry::new();
429
430 assert!(registry.all().is_empty());
431 }
432
433 #[test]
434 fn add_appends_associations() {
435 let mut registry = AssociationRegistry::new();
436 registry.add(HasManyBuilder::new("comments").build());
437 registry.add(HasOneBuilder::new("profile").build());
438
439 assert_eq!(registry.all().len(), 2);
440 }
441
442 #[test]
443 fn get_is_case_sensitive() {
444 assert!(TestUser::associations().get("Comments").is_none());
445 }
446
447 #[test]
448 fn get_returns_first_matching_name_when_duplicates_exist() {
449 let mut registry = AssociationRegistry::new();
450 registry.add(HasManyBuilder::new("comments").build());
451 registry.add(
452 HasManyBuilder::new("comments")
453 .foreign_key("owner_id")
454 .build(),
455 );
456
457 let association = registry.get("comments").expect("association should exist");
458
459 assert_eq!(association.foreign_key, "comment_id");
460 }
461
462 #[test]
463 fn of_type_returns_empty_when_no_associations_match() {
464 let registry = AssociationRegistry::new();
465
466 assert!(registry.of_type(AssociationType::HasMany).is_empty());
467 }
468
469 #[test]
470 fn of_type_preserves_declaration_order() {
471 let mut registry = AssociationRegistry::new();
472 registry.add(HasManyBuilder::new("comments").build());
473 registry.add(HasManyBuilder::new("tags").build());
474
475 let names = registry
476 .of_type(AssociationType::HasMany)
477 .into_iter()
478 .map(|meta| meta.name.as_str())
479 .collect::<Vec<_>>();
480
481 assert_eq!(names, vec!["comments", "tags"]);
482 }
483
484 #[test]
485 fn all_returns_empty_slice_for_new_registry() {
486 let registry = AssociationRegistry::new();
487
488 assert_eq!(registry.all(), &[]);
489 }
490
491 #[test]
492 fn associations_registry_is_stable_across_calls() {
493 assert!(std::ptr::eq(
494 TestUser::associations(),
495 TestUser::associations()
496 ));
497 }
498
499 fn association_kind_name(kind: AssociationKind) -> &'static str {
500 match kind {
501 AssociationKind::HasMany => "has_many",
502 AssociationKind::BelongsTo => "belongs_to",
503 AssociationKind::HasOne => "has_one",
504 }
505 }
506
507 #[test]
508 fn default_has_many_definition_sets_has_many_kind() {
509 let definition =
510 <DefaultHasManyAuthor as HasManyAssociation<DefaultHasManyPost>>::association_definition();
511
512 assert_eq!(definition.kind, AssociationKind::HasMany);
513 }
514
515 #[test]
516 fn default_has_many_definition_records_model_name() {
517 let definition =
518 <DefaultHasManyAuthor as HasManyAssociation<DefaultHasManyPost>>::association_definition();
519
520 assert_eq!(definition.model, "DefaultHasManyAuthor");
521 }
522
523 #[test]
524 fn default_has_many_definition_records_target_name() {
525 let definition =
526 <DefaultHasManyAuthor as HasManyAssociation<DefaultHasManyPost>>::association_definition();
527
528 assert_eq!(definition.target, "DefaultHasManyPost");
529 }
530
531 #[test]
532 fn default_has_many_definition_has_no_foreign_key_override() {
533 let definition =
534 <DefaultHasManyAuthor as HasManyAssociation<DefaultHasManyPost>>::association_definition();
535
536 assert_eq!(definition.foreign_key, None);
537 }
538
539 #[test]
540 fn default_has_many_definition_has_no_through_target() {
541 let definition =
542 <DefaultHasManyAuthor as HasManyAssociation<DefaultHasManyPost>>::association_definition();
543
544 assert_eq!(definition.through, None);
545 }
546
547 #[test]
548 fn foreign_key_has_many_definition_records_foreign_key_override() {
549 let definition = <ForeignKeyHasManyAuthor as HasManyAssociation<ForeignKeyHasManyPost>>::association_definition();
550
551 assert_eq!(definition.foreign_key, Some("author_id"));
552 }
553
554 #[test]
555 fn through_has_many_definition_records_through_target() {
556 let definition =
557 <ThroughHasManyAuthor as HasManyAssociation<ThroughHasManyTag>>::association_definition(
558 );
559
560 assert_eq!(definition.through, Some("PostTagging"));
561 }
562
563 #[test]
564 fn has_many_association_definition_is_stable_across_calls() {
565 assert_eq!(
566 <DefaultHasManyAuthor as HasManyAssociation<DefaultHasManyPost>>::association_definition(),
567 <DefaultHasManyAuthor as HasManyAssociation<DefaultHasManyPost>>::association_definition(),
568 );
569 }
570
571 #[test]
572 fn default_belongs_to_definition_sets_belongs_to_kind() {
573 let definition = <DefaultBelongsToComment as BelongsToAssociation<DefaultBelongsToPost>>::association_definition();
574
575 assert_eq!(definition.kind, AssociationKind::BelongsTo);
576 }
577
578 #[test]
579 fn default_belongs_to_definition_records_model_name() {
580 let definition = <DefaultBelongsToComment as BelongsToAssociation<DefaultBelongsToPost>>::association_definition();
581
582 assert_eq!(definition.model, "DefaultBelongsToComment");
583 }
584
585 #[test]
586 fn default_belongs_to_definition_records_target_name() {
587 let definition = <DefaultBelongsToComment as BelongsToAssociation<DefaultBelongsToPost>>::association_definition();
588
589 assert_eq!(definition.target, "DefaultBelongsToPost");
590 }
591
592 #[test]
593 fn default_belongs_to_definition_has_no_foreign_key_override() {
594 let definition = <DefaultBelongsToComment as BelongsToAssociation<DefaultBelongsToPost>>::association_definition();
595
596 assert_eq!(definition.foreign_key, None);
597 }
598
599 #[test]
600 fn default_belongs_to_definition_has_no_through_target() {
601 let definition = <DefaultBelongsToComment as BelongsToAssociation<DefaultBelongsToPost>>::association_definition();
602
603 assert_eq!(definition.through, None);
604 }
605
606 #[test]
607 fn foreign_key_belongs_to_definition_records_foreign_key_override() {
608 let definition = <ForeignKeyBelongsToComment as BelongsToAssociation<
609 ForeignKeyBelongsToBlog,
610 >>::association_definition();
611
612 assert_eq!(definition.foreign_key, Some("blog_id"));
613 }
614
615 #[test]
616 fn belongs_to_association_definition_is_stable_across_calls() {
617 assert_eq!(
618 <DefaultBelongsToComment as BelongsToAssociation<DefaultBelongsToPost>>::association_definition(),
619 <DefaultBelongsToComment as BelongsToAssociation<DefaultBelongsToPost>>::association_definition(),
620 );
621 }
622
623 #[test]
624 fn default_has_one_definition_sets_has_one_kind() {
625 let definition =
626 <DefaultHasOneUser as HasOneAssociation<DefaultHasOneProfile>>::association_definition(
627 );
628
629 assert_eq!(definition.kind, AssociationKind::HasOne);
630 }
631
632 #[test]
633 fn default_has_one_definition_records_model_name() {
634 let definition =
635 <DefaultHasOneUser as HasOneAssociation<DefaultHasOneProfile>>::association_definition(
636 );
637
638 assert_eq!(definition.model, "DefaultHasOneUser");
639 }
640
641 #[test]
642 fn default_has_one_definition_records_target_name() {
643 let definition =
644 <DefaultHasOneUser as HasOneAssociation<DefaultHasOneProfile>>::association_definition(
645 );
646
647 assert_eq!(definition.target, "DefaultHasOneProfile");
648 }
649
650 #[test]
651 fn default_has_one_definition_has_no_foreign_key_override() {
652 let definition =
653 <DefaultHasOneUser as HasOneAssociation<DefaultHasOneProfile>>::association_definition(
654 );
655
656 assert_eq!(definition.foreign_key, None);
657 }
658
659 #[test]
660 fn default_has_one_definition_has_no_through_target() {
661 let definition =
662 <DefaultHasOneUser as HasOneAssociation<DefaultHasOneProfile>>::association_definition(
663 );
664
665 assert_eq!(definition.through, None);
666 }
667
668 #[test]
669 fn has_one_association_definition_is_stable_across_calls() {
670 assert_eq!(
671 <DefaultHasOneUser as HasOneAssociation<DefaultHasOneProfile>>::association_definition(
672 ),
673 <DefaultHasOneUser as HasOneAssociation<DefaultHasOneProfile>>::association_definition(
674 ),
675 );
676 }
677
678 #[test]
679 fn has_many_definition_kind_matches_has_many_branch() {
680 let definition =
681 <DefaultHasManyAuthor as HasManyAssociation<DefaultHasManyPost>>::association_definition();
682
683 assert_eq!(association_kind_name(definition.kind), "has_many");
684 }
685
686 #[test]
687 fn belongs_to_definition_kind_matches_belongs_to_branch() {
688 let definition = <DefaultBelongsToComment as BelongsToAssociation<DefaultBelongsToPost>>::association_definition();
689
690 assert_eq!(association_kind_name(definition.kind), "belongs_to");
691 }
692
693 #[test]
694 fn has_one_definition_kind_matches_has_one_branch() {
695 let definition =
696 <DefaultHasOneUser as HasOneAssociation<DefaultHasOneProfile>>::association_definition(
697 );
698
699 assert_eq!(association_kind_name(definition.kind), "has_one");
700 }
701
702 #[test]
703 fn association_query_traits_load_related_records() {
704 let _runtime = runtime::init_runtime();
705 database::establish("sqlite::memory:").expect("sqlite in-memory connection should succeed");
706
707 runtime::block_on(async {
708 use sea_orm::ConnectionTrait;
709
710 let db = database::db();
711 db.execute_unprepared(
712 "CREATE TABLE query_blogs (id INTEGER PRIMARY KEY AUTOINCREMENT, title TEXT NOT NULL)",
713 )
714 .await
715 .expect("query_blogs table should be created");
716 db.execute_unprepared(
717 "CREATE TABLE query_posts (id INTEGER PRIMARY KEY AUTOINCREMENT, query_blog_id INTEGER NOT NULL, title TEXT NOT NULL)",
718 )
719 .await
720 .expect("query_posts table should be created");
721 db.execute_unprepared(
722 "CREATE TABLE query_profiles (id INTEGER PRIMARY KEY AUTOINCREMENT, query_blog_id INTEGER NOT NULL, bio TEXT NOT NULL)",
723 )
724 .await
725 .expect("query_profiles table should be created");
726 });
727
728 let blog = QueryBlog::create_sync(HashMap::from([("title".to_owned(), json!("Main"))]))
729 .expect("blog should be created");
730 let blog_id = blog.id.expect("blog should have an id");
731
732 let post = QueryPost::create_sync(HashMap::from([
733 ("query_blog_id".to_owned(), json!(blog_id)),
734 ("title".to_owned(), json!("First")),
735 ]))
736 .expect("post should be created");
737 QueryProfile::create_sync(HashMap::from([
738 ("query_blog_id".to_owned(), json!(blog_id)),
739 ("bio".to_owned(), json!("About the blog")),
740 ]))
741 .expect("profile should be created");
742
743 let posts: Vec<QueryPost> = blog.has_many().expect("has_many should load related posts");
744 assert_eq!(posts.len(), 1);
745 assert_eq!(posts[0].title, "First");
746
747 let owner: QueryBlog = post.belongs_to().expect("belongs_to should load the owner");
748 assert_eq!(owner.title, "Main");
749
750 let profile = blog
751 .has_one()
752 .expect("has_one should query the related record");
753 assert_eq!(profile.expect("profile should exist").bio, "About the blog");
754 }
755}