1use std::collections::HashMap;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16
17use async_trait::async_trait;
18use uuid::Uuid;
19
20use crate::store::{RoutingStore, RoutingStoreError};
21use crate::{Route, RoutingEngine};
22
/// How long a cached routing engine is served before the backing store is
/// consulted again: one minute. Used by [`CachingRoutingStore::new`].
pub const DEFAULT_TTL: Duration = Duration::from_millis(60_000);
25
26#[derive(Debug)]
27struct Cached {
28 engine: Arc<RoutingEngine>,
29 expires_at: Instant,
30}
31
32#[derive(Debug)]
38pub struct CachingRoutingStore {
39 inner: Arc<dyn RoutingStore>,
40 ttl: Duration,
41 cache: tokio::sync::RwLock<HashMap<Uuid, Cached>>,
42}
43
44impl CachingRoutingStore {
45 pub fn new(inner: Arc<dyn RoutingStore>) -> Self {
46 Self::with_ttl(inner, DEFAULT_TTL)
47 }
48
49 pub fn with_ttl(inner: Arc<dyn RoutingStore>, ttl: Duration) -> Self {
50 Self {
51 inner,
52 ttl,
53 cache: tokio::sync::RwLock::new(HashMap::new()),
54 }
55 }
56
57 pub async fn engine_for(&self, org_id: Uuid) -> Result<Arc<RoutingEngine>, RoutingStoreError> {
61 {
63 let g = self.cache.read().await;
64 if let Some(entry) = g.get(&org_id) {
65 if entry.expires_at > Instant::now() {
66 return Ok(Arc::clone(&entry.engine));
67 }
68 }
69 }
70
71 let routes = self.inner.list_for_org(org_id).await?;
73 let engine = Arc::new(RoutingEngine::with_routes(routes));
74 let mut g = self.cache.write().await;
75 g.insert(
76 org_id,
77 Cached {
78 engine: Arc::clone(&engine),
79 expires_at: Instant::now() + self.ttl,
80 },
81 );
82 Ok(engine)
83 }
84
85 pub async fn invalidate(&self, org_id: Uuid) {
88 let mut g = self.cache.write().await;
89 g.remove(&org_id);
90 }
91}
92
93#[async_trait]
94impl RoutingStore for CachingRoutingStore {
95 async fn list_for_org(&self, org_id: Uuid) -> Result<Vec<Route>, RoutingStoreError> {
96 let engine = self.engine_for(org_id).await?;
97 Ok(engine.routes().to_vec())
98 }
99
100 async fn list_all_for_org(&self, org_id: Uuid) -> Result<Vec<Route>, RoutingStoreError> {
101 self.inner.list_all_for_org(org_id).await
102 }
103
104 async fn create_route(
105 &self,
106 org_id: Uuid,
107 spec: crate::store::NewRoute,
108 ) -> Result<Route, RoutingStoreError> {
109 let created = self.inner.create_route(org_id, spec).await?;
110 self.invalidate(org_id).await;
111 Ok(created)
112 }
113
114 async fn get_route(&self, org_id: Uuid, id: Uuid) -> Result<Option<Route>, RoutingStoreError> {
115 self.inner.get_route(org_id, id).await
116 }
117
118 async fn delete_route(&self, org_id: Uuid, id: Uuid) -> Result<bool, RoutingStoreError> {
119 let removed = self.inner.delete_route(org_id, id).await?;
120 if removed {
121 self.invalidate(org_id).await;
122 }
123 Ok(removed)
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::InMemoryRoutingStore;
    use crate::{RouteAction, RouteConditions};

    /// Builds an enabled, priority-10 route called `name` that targets `target`.
    fn route(name: &str, target: &str) -> Route {
        let then = RouteAction {
            target_model: target.into(),
            fallbacks: Vec::new(),
            disable_cache: false,
            max_cost_usd: None,
        };
        Route {
            id: Uuid::now_v7(),
            name: name.into(),
            priority: 10,
            enabled: true,
            when: RouteConditions::default(),
            then,
        }
    }

    /// A backing store seeded with a single route ("a" -> "m1") for a fresh org,
    /// wrapped in a cache with the given `ttl`.
    fn seeded(ttl: Duration) -> (Arc<InMemoryRoutingStore>, Uuid, CachingRoutingStore) {
        let store = Arc::new(InMemoryRoutingStore::new());
        let org_id = Uuid::now_v7();
        store.set_routes(org_id, vec![route("a", "m1")]);
        let cached = CachingRoutingStore::with_ttl(
            Arc::clone(&store) as Arc<dyn RoutingStore>,
            ttl,
        );
        (store, org_id, cached)
    }

    #[tokio::test]
    async fn caches_within_ttl() {
        let (store, org_id, cached) = seeded(Duration::from_secs(60));

        let first = cached.engine_for(org_id).await.unwrap();
        store.set_routes(org_id, vec![route("b", "m2"), route("c", "m3")]);
        let second = cached.engine_for(org_id).await.unwrap();

        // Same engine instance, still reflecting the original single route.
        assert!(Arc::ptr_eq(&first, &second));
        assert_eq!(second.routes().len(), 1);
    }

    #[tokio::test]
    async fn refreshes_after_ttl_expires() {
        let (store, org_id, cached) = seeded(Duration::from_millis(50));

        let before = cached.engine_for(org_id).await.unwrap();
        assert_eq!(before.routes().len(), 1);

        store.set_routes(org_id, vec![route("b", "m2"), route("c", "m3")]);
        // Outlive the 50ms TTL so the next lookup reloads.
        tokio::time::sleep(Duration::from_millis(80)).await;

        let after = cached.engine_for(org_id).await.unwrap();
        assert_eq!(after.routes().len(), 2);
    }

    #[tokio::test]
    async fn invalidate_forces_refresh() {
        let (store, org_id, cached) = seeded(Duration::from_secs(3600));
        let _ = cached.engine_for(org_id).await.unwrap();

        store.set_routes(org_id, vec![route("b", "m2")]);
        cached.invalidate(org_id).await;

        let reloaded = cached.engine_for(org_id).await.unwrap();
        assert_eq!(reloaded.routes()[0].name, "b");
    }

    #[tokio::test]
    async fn empty_org_caches_too() {
        let store = Arc::new(InMemoryRoutingStore::new());
        let cached = CachingRoutingStore::with_ttl(
            store as Arc<dyn RoutingStore>,
            Duration::from_secs(60),
        );

        let engine = cached.engine_for(Uuid::now_v7()).await.unwrap();
        assert!(engine.routes().is_empty());
    }

    #[tokio::test]
    async fn create_invalidates_so_engine_sees_it() {
        let store = Arc::new(InMemoryRoutingStore::new());
        let org_id = Uuid::now_v7();
        let cached = CachingRoutingStore::with_ttl(
            store as Arc<dyn RoutingStore>,
            Duration::from_secs(3600),
        );

        // Prime the cache with the empty route set.
        assert_eq!(cached.engine_for(org_id).await.unwrap().routes().len(), 0);

        let spec = crate::store::NewRoute {
            name: "x".into(),
            priority: 10,
            enabled: true,
            when: RouteConditions::default(),
            then: RouteAction {
                target_model: "m".into(),
                fallbacks: vec![],
                disable_cache: false,
                max_cost_usd: None,
            },
        };
        cached.create_route(org_id, spec).await.unwrap();

        // Despite the hour-long TTL, the write through the cache is visible.
        assert_eq!(cached.engine_for(org_id).await.unwrap().routes().len(), 1);
    }
}