Skip to main content

winterbaume_core/
state.rs

//! Per-account, per-region backend state management.

3use std::collections::HashMap;
4use std::sync::{Arc, RwLock};
5
/// Twelve-digit account ID used as the default mock account.
pub const DEFAULT_ACCOUNT_ID: &str = "123456789012";

9/// Manages per-account, per-region state for a service backend.
10///
11/// Modeled after moto's `BackendDict` pattern: backends are lazily
12/// initialized on first access for each (account_id, region) pair.
13///
14/// Uses `tokio::sync::RwLock` for per-region state so locks can be held
15/// across `.await` points (e.g. during blob-backed snapshot/restore).
16pub struct BackendState<B: Default + Send + Sync> {
17    inner: RwLock<HashMap<(String, String), Arc<tokio::sync::RwLock<B>>>>,
18}
19
20impl<B: Default + Send + Sync> BackendState<B> {
21    pub fn new() -> Self {
22        Self {
23            inner: RwLock::new(HashMap::new()),
24        }
25    }
26
27    /// Get or create the backend state for the given account and region.
28    pub fn get(&self, account_id: &str, region: &str) -> Arc<tokio::sync::RwLock<B>> {
29        let key = (account_id.to_string(), region.to_string());
30
31        // Fast path: read lock on the outer map (std sync — brief, no await)
32        {
33            let map = self.inner.read().unwrap();
34            if let Some(backend) = map.get(&key) {
35                return Arc::clone(backend);
36            }
37        }
38
39        // Slow path: write lock on the outer map, create if missing
40        let mut map = self.inner.write().unwrap();
41        Arc::clone(
42            map.entry(key)
43                .or_insert_with(|| Arc::new(tokio::sync::RwLock::new(B::default()))),
44        )
45    }
46
47    /// Returns sorted `(account_id, region)` pairs that have state.
48    ///
49    /// Read-only: does not create empty backends (unlike [`get()`](Self::get)).
50    pub fn scopes_with_state(&self) -> Vec<(String, String)> {
51        let map = self.inner.read().unwrap();
52        let mut scopes: Vec<(String, String)> = map.keys().cloned().collect();
53        scopes.sort();
54        scopes
55    }
56
57    /// Reset all state (clear all backends).
58    pub fn reset(&self) {
59        let mut map = self.inner.write().unwrap();
60        map.clear();
61    }
62}
63
64impl<B: Default + Send + Sync> Default for BackendState<B> {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70impl<B: Default + Send + Sync> FromIterator<((String, String), B)> for BackendState<B> {
71    fn from_iter<T>(iter: T) -> Self
72    where
73        T: IntoIterator<Item = ((String, String), B)>,
74    {
75        Self {
76            inner: RwLock::new(HashMap::from_iter(
77                iter.into_iter()
78                    .map(|pair| (pair.0, Arc::new(tokio::sync::RwLock::new(pair.1)))),
79            )),
80        }
81    }
82}
83
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Default)]
    struct TestState {
        counter: u32,
    }

    /// Builds an owned `(account_id, region)` key.
    fn scope(account_id: &str, region: &str) -> (String, String) {
        (account_id.to_string(), region.to_string())
    }

    #[tokio::test]
    async fn test_get_creates_default() {
        let state = BackendState::<TestState>::new();
        let backend = state.get("123456789012", "us-east-1");
        assert_eq!(backend.read().await.counter, 0);
    }

    #[tokio::test]
    async fn test_get_returns_same_instance() {
        let state = BackendState::<TestState>::new();
        let b1 = state.get("123456789012", "us-east-1");
        b1.write().await.counter = 42;
        let b2 = state.get("123456789012", "us-east-1");
        assert_eq!(b2.read().await.counter, 42);
    }

    #[tokio::test]
    async fn test_different_regions_different_state() {
        let state = BackendState::<TestState>::new();
        let b1 = state.get("123456789012", "us-east-1");
        b1.write().await.counter = 10;
        let b2 = state.get("123456789012", "eu-west-1");
        assert_eq!(b2.read().await.counter, 0);
    }

    #[tokio::test]
    async fn test_reset_clears_all() {
        let state = BackendState::<TestState>::new();
        let b = state.get("123456789012", "us-east-1");
        b.write().await.counter = 99;
        state.reset();
        assert!(state.scopes_with_state().is_empty());
        let b2 = state.get("123456789012", "us-east-1");
        assert_eq!(b2.read().await.counter, 0);
    }

    #[tokio::test]
    async fn test_reset_detaches_outstanding_handles() {
        let state = BackendState::<TestState>::new();
        let old = state.get(DEFAULT_ACCOUNT_ID, "us-east-1");
        old.write().await.counter = 5;
        state.reset();
        let fresh = state.get(DEFAULT_ACCOUNT_ID, "us-east-1");
        assert!(!Arc::ptr_eq(&old, &fresh));
        assert_eq!(old.read().await.counter, 5);
        assert_eq!(fresh.read().await.counter, 0);
    }

    #[test]
    fn test_scopes_with_state_sorted_and_read_only() {
        let state = BackendState::<TestState>::new();
        assert!(state.scopes_with_state().is_empty());

        let _ = state.get("210987654321", "us-east-1");
        let _ = state.get("123456789012", "us-east-1");
        let _ = state.get("123456789012", "eu-west-1");

        let expected = vec![
            scope("123456789012", "eu-west-1"),
            scope("123456789012", "us-east-1"),
            scope("210987654321", "us-east-1"),
        ];
        assert_eq!(state.scopes_with_state(), expected);
        // Listing must not register any additional scope.
        assert_eq!(state.scopes_with_state(), expected);
    }

    #[tokio::test]
    async fn test_from_iter_prepopulates_backends() {
        let state: BackendState<TestState> = vec![
            (scope("123456789012", "us-east-1"), TestState { counter: 7 }),
            (scope("123456789012", "eu-west-1"), TestState { counter: 3 }),
        ]
        .into_iter()
        .collect();

        assert_eq!(
            state.scopes_with_state(),
            vec![
                scope("123456789012", "eu-west-1"),
                scope("123456789012", "us-east-1"),
            ]
        );
        let us = state.get("123456789012", "us-east-1");
        let eu = state.get("123456789012", "eu-west-1");
        assert_eq!(us.read().await.counter, 7);
        assert_eq!(eu.read().await.counter, 3);
    }
}