1use std::collections::HashMap;
10use std::sync::RwLock;
11
12use async_trait::async_trait;
13use uuid::Uuid;
14
15use crate::{Route, RouteAction, RouteConditions};
16
17#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
19pub struct NewRoute {
20 pub name: String,
21 #[serde(default = "default_priority")]
22 pub priority: u32,
23 #[serde(default = "default_enabled")]
24 pub enabled: bool,
25 #[serde(default)]
26 pub when: RouteConditions,
27 pub then: RouteAction,
28}
29
/// Serde default for `NewRoute::priority` when the field is omitted.
fn default_priority() -> u32 {
    const FALLBACK_PRIORITY: u32 = 100;
    FALLBACK_PRIORITY
}
/// Serde default for `NewRoute::enabled`: a route is active unless the input says otherwise.
fn default_enabled() -> bool {
    const ENABLED_UNLESS_SPECIFIED: bool = true;
    ENABLED_UNLESS_SPECIFIED
}
36
37#[async_trait]
42pub trait RoutingStore: Send + Sync + std::fmt::Debug {
43 async fn list_for_org(&self, org_id: Uuid) -> Result<Vec<Route>, RoutingStoreError>;
46
47 async fn list_all_for_org(&self, _org_id: Uuid) -> Result<Vec<Route>, RoutingStoreError> {
49 Err(RoutingStoreError::Backend(
50 "management unsupported by this store".into(),
51 ))
52 }
53 async fn create_route(
55 &self,
56 _org_id: Uuid,
57 _spec: NewRoute,
58 ) -> Result<Route, RoutingStoreError> {
59 Err(RoutingStoreError::Backend(
60 "management unsupported by this store".into(),
61 ))
62 }
63 async fn get_route(
65 &self,
66 _org_id: Uuid,
67 _id: Uuid,
68 ) -> Result<Option<Route>, RoutingStoreError> {
69 Err(RoutingStoreError::Backend(
70 "management unsupported by this store".into(),
71 ))
72 }
73 async fn delete_route(&self, _org_id: Uuid, _id: Uuid) -> Result<bool, RoutingStoreError> {
75 Err(RoutingStoreError::Backend(
76 "management unsupported by this store".into(),
77 ))
78 }
79}
80
81#[derive(Debug, thiserror::Error)]
82pub enum RoutingStoreError {
83 #[error("backend error: {0}")]
84 Backend(String),
85}
86
87#[derive(Debug, Default)]
90pub struct InMemoryRoutingStore {
91 inner: RwLock<HashMap<Uuid, Vec<Route>>>,
92}
93
94impl InMemoryRoutingStore {
95 pub fn new() -> Self {
96 Self::default()
97 }
98
99 pub fn set_routes(&self, org_id: Uuid, routes: Vec<Route>) {
101 let mut g = self.inner.write().expect("inmemory routing store poisoned");
102 g.insert(org_id, routes);
103 }
104}
105
106#[async_trait]
107impl RoutingStore for InMemoryRoutingStore {
108 async fn list_for_org(&self, org_id: Uuid) -> Result<Vec<Route>, RoutingStoreError> {
109 let g = self.inner.read().expect("inmemory routing store poisoned");
110 Ok(g.get(&org_id).cloned().unwrap_or_default())
111 }
112
113 async fn list_all_for_org(&self, org_id: Uuid) -> Result<Vec<Route>, RoutingStoreError> {
114 let g = self.inner.read().expect("inmemory routing store poisoned");
115 Ok(g.get(&org_id).cloned().unwrap_or_default())
116 }
117
118 async fn create_route(&self, org_id: Uuid, spec: NewRoute) -> Result<Route, RoutingStoreError> {
119 let route = Route {
120 id: Uuid::now_v7(),
121 name: spec.name,
122 priority: spec.priority,
123 enabled: spec.enabled,
124 when: spec.when,
125 then: spec.then,
126 };
127 let mut g = self.inner.write().expect("inmemory routing store poisoned");
128 g.entry(org_id).or_default().push(route.clone());
129 Ok(route)
130 }
131
132 async fn get_route(&self, org_id: Uuid, id: Uuid) -> Result<Option<Route>, RoutingStoreError> {
133 let g = self.inner.read().expect("inmemory routing store poisoned");
134 Ok(g.get(&org_id)
135 .and_then(|v| v.iter().find(|r| r.id == id).cloned()))
136 }
137
138 async fn delete_route(&self, org_id: Uuid, id: Uuid) -> Result<bool, RoutingStoreError> {
139 let mut g = self.inner.write().expect("inmemory routing store poisoned");
140 let Some(v) = g.get_mut(&org_id) else {
141 return Ok(false);
142 };
143 let before = v.len();
144 v.retain(|r| r.id != id);
145 Ok(v.len() != before)
146 }
147}
148
#[cfg(feature = "postgres")]
mod pg {
    use super::*;
    use crate::{RouteAction, RouteConditions};
    use sqlx::PgPool;

