1use std::future::Future;
25use std::pin::Pin;
26
27use zeph_llm::provider::{ChatResponse, Message, ToolDefinition};
28use zeph_tools::ToolError;
29use zeph_tools::executor::{ToolCall, ToolOutput};
30
31pub struct LayerDenial {
36 pub result: Result<Option<ToolOutput>, ToolError>,
38 pub reason: String,
40}
41
42pub type BeforeToolResult = Option<LayerDenial>;
44
/// Per-request context passed by shared reference to every [`RuntimeLayer`] hook.
///
/// Every field is `Copy`, so the whole context is `Copy` and can be moved into a hook's
/// future without borrowing. `LayerContext::default()` yields no conversation id and turn `0`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct LayerContext<'a> {
    /// Identifier of the current conversation, when one is known.
    pub conversation_id: Option<&'a str>,
    /// Turn counter for the current request — presumably the turn's index within the
    /// conversation; whether counting starts at 0 or 1 is not established here.
    pub turn_number: u32,
}
53
54pub trait RuntimeLayer: Send + Sync {
68 fn before_chat<'a>(
73 &'a self,
74 _ctx: &'a LayerContext<'_>,
75 _messages: &'a [Message],
76 _tools: &'a [ToolDefinition],
77 ) -> Pin<Box<dyn Future<Output = Option<ChatResponse>> + Send + 'a>> {
78 Box::pin(std::future::ready(None))
79 }
80
81 fn after_chat<'a>(
83 &'a self,
84 _ctx: &'a LayerContext<'_>,
85 _response: &'a ChatResponse,
86 ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
87 Box::pin(std::future::ready(()))
88 }
89
90 fn before_tool<'a>(
95 &'a self,
96 _ctx: &'a LayerContext<'_>,
97 _call: &'a ToolCall,
98 ) -> Pin<Box<dyn Future<Output = BeforeToolResult> + Send + 'a>> {
99 Box::pin(std::future::ready(None))
100 }
101
102 fn after_tool<'a>(
104 &'a self,
105 _ctx: &'a LayerContext<'_>,
106 _call: &'a ToolCall,
107 _result: &'a Result<Option<ToolOutput>, ToolError>,
108 ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
109 Box::pin(std::future::ready(()))
110 }
111}
112
/// A [`RuntimeLayer`] that keeps every hook at its default: it never short-circuits a chat
/// request, never denies a tool call, and ignores all results.
///
/// It is a zero-sized unit struct, so it is freely `Copy` and constructible via `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoopLayer;
118
// Empty on purpose: `NoopLayer` inherits every default hook body from `RuntimeLayer`.
impl RuntimeLayer for NoopLayer {}
120
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::{Arc, Mutex};

    use zeph_llm::provider::Role;

    use super::*;

    /// Context with no conversation id at turn 0 — the baseline most tests need.
    fn empty_ctx() -> LayerContext<'static> {
        LayerContext {
            conversation_id: None,
            turn_number: 0,
        }
    }

    /// Minimal `shell` tool call with no parameters, shared by the tool-hook tests.
    fn shell_call() -> ToolCall {
        ToolCall {
            tool_id: "shell".into(),
            params: serde_json::Map::new(),
            caller_id: None,
            context: None,
            tool_call_id: String::new(),
            skill_name: None,
        }
    }

    /// Counts how many times each chat hook fires.
    struct CountingLayer {
        before_chat_calls: AtomicU32,
        after_chat_calls: AtomicU32,
    }

    impl CountingLayer {
        fn new() -> Self {
            Self {
                before_chat_calls: AtomicU32::new(0),
                after_chat_calls: AtomicU32::new(0),
            }
        }
    }

    impl RuntimeLayer for CountingLayer {
        fn before_chat<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _messages: &'a [Message],
            _tools: &'a [ToolDefinition],
        ) -> Pin<Box<dyn Future<Output = Option<ChatResponse>> + Send + 'a>> {
            self.before_chat_calls.fetch_add(1, Ordering::Relaxed);
            Box::pin(std::future::ready(None))
        }

        fn after_chat<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _response: &'a ChatResponse,
        ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
            self.after_chat_calls.fetch_add(1, Ordering::Relaxed);
            Box::pin(std::future::ready(()))
        }
    }

    #[test]
    fn noop_layer_compiles_and_is_runtime_layer() {
        fn assert_runtime_layer<T: RuntimeLayer>() {}

        assert_runtime_layer::<NoopLayer>();
        // The trait must stay dyn-compatible so layers can be stored as trait objects.
        let _layer: Box<dyn RuntimeLayer> = Box::new(NoopLayer);
    }

    #[tokio::test]
    async fn noop_layer_before_chat_returns_none() {
        let layer = NoopLayer;
        let ctx = empty_ctx();
        let result = layer.before_chat(&ctx, &[], &[]).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn noop_layer_after_chat_completes() {
        let layer = NoopLayer;
        let ctx = empty_ctx();
        let resp = ChatResponse::Text("hello".into());
        layer.after_chat(&ctx, &resp).await;
    }

    #[tokio::test]
    async fn noop_layer_before_tool_returns_none() {
        let layer = NoopLayer;
        let ctx = empty_ctx();
        let call = shell_call();
        let result = layer.before_tool(&ctx, &call).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn layer_hooks_are_called() {
        let layer = CountingLayer::new();
        let ctx = LayerContext {
            conversation_id: Some("conv-1"),
            turn_number: 3,
        };
        let resp = ChatResponse::Text("hello".into());

        let _ = layer.before_chat(&ctx, &[], &[]).await;
        layer.after_chat(&ctx, &resp).await;

        assert_eq!(layer.before_chat_calls.load(Ordering::Relaxed), 1);
        assert_eq!(layer.after_chat_calls.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn short_circuit_layer_returns_response() {
        struct ShortCircuitLayer;
        impl RuntimeLayer for ShortCircuitLayer {
            fn before_chat<'a>(
                &'a self,
                _ctx: &'a LayerContext<'_>,
                _messages: &'a [Message],
                _tools: &'a [ToolDefinition],
            ) -> Pin<Box<dyn Future<Output = Option<ChatResponse>> + Send + 'a>> {
                Box::pin(std::future::ready(Some(ChatResponse::Text(
                    "short-circuited".into(),
                ))))
            }
        }

        let layer = ShortCircuitLayer;
        let ctx = empty_ctx();
        let result = layer.before_chat(&ctx, &[], &[]).await;
        assert!(matches!(result, Some(ChatResponse::Text(ref s)) if s == "short-circuited"));
    }

    #[test]
    fn message_from_legacy_compiles() {
        let _msg = Message::from_legacy(Role::User, "hello");
    }

    #[tokio::test]
    async fn multiple_layers_called_in_registration_order() {
        struct OrderLayer {
            id: u32,
            log: Arc<Mutex<Vec<String>>>,
        }
        impl RuntimeLayer for OrderLayer {
            fn before_chat<'a>(
                &'a self,
                _ctx: &'a LayerContext<'_>,
                _messages: &'a [Message],
                _tools: &'a [ToolDefinition],
            ) -> Pin<Box<dyn Future<Output = Option<ChatResponse>> + Send + 'a>> {
                let entry = format!("before_{}", self.id);
                self.log.lock().unwrap().push(entry);
                Box::pin(std::future::ready(None))
            }

            fn after_chat<'a>(
                &'a self,
                _ctx: &'a LayerContext<'_>,
                _response: &'a ChatResponse,
            ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
                let entry = format!("after_{}", self.id);
                self.log.lock().unwrap().push(entry);
                Box::pin(std::future::ready(()))
            }
        }

        let log: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let layer_a = OrderLayer {
            id: 1,
            log: Arc::clone(&log),
        };
        let layer_b = OrderLayer {
            id: 2,
            log: Arc::clone(&log),
        };

        let ctx = empty_ctx();
        let resp = ChatResponse::Text("ok".into());

        layer_a.before_chat(&ctx, &[], &[]).await;
        layer_b.before_chat(&ctx, &[], &[]).await;
        layer_a.after_chat(&ctx, &resp).await;
        layer_b.after_chat(&ctx, &resp).await;

        let events = log.lock().unwrap().clone();
        assert_eq!(
            events,
            vec!["before_1", "before_2", "after_1", "after_2"],
            "hooks must fire in registration order"
        );
    }

    #[tokio::test]
    async fn after_chat_receives_short_circuit_response() {
        struct CapturingAfter {
            captured: Arc<Mutex<Option<String>>>,
        }
        impl RuntimeLayer for CapturingAfter {
            fn after_chat<'a>(
                &'a self,
                _ctx: &'a LayerContext<'_>,
                response: &'a ChatResponse,
            ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
                if let ChatResponse::Text(t) = response {
                    *self.captured.lock().unwrap() = Some(t.clone());
                }
                Box::pin(std::future::ready(()))
            }
        }

        let captured: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
        let layer = CapturingAfter {
            captured: Arc::clone(&captured),
        };
        let ctx = empty_ctx();

        let sc_response = ChatResponse::Text("short-circuit".into());
        layer.after_chat(&ctx, &sc_response).await;

        let got = captured.lock().unwrap().clone();
        assert_eq!(
            got.as_deref(),
            Some("short-circuit"),
            "after_chat must receive the short-circuit response"
        );
    }

    #[tokio::test]
    async fn multi_layer_before_after_tool_ordering() {
        struct ToolOrderLayer {
            id: u32,
            log: Arc<Mutex<Vec<String>>>,
        }
        impl RuntimeLayer for ToolOrderLayer {
            fn before_tool<'a>(
                &'a self,
                _ctx: &'a LayerContext<'_>,
                _call: &'a ToolCall,
            ) -> Pin<Box<dyn Future<Output = BeforeToolResult> + Send + 'a>> {
                self.log
                    .lock()
                    .unwrap()
                    .push(format!("before_tool_{}", self.id));
                Box::pin(std::future::ready(None))
            }

            fn after_tool<'a>(
                &'a self,
                _ctx: &'a LayerContext<'_>,
                _call: &'a ToolCall,
                _result: &'a Result<Option<ToolOutput>, ToolError>,
            ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
                self.log
                    .lock()
                    .unwrap()
                    .push(format!("after_tool_{}", self.id));
                Box::pin(std::future::ready(()))
            }
        }

        let log: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let layer_a = ToolOrderLayer {
            id: 1,
            log: Arc::clone(&log),
        };
        let layer_b = ToolOrderLayer {
            id: 2,
            log: Arc::clone(&log),
        };

        let ctx = empty_ctx();
        let call = shell_call();
        let result: Result<Option<ToolOutput>, ToolError> = Ok(None);

        layer_a.before_tool(&ctx, &call).await;
        layer_b.before_tool(&ctx, &call).await;
        layer_a.after_tool(&ctx, &call, &result).await;
        layer_b.after_tool(&ctx, &call, &result).await;

        let events = log.lock().unwrap().clone();
        assert_eq!(
            events,
            vec![
                "before_tool_1",
                "before_tool_2",
                "after_tool_1",
                "after_tool_2"
            ],
            "tool hooks must fire in registration order"
        );
    }

    #[tokio::test]
    async fn noop_layer_after_tool_returns_unit() {
        let layer = NoopLayer;
        let ctx = empty_ctx();
        let call = shell_call();
        let result: Result<Option<ToolOutput>, ToolError> = Ok(None);
        layer.after_tool(&ctx, &call, &result).await;
    }
}