1use crate::collection::CollectionMarker;
2use crate::error::Result;
3use crate::store::{StoreId, StoreState};
4use crate::ManagerExt;
5use itertools::Itertools;
6use semver::Version;
7use serde::{Deserialize, Serialize};
8use serde_json::{from_slice, to_vec};
9use std::cmp::Ordering;
10use std::collections::HashMap;
11use std::fs::{self, File};
12use std::io::Write;
13use std::path::PathBuf;
14use std::sync::{Mutex, OnceLock};
15use tauri::{AppHandle, Runtime};
16use tauri_store_utils::Semver;
17
18static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
20
21type MigrationFn = dyn Fn(&mut StoreState) -> Result<()> + Send + Sync;
22type BeforeEachMigrationFn = dyn Fn(MigrationContext) + Send + Sync;
23
24#[doc(hidden)]
25#[derive(Default)]
26pub struct Migrator {
27 migrations: HashMap<StoreId, Vec<Migration>>,
28 before_each: Option<Box<BeforeEachMigrationFn>>,
29 history: MigrationHistory,
30}
31
32impl Migrator {
33 pub fn add_migration(&mut self, id: StoreId, migration: Migration) {
34 self
35 .migrations
36 .entry(id)
37 .or_default()
38 .push(migration);
39 }
40
41 pub fn add_migrations<I>(&mut self, id: StoreId, migrations: I)
42 where
43 I: IntoIterator<Item = Migration>,
44 {
45 self
46 .migrations
47 .entry(id)
48 .or_default()
49 .extend(migrations);
50 }
51
52 pub fn migrate<R, C>(
53 &mut self,
54 app: &AppHandle<R>,
55 id: &StoreId,
56 state: &mut StoreState,
57 ) -> Result<()>
58 where
59 R: Runtime,
60 C: CollectionMarker,
61 {
62 let mut migrations = self
63 .migrations
64 .get(id)
65 .map(Vec::as_slice)
66 .unwrap_or_default()
67 .iter()
68 .sorted()
69 .collect_vec();
70
71 if let Some(last) = self.history.get(id) {
72 migrations.retain(|migration| migration.version > *last);
73 }
74
75 if migrations.is_empty() {
76 return Ok(());
77 }
78
79 let mut iter = migrations.iter().peekable();
80 let mut previous = None;
81 let mut done = 0;
82 let mut last_err = None;
83
84 while let Some(migration) = iter.next() {
85 let current = &migration.version;
86 if let Some(before_each) = &self.before_each {
87 let next = iter.peek().map(|it| &it.version);
88 let context = MigrationContext { id, state, current, previous, next };
89 before_each(context);
90 }
91
92 if let Err(err) = (migration.inner)(state) {
93 last_err = Some(err);
94 break;
95 }
96
97 self.history.set(id, current);
98 previous = Some(current);
99 done += 1;
100 }
101
102 if done > 0 {
103 self.write::<R, C>(app)?;
104 }
105
106 match last_err {
107 Some(err) => Err(err),
108 None => Ok(()),
109 }
110 }
111
112 #[doc(hidden)]
113 pub fn on_before_each<F>(&mut self, f: F)
114 where
115 F: Fn(MigrationContext) + Send + Sync + 'static,
116 {
117 self.before_each = Some(Box::new(f));
118 }
119
120 pub(crate) fn read<R, C>(&mut self, app: &AppHandle<R>) -> Result<()>
121 where
122 R: Runtime,
123 C: CollectionMarker,
124 {
125 let path = path::<R, C>(app);
126 if let Ok(bytes) = fs::read(&path) {
127 self.history = from_slice(&bytes)?;
128 }
129
130 Ok(())
131 }
132
133 fn write<R, C>(&self, app: &AppHandle<R>) -> Result<()>
134 where
135 R: Runtime,
136 C: CollectionMarker,
137 {
138 let path = path::<R, C>(app);
139 let lock = LOCK
140 .get_or_init(Mutex::default)
141 .lock()
142 .expect("migrator file lock is poisoned");
143
144 if let Some(parent) = path.parent() {
145 fs::create_dir_all(parent)?;
146 }
147
148 let bytes = to_vec(&self.history)?;
149 let mut file = File::create(path)?;
150 file.write_all(&bytes)?;
151 file.flush()?;
152
153 if cfg!(feature = "file-sync-all") {
154 file.sync_all()?;
155 }
156
157 drop(lock);
158
159 Ok(())
160 }
161}
162
163fn path<R, C>(app: &AppHandle<R>) -> PathBuf
164where
165 R: Runtime,
166 C: CollectionMarker,
167{
168 app
169 .store_collection_with_marker::<C>()
170 .path()
171 .join("migration.tauristore")
172}
173
174pub struct Migration {
176 inner: Box<MigrationFn>,
177 version: Version,
178}
179
180impl Migration {
181 #[allow(clippy::needless_pass_by_value)]
187 pub fn new<F>(version: impl Semver, up: F) -> Self
188 where
189 F: Fn(&mut StoreState) -> Result<()> + Send + Sync + 'static,
190 {
191 Self {
192 inner: Box::new(up),
193 version: version.semver(),
194 }
195 }
196
197 pub fn version(&self) -> &Version {
199 &self.version
200 }
201}
202
203impl PartialOrd for Migration {
204 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
205 Some(self.cmp(other))
206 }
207}
208
209impl Ord for Migration {
210 fn cmp(&self, other: &Self) -> Ordering {
211 self.version.cmp(&other.version)
212 }
213}
214
215impl PartialEq for Migration {
216 fn eq(&self, other: &Self) -> bool {
217 self.version == other.version
218 }
219}
220
221impl Eq for Migration {}
222
223#[derive(Debug)]
225pub struct MigrationContext<'a> {
226 pub id: &'a StoreId,
227 pub state: &'a StoreState,
228 pub current: &'a Version,
229 pub previous: Option<&'a Version>,
230 pub next: Option<&'a Version>,
231}
232
233#[derive(Clone, Debug, Default, Serialize, Deserialize)]
234pub(crate) struct MigrationHistory(HashMap<StoreId, Version>);
235
236impl MigrationHistory {
237 pub fn get(&self, id: &StoreId) -> Option<&Version> {
238 self.0.get(id)
239 }
240
241 pub fn set(&mut self, id: &StoreId, version: &Version) {
242 self.0.insert(id.clone(), version.clone());
243 }
244}