1use crate::errors::MigrationError;
4use crate::{IntoDomain, MigratesTo, Versioned, VersionedWrapper};
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7use std::collections::HashMap;
8use std::marker::PhantomData;
9
10type MigrationFn = Box<dyn Fn(serde_json::Value) -> Result<serde_json::Value, MigrationError>>;
11
12struct EntityMigrationPath {
14 steps: HashMap<String, MigrationFn>,
16 finalize: Box<dyn Fn(serde_json::Value) -> Result<serde_json::Value, MigrationError>>,
18}
19
20pub struct Migrator {
22 paths: HashMap<String, EntityMigrationPath>,
23}
24
25impl Migrator {
26 pub fn new() -> Self {
28 Self {
29 paths: HashMap::new(),
30 }
31 }
32
33 pub fn define(entity: &str) -> MigrationPathBuilder<Start> {
35 MigrationPathBuilder::new(entity.to_string())
36 }
37
38 pub fn register<D>(&mut self, path: MigrationPath<D>) -> Result<(), MigrationError> {
48 Self::validate_migration_path(&path.entity, &path.versions)?;
49 self.paths.insert(path.entity, path.inner);
50 Ok(())
51 }
52
53 fn validate_migration_path(entity: &str, versions: &[String]) -> Result<(), MigrationError> {
55 Self::check_circular_path(entity, versions)?;
57
58 Self::check_version_ordering(entity, versions)?;
60
61 Ok(())
62 }
63
64 fn check_circular_path(entity: &str, versions: &[String]) -> Result<(), MigrationError> {
66 let mut seen = std::collections::HashSet::new();
67
68 for version in versions {
69 if !seen.insert(version) {
70 let path = versions.join(" -> ");
72 return Err(MigrationError::CircularMigrationPath {
73 entity: entity.to_string(),
74 path,
75 });
76 }
77 }
78
79 Ok(())
80 }
81
82 fn check_version_ordering(entity: &str, versions: &[String]) -> Result<(), MigrationError> {
84 for i in 0..versions.len().saturating_sub(1) {
85 let current = &versions[i];
86 let next = &versions[i + 1];
87
88 let current_ver = semver::Version::parse(current).map_err(|e| {
90 MigrationError::DeserializationError(format!("Invalid semver '{}': {}", current, e))
91 })?;
92
93 let next_ver = semver::Version::parse(next).map_err(|e| {
94 MigrationError::DeserializationError(format!("Invalid semver '{}': {}", next, e))
95 })?;
96
97 if next_ver <= current_ver {
99 return Err(MigrationError::InvalidVersionOrder {
100 entity: entity.to_string(),
101 from: current.clone(),
102 to: next.clone(),
103 });
104 }
105 }
106
107 Ok(())
108 }
109
110 pub fn load_from<D, T>(&self, entity: &str, data: T) -> Result<D, MigrationError>
143 where
144 D: DeserializeOwned,
145 T: Serialize,
146 {
147 let value = serde_json::to_value(data).map_err(|e| {
149 MigrationError::DeserializationError(format!(
150 "Failed to convert input data to internal format: {}",
151 e
152 ))
153 })?;
154
155 let wrapper: VersionedWrapper<serde_json::Value> =
157 serde_json::from_value(value).map_err(|e| {
158 MigrationError::DeserializationError(format!(
159 "Failed to parse VersionedWrapper: {}",
160 e
161 ))
162 })?;
163
164 let path = self
166 .paths
167 .get(entity)
168 .ok_or_else(|| MigrationError::EntityNotFound(entity.to_string()))?;
169
170 let mut current_version = wrapper.version.clone();
172 let mut current_data = wrapper.data;
173
174 while let Some(migrate_fn) = path.steps.get(¤t_version) {
176 current_data = migrate_fn(current_data)?;
177 if let Ok(wrapped) =
180 serde_json::from_value::<VersionedWrapper<serde_json::Value>>(current_data.clone())
181 {
182 current_version = wrapped.version;
183 current_data = wrapped.data;
184 } else {
185 break;
186 }
187 }
188
189 let domain_value = (path.finalize)(current_data)?;
191
192 serde_json::from_value(domain_value).map_err(|e| {
193 MigrationError::DeserializationError(format!("Failed to convert to domain: {}", e))
194 })
195 }
196
197 pub fn load<D: DeserializeOwned>(&self, entity: &str, json: &str) -> Result<D, MigrationError> {
225 let data: serde_json::Value = serde_json::from_str(json).map_err(|e| {
226 MigrationError::DeserializationError(format!("Failed to parse JSON: {}", e))
227 })?;
228 self.load_from(entity, data)
229 }
230
231 pub fn save<T: Versioned + Serialize>(&self, data: T) -> Result<String, MigrationError> {
262 let wrapper = VersionedWrapper::from_versioned(data);
263
264 serde_json::to_string(&wrapper).map_err(|e| {
265 MigrationError::SerializationError(format!("Failed to serialize data: {}", e))
266 })
267 }
268}
269
270impl Default for Migrator {
271 fn default() -> Self {
272 Self::new()
273 }
274}
275
/// Builder state: nothing declared yet; the next call must be `from::<V>()`.
pub struct Start;