    /// [`RoutingStore`] backed by a Postgres `routes` table.
    ///
    /// NOTE(review): the table layout (`id`, `org_id`, `name`, `priority`, `enabled`,
    /// `conditions`, `target`, `created_at`) is inferred from the queries below. Confirm it
    /// against the migrations.
    #[derive(Clone, Debug)]
    pub struct PostgresRoutingStore {
        pool: PgPool,
    }

    impl PostgresRoutingStore {
        /// Creates a store that runs every query on `pool`.
        pub fn new(pool: PgPool) -> Self {
            Self { pool }
        }
    }

    /// Wraps a driver or (de)serialization error in [`RoutingStoreError::Backend`].
    fn backend_err<E: std::fmt::Display>(e: E) -> RoutingStoreError {
        RoutingStoreError::Backend(e.to_string())
    }

    #[async_trait]
    impl RoutingStore for PostgresRoutingStore {
        /// Enabled routes for `org_id`, highest priority first and oldest first among equals.
        /// Rows whose JSON columns fail to decode are logged and skipped.
        async fn list_for_org(&self, org_id: Uuid) -> Result<Vec<Route>, RoutingStoreError> {
            let rows = sqlx::query_as::<_, RouteRow>(
                "SELECT id, name, priority, conditions, target \
                 FROM routes \
                 WHERE org_id = $1 AND enabled = TRUE \
                 ORDER BY priority DESC, created_at ASC",
            )
            .bind(org_id)
            .fetch_all(&self.pool)
            .await
            .map_err(backend_err)?;

            Ok(rows.into_iter().filter_map(RouteRow::into_route).collect())
        }

        /// Every route for `org_id`, enabled or not, in the same order as `list_for_org`.
        /// Rows whose JSON columns fail to decode are logged and skipped.
        async fn list_all_for_org(&self, org_id: Uuid) -> Result<Vec<Route>, RoutingStoreError> {
            let rows = sqlx::query_as::<_, MgmtRouteRow>(
                "SELECT id, name, priority, enabled, conditions, target \
                 FROM routes WHERE org_id = $1 ORDER BY priority DESC, created_at ASC",
            )
            .bind(org_id)
            .fetch_all(&self.pool)
            .await
            .map_err(backend_err)?;

            Ok(rows.into_iter().filter_map(MgmtRouteRow::into_route).collect())
        }

