winterbaume_core/
state.rs1use std::collections::HashMap;
4use std::sync::{Arc, RwLock};
5
/// Twelve-digit account ID exported as the crate-wide default account.
///
/// NOTE(review): presumably the account used when a request does not
/// identify one — confirm against callers before relying on that.
pub const DEFAULT_ACCOUNT_ID: &str = "123456789012";
8
9pub struct BackendState<B: Default + Send + Sync> {
17 inner: RwLock<HashMap<(String, String), Arc<tokio::sync::RwLock<B>>>>,
18}
19
20impl<B: Default + Send + Sync> BackendState<B> {
21 pub fn new() -> Self {
22 Self {
23 inner: RwLock::new(HashMap::new()),
24 }
25 }
26
27 pub fn get(&self, account_id: &str, region: &str) -> Arc<tokio::sync::RwLock<B>> {
29 let key = (account_id.to_string(), region.to_string());
30
31 {
33 let map = self.inner.read().unwrap();
34 if let Some(backend) = map.get(&key) {
35 return Arc::clone(backend);
36 }
37 }
38
39 let mut map = self.inner.write().unwrap();
41 Arc::clone(
42 map.entry(key)
43 .or_insert_with(|| Arc::new(tokio::sync::RwLock::new(B::default()))),
44 )
45 }
46
47 pub fn scopes_with_state(&self) -> Vec<(String, String)> {
51 let map = self.inner.read().unwrap();
52 let mut scopes: Vec<(String, String)> = map.keys().cloned().collect();
53 scopes.sort();
54 scopes
55 }
56
57 pub fn reset(&self) {
59 let mut map = self.inner.write().unwrap();
60 map.clear();
61 }
62}
63
64impl<B: Default + Send + Sync> Default for BackendState<B> {
65 fn default() -> Self {
66 Self::new()
67 }
68}
69
70impl<B: Default + Send + Sync> FromIterator<((String, String), B)> for BackendState<B> {
71 fn from_iter<T>(iter: T) -> Self
72 where
73 T: IntoIterator<Item = ((String, String), B)>,
74 {
75 Self {
76 inner: RwLock::new(HashMap::from_iter(
77 iter.into_iter()
78 .map(|pair| (pair.0, Arc::new(tokio::sync::RwLock::new(pair.1)))),
79 )),
80 }
81 }
82}
83
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Default)]
    struct TestState {
        counter: u32,
    }

    /// Builds an owned `(account_id, region)` key.
    fn scope(account_id: &str, region: &str) -> (String, String) {
        (account_id.to_string(), region.to_string())
    }

    #[tokio::test]
    async fn test_get_creates_default() {
        let state = BackendState::<TestState>::new();
        let backend = state.get("123456789012", "us-east-1");
        assert_eq!(backend.read().await.counter, 0);
    }

    #[tokio::test]
    async fn test_get_returns_same_instance() {
        let state = BackendState::<TestState>::new();
        let b1 = state.get("123456789012", "us-east-1");
        b1.write().await.counter = 42;
        let b2 = state.get("123456789012", "us-east-1");
        assert_eq!(b2.read().await.counter, 42);
        assert!(Arc::ptr_eq(&b1, &b2));
    }

    #[tokio::test]
    async fn test_different_regions_different_state() {
        let state = BackendState::<TestState>::new();
        let b1 = state.get("123456789012", "us-east-1");
        b1.write().await.counter = 10;
        let b2 = state.get("123456789012", "eu-west-1");
        assert_eq!(b2.read().await.counter, 0);
    }

    #[tokio::test]
    async fn test_reset_clears_all() {
        let state = BackendState::<TestState>::new();
        let b = state.get("123456789012", "us-east-1");
        b.write().await.counter = 99;
        state.reset();
        assert!(state.scopes_with_state().is_empty());
        let b2 = state.get("123456789012", "us-east-1");
        assert_eq!(b2.read().await.counter, 0);
        // A handle taken before the reset still points at the old, detached backend.
        assert_eq!(b.read().await.counter, 99);
        assert!(!Arc::ptr_eq(&b, &b2));
    }

    #[test]
    fn test_scopes_with_state_lists_sorted_scopes() {
        let state = BackendState::<TestState>::new();
        assert!(state.scopes_with_state().is_empty());

        state.get("123456789012", "us-west-2");
        state.get("000000000000", "us-east-1");
        state.get("123456789012", "eu-west-1");
        // Re-fetching an existing scope must not add a duplicate entry.
        state.get("123456789012", "us-west-2");

        assert_eq!(
            state.scopes_with_state(),
            vec![
                scope("000000000000", "us-east-1"),
                scope("123456789012", "eu-west-1"),
                scope("123456789012", "us-west-2"),
            ]
        );
    }

    #[tokio::test]
    async fn test_from_iter_seeds_backends_last_wins() {
        let state: BackendState<TestState> = [
            (scope("123456789012", "us-east-1"), TestState { counter: 1 }),
            (scope("210987654321", "eu-west-1"), TestState { counter: 3 }),
            (scope("123456789012", "us-east-1"), TestState { counter: 7 }),
        ]
        .into_iter()
        .collect();

        assert_eq!(
            state.scopes_with_state(),
            vec![
                scope("123456789012", "us-east-1"),
                scope("210987654321", "eu-west-1"),
            ]
        );
        assert_eq!(
            state.get("123456789012", "us-east-1").read().await.counter,
            7
        );
        assert_eq!(
            state.get("210987654321", "eu-west-1").read().await.counter,
            3
        );
    }

    #[test]
    fn test_default_is_empty() {
        let state = BackendState::<TestState>::default();
        assert!(state.scopes_with_state().is_empty());
    }
}