/// Builder state: the starting version `V` is declared and no step has been added.
pub struct HasFrom<V>(PhantomData<V>);

/// Builder state: at least one step has been added; `V` is the latest version so far.
pub struct HasSteps<V>(PhantomData<V>);
284
285pub struct MigrationPathBuilder<State> {
287 entity: String,
288 steps: HashMap<String, MigrationFn>,
289 versions: Vec<String>,
290 _state: PhantomData<State>,
291}
292
293impl MigrationPathBuilder<Start> {
294 fn new(entity: String) -> Self {
295 Self {
296 entity,
297 steps: HashMap::new(),
298 versions: Vec::new(),
299 _state: PhantomData,
300 }
301 }
302
303 pub fn from<V: Versioned + DeserializeOwned>(self) -> MigrationPathBuilder<HasFrom<V>> {
305 let mut versions = self.versions;
306 versions.push(V::VERSION.to_string());
307
308 MigrationPathBuilder {
309 entity: self.entity,
310 steps: self.steps,
311 versions,
312 _state: PhantomData,
313 }
314 }
315}
316
317impl<V> MigrationPathBuilder<HasFrom<V>>
318where
319 V: Versioned + DeserializeOwned,
320{
321 pub fn step<Next>(mut self) -> MigrationPathBuilder<HasSteps<Next>>
323 where
324 V: MigratesTo<Next>,
325 Next: Versioned + DeserializeOwned + Serialize,
326 {
327 let from_version = V::VERSION.to_string();
328 let migration_fn: MigrationFn = Box::new(move |value| {
329 let from_value: V = serde_json::from_value(value).map_err(|e| {
330 MigrationError::DeserializationError(format!(
331 "Failed to deserialize version {}: {}",
332 V::VERSION,
333 e
334 ))
335 })?;
336
337 let to_value = from_value.migrate();
338 let wrapped = VersionedWrapper::from_versioned(to_value);
339
340 serde_json::to_value(wrapped).map_err(|e| MigrationError::MigrationStepFailed {
341 from: V::VERSION.to_string(),
342 to: Next::VERSION.to_string(),
343 error: e.to_string(),
344 })
345 });
346
347 self.steps.insert(from_version, migration_fn);
348 self.versions.push(Next::VERSION.to_string());
349
350 MigrationPathBuilder {
351 entity: self.entity,
352 steps: self.steps,
353 versions: self.versions,
354 _state: PhantomData,
355 }
356 }
357
358 pub fn into<D: DeserializeOwned + Serialize>(self) -> MigrationPath<D>
360 where
361 V: IntoDomain<D>,
362 {
363 let finalize: Box<dyn Fn(serde_json::Value) -> Result<serde_json::Value, MigrationError>> =
364 Box::new(move |value| {
365 let versioned: V = serde_json::from_value(value).map_err(|e| {
366 MigrationError::DeserializationError(format!(
367 "Failed to deserialize final version: {}",
368 e
369 ))
370 })?;
371
372 let domain = versioned.into_domain();
373
374 serde_json::to_value(domain).map_err(|e| MigrationError::MigrationStepFailed {
375 from: V::VERSION.to_string(),
376 to: "domain".to_string(),
377 error: e.to_string(),
378 })
379 });
380
381 MigrationPath {
382 entity: self.entity,
383 inner: EntityMigrationPath {
384 steps: self.steps,
385 finalize,
386 },
387 versions: self.versions,
388 _phantom: PhantomData,
389 }
390 }
391}
392
393impl<V> MigrationPathBuilder<HasSteps<V>>
394where
395 V: Versioned + DeserializeOwned,
396{
397 pub fn step<Next>(mut self) -> MigrationPathBuilder<HasSteps<Next>>
399 where
400 V: MigratesTo<Next>,
401 Next: Versioned + DeserializeOwned + Serialize,
402 {
403 let from_version = V::VERSION.to_string();
404 let migration_fn: MigrationFn = Box::new(move |value| {
405 let from_value: V = serde_json::from_value(value).map_err(|e| {
406 MigrationError::DeserializationError(format!(
407 "Failed to deserialize version {}: {}",
408 V::VERSION,
409 e
410 ))
411 })?;
412
413 let to_value = from_value.migrate();
414 let wrapped = VersionedWrapper::from_versioned(to_value);
415
416 serde_json::to_value(wrapped).map_err(|e| MigrationError::MigrationStepFailed {
417 from: V::VERSION.to_string(),
418 to: Next::VERSION.to_string(),
419 error: e.to_string(),
420 })
421 });
422
423 self.steps.insert(from_version, migration_fn);
424 self.versions.push(Next::VERSION.to_string());
425
426 MigrationPathBuilder {
427 entity: self.entity,
428 steps: self.steps,
429 versions: self.versions,
430 _state: PhantomData,
431 }
432 }
433
434 pub fn into<D: DeserializeOwned + Serialize>(self) -> MigrationPath<D>
436 where
437 V: IntoDomain<D>,
438 {
439 let finalize: Box<dyn Fn(serde_json::Value) -> Result<serde_json::Value, MigrationError>> =
440 Box::new(move |value| {
441 let versioned: V = serde_json::from_value(value).map_err(|e| {
442 MigrationError::DeserializationError(format!(
443 "Failed to deserialize final version: {}",
444 e
445 ))
446 })?;
447
448 let domain = versioned.into_domain();
449
450 serde_json::to_value(domain).map_err(|e| MigrationError::MigrationStepFailed {
451 from: V::VERSION.to_string(),
452 to: "domain".to_string(),
453 error: e.to_string(),
454 })
455 });
456
457 MigrationPath {
458 entity: self.entity,
459 inner: EntityMigrationPath {
460 steps: self.steps,
461 finalize,
462 },
463 versions: self.versions,
464 _phantom: PhantomData,
465 }
466 }
467}
468
469pub struct MigrationPath<D> {
471 entity: String,
472 inner: EntityMigrationPath,
473 versions: Vec<String>,
475 _phantom: PhantomData<D>,
476}
477
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{IntoDomain, MigratesTo, Versioned, VersionedWrapper};
    use serde::{Deserialize, Serialize};

    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    struct V1 {
        value: String,
    }

    impl Versioned for V1 {
        const VERSION: &'static str = "1.0.0";
    }

    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    struct V2 {
        value: String,
        count: u32,
    }

    impl Versioned for V2 {
        const VERSION: &'static str = "2.0.0";
    }

    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    struct V3 {
        value: String,
        count: u32,
        enabled: bool,
    }

    impl Versioned for V3 {
        const VERSION: &'static str = "3.0.0";
    }

    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    struct Domain {
        value: String,
        count: u32,
        enabled: bool,
    }

    impl MigratesTo<V2> for V1 {
        fn migrate(self) -> V2 {
            V2 {
                value: self.value,
                count: 0,
            }
        }
    }

    impl MigratesTo<V3> for V2 {
        fn migrate(self) -> V3 {
            V3 {
                value: self.value,
                count: self.count,
                enabled: true,
            }
        }
    }

    impl IntoDomain<Domain> for V3 {
        fn into_domain(self) -> Domain {
            Domain {
                value: self.value,
                count: self.count,
                enabled: self.enabled,
            }
        }
    }

    #[test]
    fn test_migrator_new() {
        let migrator = Migrator::new();
        assert_eq!(migrator.paths.len(), 0);
    }

    #[test]
    fn test_migrator_default() {
        let migrator = Migrator::default();
        assert_eq!(migrator.paths.len(), 0);
    }

    #[test]
    fn test_single_step_migration() {
        let path = Migrator::define("test")
            .from::<V2>()
            .step::<V3>()
            .into::<Domain>();

        let mut migrator = Migrator::new();
        migrator.register(path).unwrap();

        let v2 = V2 {
            value: "test".to_string(),
            count: 42,
        };
        let wrapper = VersionedWrapper::from_versioned(v2);
        let json = serde_json::to_string(&wrapper).unwrap();

        let result: Domain = migrator.load("test", &json).unwrap();
        assert_eq!(result.value, "test");
        assert_eq!(result.count, 42);
        assert!(result.enabled);
    }

    #[test]
    fn test_multi_step_migration() {
        let path = Migrator::define("test")
            .from::<V1>()
            .step::<V2>()
            .step::<V3>()
            .into::<Domain>();

        let mut migrator = Migrator::new();
        migrator.register(path).unwrap();

        let v1 = V1 {
            value: "multi_step".to_string(),
        };
        let wrapper = VersionedWrapper::from_versioned(v1);
        let json = serde_json::to_string(&wrapper).unwrap();

        let result: Domain = migrator.load("test", &json).unwrap();
        assert_eq!(result.value, "multi_step");
        assert_eq!(result.count, 0);
        assert!(result.enabled);
    }

    #[test]
    fn test_no_migration_needed() {
        let path = Migrator::define("test").from::<V3>().into::<Domain>();

        let mut migrator = Migrator::new();
        migrator.register(path).unwrap();

        let v3 = V3 {
            value: "latest".to_string(),
            count: 100,
            enabled: false,
        };
        let wrapper = VersionedWrapper::from_versioned(v3);
        let json = serde_json::to_string(&wrapper).unwrap();

        let result: Domain = migrator.load("test", &json).unwrap();
        assert_eq!(result.value, "latest");
        assert_eq!(result.count, 100);
        assert!(!result.enabled);
    }

    #[test]
    fn test_load_from_typed_wrapper() {
        // `load_from` accepts any `Serialize` value shaped like a `VersionedWrapper`.
        let path = Migrator::define("test")
            .from::<V1>()
            .step::<V2>()
            .step::<V3>()
            .into::<Domain>();

        let mut migrator = Migrator::new();
        migrator.register(path).unwrap();

        let wrapper = VersionedWrapper::from_versioned(V1 {
            value: "typed".to_string(),
        });

        let result: Domain = migrator.load_from("test", wrapper).unwrap();
        assert_eq!(
            result,
            Domain {
                value: "typed".to_string(),
                count: 0,
                enabled: true,
            }
        );
    }

    #[test]
    fn test_entity_not_found() {
        let migrator = Migrator::new();

        let v1 = V1 {
            value: "test".to_string(),
        };
        let wrapper = VersionedWrapper::from_versioned(v1);
        let json = serde_json::to_string(&wrapper).unwrap();

        let result: Result<Domain, MigrationError> = migrator.load("unknown", &json);
        assert!(matches!(result, Err(MigrationError::EntityNotFound(_))));

        if let Err(MigrationError::EntityNotFound(entity)) = result {
            assert_eq!(entity, "unknown");
        }
    }

    #[test]
    fn test_invalid_json() {
        let path = Migrator::define("test").from::<V3>().into::<Domain>();

        let mut migrator = Migrator::new();
        migrator.register(path).unwrap();

        let invalid_json = "{ invalid json }";
        let result: Result<Domain, MigrationError> = migrator.load("test", invalid_json);

        assert!(matches!(
            result,
            Err(MigrationError::DeserializationError(_))
        ));
    }

    #[test]
    fn test_multiple_entities() {
        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct OtherDomain {
            value: String,
        }

        impl IntoDomain<OtherDomain> for V1 {
            fn into_domain(self) -> OtherDomain {
                OtherDomain { value: self.value }
            }
        }

        let path1 = Migrator::define("entity1")
            .from::<V1>()
            .step::<V2>()
            .step::<V3>()
            .into::<Domain>();

        let path2 = Migrator::define("entity2")
            .from::<V1>()
            .into::<OtherDomain>();

        let mut migrator = Migrator::new();
        migrator.register(path1).unwrap();
        migrator.register(path2).unwrap();

        let v1 = V1 {
            value: "entity1".to_string(),
        };
        let wrapper = VersionedWrapper::from_versioned(v1);
        let json = serde_json::to_string(&wrapper).unwrap();
        let result: Domain = migrator.load("entity1", &json).unwrap();
        assert_eq!(result.value, "entity1");

        let v1 = V1 {
            value: "entity2".to_string(),
        };
        let wrapper = VersionedWrapper::from_versioned(v1);
        let json = serde_json::to_string(&wrapper).unwrap();
        let result: OtherDomain = migrator.load("entity2", &json).unwrap();
        assert_eq!(result.value, "entity2");
    }

    #[test]
    fn test_save() {
        let migrator = Migrator::new();

        let v1 = V1 {
            value: "test_save".to_string(),
        };

        let json = migrator.save(v1).unwrap();

        assert!(json.contains("\"version\""));
        assert!(json.contains("\"1.0.0\""));
        assert!(json.contains("\"data\""));
        assert!(json.contains("\"test_save\""));

        let parsed: VersionedWrapper<serde_json::Value> = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.version, "1.0.0");
    }

    #[test]
    fn test_save_and_load_roundtrip() {
        let path = Migrator::define("test")
            .from::<V1>()
            .step::<V2>()
            .step::<V3>()
            .into::<Domain>();

        let mut migrator = Migrator::new();
        migrator.register(path).unwrap();

        let v1 = V1 {
            value: "roundtrip".to_string(),
        };
        let json = migrator.save(v1).unwrap();

        let domain: Domain = migrator.load("test", &json).unwrap();

        assert_eq!(domain.value, "roundtrip");
        assert_eq!(domain.count, 0);
        assert!(domain.enabled);
    }

    #[test]
    fn test_save_latest_version() {
        let migrator = Migrator::new();

        let v3 = V3 {
            value: "latest".to_string(),
            count: 42,
            enabled: false,
        };

        let json = migrator.save(v3).unwrap();

        assert!(json.contains("\"version\":\"3.0.0\""));
        assert!(json.contains("\"value\":\"latest\""));
        assert!(json.contains("\"count\":42"));
        assert!(json.contains("\"enabled\":false"));
    }

    #[test]
    fn test_save_is_compact() {
        // `save` uses `serde_json::to_string`, so the output is NOT pretty-printed.
        let migrator = Migrator::new();

        let v2 = V2 {
            value: "pretty".to_string(),
            count: 10,
        };

        let json = migrator.save(v2).unwrap();

        assert!(!json.contains('\n'));
        assert!(json.contains("\"version\":\"2.0.0\""));
    }

    #[test]
    fn test_validation_invalid_version_order() {
        let entity = "test".to_string();
        let versions = vec!["2.0.0".to_string(), "1.0.0".to_string()];

        let result = Migrator::validate_migration_path(&entity, &versions);
        assert!(matches!(
            result,
            Err(MigrationError::InvalidVersionOrder { .. })
        ));

        if let Err(MigrationError::InvalidVersionOrder {
            entity: e,
            from,
            to,
        }) = result
        {
            assert_eq!(e, "test");
            assert_eq!(from, "2.0.0");
            assert_eq!(to, "1.0.0");
        }
    }

    #[test]
    fn test_validation_invalid_semver() {
        let versions = vec!["1.0.0".to_string(), "not-a-version".to_string()];

        let result = Migrator::validate_migration_path("test", &versions);
        assert!(matches!(
            result,
            Err(MigrationError::DeserializationError(_))
        ));
    }

    #[test]
    fn test_validation_circular_path() {
        let entity = "test".to_string();
        let versions = vec![
            "1.0.0".to_string(),
            "2.0.0".to_string(),
            "1.0.0".to_string(),
        ];

        let result = Migrator::validate_migration_path(&entity, &versions);
        assert!(matches!(
            result,
            Err(MigrationError::CircularMigrationPath { .. })
        ));

        if let Err(MigrationError::CircularMigrationPath { entity: e, path }) = result {
            assert_eq!(e, "test");
            assert!(path.contains("1.0.0"));
            assert!(path.contains("2.0.0"));
        }
    }

    #[test]
    fn test_validation_valid_path() {
        let entity = "test".to_string();
        let versions = vec![
            "1.0.0".to_string(),
            "1.1.0".to_string(),
            "2.0.0".to_string(),
        ];

        let result = Migrator::validate_migration_path(&entity, &versions);
        assert!(result.is_ok());
    }

    #[test]
    fn test_validation_empty_path() {
        let entity = "test".to_string();
        let versions: Vec<String> = vec![];

        let result = Migrator::validate_migration_path(&entity, &versions);
        assert!(result.is_ok());
    }

    #[test]
    fn test_validation_single_version() {
        let entity = "test".to_string();
        let versions = vec!["1.0.0".to_string()];

        let result = Migrator::validate_migration_path(&entity, &versions);
        assert!(result.is_ok());
    }
}