sos_filesystem/
server_origins.rs

1use crate::{write_exclusive, Error};
2use async_fd_lock::LockRead;
3use async_trait::async_trait;
4use sos_core::{Origin, Paths, RemoteOrigins};
5use sos_vfs::{self as vfs, File};
6use std::{collections::HashSet, sync::Arc};
7use tokio::io::AsyncReadExt;
8
9/// Collection of server origins.
10pub struct ServerOrigins<E>
11where
12    E: std::error::Error
13        + std::fmt::Debug
14        + From<Error>
15        + From<std::io::Error>
16        + Send
17        + Sync
18        + 'static,
19{
20    paths: Arc<Paths>,
21    marker: std::marker::PhantomData<E>,
22}
23
24impl<E> ServerOrigins<E>
25where
26    E: std::error::Error
27        + std::fmt::Debug
28        + From<Error>
29        + From<std::io::Error>
30        + Send
31        + Sync
32        + 'static,
33{
34    /// Create new server origins.
35    pub fn new(paths: Arc<Paths>) -> Self {
36        Self {
37            paths,
38            marker: std::marker::PhantomData,
39        }
40    }
41
42    async fn list_origins(&self) -> Result<HashSet<Origin>, E> {
43        let remotes_file = self.paths.remote_origins();
44        if vfs::try_exists(&remotes_file).await? {
45            let file = File::open(&remotes_file).await?;
46            let mut guard = file.lock_read().await.map_err(|e| e.error)?;
47            let mut content = Vec::new();
48            guard.read_to_end(&mut content).await?;
49
50            let origins: HashSet<Origin> =
51                serde_json::from_slice(&content).map_err(Error::from)?;
52            Ok(origins)
53        } else {
54            Ok(Default::default())
55        }
56    }
57
58    async fn save_origins(&self, origins: &HashSet<Origin>) -> Result<(), E> {
59        let data =
60            serde_json::to_vec_pretty(&origins).map_err(Error::from)?;
61        let file = self.paths.remote_origins();
62        write_exclusive(file, data).await?;
63        Ok(())
64    }
65}
66
67#[async_trait]
68impl<E> RemoteOrigins for ServerOrigins<E>
69where
70    E: std::error::Error
71        + std::fmt::Debug
72        + From<Error>
73        + From<std::io::Error>
74        + Send
75        + Sync
76        + 'static,
77{
78    type Error = E;
79
80    async fn list_servers(&self) -> Result<HashSet<Origin>, Self::Error> {
81        self.list_origins().await
82    }
83
84    async fn add_server(
85        &mut self,
86        origin: Origin,
87    ) -> Result<(), Self::Error> {
88        let mut origins = self.list_origins().await?;
89        origins.insert(origin);
90        self.save_origins(&origins).await?;
91        Ok(())
92    }
93
94    async fn replace_server(
95        &mut self,
96        old_origin: &Origin,
97        new_origin: Origin,
98    ) -> Result<(), Self::Error> {
99        self.remove_server(old_origin).await?;
100        self.add_server(new_origin).await?;
101        Ok(())
102    }
103
104    async fn remove_server(
105        &mut self,
106        origin: &Origin,
107    ) -> Result<(), Self::Error> {
108        let mut origins = self.list_origins().await?;
109        origins.remove(origin);
110        self.save_origins(&origins).await?;
111        Ok(())
112    }
113}