Skip to main content

stygian_proxy/storage/
mod.rs

1//! Storage port and in-memory adapter for proxy records.
2
3use async_trait::async_trait;
4use uuid::Uuid;
5
6use crate::error::ProxyResult;
7use crate::types::{Proxy, ProxyRecord};
8
9/// Abstract storage interface for persisting and querying proxy records.
10///
11/// Implementors must be `Send + Sync + 'static` to support concurrent access
12/// across async tasks. The trait is object-safe via [`async_trait`].
13///
14/// # Example
15/// ```rust,no_run
16/// use stygian_proxy::storage::ProxyStoragePort;
17/// use stygian_proxy::types::{Proxy, ProxyType};
18/// use uuid::Uuid;
19///
20/// async fn demo(store: &dyn ProxyStoragePort) {
21///     let proxy = Proxy {
22///         url: "http://proxy.example.com:8080".into(),
23///         proxy_type: ProxyType::Http,
24///         username: None,
25///         password: None,
26///         weight: 1,
27///         tags: vec![],
28///     };
29///     let record = store.add(proxy).await.unwrap();
30///     let _ = store.get(record.id).await.unwrap();
31/// }
32/// ```
33#[async_trait]
34pub trait ProxyStoragePort: Send + Sync + 'static {
35    /// Add a new proxy to the store and return its [`ProxyRecord`].
36    async fn add(&self, proxy: Proxy) -> ProxyResult<ProxyRecord>;
37
38    /// Remove a proxy by its UUID. Returns an error if the ID is not found.
39    async fn remove(&self, id: Uuid) -> ProxyResult<()>;
40
41    /// Return all stored proxy records.
42    async fn list(&self) -> ProxyResult<Vec<ProxyRecord>>;
43
44    /// Fetch a single proxy record by UUID.
45    async fn get(&self, id: Uuid) -> ProxyResult<ProxyRecord>;
46
47    /// Record the outcome of a request through a proxy.
48    ///
49    /// - `success`: whether the request succeeded.
50    /// - `latency_ms`: elapsed time in milliseconds.
51    async fn update_metrics(&self, id: Uuid, success: bool, latency_ms: u64) -> ProxyResult<()>;
52
53    /// Return all stored proxy records paired with their live metrics reference.
54    ///
55    /// Used by [`ProxyManager`](crate::manager::ProxyManager) when building
56    /// [`ProxyCandidate`](crate::strategy::ProxyCandidate) slices so that
57    /// latency-aware strategies (e.g. least-used) see up-to-date counters.
58    async fn list_with_metrics(&self) -> ProxyResult<Vec<(ProxyRecord, Arc<ProxyMetrics>)>>;
59}
60
61/// Convenience alias for a heap-allocated, type-erased [`ProxyStoragePort`].
62pub type BoxedProxyStorage = Box<dyn ProxyStoragePort>;
63
64// ─────────────────────────────────────────────────────────────────────────────
65// URL validation helper
66// ─────────────────────────────────────────────────────────────────────────────
67
68/// Validate a proxy URL: scheme must be recognised, host must be non-empty,
69/// and the explicit port (if present) must be in [1, 65535].
70fn validate_proxy_url(url: &str) -> ProxyResult<()> {
71    use crate::error::ProxyError;
72
73    let (scheme, rest) = url
74        .split_once("://")
75        .ok_or_else(|| ProxyError::InvalidProxyUrl {
76            url: url.to_owned(),
77            reason: "missing scheme separator '://'".into(),
78        })?;
79
80    match scheme {
81        "http" | "https" => {}
82        #[cfg(feature = "socks")]
83        "socks4" | "socks5" => {}
84        other => {
85            return Err(ProxyError::InvalidProxyUrl {
86                url: url.to_owned(),
87                reason: format!("unsupported scheme '{other}'"),
88            });
89        }
90    }
91
92    // Strip any path/query, then strip user:pass@ if present.
93    let authority = rest.split('/').next().unwrap_or("");
94    let host_and_port = authority.split('@').next_back().unwrap_or("");
95
96    // Split host from port, handling IPv6 brackets.
97    let (host, port_str) = if host_and_port.starts_with('[') {
98        let close = host_and_port.find(']').unwrap_or(host_and_port.len());
99        let after = &host_and_port[close + 1..];
100        let port = after.strip_prefix(':').unwrap_or("");
101        (&host_and_port[..=close], port)
102    } else {
103        match host_and_port.rsplit_once(':') {
104            Some((h, p)) => (h, p),
105            None => (host_and_port, ""),
106        }
107    };
108
109    if host.is_empty() || host == "[]" {
110        return Err(ProxyError::InvalidProxyUrl {
111            url: url.to_owned(),
112            reason: "empty host".into(),
113        });
114    }
115
116    if !port_str.is_empty() {
117        let port: u32 = port_str.parse().map_err(|_| ProxyError::InvalidProxyUrl {
118            url: url.to_owned(),
119            reason: format!("non-numeric port '{port_str}'"),
120        })?;
121        if port == 0 || port > 65535 {
122            return Err(ProxyError::InvalidProxyUrl {
123                url: url.to_owned(),
124                reason: format!("port {port} is out of range [1, 65535]"),
125            });
126        }
127    }
128
129    Ok(())
130}
131
132// ─────────────────────────────────────────────────────────────────────────────
133// MemoryProxyStore
134// ─────────────────────────────────────────────────────────────────────────────
135
136use std::collections::HashMap;
137use tokio::sync::RwLock;
138
139use crate::types::ProxyMetrics;
140use std::sync::Arc;
141
142type StoreMap = HashMap<Uuid, (ProxyRecord, Arc<ProxyMetrics>)>;
143
144/// In-memory implementation of [`ProxyStoragePort`].
145///
146/// Uses a `tokio::sync::RwLock`-guarded `HashMap` for thread-safe access.
147/// Metrics are updated via atomic operations, so only a **read** lock is
148/// needed for [`update_metrics`](MemoryProxyStore::update_metrics) calls —
149/// write contention stays low even under heavy concurrent load.
150///
151/// # Example
152/// ```
153/// # tokio_test::block_on(async {
154/// use stygian_proxy::storage::{MemoryProxyStore, ProxyStoragePort};
155/// use stygian_proxy::types::{Proxy, ProxyType};
156///
157/// let store = MemoryProxyStore::default();
158/// let proxy = Proxy { url: "http://proxy.example.com:8080".into(), proxy_type: ProxyType::Http,
159///                     username: None, password: None, weight: 1, tags: vec![] };
160/// let record = store.add(proxy).await.unwrap();
161/// assert_eq!(store.list().await.unwrap().len(), 1);
162/// store.remove(record.id).await.unwrap();
163/// assert!(store.list().await.unwrap().is_empty());
164/// # })
165/// ```
166#[derive(Debug, Default, Clone)]
167pub struct MemoryProxyStore {
168    inner: Arc<RwLock<StoreMap>>,
169}
170
171impl MemoryProxyStore {
172    /// Build a store pre-populated with `proxies`, validating each URL.
173    ///
174    /// Returns an error on the first invalid URL encountered.
175    pub async fn with_proxies(proxies: Vec<Proxy>) -> ProxyResult<Self> {
176        let store = Self::default();
177        for proxy in proxies {
178            store.add(proxy).await?;
179        }
180        Ok(store)
181    }
182}
183
184#[async_trait]
185impl ProxyStoragePort for MemoryProxyStore {
186    async fn add(&self, proxy: Proxy) -> ProxyResult<ProxyRecord> {
187        validate_proxy_url(&proxy.url)?;
188        let record = ProxyRecord::new(proxy);
189        let metrics = Arc::new(ProxyMetrics::default());
190        self.inner
191            .write()
192            .await
193            .insert(record.id, (record.clone(), metrics));
194        Ok(record)
195    }
196
197    async fn remove(&self, id: Uuid) -> ProxyResult<()> {
198        self.inner
199            .write()
200            .await
201            .remove(&id)
202            .map(|_| ())
203            .ok_or_else(|| crate::error::ProxyError::StorageError(format!("proxy {id} not found")))
204    }
205
206    async fn list(&self) -> ProxyResult<Vec<ProxyRecord>> {
207        Ok(self
208            .inner
209            .read()
210            .await
211            .values()
212            .map(|(r, _)| r.clone())
213            .collect())
214    }
215
216    async fn get(&self, id: Uuid) -> ProxyResult<ProxyRecord> {
217        self.inner
218            .read()
219            .await
220            .get(&id)
221            .map(|(r, _)| r.clone())
222            .ok_or_else(|| crate::error::ProxyError::StorageError(format!("proxy {id} not found")))
223    }
224
225    async fn list_with_metrics(&self) -> ProxyResult<Vec<(ProxyRecord, Arc<ProxyMetrics>)>> {
226        Ok(self
227            .inner
228            .read()
229            .await
230            .values()
231            .map(|(r, m)| (r.clone(), Arc::clone(m)))
232            .collect())
233    }
234
235    async fn update_metrics(&self, id: Uuid, success: bool, latency_ms: u64) -> ProxyResult<()> {
236        use std::sync::atomic::Ordering;
237
238        let metrics = self
239            .inner
240            .read()
241            .await
242            .get(&id)
243            .map(|(_, m)| Arc::clone(m))
244            .ok_or_else(|| {
245                crate::error::ProxyError::StorageError(format!("proxy {id} not found"))
246            })?;
247
248        // Lock released before the atomic updates — no long critical section.
249        metrics.requests_total.fetch_add(1, Ordering::Relaxed);
250        if success {
251            metrics.successes.fetch_add(1, Ordering::Relaxed);
252        } else {
253            metrics.failures.fetch_add(1, Ordering::Relaxed);
254        }
255        metrics
256            .total_latency_ms
257            .fetch_add(latency_ms, Ordering::Relaxed);
258        Ok(())
259    }
260}
261
262// ─────────────────────────────────────────────────────────────────────────────
263// Tests
264// ─────────────────────────────────────────────────────────────────────────────
265
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ProxyError;
    use crate::types::ProxyType;
    use std::sync::atomic::Ordering;

    /// Build a minimal HTTP proxy definition pointing at `url`.
    fn make_proxy(url: &str) -> Proxy {
        Proxy {
            url: url.into(),
            proxy_type: ProxyType::Http,
            username: None,
            password: None,
            weight: 1,
            tags: vec![],
        }
    }

    #[tokio::test]
    async fn add_list_remove() {
        let store = MemoryProxyStore::default();
        let r1 = store.add(make_proxy("http://a.test:8080")).await.unwrap();
        let r2 = store.add(make_proxy("http://b.test:8080")).await.unwrap();
        let r3 = store.add(make_proxy("http://c.test:8080")).await.unwrap();
        assert_eq!(store.list().await.unwrap().len(), 3);
        store.remove(r2.id).await.unwrap();
        let remaining = store.list().await.unwrap();
        assert_eq!(remaining.len(), 2);
        let ids: Vec<_> = remaining.iter().map(|r| r.id).collect();
        assert!(ids.contains(&r1.id));
        assert!(ids.contains(&r3.id));
    }

    #[tokio::test]
    async fn get_returns_added_record_until_removed() {
        let store = MemoryProxyStore::default();
        let record = store.add(make_proxy("http://a.test:8080")).await.unwrap();
        assert_eq!(store.get(record.id).await.unwrap().id, record.id);

        store.remove(record.id).await.unwrap();
        assert!(matches!(
            store.get(record.id).await,
            Err(ProxyError::StorageError(_))
        ));
        // A second removal of the same id must fail rather than silently succeed.
        assert!(matches!(
            store.remove(record.id).await,
            Err(ProxyError::StorageError(_))
        ));
    }

    #[tokio::test]
    async fn unknown_id_rejected_by_update_metrics() {
        let store = MemoryProxyStore::default();
        let result = store.update_metrics(Uuid::nil(), true, 5).await;
        assert!(matches!(result, Err(ProxyError::StorageError(_))));
    }

    #[tokio::test]
    async fn invalid_url_rejected() {
        let store = MemoryProxyStore::default();
        let err = store.add(make_proxy("not-a-url")).await.unwrap_err();
        assert!(matches!(err, ProxyError::InvalidProxyUrl { .. }));
    }

    #[tokio::test]
    async fn invalid_url_empty_host() {
        let store = MemoryProxyStore::default();
        let err = store.add(make_proxy("http://:8080")).await.unwrap_err();
        assert!(matches!(err, ProxyError::InvalidProxyUrl { .. }));
    }

    #[tokio::test]
    async fn invalid_port_rejected() {
        let store = MemoryProxyStore::default();
        for url in ["http://a.test:0", "http://a.test:65536", "http://a.test:abc"] {
            let err = store.add(make_proxy(url)).await.unwrap_err();
            assert!(
                matches!(err, ProxyError::InvalidProxyUrl { .. }),
                "{url} should be rejected"
            );
        }
        // Rejected proxies must never reach the map.
        assert!(store.list().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn valid_url_variants_accepted() {
        let store = MemoryProxyStore::default();
        for url in [
            "http://a.test",
            "https://user:pass@a.test:443",
            "http://a.test:8080/path",
            "http://[::1]:3128",
        ] {
            store.add(make_proxy(url)).await.unwrap();
        }
        assert_eq!(store.list().await.unwrap().len(), 4);
    }

    #[tokio::test]
    async fn with_proxies_validates_every_entry() {
        let store = MemoryProxyStore::with_proxies(vec![
            make_proxy("http://a.test:1"),
            make_proxy("http://b.test:2"),
        ])
        .await
        .unwrap();
        assert_eq!(store.list().await.unwrap().len(), 2);

        let err = MemoryProxyStore::with_proxies(vec![
            make_proxy("http://a.test:1"),
            make_proxy("not-a-url"),
        ])
        .await
        .unwrap_err();
        assert!(matches!(err, ProxyError::InvalidProxyUrl { .. }));
    }

    #[tokio::test]
    async fn concurrent_metrics_updates() {
        use tokio::task::JoinSet;

        let store = Arc::new(MemoryProxyStore::default());
        let record = store
            .add(make_proxy("http://proxy.test:3128"))
            .await
            .unwrap();
        let id = record.id;

        let mut tasks = JoinSet::new();
        for i in 0u64..50 {
            let s = Arc::clone(&store);
            tasks.spawn(async move {
                s.update_metrics(id, i % 2 == 0, i * 10).await.unwrap();
            });
        }
        while let Some(res) = tasks.join_next().await {
            res.unwrap();
        }

        // Read the counters through the public API: the handle must be live.
        let (_, metrics) = store
            .list_with_metrics()
            .await
            .unwrap()
            .into_iter()
            .find(|(r, _)| r.id == id)
            .expect("record must still be stored");

        assert_eq!(metrics.requests_total.load(Ordering::Relaxed), 50);
        // Even i (0, 2, …, 48) succeed, odd i fail: 25 each.
        assert_eq!(metrics.successes.load(Ordering::Relaxed), 25);
        assert_eq!(metrics.failures.load(Ordering::Relaxed), 25);
        // 10 * (0 + 1 + … + 49) = 10 * 1225.
        assert_eq!(metrics.total_latency_ms.load(Ordering::Relaxed), 12_250);
    }
}