1use std::future::Future;
25use std::pin::Pin;
26
27use zeph_llm::provider::{ChatResponse, Message, ToolDefinition};
28use zeph_tools::ToolError;
29use zeph_tools::executor::{ToolCall, ToolOutput};
30
/// Value a [`RuntimeLayer`] produces from `before_tool` to deny a tool call.
///
/// NOTE(review): the code that acts on a denial is not in this file; presumably the caller
/// skips executing the tool and reports `result` in its place — confirm at the call site.
pub struct LayerDenial {
    /// Tool result to report for the denied call (the same type `after_tool` receives).
    pub result: Result<Option<ToolOutput>, ToolError>,
    /// Explanation for the denial — presumably surfaced in logs or to the model; verify.
    pub reason: String,
}

/// Output of `RuntimeLayer::before_tool`: `None` (what the default hook returns) means the
/// layer raises no objection; `Some` carries a [`LayerDenial`].
pub type BeforeToolResult = Option<LayerDenial>;
44
/// Per-call context passed by reference to every [`RuntimeLayer`] hook.
///
/// Every field is either a shared borrow or `Copy`, so the whole context is `Copy`. Layers can
/// stash or compare it without cloning anything.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct LayerContext<'a> {
    /// Conversation identifier when the caller has one; `None` otherwise.
    pub conversation_id: Option<&'a str>,
    /// Turn counter supplied by the caller — NOTE(review): confirm whether it is 0- or 1-based.
    pub turn_number: u32,
}
53
54pub trait RuntimeLayer: Send + Sync {
68 fn before_chat<'a>(
73 &'a self,
74 _ctx: &'a LayerContext<'_>,
75 _messages: &'a [Message],
76 _tools: &'a [ToolDefinition],
77 ) -> Pin<Box<dyn Future<Output = Option<ChatResponse>> + Send + 'a>> {
78 Box::pin(std::future::ready(None))
79 }
80
81 fn after_chat<'a>(
83 &'a self,
84 _ctx: &'a LayerContext<'_>,
85 _response: &'a ChatResponse,
86 ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
87 Box::pin(std::future::ready(()))
88 }
89
90 fn before_tool<'a>(
95 &'a self,
96 _ctx: &'a LayerContext<'_>,
97 _call: &'a ToolCall,
98 ) -> Pin<Box<dyn Future<Output = BeforeToolResult> + Send + 'a>> {
99 Box::pin(std::future::ready(None))
100 }
101
102 fn after_tool<'a>(
104 &'a self,
105 _ctx: &'a LayerContext<'_>,
106 _call: &'a ToolCall,
107 _result: &'a Result<Option<ToolOutput>, ToolError>,
108 ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
109 Box::pin(std::future::ready(()))
110 }
111}
112
/// A layer that overrides none of the `RuntimeLayer` hooks.
///
/// Every hook therefore resolves immediately with the trait's default (`None` or `()`). It is a
/// stateless unit value, so it is freely `Copy` and `Default`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct NoopLayer;
118
119impl RuntimeLayer for NoopLayer {}
120
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::{Arc, Mutex};

    use super::*;
    use zeph_llm::provider::Role;

    /// Context with no conversation id on turn 0, for tests that don't inspect the context.
    fn empty_ctx() -> LayerContext<'static> {
        LayerContext {
            conversation_id: None,
            turn_number: 0,
        }
    }

    /// A `shell` tool call with no parameters, caller, or context.
    fn shell_call() -> ToolCall {
        ToolCall {
            tool_id: "shell".into(),
            params: serde_json::Map::new(),
            caller_id: None,
            context: None,
        }
    }

    /// Counts invocations of the two chat hooks.
    struct CountingLayer {
        before_chat_calls: AtomicU32,
        after_chat_calls: AtomicU32,
    }

    impl CountingLayer {
        fn new() -> Self {
            Self {
                before_chat_calls: AtomicU32::new(0),
                after_chat_calls: AtomicU32::new(0),
            }
        }
    }

    impl RuntimeLayer for CountingLayer {
        fn before_chat<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _messages: &'a [Message],
            _tools: &'a [ToolDefinition],
        ) -> Pin<Box<dyn Future<Output = Option<ChatResponse>> + Send + 'a>> {
            // Counted when the hook is called, before the returned future is polled.
            self.before_chat_calls.fetch_add(1, Ordering::Relaxed);
            Box::pin(std::future::ready(None))
        }

        fn after_chat<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _response: &'a ChatResponse,
        ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
            self.after_chat_calls.fetch_add(1, Ordering::Relaxed);
            Box::pin(std::future::ready(()))
        }
    }

    #[test]
    fn noop_layer_compiles_and_is_runtime_layer() {
        fn assert_runtime_layer<T: RuntimeLayer>() {}

        assert_runtime_layer::<NoopLayer>();
    }

    #[tokio::test]
    async fn noop_layer_before_chat_returns_none() {
        let layer = NoopLayer;
        let ctx = empty_ctx();
        let result = layer.before_chat(&ctx, &[], &[]).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn noop_layer_before_tool_returns_none() {
        let layer = NoopLayer;
        let ctx = empty_ctx();
        let call = shell_call();
        let result = layer.before_tool(&ctx, &call).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn layer_hooks_are_called() {
        let layer = CountingLayer::new();
        let ctx = LayerContext {
            conversation_id: Some("conv-1"),
            turn_number: 3,
        };
        let resp = ChatResponse::Text("hello".into());

        assert!(layer.before_chat(&ctx, &[], &[]).await.is_none());
        layer.after_chat(&ctx, &resp).await;

        assert_eq!(layer.before_chat_calls.load(Ordering::Relaxed), 1);
        assert_eq!(layer.after_chat_calls.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn short_circuit_layer_returns_response() {
        /// Always answers `before_chat` with a canned response.
        struct ShortCircuitLayer;

        impl RuntimeLayer for ShortCircuitLayer {
            fn before_chat<'a>(
                &'a self,
                _ctx: &'a LayerContext<'_>,
                _messages: &'a [Message],
                _tools: &'a [ToolDefinition],
            ) -> Pin<Box<dyn Future<Output = Option<ChatResponse>> + Send + 'a>> {
                Box::pin(std::future::ready(Some(ChatResponse::Text(
                    "short-circuited".into(),
                ))))
            }
        }

        let layer = ShortCircuitLayer;
        let ctx = empty_ctx();
        let result = layer.before_chat(&ctx, &[], &[]).await;
        assert!(matches!(result, Some(ChatResponse::Text(ref s)) if s == "short-circuited"));
    }

    #[test]
    fn message_from_legacy_compiles() {
        let _msg = Message::from_legacy(Role::User, "hello");
    }

    #[tokio::test]
    async fn multiple_layers_called_in_registration_order() {
        /// Appends `before_<id>` / `after_<id>` to a shared log from the chat hooks.
        struct OrderLayer {
            id: u32,
            log: Arc<Mutex<Vec<String>>>,
        }

        impl RuntimeLayer for OrderLayer {
            fn before_chat<'a>(
                &'a self,
                _ctx: &'a LayerContext<'_>,
                _messages: &'a [Message],
                _tools: &'a [ToolDefinition],
            ) -> Pin<Box<dyn Future<Output = Option<ChatResponse>> + Send + 'a>> {
                let entry = format!("before_{}", self.id);
                self.log.lock().unwrap().push(entry);
                Box::pin(std::future::ready(None))
            }

            fn after_chat<'a>(
                &'a self,
                _ctx: &'a LayerContext<'_>,
                _response: &'a ChatResponse,
            ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
                let entry = format!("after_{}", self.id);
                self.log.lock().unwrap().push(entry);
                Box::pin(std::future::ready(()))
            }
        }

        let log: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let layer_a = OrderLayer {
            id: 1,
            log: Arc::clone(&log),
        };
        let layer_b = OrderLayer {
            id: 2,
            log: Arc::clone(&log),
        };

        let ctx = empty_ctx();
        let resp = ChatResponse::Text("ok".into());

        // The layers are driven by hand in registration order. This pins that each hook
        // logs exactly once per call, in call order, and that neither layer short-circuits.
        assert!(layer_a.before_chat(&ctx, &[], &[]).await.is_none());
        assert!(layer_b.before_chat(&ctx, &[], &[]).await.is_none());
        layer_a.after_chat(&ctx, &resp).await;
        layer_b.after_chat(&ctx, &resp).await;

        let events = log.lock().unwrap().clone();
        assert_eq!(
            events,
            vec!["before_1", "before_2", "after_1", "after_2"],
            "hooks must fire in registration order"
        );
    }

    #[tokio::test]
    async fn after_chat_receives_short_circuit_response() {
        /// Records the text of the last `ChatResponse::Text` seen by `after_chat`.
        struct CapturingAfter {
            captured: Arc<Mutex<Option<String>>>,
        }

        impl RuntimeLayer for CapturingAfter {
            fn after_chat<'a>(
                &'a self,
                _ctx: &'a LayerContext<'_>,
                response: &'a ChatResponse,
            ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
                if let ChatResponse::Text(t) = response {
                    *self.captured.lock().unwrap() = Some(t.clone());
                }
                Box::pin(std::future::ready(()))
            }
        }

        let captured: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
        let layer = CapturingAfter {
            captured: Arc::clone(&captured),
        };
        let ctx = empty_ctx();

        // The short-circuit wiring lives outside this file. Here we only check that
        // `after_chat` observes whatever response it is handed, such as one produced by a
        // short-circuiting `before_chat`.
        let sc_response = ChatResponse::Text("short-circuit".into());
        layer.after_chat(&ctx, &sc_response).await;

        let got = captured.lock().unwrap().clone();
        assert_eq!(
            got.as_deref(),
            Some("short-circuit"),
            "after_chat must receive the short-circuit response"
        );
    }

    #[tokio::test]
    async fn multi_layer_before_after_tool_ordering() {
        /// Appends `before_tool_<id>` / `after_tool_<id>` to a shared log from the tool hooks.
        struct ToolOrderLayer {
            id: u32,
            log: Arc<Mutex<Vec<String>>>,
        }

        impl RuntimeLayer for ToolOrderLayer {
            fn before_tool<'a>(
                &'a self,
                _ctx: &'a LayerContext<'_>,
                _call: &'a ToolCall,
            ) -> Pin<Box<dyn Future<Output = BeforeToolResult> + Send + 'a>> {
                self.log
                    .lock()
                    .unwrap()
                    .push(format!("before_tool_{}", self.id));
                Box::pin(std::future::ready(None))
            }

            fn after_tool<'a>(
                &'a self,
                _ctx: &'a LayerContext<'_>,
                _call: &'a ToolCall,
                _result: &'a Result<Option<ToolOutput>, ToolError>,
            ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
                self.log
                    .lock()
                    .unwrap()
                    .push(format!("after_tool_{}", self.id));
                Box::pin(std::future::ready(()))
            }
        }

        let log: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let layer_a = ToolOrderLayer {
            id: 1,
            log: Arc::clone(&log),
        };
        let layer_b = ToolOrderLayer {
            id: 2,
            log: Arc::clone(&log),
        };

        let ctx = empty_ctx();
        let call = shell_call();
        let result: Result<Option<ToolOutput>, ToolError> = Ok(None);

        // Neither layer denies the call; both must still log in call order.
        assert!(layer_a.before_tool(&ctx, &call).await.is_none());
        assert!(layer_b.before_tool(&ctx, &call).await.is_none());
        layer_a.after_tool(&ctx, &call, &result).await;
        layer_b.after_tool(&ctx, &call, &result).await;

        let events = log.lock().unwrap().clone();
        assert_eq!(
            events,
            vec![
                "before_tool_1",
                "before_tool_2",
                "after_tool_1",
                "after_tool_2"
            ],
            "tool hooks must fire in registration order"
        );
    }

    #[tokio::test]
    async fn noop_layer_after_tool_returns_unit() {
        let layer = NoopLayer;
        let ctx = empty_ctx();
        let call = shell_call();
        let result: Result<Option<ToolOutput>, ToolError> = Ok(None);
        layer.after_tool(&ctx, &call, &result).await;
    }
}