1use std::future::Future;
25use std::pin::Pin;
26
27use zeph_llm::provider::{ChatResponse, Message, ToolDefinition};
28use zeph_tools::ToolError;
29use zeph_tools::executor::{ToolCall, ToolOutput};
30
/// A refusal produced by [`RuntimeLayer::before_tool`] for a single tool call.
///
/// A layer returns `Some(LayerDenial)` (see [`BeforeToolResult`]) to object to a
/// tool call instead of letting it proceed.
pub struct LayerDenial {
    /// Outcome to report for the denied call. It has the same type as the
    /// result handed to `after_tool`.
    /// NOTE(review): presumably used in place of actually executing the tool;
    /// confirm against the dispatcher.
    pub result: Result<Option<ToolOutput>, ToolError>,
    /// Free-form explanation of why the call was denied.
    pub reason: String,
}
41
/// Output of [`RuntimeLayer::before_tool`].
///
/// `None` means the layer raises no objection; this is what the default
/// `before_tool` returns. `Some` carries a [`LayerDenial`] describing the refusal.
pub type BeforeToolResult = Option<LayerDenial>;
44
/// Per-invocation context handed to every [`RuntimeLayer`] hook.
///
/// The conversation id is borrowed rather than owned, so building a context is
/// cheap. Every field is `Copy`, so the struct is `Copy` too. Hooks can then
/// duplicate or compare it without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LayerContext<'a> {
    /// Identifier of the current conversation, or `None` when there is none.
    pub conversation_id: Option<&'a str>,
    /// Turn number the hook is invoked for. The caller decides the numbering base.
    pub turn_number: u32,
}
53
54pub trait RuntimeLayer: Send + Sync {
68 fn before_chat<'a>(
73 &'a self,
74 _ctx: &'a LayerContext<'_>,
75 _messages: &'a [Message],
76 _tools: &'a [ToolDefinition],
77 ) -> Pin<Box<dyn Future<Output = Option<ChatResponse>> + Send + 'a>> {
78 Box::pin(std::future::ready(None))
79 }
80
81 fn after_chat<'a>(
83 &'a self,
84 _ctx: &'a LayerContext<'_>,
85 _response: &'a ChatResponse,
86 ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
87 Box::pin(std::future::ready(()))
88 }
89
90 fn before_tool<'a>(
95 &'a self,
96 _ctx: &'a LayerContext<'_>,
97 _call: &'a ToolCall,
98 ) -> Pin<Box<dyn Future<Output = BeforeToolResult> + Send + 'a>> {
99 Box::pin(std::future::ready(None))
100 }
101
102 fn after_tool<'a>(
104 &'a self,
105 _ctx: &'a LayerContext<'_>,
106 _call: &'a ToolCall,
107 _result: &'a Result<Option<ToolOutput>, ToolError>,
108 ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
109 Box::pin(std::future::ready(()))
110 }
111}
112
/// A [`RuntimeLayer`] that overrides no hooks, so every hook uses its default.
///
/// It is a stateless unit struct, so it derives `Debug`, `Clone`, `Copy` and
/// `Default`. It can then be built generically (`NoopLayer::default()`) and
/// printed in diagnostics.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoopLayer;
118
// Empty impl on purpose: every hook falls back to the trait's default
// implementation, so `NoopLayer` never short-circuits, never denies, and
// ignores every `after_*` notification.
impl RuntimeLayer for NoopLayer {}
120
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::{Arc, Mutex};
    use zeph_llm::provider::Role;

    /// Context with no conversation id at turn zero; most tests need nothing more.
    fn empty_ctx() -> LayerContext<'static> {
        LayerContext {
            conversation_id: None,
            turn_number: 0,
        }
    }

    /// A parameterless `shell` tool call for exercising the tool hooks.
    fn shell_call() -> ToolCall {
        ToolCall {
            tool_id: "shell".into(),
            params: serde_json::Map::new(),
            caller_id: None,
            context: None,
            tool_call_id: String::new(),
        }
    }

    /// Counts invocations of the two chat hooks.
    struct CountingLayer {
        before_chat_calls: AtomicU32,
        after_chat_calls: AtomicU32,
    }

    impl CountingLayer {
        fn new() -> Self {
            Self {
                before_chat_calls: AtomicU32::new(0),
                after_chat_calls: AtomicU32::new(0),
            }
        }
    }

    impl RuntimeLayer for CountingLayer {
        fn before_chat<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _messages: &'a [Message],
            _tools: &'a [ToolDefinition],
        ) -> Pin<Box<dyn Future<Output = Option<ChatResponse>> + Send + 'a>> {
            self.before_chat_calls.fetch_add(1, Ordering::Relaxed);
            Box::pin(std::future::ready(None))
        }

        fn after_chat<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _response: &'a ChatResponse,
        ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
            self.after_chat_calls.fetch_add(1, Ordering::Relaxed);
            Box::pin(std::future::ready(()))
        }
    }

    /// Answers every chat request itself with a fixed text response.
    struct ShortCircuitLayer;

    impl RuntimeLayer for ShortCircuitLayer {
        fn before_chat<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _messages: &'a [Message],
            _tools: &'a [ToolDefinition],
        ) -> Pin<Box<dyn Future<Output = Option<ChatResponse>> + Send + 'a>> {
            Box::pin(std::future::ready(Some(ChatResponse::Text(
                "short-circuited".into(),
            ))))
        }
    }

    /// Appends `"<hook>_<id>"` to a shared log from every hook.
    struct OrderLayer {
        id: u32,
        log: Arc<Mutex<Vec<String>>>,
    }

    impl OrderLayer {
        fn new(id: u32, log: &Arc<Mutex<Vec<String>>>) -> Self {
            Self {
                id,
                log: Arc::clone(log),
            }
        }

        /// Records one entry. The guard is dropped before the hook returns its
        /// future, so no lock is held across an `.await`.
        fn record(&self, hook: &str) {
            self.log
                .lock()
                .unwrap()
                .push(format!("{hook}_{}", self.id));
        }
    }

    impl RuntimeLayer for OrderLayer {
        fn before_chat<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _messages: &'a [Message],
            _tools: &'a [ToolDefinition],
        ) -> Pin<Box<dyn Future<Output = Option<ChatResponse>> + Send + 'a>> {
            self.record("before");
            Box::pin(std::future::ready(None))
        }

        fn after_chat<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _response: &'a ChatResponse,
        ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
            self.record("after");
            Box::pin(std::future::ready(()))
        }

        fn before_tool<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _call: &'a ToolCall,
        ) -> Pin<Box<dyn Future<Output = BeforeToolResult> + Send + 'a>> {
            self.record("before_tool");
            Box::pin(std::future::ready(None))
        }

        fn after_tool<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _call: &'a ToolCall,
            _result: &'a Result<Option<ToolOutput>, ToolError>,
        ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
            self.record("after_tool");
            Box::pin(std::future::ready(()))
        }
    }

    #[test]
    fn noop_layer_compiles_and_is_runtime_layer() {
        fn assert_runtime_layer<T: RuntimeLayer>() {}

        assert_runtime_layer::<NoopLayer>();
        // Fails to compile if the trait stops being dyn-compatible.
        let _layer: Box<dyn RuntimeLayer> = Box::new(NoopLayer);
    }

    #[tokio::test]
    async fn noop_layer_before_chat_returns_none() {
        let ctx = empty_ctx();
        let result = NoopLayer.before_chat(&ctx, &[], &[]).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn noop_layer_before_tool_returns_none() {
        let ctx = empty_ctx();
        let call = shell_call();
        let result = NoopLayer.before_tool(&ctx, &call).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn layer_hooks_are_called() {
        let layer = CountingLayer::new();
        let ctx = LayerContext {
            conversation_id: Some("conv-1"),
            turn_number: 3,
        };
        let resp = ChatResponse::Text("hello".into());

        let _ = layer.before_chat(&ctx, &[], &[]).await;
        layer.after_chat(&ctx, &resp).await;

        assert_eq!(layer.before_chat_calls.load(Ordering::Relaxed), 1);
        assert_eq!(layer.after_chat_calls.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn short_circuit_layer_returns_response() {
        let ctx = empty_ctx();
        let result = ShortCircuitLayer.before_chat(&ctx, &[], &[]).await;
        assert!(matches!(result, Some(ChatResponse::Text(ref s)) if s == "short-circuited"));
    }

    #[test]
    fn message_from_legacy_compiles() {
        let _msg = Message::from_legacy(Role::User, "hello");
    }

    #[tokio::test]
    async fn multiple_layers_called_in_registration_order() {
        let log = Arc::new(Mutex::new(Vec::new()));
        // Drive the layers through `dyn RuntimeLayer`, in list order, the way a
        // registration list would be walked.
        let layers: Vec<Box<dyn RuntimeLayer>> = vec![
            Box::new(OrderLayer::new(1, &log)),
            Box::new(OrderLayer::new(2, &log)),
        ];
        let ctx = empty_ctx();
        let resp = ChatResponse::Text("ok".into());

        for layer in &layers {
            layer.before_chat(&ctx, &[], &[]).await;
        }
        for layer in &layers {
            layer.after_chat(&ctx, &resp).await;
        }

        let events = log.lock().unwrap().clone();
        assert_eq!(
            events,
            vec!["before_1", "before_2", "after_1", "after_2"],
            "hooks must fire in registration order"
        );
    }

    #[tokio::test]
    async fn after_chat_receives_short_circuit_response() {
        struct CapturingAfter {
            captured: Arc<Mutex<Option<String>>>,
        }

        impl RuntimeLayer for CapturingAfter {
            fn after_chat<'a>(
                &'a self,
                _ctx: &'a LayerContext<'_>,
                response: &'a ChatResponse,
            ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
                if let ChatResponse::Text(t) = response {
                    *self.captured.lock().unwrap() = Some(t.clone());
                }
                Box::pin(std::future::ready(()))
            }
        }

        let captured: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
        let layer = CapturingAfter {
            captured: Arc::clone(&captured),
        };
        let ctx = empty_ctx();

        // Use the response actually produced by a short-circuiting layer rather
        // than a hand-built stand-in.
        let sc_response = ShortCircuitLayer
            .before_chat(&ctx, &[], &[])
            .await
            .expect("short-circuit layer must produce a response");
        layer.after_chat(&ctx, &sc_response).await;

        let got = captured.lock().unwrap().clone();
        assert_eq!(
            got.as_deref(),
            Some("short-circuited"),
            "after_chat must receive the short-circuit response"
        );
    }

    #[tokio::test]
    async fn multi_layer_before_after_tool_ordering() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let layers: Vec<Box<dyn RuntimeLayer>> = vec![
            Box::new(OrderLayer::new(1, &log)),
            Box::new(OrderLayer::new(2, &log)),
        ];
        let ctx = empty_ctx();
        let call = shell_call();
        let result: Result<Option<ToolOutput>, ToolError> = Ok(None);

        for layer in &layers {
            layer.before_tool(&ctx, &call).await;
        }
        for layer in &layers {
            layer.after_tool(&ctx, &call, &result).await;
        }

        let events = log.lock().unwrap().clone();
        assert_eq!(
            events,
            vec![
                "before_tool_1",
                "before_tool_2",
                "after_tool_1",
                "after_tool_2"
            ],
            "tool hooks must fire in registration order"
        );
    }

    #[tokio::test]
    async fn denying_layer_before_tool_returns_denial() {
        struct DenyLayer;

        impl RuntimeLayer for DenyLayer {
            fn before_tool<'a>(
                &'a self,
                _ctx: &'a LayerContext<'_>,
                _call: &'a ToolCall,
            ) -> Pin<Box<dyn Future<Output = BeforeToolResult> + Send + 'a>> {
                Box::pin(std::future::ready(Some(LayerDenial {
                    result: Ok(None),
                    reason: "blocked by policy".into(),
                })))
            }
        }

        let ctx = empty_ctx();
        let call = shell_call();
        let denial = DenyLayer
            .before_tool(&ctx, &call)
            .await
            .expect("deny layer must return a denial");
        assert_eq!(denial.reason, "blocked by policy");
        assert!(matches!(denial.result, Ok(None)));
    }

    #[tokio::test]
    async fn noop_layer_after_tool_returns_unit() {
        let ctx = empty_ctx();
        let call = shell_call();
        let result: Result<Option<ToolOutput>, ToolError> = Ok(None);
        NoopLayer.after_tool(&ctx, &call, &result).await;
    }
}