1use std::future::Future;
25use std::pin::Pin;
26
27use zeph_llm::provider::{ChatResponse, Message, ToolDefinition};
28use zeph_tools::ToolError;
29use zeph_tools::executor::{ToolCall, ToolOutput};
30
31pub struct LayerDenial {
36 pub result: Result<Option<ToolOutput>, ToolError>,
38 pub reason: String,
40}
41
42pub type BeforeToolResult = Option<LayerDenial>;
44
/// Per-request context passed by reference to every [`RuntimeLayer`] hook.
///
/// All fields are borrowed or `Copy`, so the context itself is cheap to copy.
/// [`Default`] yields the empty context (`conversation_id: None`,
/// `turn_number: 0`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct LayerContext<'a> {
    /// Identifier of the current conversation, when the caller has one.
    pub conversation_id: Option<&'a str>,
    /// Turn number within the conversation, as supplied by the caller.
    pub turn_number: u32,
}
53
54pub trait RuntimeLayer: Send + Sync {
68 fn before_chat<'a>(
73 &'a self,
74 _ctx: &'a LayerContext<'_>,
75 _messages: &'a [Message],
76 _tools: &'a [ToolDefinition],
77 ) -> Pin<Box<dyn Future<Output = Option<ChatResponse>> + Send + 'a>> {
78 Box::pin(std::future::ready(None))
79 }
80
81 fn after_chat<'a>(
83 &'a self,
84 _ctx: &'a LayerContext<'_>,
85 _response: &'a ChatResponse,
86 ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
87 Box::pin(std::future::ready(()))
88 }
89
90 fn before_tool<'a>(
95 &'a self,
96 _ctx: &'a LayerContext<'_>,
97 _call: &'a ToolCall,
98 ) -> Pin<Box<dyn Future<Output = BeforeToolResult> + Send + 'a>> {
99 Box::pin(std::future::ready(None))
100 }
101
102 fn after_tool<'a>(
104 &'a self,
105 _ctx: &'a LayerContext<'_>,
106 _call: &'a ToolCall,
107 _result: &'a Result<Option<ToolOutput>, ToolError>,
108 ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
109 Box::pin(std::future::ready(()))
110 }
111}
112
113pub struct NoopLayer;
118
119impl RuntimeLayer for NoopLayer {}
120
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::{Arc, Mutex};

    use super::*;
    use zeph_llm::provider::Role;

    // ---------------------------------------------------------------------
    // Fixtures
    // ---------------------------------------------------------------------

    /// Context with no conversation id at turn zero, shared by most tests.
    fn empty_ctx() -> LayerContext<'static> {
        LayerContext {
            conversation_id: None,
            turn_number: 0,
        }
    }

    /// A `shell` tool call with empty parameters and no caller id.
    fn shell_call() -> ToolCall {
        ToolCall {
            tool_id: "shell".into(),
            params: serde_json::Map::new(),
            caller_id: None,
        }
    }

    /// Counts how many times `before_chat` and `after_chat` are invoked.
    #[derive(Default)]
    struct CountingLayer {
        before_chat_calls: AtomicU32,
        after_chat_calls: AtomicU32,
    }

    impl RuntimeLayer for CountingLayer {
        fn before_chat<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _messages: &'a [Message],
            _tools: &'a [ToolDefinition],
        ) -> Pin<Box<dyn Future<Output = Option<ChatResponse>> + Send + 'a>> {
            self.before_chat_calls.fetch_add(1, Ordering::Relaxed);
            Box::pin(std::future::ready(None))
        }

        fn after_chat<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _response: &'a ChatResponse,
        ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
            self.after_chat_calls.fetch_add(1, Ordering::Relaxed);
            Box::pin(std::future::ready(()))
        }
    }

    /// Answers every chat request itself with a fixed text response.
    struct ShortCircuitLayer;

    impl RuntimeLayer for ShortCircuitLayer {
        fn before_chat<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _messages: &'a [Message],
            _tools: &'a [ToolDefinition],
        ) -> Pin<Box<dyn Future<Output = Option<ChatResponse>> + Send + 'a>> {
            Box::pin(std::future::ready(Some(ChatResponse::Text(
                "short-circuited".into(),
            ))))
        }
    }

    /// Denies every tool call with `Ok(None)` and a fixed reason.
    struct DenyingLayer;

    impl RuntimeLayer for DenyingLayer {
        fn before_tool<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _call: &'a ToolCall,
        ) -> Pin<Box<dyn Future<Output = BeforeToolResult> + Send + 'a>> {
            Box::pin(std::future::ready(Some(LayerDenial {
                result: Ok(None),
                reason: "blocked by policy".into(),
            })))
        }
    }

    /// Appends `"<hook>_<id>"` to a shared log whenever one of its hooks runs.
    struct RecordingLayer {
        id: u32,
        log: Arc<Mutex<Vec<String>>>,
    }

    impl RecordingLayer {
        fn new(id: u32, log: &Arc<Mutex<Vec<String>>>) -> Self {
            Self {
                id,
                log: Arc::clone(log),
            }
        }

        /// Records `hook` under this layer's id. The lock is released before
        /// any future is returned, so it is never held across an `.await`.
        fn record(&self, hook: &str) {
            self.log
                .lock()
                .unwrap()
                .push(format!("{}_{}", hook, self.id));
        }
    }

    impl RuntimeLayer for RecordingLayer {
        fn before_chat<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _messages: &'a [Message],
            _tools: &'a [ToolDefinition],
        ) -> Pin<Box<dyn Future<Output = Option<ChatResponse>> + Send + 'a>> {
            self.record("before");
            Box::pin(std::future::ready(None))
        }

        fn after_chat<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _response: &'a ChatResponse,
        ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
            self.record("after");
            Box::pin(std::future::ready(()))
        }

        fn before_tool<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _call: &'a ToolCall,
        ) -> Pin<Box<dyn Future<Output = BeforeToolResult> + Send + 'a>> {
            self.record("before_tool");
            Box::pin(std::future::ready(None))
        }

        fn after_tool<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _call: &'a ToolCall,
            _result: &'a Result<Option<ToolOutput>, ToolError>,
        ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
            self.record("after_tool");
            Box::pin(std::future::ready(()))
        }
    }

    // ---------------------------------------------------------------------
    // NoopLayer
    // ---------------------------------------------------------------------

    #[test]
    fn noop_layer_compiles_and_is_runtime_layer() {
        fn assert_runtime_layer<T: RuntimeLayer>() {}

        assert_runtime_layer::<NoopLayer>();
        // The trait must stay usable behind a trait object.
        let _layer: Box<dyn RuntimeLayer> = Box::new(NoopLayer);
    }

    #[tokio::test]
    async fn noop_layer_before_chat_returns_none() {
        let layer = NoopLayer;
        let ctx = empty_ctx();
        let result = layer.before_chat(&ctx, &[], &[]).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn noop_layer_after_chat_returns_unit() {
        let layer = NoopLayer;
        let ctx = empty_ctx();
        let resp = ChatResponse::Text("ignored".into());
        let () = layer.after_chat(&ctx, &resp).await;
    }

    #[tokio::test]
    async fn noop_layer_before_tool_returns_none() {
        let layer = NoopLayer;
        let ctx = empty_ctx();
        let call = shell_call();
        let result = layer.before_tool(&ctx, &call).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn noop_layer_after_tool_returns_unit() {
        let layer = NoopLayer;
        let ctx = empty_ctx();
        let call = shell_call();
        let result: Result<Option<ToolOutput>, ToolError> = Ok(None);
        let () = layer.after_tool(&ctx, &call, &result).await;
    }

    // ---------------------------------------------------------------------
    // Custom layers
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn layer_hooks_are_called() {
        let layer = CountingLayer::default();
        let ctx = LayerContext {
            conversation_id: Some("conv-1"),
            turn_number: 3,
        };
        let resp = ChatResponse::Text("hello".into());

        let _ = layer.before_chat(&ctx, &[], &[]).await;
        layer.after_chat(&ctx, &resp).await;

        assert_eq!(layer.before_chat_calls.load(Ordering::Relaxed), 1);
        assert_eq!(layer.after_chat_calls.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn short_circuit_layer_returns_response() {
        let layer = ShortCircuitLayer;
        let ctx = empty_ctx();
        let result = layer.before_chat(&ctx, &[], &[]).await;
        assert!(matches!(result, Some(ChatResponse::Text(ref s)) if s == "short-circuited"));
    }

    #[tokio::test]
    async fn denying_layer_returns_denial() {
        let layer = DenyingLayer;
        let ctx = empty_ctx();
        let call = shell_call();

        let denial = layer
            .before_tool(&ctx, &call)
            .await
            .expect("DenyingLayer must deny every tool call");

        assert_eq!(denial.reason, "blocked by policy");
        assert!(matches!(denial.result, Ok(None)));
    }

    #[test]
    fn message_from_legacy_compiles() {
        let _msg = Message::from_legacy(Role::User, "hello");
    }

    // NOTE(review): this module has no layer registry, so the ordering tests
    // below invoke layers in sequence by hand. They pin that each hook records
    // under its own layer id. The dispatcher's actual ordering must be tested
    // where the layer stack lives.

    #[tokio::test]
    async fn multiple_layers_called_in_registration_order() {
        let log: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let layer_a = RecordingLayer::new(1, &log);
        let layer_b = RecordingLayer::new(2, &log);
        let ctx = empty_ctx();
        let resp = ChatResponse::Text("ok".into());

        layer_a.before_chat(&ctx, &[], &[]).await;
        layer_b.before_chat(&ctx, &[], &[]).await;
        layer_a.after_chat(&ctx, &resp).await;
        layer_b.after_chat(&ctx, &resp).await;

        let events = log.lock().unwrap().clone();
        assert_eq!(
            events,
            vec!["before_1", "before_2", "after_1", "after_2"],
            "hooks must fire in registration order"
        );
    }

    #[tokio::test]
    async fn after_chat_receives_short_circuit_response() {
        struct CapturingAfter {
            captured: Arc<Mutex<Option<String>>>,
        }

        impl RuntimeLayer for CapturingAfter {
            fn after_chat<'a>(
                &'a self,
                _ctx: &'a LayerContext<'_>,
                response: &'a ChatResponse,
            ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
                if let ChatResponse::Text(t) = response {
                    *self.captured.lock().unwrap() = Some(t.clone());
                }
                Box::pin(std::future::ready(()))
            }
        }

        let captured: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
        let capturing = CapturingAfter {
            captured: Arc::clone(&captured),
        };
        let short_circuit = ShortCircuitLayer;
        let ctx = empty_ctx();

        // Obtain the response from a real short-circuiting `before_chat`, then
        // hand exactly that response to another layer's `after_chat`.
        let sc_response = short_circuit
            .before_chat(&ctx, &[], &[])
            .await
            .expect("ShortCircuitLayer must short-circuit");
        capturing.after_chat(&ctx, &sc_response).await;

        let got = captured.lock().unwrap().clone();
        assert_eq!(
            got.as_deref(),
            Some("short-circuited"),
            "after_chat must receive the short-circuit response"
        );
    }

    #[tokio::test]
    async fn multi_layer_before_after_tool_ordering() {
        let log: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let layer_a = RecordingLayer::new(1, &log);
        let layer_b = RecordingLayer::new(2, &log);
        let ctx = empty_ctx();
        let call = shell_call();
        let result: Result<Option<ToolOutput>, ToolError> = Ok(None);

        layer_a.before_tool(&ctx, &call).await;
        layer_b.before_tool(&ctx, &call).await;
        layer_a.after_tool(&ctx, &call, &result).await;
        layer_b.after_tool(&ctx, &call, &result).await;

        let events = log.lock().unwrap().clone();
        assert_eq!(
            events,
            vec![
                "before_tool_1",
                "before_tool_2",
                "after_tool_1",
                "after_tool_2"
            ],
            "tool hooks must fire in registration order"
        );
    }
}