Skip to main content

upstream_rs/services/storage/
rollback_storage.rs

1use std::collections::HashMap;
2use std::fs;
3use std::path::{Path, PathBuf};
4
5use anyhow::{Context, Result, anyhow};
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8
9use crate::models::upstream::Package;
10use crate::utils::filesystem::atomic_ops::write_atomic;
11
/// Schema version stamped into freshly created storage documents and required when loading one.
const ROLLBACK_STORAGE_VERSION: u32 = 1_u32;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub enum RollbackSource {
16    Upgrade,
17    Reinstall,
18    Remove,
19}
20
21#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
22#[serde(rename_all = "snake_case")]
23pub enum RollbackArtifactFormat {
24    #[default]
25    Raw,
26    Tgz,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct RollbackRecord {
31    pub package_snapshot: Package,
32    pub artifact_relative_path: PathBuf,
33    #[serde(default)]
34    pub icon_relative_path: Option<PathBuf>,
35    #[serde(default)]
36    pub artifact_format: RollbackArtifactFormat,
37    #[serde(default)]
38    pub artifact_entry_path: Option<PathBuf>,
39    #[serde(default)]
40    pub icon_entry_path: Option<PathBuf>,
41    pub source: RollbackSource,
42    pub created_at: DateTime<Utc>,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
46struct RollbackStorageFile {
47    version: u32,
48    records: HashMap<String, Vec<RollbackRecord>>,
49}
50
51impl Default for RollbackStorageFile {
52    fn default() -> Self {
53        Self {
54            version: ROLLBACK_STORAGE_VERSION,
55            records: HashMap::new(),
56        }
57    }
58}
59
60pub struct RollbackStorage {
61    file: RollbackStorageFile,
62    rollback_file: PathBuf,
63}
64
65impl RollbackStorage {
66    pub fn new(rollback_file: &Path) -> Result<Self> {
67        let mut storage = Self {
68            file: RollbackStorageFile::default(),
69            rollback_file: rollback_file.to_path_buf(),
70        };
71        storage.load()?;
72        Ok(storage)
73    }
74
75    pub fn load(&mut self) -> Result<()> {
76        if !self.rollback_file.exists() {
77            self.file = RollbackStorageFile::default();
78            return Ok(());
79        }
80
81        let json = fs::read_to_string(&self.rollback_file).with_context(|| {
82            format!(
83                "Failed to read rollback storage '{}'",
84                self.rollback_file.display()
85            )
86        })?;
87
88        if json.trim().is_empty() {
89            self.file = RollbackStorageFile::default();
90            return Ok(());
91        }
92
93        let parsed: RollbackStorageFile = serde_json::from_str(&json)
94            .or_else(|_| parse_legacy_storage_file(&json))
95            .with_context(|| {
96                format!(
97                    "Failed to parse rollback storage '{}'",
98                    self.rollback_file.display()
99                )
100            })?;
101        if parsed.version != ROLLBACK_STORAGE_VERSION {
102            return Err(anyhow!(
103                "Unsupported rollback storage version {} in '{}'. Expected version {}.",
104                parsed.version,
105                self.rollback_file.display(),
106                ROLLBACK_STORAGE_VERSION
107            ));
108        }
109        self.file = parsed;
110        Ok(())
111    }
112
113    pub fn save(&self) -> Result<()> {
114        let json = serde_json::to_string_pretty(&self.file)
115            .context("Failed to serialize rollback storage")?;
116        write_atomic(&self.rollback_file, json.as_bytes()).with_context(|| {
117            format!(
118                "Failed to write rollback storage to '{}'",
119                self.rollback_file.display()
120            )
121        })
122    }
123
124    pub fn get_record(&self, package_name: &str) -> Option<&RollbackRecord> {
125        self.file
126            .records
127            .get(package_name)
128            .and_then(|records| records.last())
129    }
130
131    pub fn get_records(&self, package_name: &str) -> &[RollbackRecord] {
132        self.file
133            .records
134            .get(package_name)
135            .map(Vec::as_slice)
136            .unwrap_or(&[])
137    }
138
139    pub fn list_records(&self) -> &HashMap<String, Vec<RollbackRecord>> {
140        &self.file.records
141    }
142
143    pub fn upsert_record(&mut self, package_name: &str, record: RollbackRecord) -> Result<()> {
144        self.push_record(package_name, record, 1).map(|_| ())
145    }
146
147    pub fn push_record(
148        &mut self,
149        package_name: &str,
150        record: RollbackRecord,
151        max_records: usize,
152    ) -> Result<Vec<RollbackRecord>> {
153        let records = self
154            .file
155            .records
156            .entry(package_name.to_string())
157            .or_default();
158        records.push(record);
159        let pruned = if max_records > 0 && records.len() > max_records {
160            let remove_count = records.len() - max_records;
161            records.drain(0..remove_count).collect()
162        } else {
163            Vec::new()
164        };
165        self.save()?;
166        Ok(pruned)
167    }
168
169    pub fn remove_record(&mut self, package_name: &str) -> Result<Option<RollbackRecord>> {
170        let removed = self.file.records.get_mut(package_name).and_then(Vec::pop);
171        if self
172            .file
173            .records
174            .get(package_name)
175            .is_some_and(Vec::is_empty)
176        {
177            self.file.records.remove(package_name);
178        }
179        self.save()?;
180        Ok(removed)
181    }
182
183    pub fn remove_all_records(&mut self, package_name: &str) -> Result<Vec<RollbackRecord>> {
184        let removed = self.file.records.remove(package_name).unwrap_or_default();
185        self.save()?;
186        Ok(removed)
187    }
188}
189
190#[derive(Debug, Clone, Deserialize)]
191struct LegacyRollbackStorageFile {
192    version: u32,
193    records: HashMap<String, RollbackRecord>,
194}
195
196fn parse_legacy_storage_file(json: &str) -> serde_json::Result<RollbackStorageFile> {
197    let legacy: LegacyRollbackStorageFile = serde_json::from_str(json)?;
198    Ok(RollbackStorageFile {
199        version: legacy.version,
200        records: legacy
201            .records
202            .into_iter()
203            .map(|(name, record)| (name, vec![record]))
204            .collect(),
205    })
206}
207
#[cfg(test)]
mod tests {
    use super::{
        ROLLBACK_STORAGE_VERSION, RollbackArtifactFormat, RollbackRecord, RollbackSource,
        RollbackStorage,
    };
    use crate::models::common::enums::{Channel, Filetype, Provider};
    use crate::models::upstream::Package;
    use chrono::Utc;
    use std::fs;
    use std::path::{Path, PathBuf};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// A unique `rollback.json` path inside its own temporary directory.
    ///
    /// The directory is removed on drop, so it is cleaned up even when a test panics
    /// part-way through.
    struct TempRollbackFile {
        path: PathBuf,
    }

    impl TempRollbackFile {
        fn new(name: &str) -> Self {
            let nanos = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_nanos())
                .unwrap_or(0);
            // The process id keeps paths unique across concurrently running test binaries.
            let dir = std::env::temp_dir().join(format!(
                "upstream-rollback-storage-test-{name}-{}-{nanos}",
                std::process::id()
            ));
            Self {
                path: dir.join("rollback.json"),
            }
        }

        fn path(&self) -> &Path {
            &self.path
        }

        /// Writes `contents` to the rollback file, creating its directory first.
        fn write(&self, contents: &str) {
            let parent = self.path.parent().expect("rollback file has a parent");
            fs::create_dir_all(parent).expect("create temp dir");
            fs::write(&self.path, contents).expect("write rollback file");
        }
    }

    impl Drop for TempRollbackFile {
        fn drop(&mut self) {
            if let Some(parent) = self.path.parent() {
                // Best effort: the directory may never have been created, and panicking in
                // `drop` while a failed test is already unwinding would abort the process.
                let _ = fs::remove_dir_all(parent);
            }
        }
    }

    fn test_package(name: &str) -> Package {
        Package::with_defaults(
            name.to_string(),
            format!("owner/{name}"),
            Filetype::Binary,
            None,
            None,
            Channel::Stable,
            Provider::Github,
            None,
        )
    }

    fn test_record(name: &str, source: RollbackSource) -> RollbackRecord {
        RollbackRecord {
            package_snapshot: test_package(name),
            artifact_relative_path: PathBuf::from(format!("{name}/{name}.old")),
            icon_relative_path: None,
            artifact_format: RollbackArtifactFormat::Raw,
            artifact_entry_path: None,
            icon_entry_path: None,
            source,
            created_at: Utc::now(),
        }
    }

    #[test]
    fn upsert_and_reload_record_round_trips() {
        let temp = TempRollbackFile::new("roundtrip");
        let mut storage = RollbackStorage::new(temp.path()).expect("create storage");
        let mut record = test_record("tool", RollbackSource::Upgrade);
        record.icon_relative_path = Some(PathBuf::from("tool/icon.png"));
        storage
            .upsert_record("tool", record.clone())
            .expect("upsert");

        let reloaded = RollbackStorage::new(temp.path()).expect("reload");
        let loaded = reloaded.get_record("tool").expect("record");
        assert_eq!(loaded.package_snapshot.name, "tool");
        assert_eq!(loaded.artifact_relative_path, record.artifact_relative_path);
        assert_eq!(loaded.icon_relative_path, record.icon_relative_path);
    }

    #[test]
    fn upsert_record_keeps_only_latest() {
        let temp = TempRollbackFile::new("upsert-latest");
        let mut storage = RollbackStorage::new(temp.path()).expect("create storage");
        storage
            .upsert_record("tool", test_record("tool", RollbackSource::Upgrade))
            .expect("first upsert");
        storage
            .upsert_record("tool", test_record("tool", RollbackSource::Reinstall))
            .expect("second upsert");

        let records = storage.get_records("tool");
        assert_eq!(records.len(), 1);
        assert!(matches!(records[0].source, RollbackSource::Reinstall));
    }

    #[test]
    fn remove_record_returns_removed_value() {
        let temp = TempRollbackFile::new("remove");
        let mut storage = RollbackStorage::new(temp.path()).expect("create storage");
        storage
            .upsert_record("tool", test_record("tool", RollbackSource::Remove))
            .expect("upsert");

        let removed = storage.remove_record("tool").expect("remove");
        assert!(removed.is_some());
        assert!(storage.get_record("tool").is_none());
        assert!(storage.list_records().is_empty());
    }

    #[test]
    fn remove_record_on_unknown_package_returns_none() {
        let temp = TempRollbackFile::new("remove-missing");
        let mut storage = RollbackStorage::new(temp.path()).expect("create storage");

        let removed = storage.remove_record("missing").expect("remove");
        assert!(removed.is_none());
    }

    #[test]
    fn remove_all_records_returns_every_record() {
        let temp = TempRollbackFile::new("remove-all");
        let mut storage = RollbackStorage::new(temp.path()).expect("create storage");
        // A limit of 0 means "unlimited", so both records are kept.
        storage
            .push_record("tool", test_record("tool", RollbackSource::Upgrade), 0)
            .expect("push first");
        storage
            .push_record("tool", test_record("tool", RollbackSource::Remove), 0)
            .expect("push second");
        assert_eq!(storage.get_records("tool").len(), 2);

        let removed = storage.remove_all_records("tool").expect("remove all");
        assert_eq!(removed.len(), 2);
        assert!(storage.get_records("tool").is_empty());
        assert!(storage.list_records().is_empty());
    }

    #[test]
    fn push_record_keeps_latest_records_with_limit() {
        let temp = TempRollbackFile::new("multiple");
        let mut storage = RollbackStorage::new(temp.path()).expect("create storage");
        storage
            .push_record("tool", test_record("tool", RollbackSource::Upgrade), 2)
            .expect("push first");
        storage
            .push_record("tool", test_record("tool", RollbackSource::Remove), 2)
            .expect("push second");
        let pruned = storage
            .push_record("tool", test_record("tool", RollbackSource::Reinstall), 2)
            .expect("push third");
        assert_eq!(pruned.len(), 1);
        assert!(matches!(pruned[0].source, RollbackSource::Upgrade));

        let records = storage.get_records("tool");
        assert_eq!(records.len(), 2);
        assert!(matches!(records[0].source, RollbackSource::Remove));
        assert!(matches!(records[1].source, RollbackSource::Reinstall));
        assert!(matches!(
            storage.get_record("tool").expect("latest").source,
            RollbackSource::Reinstall
        ));
    }

    #[test]
    fn missing_file_loads_as_empty_storage() {
        let temp = TempRollbackFile::new("missing");
        let storage = RollbackStorage::new(temp.path()).expect("load missing storage");
        assert!(storage.list_records().is_empty());
        assert!(!temp.path().exists());
    }

    #[test]
    fn blank_file_loads_as_empty_storage() {
        let temp = TempRollbackFile::new("blank");
        temp.write("  \n");
        let storage = RollbackStorage::new(temp.path()).expect("load blank storage");
        assert!(storage.list_records().is_empty());
    }

    #[test]
    fn loads_legacy_single_record_layout() {
        let temp = TempRollbackFile::new("legacy");
        let record = test_record("tool", RollbackSource::Upgrade);
        let record_json = serde_json::to_value(&record).expect("serialize record");
        let legacy = serde_json::json!({
            "version": ROLLBACK_STORAGE_VERSION,
            "records": { "tool": record_json },
        });
        temp.write(&legacy.to_string());

        let storage = RollbackStorage::new(temp.path()).expect("load legacy storage");
        let records = storage.get_records("tool");
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].artifact_relative_path, record.artifact_relative_path);
    }

    #[test]
    fn rejects_unsupported_version() {
        let temp = TempRollbackFile::new("bad-version");
        let document = serde_json::json!({
            "version": ROLLBACK_STORAGE_VERSION + 1,
            "records": {},
        });
        temp.write(&document.to_string());

        assert!(RollbackStorage::new(temp.path()).is_err());
    }
}