        /// Inserts a route built from `spec` and returns it as stored, including the
        /// database-assigned id.
        ///
        /// `spec.priority` values above `i32::MAX` are clamped to `i32::MAX`, because the
        /// column is bound as a signed 32-bit integer. The returned route reflects the stored
        /// (clamped) value.
        async fn create_route(
            &self,
            org_id: Uuid,
            spec: NewRoute,
        ) -> Result<Route, RoutingStoreError> {
            let conditions = serde_json::to_value(&spec.when).map_err(backend_err)?;
            let target = serde_json::to_value(&spec.then).map_err(backend_err)?;
            // Clamp rather than wrap so an oversized priority stays the highest possible.
            let priority = i32::try_from(spec.priority).unwrap_or(i32::MAX);

            let row = sqlx::query_as::<_, MgmtRouteRow>(
                "INSERT INTO routes (org_id, name, priority, conditions, target, enabled) \
                 VALUES ($1, $2, $3, $4, $5, $6) \
                 RETURNING id, name, priority, enabled, conditions, target",
            )
            .bind(org_id)
            .bind(&spec.name)
            .bind(priority)
            .bind(&conditions)
            .bind(&target)
            .bind(spec.enabled)
            .fetch_one(&self.pool)
            .await
            .map_err(backend_err)?;

            row.into_route()
                .ok_or_else(|| RoutingStoreError::Backend("created route failed to decode".into()))
        }

        /// Route `id` within `org_id`, enabled or not.
        ///
        /// Returns `Ok(None)` both when no such row exists and when the row's JSON fails to
        /// decode. The latter case is logged.
        async fn get_route(
            &self,
            org_id: Uuid,
            id: Uuid,
        ) -> Result<Option<Route>, RoutingStoreError> {
            let row = sqlx::query_as::<_, MgmtRouteRow>(
                "SELECT id, name, priority, enabled, conditions, target \
                 FROM routes WHERE org_id = $1 AND id = $2",
            )
            .bind(org_id)
            .bind(id)
            .fetch_optional(&self.pool)
            .await
            .map_err(backend_err)?;

            Ok(row.and_then(MgmtRouteRow::into_route))
        }

        /// Deletes route `id` within `org_id`; `Ok(true)` if at least one row was removed.
        async fn delete_route(&self, org_id: Uuid, id: Uuid) -> Result<bool, RoutingStoreError> {
            let res = sqlx::query("DELETE FROM routes WHERE org_id = $1 AND id = $2")
                .bind(org_id)
                .bind(id)
                .execute(&self.pool)
                .await
                .map_err(backend_err)?;
            Ok(res.rows_affected() > 0)
        }
    }

    /// Row returned by the routing query in `list_for_org`, which only selects enabled routes.
    #[derive(sqlx::FromRow)]
    struct RouteRow {
        id: Uuid,
        name: String,
        priority: i32,
        conditions: sqlx::types::Json<serde_json::Value>,
        target: sqlx::types::Json<serde_json::Value>,
    }

    impl RouteRow {
        /// Decodes into a [`Route`]. `enabled` is `true` because the query filters on it.
        fn into_route(self) -> Option<Route> {
            decode_route(
                self.id,
                self.name,
                self.priority,
                true,
                self.conditions.0,
                self.target.0,
            )
        }
    }

    /// Row returned by the management queries, which carry the `enabled` flag explicitly.
    #[derive(sqlx::FromRow)]
    struct MgmtRouteRow {
        id: Uuid,
        name: String,
        priority: i32,
        enabled: bool,
        conditions: sqlx::types::Json<serde_json::Value>,
        target: sqlx::types::Json<serde_json::Value>,
    }

    impl MgmtRouteRow {
        /// Decodes into a [`Route`], logging (rather than silently dropping) rows whose JSON
        /// columns are invalid.
        fn into_route(self) -> Option<Route> {
            decode_route(
                self.id,
                self.name,
                self.priority,
                self.enabled,
                self.conditions.0,
                self.target.0,
            )
        }
    }

