shaperail_runtime/handlers/
controller.rs1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use shaperail_core::{ShaperailError, WASM_HOOK_PREFIX};
7
8use crate::auth::extractor::AuthenticatedUser;
9use crate::plugins::{PluginContext, PluginUser, WasmRuntime};
10
/// Mutable state handed to a controller (a registered Rust function or a WASM
/// plugin hook).
///
/// Controllers receive `&mut Context` and may rewrite its fields; see
/// `dispatch_controller` for which fields are exposed to WASM plugins and which
/// plugin changes are written back.
pub struct Context {
    /// Input fields for the operation as a JSON object. Rust controllers may
    /// rewrite entries in place, and a WASM hook that returns a context replaces
    /// the whole map. Presumably the request payload — TODO confirm with caller.
    pub input: serde_json::Map<String, serde_json::Value>,
    /// Data produced by the operation, if any. `dispatch_controller` invokes a
    /// WASM plugin's `before_hook` while this is `None` and its `after_hook`
    /// once it is `Some`.
    pub data: Option<serde_json::Value>,
    /// Authenticated caller, if any. Only its `id` (stringified) and `role` are
    /// exposed to WASM plugins.
    pub user: Option<AuthenticatedUser>,
    /// Postgres connection pool available to controllers.
    pub pool: sqlx::PgPool,
    /// Header name/value pairs (presumably from the incoming request — verify
    /// against caller). Copied into the WASM plugin context; plugin changes to
    /// it are not written back.
    pub headers: HashMap<String, String>,
    /// Extra `(name, value)` header pairs. Not read in this file — presumably
    /// appended to the HTTP response by the caller. TODO confirm.
    pub response_headers: Vec<(String, String)>,
    /// Tenant identifier, if any; copied into the WASM plugin context.
    pub tenant_id: Option<String>,
}
44
45pub type ControllerResult = Result<(), ShaperailError>;
47
48pub trait ControllerHandler: Send + Sync {
50 fn call<'a>(
51 &'a self,
52 ctx: &'a mut Context,
53 ) -> Pin<Box<dyn Future<Output = ControllerResult> + Send + 'a>>;
54}
55
56impl<F> ControllerHandler for F
68where
69 F: for<'a> AsyncControllerFn<'a> + Send + Sync,
70{
71 fn call<'a>(
72 &'a self,
73 ctx: &'a mut Context,
74 ) -> Pin<Box<dyn Future<Output = ControllerResult> + Send + 'a>> {
75 Box::pin(self.call_async(ctx))
76 }
77}
78
79pub trait AsyncControllerFn<'a> {
81 type Fut: Future<Output = ControllerResult> + Send + 'a;
82 fn call_async(&self, ctx: &'a mut Context) -> Self::Fut;
83}
84
85impl<'a, F, Fut> AsyncControllerFn<'a> for F
86where
87 F: Fn(&'a mut Context) -> Fut + Send + Sync,
88 Fut: Future<Output = ControllerResult> + Send + 'a,
89{
90 type Fut = Fut;
91 fn call_async(&self, ctx: &'a mut Context) -> Self::Fut {
92 (self)(ctx)
93 }
94}
95
96pub struct ControllerMap {
101 fns: HashMap<(String, String), Arc<dyn ControllerHandler>>,
102}
103
104impl ControllerMap {
105 pub fn new() -> Self {
107 Self {
108 fns: HashMap::new(),
109 }
110 }
111
112 pub fn register<F>(&mut self, resource: &str, name: &str, f: F)
114 where
115 F: ControllerHandler + 'static,
116 {
117 self.fns
118 .insert((resource.to_string(), name.to_string()), Arc::new(f));
119 }
120
121 pub async fn call(&self, resource: &str, name: &str, ctx: &mut Context) -> ControllerResult {
125 if let Some(f) = self.fns.get(&(resource.to_string(), name.to_string())) {
126 f.call(ctx).await
127 } else {
128 Err(ShaperailError::Internal(format!(
129 "Controller '{name}' not found for resource '{resource}'. \
130 Ensure the function exists in resources/{resource}.controller.rs"
131 )))
132 }
133 }
134
135 pub fn has(&self, resource: &str, name: &str) -> bool {
137 self.fns
138 .contains_key(&(resource.to_string(), name.to_string()))
139 }
140}
141
142impl Default for ControllerMap {
143 fn default() -> Self {
144 Self::new()
145 }
146}
147
148pub async fn dispatch_controller(
153 name: &str,
154 resource: &str,
155 ctx: &mut Context,
156 controllers: Option<&ControllerMap>,
157 wasm_runtime: Option<&WasmRuntime>,
158) -> ControllerResult {
159 if let Some(wasm_path) = name.strip_prefix(WASM_HOOK_PREFIX) {
160 let runtime = wasm_runtime.ok_or_else(|| {
162 ShaperailError::Internal(
163 "WASM plugin declared but no WasmRuntime configured".to_string(),
164 )
165 })?;
166
167 let hook_name = if ctx.data.is_none() {
170 "before_hook"
171 } else {
172 "after_hook"
173 };
174
175 let plugin_ctx = PluginContext {
176 input: ctx.input.clone(),
177 data: ctx.data.clone(),
178 user: ctx.user.as_ref().map(|u| PluginUser {
179 id: u.id.to_string(),
180 role: u.role.clone(),
181 }),
182 headers: ctx.headers.clone(),
183 tenant_id: ctx.tenant_id.clone(),
184 };
185
186 let result = runtime.call_hook(wasm_path, hook_name, &plugin_ctx).await?;
187
188 if !result.ok {
189 let msg = result
190 .error
191 .unwrap_or_else(|| "WASM plugin returned error".to_string());
192 return Err(ShaperailError::Validation(vec![
193 shaperail_core::FieldError {
194 field: "wasm_plugin".to_string(),
195 message: msg,
196 code: "wasm_error".to_string(),
197 },
198 ]));
199 }
200
201 if let Some(modified_ctx) = result.ctx {
203 ctx.input = modified_ctx.input;
204 if modified_ctx.data.is_some() {
205 ctx.data = modified_ctx.data;
206 }
207 }
208
209 Ok(())
210 } else {
211 let map = controllers.ok_or_else(|| {
213 ShaperailError::Internal(format!(
214 "Controller '{name}' declared for '{resource}' but no ControllerMap configured"
215 ))
216 })?;
217 map.call(resource, name, ctx).await
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    async fn normalize_email(ctx: &mut Context) -> ControllerResult {
        if let Some(email) = ctx.input.get("email").and_then(|v| v.as_str()) {
            let lower = email.to_lowercase();
            ctx.input["email"] = serde_json::json!(lower);
        }
        Ok(())
    }

    async fn noop(_ctx: &mut Context) -> ControllerResult {
        Ok(())
    }

    fn test_pool() -> sqlx::PgPool {
        sqlx::PgPool::connect_lazy("postgres://localhost/test").unwrap()
    }

    /// Builds a `Context` around `input` with every other field empty.
    /// Only called from `#[tokio::test]`s, like the original `test_pool` callers.
    fn test_ctx(input: serde_json::Map<String, serde_json::Value>) -> Context {
        Context {
            input,
            data: None,
            user: None,
            pool: test_pool(),
            headers: HashMap::new(),
            response_headers: vec![],
            tenant_id: None,
        }
    }

    /// Input map containing only an `email` field.
    fn email_input(email: &str) -> serde_json::Map<String, serde_json::Value> {
        let mut input = serde_json::Map::new();
        input.insert("email".to_string(), serde_json::json!(email));
        input
    }

    #[tokio::test]
    async fn controller_map_register_and_call() {
        let mut map = ControllerMap::new();
        map.register("users", "normalize_email", normalize_email);

        let mut ctx = test_ctx(email_input("USER@EXAMPLE.COM"));

        map.call("users", "normalize_email", &mut ctx)
            .await
            .unwrap();
        assert_eq!(ctx.input["email"], serde_json::json!("user@example.com"));
    }

    #[tokio::test]
    async fn controller_map_missing_returns_error() {
        let map = ControllerMap::new();
        let mut ctx = test_ctx(serde_json::Map::new());

        let result = map.call("users", "nonexistent", &mut ctx).await;
        assert!(
            matches!(result, Err(ShaperailError::Internal(_))),
            "expected Internal error, got {result:?}"
        );
    }

    #[test]
    fn controller_map_has() {
        let mut map = ControllerMap::new();
        assert!(!map.has("users", "check"));
        map.register("users", "check", noop);
        assert!(map.has("users", "check"));
        // Registrations are scoped to their resource and name.
        assert!(!map.has("posts", "check"));
        assert!(!map.has("users", "other"));
    }

    #[tokio::test]
    async fn dispatch_routes_plain_names_to_controller_map() {
        let mut map = ControllerMap::new();
        map.register("users", "normalize_email", normalize_email);
        let mut ctx = test_ctx(email_input("MIXED@Case.io"));

        dispatch_controller("normalize_email", "users", &mut ctx, Some(&map), None)
            .await
            .unwrap();
        assert_eq!(ctx.input["email"], serde_json::json!("mixed@case.io"));
    }

    #[tokio::test]
    async fn dispatch_without_controller_map_errors() {
        let mut ctx = test_ctx(serde_json::Map::new());

        let result = dispatch_controller("normalize_email", "users", &mut ctx, None, None).await;
        assert!(
            matches!(result, Err(ShaperailError::Internal(_))),
            "expected Internal error, got {result:?}"
        );
    }

    #[tokio::test]
    async fn dispatch_wasm_without_runtime_errors() {
        let mut map = ControllerMap::new();
        map.register("users", "normalize_email", normalize_email);
        let name = format!("{WASM_HOOK_PREFIX}normalize_email");
        let mut ctx = test_ctx(email_input("USER@EXAMPLE.COM"));

        let result = dispatch_controller(&name, "users", &mut ctx, Some(&map), None).await;
        assert!(
            matches!(result, Err(ShaperailError::Internal(_))),
            "expected Internal error, got {result:?}"
        );
        // A WASM-prefixed name must never fall back to the Rust controller.
        assert_eq!(ctx.input["email"], serde_json::json!("USER@EXAMPLE.COM"));
    }
}