Skip to main content

signet_cold_sql/
connector.rs

1//! SQL cold storage connector.
2
3use crate::{
4    SqlColdBackend, SqlColdError,
5    backend::{DEFAULT_READ_TIMEOUT, DEFAULT_WRITE_TIMEOUT},
6};
7use signet_cold::ColdConnect;
8use sqlx::pool::PoolOptions;
9use std::time::Duration;
10
11/// Errors that can occur when initializing SQL connectors.
12#[derive(Debug, thiserror::Error)]
13pub enum SqlConnectorError {
14    /// Missing environment variable.
15    #[error("missing environment variable: {0}")]
16    MissingEnvVar(&'static str),
17
18    /// Cold storage initialization failed.
19    #[error("cold storage initialization failed: {0}")]
20    ColdInit(#[from] SqlColdError),
21}
22
/// Connector for SQL cold storage (PostgreSQL or SQLite).
///
/// The database backend is chosen from the connection URL:
/// - `postgres://…` or `postgresql://…` selects PostgreSQL
/// - `sqlite:…` selects SQLite
///
/// Pool behaviour can be tuned piecemeal through the `with_*` builder
/// methods, which mirror [`sqlx::pool::PoolOptions`], or replaced wholesale
/// by passing a complete [`PoolOptions`] to
/// [`with_pool_options`](Self::with_pool_options). For in-memory SQLite
/// URLs, `max_connections` is forced to 1 whatever the supplied options say.
///
/// Per-transaction timeouts are configured separately through
/// [`with_read_timeout`](Self::with_read_timeout) and
/// [`with_write_timeout`](Self::with_write_timeout).
///
/// # Example
///
/// ```ignore
/// use signet_cold::ColdConnect;
/// use signet_cold_sql::SqlConnector;
///
/// // PostgreSQL with custom pool size
/// let pg = SqlConnector::new("postgres://localhost/signet")
///     .with_max_connections(20);
/// let backend = pg.connect().await?;
///
/// // SQLite (defaults)
/// let sqlite = SqlConnector::new("sqlite::memory:");
/// let backend = sqlite.connect().await?;
/// ```
#[cfg(any(feature = "sqlite", feature = "postgres"))]
#[derive(Clone, Debug)]
pub struct SqlConnector {
    /// Connection URL; its prefix selects the database backend.
    url: String,
    /// Pool settings passed to the backend when connecting.
    pool_opts: PoolOptions<sqlx::Any>,
    /// Per-transaction read timeout (enforced on Postgres only).
    read_timeout: Duration,
    /// Per-transaction write timeout (enforced on Postgres only).
    write_timeout: Duration,
}
57
#[cfg(any(feature = "sqlite", feature = "postgres"))]
impl SqlConnector {
    /// Create a new SQL connector with default pool options and the default
    /// read/write timeouts.
    ///
    /// The URL is only stored here; it is not parsed or validated until the
    /// connector is used to connect, at which point the database type is
    /// detected from the URL prefix.
    #[must_use]
    pub fn new(url: impl Into<String>) -> Self {
        Self {
            url: url.into(),
            pool_opts: PoolOptions::new(),
            read_timeout: DEFAULT_READ_TIMEOUT,
            write_timeout: DEFAULT_WRITE_TIMEOUT,
        }
    }

    /// Get a reference to the connection URL.
    #[must_use]
    pub fn url(&self) -> &str {
        &self.url
    }

    /// Replace the pool options entirely.
    ///
    /// This overwrites anything previously set through the other pool builder
    /// methods (`with_max_connections`, `with_min_connections`, ...), so call
    /// it first if you want to combine the two. The read/write timeouts are
    /// stored separately and are not affected.
    ///
    /// For in-memory SQLite URLs, `max_connections` is forced to 1
    /// regardless of the value set here.
    #[must_use]
    pub fn with_pool_options(mut self, pool_opts: PoolOptions<sqlx::Any>) -> Self {
        self.pool_opts = pool_opts;
        self
    }

    /// Set the maximum number of pool connections.
    ///
    /// Forwarded to [`PoolOptions::max_connections`]. Ignored for in-memory
    /// SQLite URLs, which always use 1.
    #[must_use]
    pub fn with_max_connections(mut self, n: u32) -> Self {
        self.pool_opts = self.pool_opts.max_connections(n);
        self
    }

    /// Set the minimum number of connections to maintain at all times.
    ///
    /// Forwarded to [`PoolOptions::min_connections`].
    #[must_use]
    pub fn with_min_connections(mut self, n: u32) -> Self {
        self.pool_opts = self.pool_opts.min_connections(n);
        self
    }

    /// Set the connection acquire timeout.
    ///
    /// Forwarded to [`PoolOptions::acquire_timeout`].
    #[must_use]
    pub fn with_acquire_timeout(mut self, timeout: Duration) -> Self {
        self.pool_opts = self.pool_opts.acquire_timeout(timeout);
        self
    }

    /// Set the maximum lifetime of individual connections.
    ///
    /// Forwarded to [`PoolOptions::max_lifetime`]; see that method for the
    /// meaning of `None`.
    #[must_use]
    pub fn with_max_lifetime(mut self, lifetime: Option<Duration>) -> Self {
        self.pool_opts = self.pool_opts.max_lifetime(lifetime);
        self
    }

    /// Set the idle timeout for connections.
    ///
    /// Forwarded to [`PoolOptions::idle_timeout`]; see that method for the
    /// meaning of `None`.
    #[must_use]
    pub fn with_idle_timeout(mut self, timeout: Option<Duration>) -> Self {
        self.pool_opts = self.pool_opts.idle_timeout(timeout);
        self
    }

    /// Set the per-transaction read timeout (default 500 ms).
    ///
    /// On Postgres this is applied via `SET LOCAL statement_timeout`
    /// at the start of every read transaction. On SQLite the value is
    /// stored but not enforced.
    ///
    /// # Panics
    ///
    /// Panics if `d` is shorter than 1 ms. [`Duration::as_millis`] truncates,
    /// so any sub-millisecond value (e.g. 999 µs) counts as 0 ms. Postgres
    /// treats a `statement_timeout` of `0` as "no timeout", which would
    /// silently disable the trait contract.
    #[must_use]
    pub fn with_read_timeout(mut self, d: Duration) -> Self {
        assert!(d.as_millis() >= 1, "read_timeout must be >= 1ms (got {d:?})");
        self.read_timeout = d;
        self
    }

    /// Set the per-transaction write timeout (default 2 s).
    ///
    /// On Postgres this is applied via `SET LOCAL statement_timeout`
    /// at the start of every write transaction. On SQLite the value is
    /// stored but not enforced.
    ///
    /// # Panics
    ///
    /// Panics if `d` is shorter than 1 ms (sub-millisecond values truncate to
    /// 0 ms). See [`with_read_timeout`](Self::with_read_timeout).
    #[must_use]
    pub fn with_write_timeout(mut self, d: Duration) -> Self {
        assert!(d.as_millis() >= 1, "write_timeout must be >= 1ms (got {d:?})");
        self.write_timeout = d;
        self
    }

    /// Create a connector from an environment variable.
    ///
    /// Reads the SQL URL from the specified environment variable and uses
    /// default pool settings and timeouts. The value is used as-is; it is not
    /// validated until the connector is used to connect.
    ///
    /// # Errors
    ///
    /// Returns [`SqlConnectorError::MissingEnvVar`] if the variable is unset
    /// or if its value is not valid Unicode. Both failure modes of
    /// [`std::env::var`] are reported the same way.
    ///
    /// # Example
    ///
    /// ```ignore
    /// use signet_cold_sql::SqlConnector;
    ///
    /// let cold = SqlConnector::from_env("SIGNET_COLD_SQL_URL")?;
    /// ```
    pub fn from_env(env_var: &'static str) -> Result<Self, SqlConnectorError> {
        let url = std::env::var(env_var).map_err(|_| SqlConnectorError::MissingEnvVar(env_var))?;
        Ok(Self::new(url))
    }
}
168
#[cfg(any(feature = "sqlite", feature = "postgres"))]
impl ColdConnect for SqlConnector {
    type Cold = SqlColdBackend;
    type Error = SqlColdError;

    /// Open a pool for the configured URL and return a backend that carries
    /// this connector's read and write timeouts.
    ///
    /// The URL, pool options and timeouts are copied out of `self` before the
    /// future is built, so the returned future owns everything it uses.
    fn connect(&self) -> impl std::future::Future<Output = Result<Self::Cold, Self::Error>> + Send {
        let Self { url, pool_opts, read_timeout, write_timeout } = self;
        let (url, pool_opts) = (url.clone(), pool_opts.clone());
        let (read_timeout, write_timeout) = (*read_timeout, *write_timeout);

        async move {
            let configured = SqlColdBackend::connect_with(&url, pool_opts)
                .await?
                .with_read_timeout(read_timeout)
                .with_write_timeout(write_timeout);
            Ok(configured)
        }
    }
}