    /// Shared row decoder for both row types.
    ///
    /// Returns `None` after a `warn!` when either JSON column fails to deserialize. One
    /// malformed row is skipped instead of failing a whole listing.
    fn decode_route(
        id: Uuid,
        name: String,
        priority: i32,
        enabled: bool,
        conditions: serde_json::Value,
        target: serde_json::Value,
    ) -> Option<Route> {
        let when = match serde_json::from_value::<RouteConditions>(conditions) {
            Ok(c) => c,
            Err(e) => {
                tracing::warn!(route_id = %id, error = %e, "skipping route — conditions JSON failed to decode");
                return None;
            }
        };
        let then = match serde_json::from_value::<RouteAction>(target) {
            Ok(t) => t,
            Err(e) => {
                tracing::warn!(route_id = %id, error = %e, "skipping route — target JSON failed to decode");
                return None;
            }
        };
        Some(Route {
            id,
            name,
            // Negative stored priorities cannot be represented as u32; treat them as lowest.
            priority: u32::try_from(priority).unwrap_or(0),
            enabled,
            when,
            then,
        })
    }
}
336
337#[cfg(feature = "postgres")]
338pub use pg::PostgresRoutingStore;
339
#[cfg(test)]
mod tests {
    use super::*;
    // Also reachable through `super::*`; kept explicit, so silence the redundant-import lint.
    #[allow(unused_imports)]
    use crate::Route;
    use crate::{RouteAction, RouteConditions};

    /// Action targeting `target` with no fallbacks, cache not disabled and no cost cap.
    fn action(target: &str) -> RouteAction {
        RouteAction {
            target_model: target.into(),
            fallbacks: Vec::new(),
            disable_cache: false,
            max_cost_usd: None,
        }
    }

    /// Enabled route with default conditions and a fresh id.
    fn route(name: &str, priority: u32, target: &str) -> Route {
        Route {
            id: Uuid::now_v7(),
            name: name.into(),
            priority,
            enabled: true,
            when: RouteConditions::default(),
            then: action(target),
        }
    }

    /// Enabled creation spec with default conditions.
    fn spec(name: &str, priority: u32, target: &str) -> NewRoute {
        NewRoute {
            name: name.into(),
            priority,
            enabled: true,
            when: RouteConditions::default(),
            then: action(target),
        }
    }

    #[tokio::test]
    async fn in_memory_returns_empty_for_unknown_org() {
        let store = InMemoryRoutingStore::new();
        let routes = store.list_for_org(Uuid::now_v7()).await.unwrap();
        assert!(routes.is_empty());
    }

    #[tokio::test]
    async fn in_memory_set_and_fetch_round_trips() {
        let store = InMemoryRoutingStore::new();
        let org = Uuid::now_v7();
        store.set_routes(org, vec![route("a", 10, "m1"), route("b", 5, "m2")]);
        assert_eq!(store.list_for_org(org).await.unwrap().len(), 2);
    }

    #[tokio::test]
    async fn in_memory_create_list_get_delete() {
        let store = InMemoryRoutingStore::new();
        let org = Uuid::now_v7();

        let created = store.create_route(org, spec("pin", 100, "m1")).await.unwrap();
        assert_eq!(created.name, "pin");
        assert_eq!(store.list_all_for_org(org).await.unwrap().len(), 1);

        let fetched = store.get_route(org, created.id).await.unwrap();
        assert_eq!(fetched.unwrap().id, created.id);

        assert!(store.delete_route(org, created.id).await.unwrap());
        assert!(store.get_route(org, created.id).await.unwrap().is_none());
        // Deleting again finds nothing to remove.
        assert!(!store.delete_route(org, created.id).await.unwrap());
    }

    #[tokio::test]
    async fn in_memory_management_is_org_scoped() {
        let store = InMemoryRoutingStore::new();
        let org_a = Uuid::now_v7();
        let org_b = Uuid::now_v7();
        let created = store.create_route(org_a, spec("a", 1, "m")).await.unwrap();

        // org_b must not see, delete, or list org_a's route.
        assert!(store.get_route(org_b, created.id).await.unwrap().is_none());
        assert!(!store.delete_route(org_b, created.id).await.unwrap());
        assert_eq!(store.list_all_for_org(org_b).await.unwrap().len(), 0);
    }
}