shuttle_turso/
lib.rs

1use async_trait::async_trait;
2use libsql::{Builder, Database};
3use serde::{Deserialize, Serialize};
4use shuttle_service::{
5    error::{CustomError, Error as ShuttleError},
6    Environment, IntoResource, ResourceFactory, ResourceInputBuilder,
7};
8use url::Url;
9
10#[derive(Serialize, Default)]
11pub struct Turso {
12    addr: String,
13    token: String,
14    local_addr: Option<String>,
15}
16
17#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
18pub struct TursoOutput {
19    conn_url: Url,
20    token: Option<String>,
21    remote: bool,
22}
23
24impl Turso {
25    pub fn addr(mut self, addr: &str) -> Self {
26        self.addr = addr.to_string();
27        self
28    }
29
30    pub fn token(mut self, token: &str) -> Self {
31        self.token = token.to_string();
32        self
33    }
34
35    pub fn local_addr(mut self, local_addr: &str) -> Self {
36        self.local_addr = Some(local_addr.to_string());
37        self
38    }
39}
40
41pub enum Error {
42    UrlParseError(url::ParseError),
43    LocateLocalDB(std::io::Error),
44}
45
46impl From<Error> for shuttle_service::Error {
47    fn from(error: Error) -> Self {
48        let msg = match error {
49            Error::UrlParseError(err) => format!("Failed to parse Turso Url: {}", err),
50            Error::LocateLocalDB(err) => format!("Failed to get path to local db file: {}", err),
51        };
52
53        ShuttleError::Custom(CustomError::msg(msg))
54    }
55}
56
57impl Turso {
58    async fn output_from_addr(
59        &self,
60        addr: &str,
61        remote: bool,
62    ) -> Result<TursoOutput, shuttle_service::Error> {
63        Ok(TursoOutput {
64            conn_url: Url::parse(addr).map_err(Error::UrlParseError)?,
65            token: if self.token.is_empty() {
66                None
67            } else {
68                Some(self.token.clone())
69            },
70            remote,
71        })
72    }
73}
74
75#[async_trait]
76impl ResourceInputBuilder for Turso {
77    type Input = TursoOutput;
78    type Output = TursoOutput;
79
80    async fn build(self, factory: &ResourceFactory) -> Result<Self::Input, ShuttleError> {
81        let md = factory.get_metadata();
82        match md.env {
83            Environment::Deployment => {
84                if self.addr.is_empty() {
85                    Err(ShuttleError::Custom(CustomError::msg("missing addr")))
86                } else {
87                    if !self.addr.starts_with("libsql://") && !self.addr.starts_with("https://") {
88                        return Err(ShuttleError::Custom(CustomError::msg(
89                            "addr must start with either libsql:// or https://",
90                        )));
91                    }
92                    self.output_from_addr(&self.addr, true).await
93                }
94            }
95            Environment::Local => {
96                match self.local_addr {
97                    Some(ref local_addr) => self.output_from_addr(local_addr, true).await,
98                    None => {
99                        // Default to a local db of the name of the service.
100                        let db_file = std::env::current_dir() // Should be root of the project's workspace
101                            .and_then(dunce::canonicalize)
102                            .map(|cd| {
103                                let mut p = cd.join(md.project_name);
104                                p.set_extension("db");
105                                p
106                            })
107                            .map_err(Error::LocateLocalDB)?;
108                        let conn_url = format!("file:{}", db_file.display());
109                        Ok(TursoOutput {
110                            conn_url: Url::parse(&conn_url).map_err(Error::UrlParseError)?,
111                            // Nullify the token since we're using a file as database.
112                            token: None,
113                            remote: false,
114                        })
115                    }
116                }
117            }
118        }
119    }
120}
121
122#[async_trait]
123impl IntoResource<Database> for TursoOutput {
124    async fn into_resource(self) -> Result<Database, shuttle_service::Error> {
125        let database = if self.remote {
126            Builder::new_remote(
127                self.conn_url.to_string(),
128                self.token
129                    .clone()
130                    .ok_or(ShuttleError::Custom(CustomError::msg(
131                        "missing token for remote database",
132                    )))?,
133            )
134            .build()
135            .await
136        } else {
137            Builder::new_local(self.conn_url.to_string()).build().await
138        };
139
140        database.map_err(|err| ShuttleError::Custom(err.into()))
141    }
142}
143
#[cfg(test)]
mod test {
    use super::*;

    /// Builds a resource factory for `env`, with default values for everything else.
    fn factory_for(env: Environment) -> ResourceFactory {
        ResourceFactory::new(Default::default(), Default::default(), env)
    }

    #[tokio::test]
    async fn local_database_user_supplied() {
        let factory = factory_for(Default::default());
        let local_addr = "libsql://test-addr.turso.io";

        let output = Turso::default()
            .local_addr(local_addr)
            .build(&factory)
            .await
            .unwrap();

        let expected = TursoOutput {
            conn_url: Url::parse(local_addr).unwrap(),
            token: None,
            remote: true,
        };
        assert_eq!(output, expected)
    }

    #[tokio::test]
    #[should_panic(expected = "missing addr")]
    async fn remote_database_empty_addr() {
        let factory = factory_for(Environment::Deployment);
        Turso::default().build(&factory).await.unwrap();
    }

    #[tokio::test]
    async fn remote_database() {
        let factory = factory_for(Environment::Deployment);
        let addr = "libsql://my-turso-addr.turso.io";

        let output = Turso::default()
            .addr(addr)
            .token("token")
            .build(&factory)
            .await
            .unwrap();

        let expected = TursoOutput {
            conn_url: Url::parse(addr).unwrap(),
            token: Some("token".to_string()),
            remote: true,
        };
        assert_eq!(output, expected)
    }
}