Skip to main content

stygian_proxy/storage/
mod.rs

1//! Storage port and in-memory adapter for proxy records.
2
3use async_trait::async_trait;
4use uuid::Uuid;
5
6use crate::error::ProxyResult;
7use crate::types::{Proxy, ProxyRecord};
8
9/// Abstract storage interface for persisting and querying proxy records.
10///
11/// Implementors must be `Send + Sync + 'static` to support concurrent access
12/// across async tasks. The trait is object-safe via [`async_trait`].
13///
14/// # Example
15/// ```rust,no_run
16/// use stygian_proxy::storage::ProxyStoragePort;
17/// use stygian_proxy::types::{Proxy, ProxyType};
18/// use uuid::Uuid;
19///
20/// async fn demo(store: &dyn ProxyStoragePort) {
21///     let proxy = Proxy {
22///         url: "http://proxy.example.com:8080".into(),
23///         proxy_type: ProxyType::Http,
24///         username: None,
25///         password: None,
26///         weight: 1,
27///         tags: vec![],
28///     };
29///     let record = store.add(proxy).await.unwrap();
30///     let _ = store.get(record.id).await.unwrap();
31/// }
32/// ```
33#[async_trait]
34pub trait ProxyStoragePort: Send + Sync + 'static {
35    /// Add a new proxy to the store and return its [`ProxyRecord`].
36    async fn add(&self, proxy: Proxy) -> ProxyResult<ProxyRecord>;
37
38    /// Remove a proxy by its UUID. Returns an error if the ID is not found.
39    async fn remove(&self, id: Uuid) -> ProxyResult<()>;
40
41    /// Return all stored proxy records.
42    async fn list(&self) -> ProxyResult<Vec<ProxyRecord>>;
43
44    /// Fetch a single proxy record by UUID.
45    async fn get(&self, id: Uuid) -> ProxyResult<ProxyRecord>;
46
47    /// Record the outcome of a request through a proxy.
48    ///
49    /// - `success`: whether the request succeeded.
50    /// - `latency_ms`: elapsed time in milliseconds.
51    async fn update_metrics(&self, id: Uuid, success: bool, latency_ms: u64) -> ProxyResult<()>;
52
53    /// Return all stored proxy records paired with their live metrics reference.
54    ///
55    /// Used by [`ProxyManager`](crate::manager::ProxyManager) when building
56    /// [`ProxyCandidate`](crate::strategy::ProxyCandidate) slices so that
57    /// latency-aware strategies (e.g. least-used) see up-to-date counters.
58    async fn list_with_metrics(&self) -> ProxyResult<Vec<(ProxyRecord, Arc<ProxyMetrics>)>>;
59}
60
61/// Convenience alias for a heap-allocated, type-erased [`ProxyStoragePort`].
62pub type BoxedProxyStorage = Box<dyn ProxyStoragePort>;
63
64// ─────────────────────────────────────────────────────────────────────────────
65// URL validation helper
66// ─────────────────────────────────────────────────────────────────────────────
67
68/// Validate a proxy URL: scheme must be recognised, host must be non-empty,
69/// and the explicit port (if present) must be in [1, 65535].
70fn validate_proxy_url(url: &str) -> ProxyResult<()> {
71    use crate::error::ProxyError;
72
73    let (scheme, rest) = url.split_once("://").ok_or_else(|| ProxyError::InvalidProxyUrl {
74        url: url.to_owned(),
75        reason: "missing scheme separator '://'".into(),
76    })?;
77
78    match scheme {
79        "http" | "https" => {}
80        #[cfg(feature = "socks")]
81        "socks4" | "socks5" => {}
82        other => {
83            return Err(ProxyError::InvalidProxyUrl {
84                url: url.to_owned(),
85                reason: format!("unsupported scheme '{other}'"),
86            })
87        }
88    }
89
90    // Strip any path/query, then strip user:pass@ if present.
91    let authority = rest.split('/').next().unwrap_or("");
92    let host_and_port = authority.split('@').next_back().unwrap_or("");
93
94    // Split host from port, handling IPv6 brackets.
95    let (host, port_str) = if host_and_port.starts_with('[') {
96        let close = host_and_port.find(']').unwrap_or(host_and_port.len());
97        let after = &host_and_port[close + 1..];
98        let port = after.strip_prefix(':').unwrap_or("");
99        (&host_and_port[..=close], port)
100    } else {
101        match host_and_port.rsplit_once(':') {
102            Some((h, p)) => (h, p),
103            None => (host_and_port, ""),
104        }
105    };
106
107    if host.is_empty() || host == "[]" {
108        return Err(ProxyError::InvalidProxyUrl {
109            url: url.to_owned(),
110            reason: "empty host".into(),
111        });
112    }
113
114    if !port_str.is_empty() {
115        let port: u32 = port_str
116            .parse()
117            .map_err(|_| ProxyError::InvalidProxyUrl {
118                url: url.to_owned(),
119                reason: format!("non-numeric port '{port_str}'"),
120            })?;
121        if port == 0 || port > 65535 {
122            return Err(ProxyError::InvalidProxyUrl {
123                url: url.to_owned(),
124                reason: format!("port {port} is out of range [1, 65535]"),
125            });
126        }
127    }
128
129    Ok(())
130}
131
132// ─────────────────────────────────────────────────────────────────────────────
133// MemoryProxyStore
134// ─────────────────────────────────────────────────────────────────────────────
135
136use std::collections::HashMap;
137use tokio::sync::RwLock;
138
139use std::sync::Arc;
140use crate::types::ProxyMetrics;
141
142type StoreMap = HashMap<Uuid, (ProxyRecord, Arc<ProxyMetrics>)>;
143
144/// In-memory implementation of [`ProxyStoragePort`].
145///
146/// Uses a `tokio::sync::RwLock`-guarded `HashMap` for thread-safe access.
147/// Metrics are updated via atomic operations, so only a **read** lock is
148/// needed for [`update_metrics`](MemoryProxyStore::update_metrics) calls —
149/// write contention stays low even under heavy concurrent load.
150///
151/// # Example
152/// ```
153/// # tokio_test::block_on(async {
154/// use stygian_proxy::storage::{MemoryProxyStore, ProxyStoragePort};
155/// use stygian_proxy::types::{Proxy, ProxyType};
156///
157/// let store = MemoryProxyStore::default();
158/// let proxy = Proxy { url: "http://proxy.example.com:8080".into(), proxy_type: ProxyType::Http,
159///                     username: None, password: None, weight: 1, tags: vec![] };
160/// let record = store.add(proxy).await.unwrap();
161/// assert_eq!(store.list().await.unwrap().len(), 1);
162/// store.remove(record.id).await.unwrap();
163/// assert!(store.list().await.unwrap().is_empty());
164/// # })
165/// ```
166#[derive(Debug, Default, Clone)]
167pub struct MemoryProxyStore {
168    inner: Arc<RwLock<StoreMap>>,
169}
170
171impl MemoryProxyStore {
172    /// Build a store pre-populated with `proxies`, validating each URL.
173    ///
174    /// Returns an error on the first invalid URL encountered.
175    pub async fn with_proxies(proxies: Vec<Proxy>) -> ProxyResult<Self> {
176        let store = Self::default();
177        for proxy in proxies {
178            store.add(proxy).await?;
179        }
180        Ok(store)
181    }
182}
183
184#[async_trait]
185impl ProxyStoragePort for MemoryProxyStore {
186    async fn add(&self, proxy: Proxy) -> ProxyResult<ProxyRecord> {
187        validate_proxy_url(&proxy.url)?;
188        let record = ProxyRecord::new(proxy);
189        let metrics = Arc::new(ProxyMetrics::default());
190        self.inner.write().await.insert(record.id, (record.clone(), metrics));
191        Ok(record)
192    }
193
194    async fn remove(&self, id: Uuid) -> ProxyResult<()> {
195        self.inner
196            .write()
197            .await
198            .remove(&id)
199            .map(|_| ())
200            .ok_or_else(|| crate::error::ProxyError::StorageError(format!("proxy {id} not found")))
201    }
202
203    async fn list(&self) -> ProxyResult<Vec<ProxyRecord>> {
204        Ok(self.inner.read().await.values().map(|(r, _)| r.clone()).collect())
205    }
206
207    async fn get(&self, id: Uuid) -> ProxyResult<ProxyRecord> {
208        self.inner
209            .read()
210            .await
211            .get(&id)
212            .map(|(r, _)| r.clone())
213            .ok_or_else(|| crate::error::ProxyError::StorageError(format!("proxy {id} not found")))
214    }
215
216    async fn list_with_metrics(&self) -> ProxyResult<Vec<(ProxyRecord, Arc<ProxyMetrics>)>> {
217        Ok(self.inner.read().await.values().map(|(r, m)| (r.clone(), Arc::clone(m))).collect())
218    }
219
220    async fn update_metrics(&self, id: Uuid, success: bool, latency_ms: u64) -> ProxyResult<()> {
221        use std::sync::atomic::Ordering;
222
223        let metrics = self
224            .inner
225            .read()
226            .await
227            .get(&id)
228            .map(|(_, m)| Arc::clone(m))
229            .ok_or_else(|| crate::error::ProxyError::StorageError(format!("proxy {id} not found")))?;
230
231        // Lock released before the atomic updates — no long critical section.
232        metrics.requests_total.fetch_add(1, Ordering::Relaxed);
233        if success {
234            metrics.successes.fetch_add(1, Ordering::Relaxed);
235        } else {
236            metrics.failures.fetch_add(1, Ordering::Relaxed);
237        }
238        metrics.total_latency_ms.fetch_add(latency_ms, Ordering::Relaxed);
239        Ok(())
240    }
241}
242
243// ─────────────────────────────────────────────────────────────────────────────
244// Tests
245// ─────────────────────────────────────────────────────────────────────────────
246
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ProxyError;
    use crate::types::ProxyType;
    use std::sync::atomic::Ordering;

    /// Build a minimal HTTP proxy description pointing at `url`.
    fn make_proxy(url: &str) -> Proxy {
        Proxy {
            url: url.into(),
            proxy_type: ProxyType::Http,
            username: None,
            password: None,
            weight: 1,
            tags: vec![],
        }
    }

    #[tokio::test]
    async fn add_list_remove() {
        let store = MemoryProxyStore::default();
        let r1 = store.add(make_proxy("http://a.test:8080")).await.unwrap();
        let r2 = store.add(make_proxy("http://b.test:8080")).await.unwrap();
        let r3 = store.add(make_proxy("http://c.test:8080")).await.unwrap();
        assert_eq!(store.list().await.unwrap().len(), 3);
        store.remove(r2.id).await.unwrap();
        let remaining = store.list().await.unwrap();
        assert_eq!(remaining.len(), 2);
        let ids: Vec<_> = remaining.iter().map(|r| r.id).collect();
        assert!(ids.contains(&r1.id));
        assert!(ids.contains(&r3.id));
    }

    #[tokio::test]
    async fn get_returns_added_record() {
        let store = MemoryProxyStore::default();
        let added = store.add(make_proxy("http://a.test:8080")).await.unwrap();
        let fetched = store.get(added.id).await.unwrap();
        assert_eq!(fetched.id, added.id);
    }

    #[tokio::test]
    async fn unknown_id_is_reported_as_storage_error() {
        let store = MemoryProxyStore::default();
        // Add then remove, so the ID is guaranteed not to be in the store.
        let record = store.add(make_proxy("http://gone.test:8080")).await.unwrap();
        store.remove(record.id).await.unwrap();

        assert!(matches!(
            store.get(record.id).await.unwrap_err(),
            ProxyError::StorageError(_)
        ));
        assert!(matches!(
            store.remove(record.id).await.unwrap_err(),
            ProxyError::StorageError(_)
        ));
        assert!(matches!(
            store.update_metrics(record.id, true, 5).await.unwrap_err(),
            ProxyError::StorageError(_)
        ));
    }

    #[tokio::test]
    async fn invalid_url_rejected() {
        let store = MemoryProxyStore::default();
        let err = store.add(make_proxy("not-a-url")).await.unwrap_err();
        assert!(matches!(err, ProxyError::InvalidProxyUrl { .. }));
    }

    #[tokio::test]
    async fn invalid_url_empty_host() {
        let store = MemoryProxyStore::default();
        let err = store.add(make_proxy("http://:8080")).await.unwrap_err();
        assert!(matches!(err, ProxyError::InvalidProxyUrl { .. }));
    }

    #[tokio::test]
    async fn invalid_scheme_and_ports_rejected() {
        let store = MemoryProxyStore::default();
        for bad in [
            "ftp://proxy.test:21",
            "http://proxy.test:0",
            "http://proxy.test:70000",
            "http://proxy.test:http",
        ] {
            let err = store.add(make_proxy(bad)).await.unwrap_err();
            assert!(
                matches!(err, ProxyError::InvalidProxyUrl { .. }),
                "{bad} should be rejected"
            );
        }
        // Rejected proxies must not leave anything behind.
        assert!(store.list().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn valid_url_shapes_accepted() {
        let store = MemoryProxyStore::default();
        for good in [
            "https://proxy.test",
            "http://user:p@ss@proxy.test:3128",
            "http://[::1]:8080",
            "http://proxy.test:8080/",
        ] {
            store.add(make_proxy(good)).await.unwrap();
        }
        assert_eq!(store.list().await.unwrap().len(), 4);
    }

    #[tokio::test]
    async fn with_proxies_seeds_store_and_fails_fast() {
        let store = MemoryProxyStore::with_proxies(vec![
            make_proxy("http://a.test:8080"),
            make_proxy("http://b.test:8080"),
        ])
        .await
        .unwrap();
        assert_eq!(store.list().await.unwrap().len(), 2);

        let err = MemoryProxyStore::with_proxies(vec![
            make_proxy("http://a.test:8080"),
            make_proxy("not-a-url"),
        ])
        .await
        .unwrap_err();
        assert!(matches!(err, ProxyError::InvalidProxyUrl { .. }));
    }

    #[tokio::test]
    async fn concurrent_metrics_updates() {
        use tokio::task::JoinSet;

        let store = MemoryProxyStore::default();
        let id = store.add(make_proxy("http://proxy.test:3128")).await.unwrap().id;

        let mut tasks = JoinSet::new();
        for i in 0u64..50 {
            // Clones share the same underlying map, so every task hits the same proxy.
            let handle = store.clone();
            tasks.spawn(async move {
                handle.update_metrics(id, i % 2 == 0, i * 10).await.unwrap();
            });
        }
        while let Some(res) = tasks.join_next().await {
            res.unwrap();
        }

        // Read the counters through the public API rather than the private map.
        let (_, metrics) = store
            .list_with_metrics()
            .await
            .unwrap()
            .into_iter()
            .find(|(record, _)| record.id == id)
            .expect("proxy should still be stored");

        // Even i (0, 2, ..., 48) succeed and odd i fail: 25 of each.
        assert_eq!(metrics.requests_total.load(Ordering::Relaxed), 50);
        assert_eq!(metrics.successes.load(Ordering::Relaxed), 25);
        assert_eq!(metrics.failures.load(Ordering::Relaxed), 25);
        // Sum of i * 10 for i in 0..50 is 10 * 1225.
        assert_eq!(metrics.total_latency_ms.load(Ordering::Relaxed), 12_250);
    }
}