1use std::future::Future;
25use std::pin::Pin;
26
27use zeph_llm::provider::{ChatResponse, Message, ToolDefinition};
28use zeph_tools::ToolError;
29use zeph_tools::executor::{ToolCall, ToolOutput};
30
/// Value produced by [`RuntimeLayer::before_tool`].
///
/// `None` is what the default `before_tool` body returns, i.e. the layer has nothing to
/// contribute. The `Some(..)` cases carry a tool outcome (`Ok(Some(output))`, `Ok(None)` or
/// `Err(..)`); how the caller acts on them is decided by the code that drives the layers —
/// presumably "use this instead of running the tool". NOTE(review): confirm against the
/// tool dispatch loop.
pub type BeforeToolResult = Option<Result<Option<ToolOutput>, ToolError>>;
33
/// Per-call context passed (by reference) to every [`RuntimeLayer`] hook.
///
/// Contains only borrowed and `Copy` data, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LayerContext<'a> {
    /// Identifier of the current conversation, when the caller has one.
    pub conversation_id: Option<&'a str>,
    /// Turn counter supplied by the caller.
    /// NOTE(review): confirm whether numbering is 0- or 1-based.
    pub turn_number: u32,
}
42
43pub trait RuntimeLayer: Send + Sync {
57 fn before_chat<'a>(
62 &'a self,
63 _ctx: &'a LayerContext<'_>,
64 _messages: &'a [Message],
65 _tools: &'a [ToolDefinition],
66 ) -> Pin<Box<dyn Future<Output = Option<ChatResponse>> + Send + 'a>> {
67 Box::pin(std::future::ready(None))
68 }
69
70 fn after_chat<'a>(
72 &'a self,
73 _ctx: &'a LayerContext<'_>,
74 _response: &'a ChatResponse,
75 ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
76 Box::pin(std::future::ready(()))
77 }
78
79 fn before_tool<'a>(
84 &'a self,
85 _ctx: &'a LayerContext<'_>,
86 _call: &'a ToolCall,
87 ) -> Pin<Box<dyn Future<Output = BeforeToolResult> + Send + 'a>> {
88 Box::pin(std::future::ready(None))
89 }
90
91 fn after_tool<'a>(
93 &'a self,
94 _ctx: &'a LayerContext<'_>,
95 _call: &'a ToolCall,
96 _result: &'a Result<Option<ToolOutput>, ToolError>,
97 ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
98 Box::pin(std::future::ready(()))
99 }
100}
101
/// A [`RuntimeLayer`] that overrides no hooks.
///
/// Every hook uses the trait's default body, resolving immediately to `None` or `()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoopLayer;
107
// No overrides: every hook falls back to the `RuntimeLayer` default body.
impl RuntimeLayer for NoopLayer {}
109
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::{Arc, Mutex};

    use super::*;
    use zeph_llm::provider::Role;

    /// Context for tests that depend on neither conversation identity nor turn number.
    fn empty_ctx() -> LayerContext<'static> {
        LayerContext {
            conversation_id: None,
            turn_number: 0,
        }
    }

    /// A parameterless `shell` tool call shared by the tool-hook tests.
    fn shell_call() -> ToolCall {
        ToolCall {
            tool_id: "shell".into(),
            params: serde_json::Map::new(),
            caller_id: None,
        }
    }

    /// Counts invocations of the chat hooks.
    struct CountingLayer {
        before_chat_calls: AtomicU32,
        after_chat_calls: AtomicU32,
    }

    impl CountingLayer {
        fn new() -> Self {
            Self {
                before_chat_calls: AtomicU32::new(0),
                after_chat_calls: AtomicU32::new(0),
            }
        }
    }

    impl RuntimeLayer for CountingLayer {
        fn before_chat<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _messages: &'a [Message],
            _tools: &'a [ToolDefinition],
        ) -> Pin<Box<dyn Future<Output = Option<ChatResponse>> + Send + 'a>> {
            // Counted when the hook is called, not when the returned future is polled.
            self.before_chat_calls.fetch_add(1, Ordering::Relaxed);
            Box::pin(std::future::ready(None))
        }

        fn after_chat<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _response: &'a ChatResponse,
        ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
            self.after_chat_calls.fetch_add(1, Ordering::Relaxed);
            Box::pin(std::future::ready(()))
        }
    }

    #[test]
    fn noop_layer_compiles_and_is_runtime_layer() {
        fn assert_runtime_layer<T: RuntimeLayer>() {}
        assert_runtime_layer::<NoopLayer>();
    }

    #[test]
    fn runtime_layer_is_object_safe() {
        // Fails to compile if `RuntimeLayer` stops being usable as `dyn RuntimeLayer`.
        let layers: Vec<Box<dyn RuntimeLayer>> =
            vec![Box::new(NoopLayer), Box::new(CountingLayer::new())];
        assert_eq!(layers.len(), 2);
    }

    #[tokio::test]
    async fn noop_layer_before_chat_returns_none() {
        let layer = NoopLayer;
        let ctx = empty_ctx();
        let result = layer.before_chat(&ctx, &[], &[]).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn noop_layer_after_chat_completes() {
        let layer = NoopLayer;
        let ctx = empty_ctx();
        let resp = ChatResponse::Text("hello".into());
        layer.after_chat(&ctx, &resp).await;
    }

    #[tokio::test]
    async fn noop_layer_before_tool_returns_none() {
        let layer = NoopLayer;
        let ctx = empty_ctx();
        let call = shell_call();
        let result = layer.before_tool(&ctx, &call).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn noop_layer_after_tool_returns_unit() {
        let layer = NoopLayer;
        let ctx = empty_ctx();
        let call = shell_call();
        let result: Result<Option<ToolOutput>, ToolError> = Ok(None);
        layer.after_tool(&ctx, &call, &result).await;
    }

    #[tokio::test]
    async fn layer_hooks_are_called() {
        let layer = CountingLayer::new();
        let ctx = LayerContext {
            conversation_id: Some("conv-1"),
            turn_number: 3,
        };
        let resp = ChatResponse::Text("hello".into());

        let _ = layer.before_chat(&ctx, &[], &[]).await;
        layer.after_chat(&ctx, &resp).await;

        assert_eq!(layer.before_chat_calls.load(Ordering::Relaxed), 1);
        assert_eq!(layer.after_chat_calls.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn short_circuit_layer_returns_response() {
        struct ShortCircuitLayer;
        impl RuntimeLayer for ShortCircuitLayer {
            fn before_chat<'a>(
                &'a self,
                _ctx: &'a LayerContext<'_>,
                _messages: &'a [Message],
                _tools: &'a [ToolDefinition],
            ) -> Pin<Box<dyn Future<Output = Option<ChatResponse>> + Send + 'a>> {
                Box::pin(std::future::ready(Some(ChatResponse::Text(
                    "short-circuited".into(),
                ))))
            }
        }

        let layer = ShortCircuitLayer;
        let ctx = empty_ctx();
        let result = layer.before_chat(&ctx, &[], &[]).await;
        assert!(matches!(result, Some(ChatResponse::Text(ref s)) if s == "short-circuited"));
    }

    #[test]
    fn message_from_legacy_compiles() {
        let _msg = Message::from_legacy(Role::User, "hello");
    }

    /// Drives two layers through `dyn RuntimeLayer` and checks the chat hooks fire in the
    /// order they are invoked.
    ///
    /// NOTE(review): the production dispatch loop is not exercised here; this test iterates
    /// the layers itself.
    #[tokio::test]
    async fn multiple_layers_called_in_registration_order() {
        struct OrderLayer {
            id: u32,
            log: Arc<Mutex<Vec<String>>>,
        }
        impl RuntimeLayer for OrderLayer {
            fn before_chat<'a>(
                &'a self,
                _ctx: &'a LayerContext<'_>,
                _messages: &'a [Message],
                _tools: &'a [ToolDefinition],
            ) -> Pin<Box<dyn Future<Output = Option<ChatResponse>> + Send + 'a>> {
                let entry = format!("before_{}", self.id);
                self.log.lock().unwrap().push(entry);
                Box::pin(std::future::ready(None))
            }

            fn after_chat<'a>(
                &'a self,
                _ctx: &'a LayerContext<'_>,
                _response: &'a ChatResponse,
            ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
                let entry = format!("after_{}", self.id);
                self.log.lock().unwrap().push(entry);
                Box::pin(std::future::ready(()))
            }
        }

        let log: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let layers: Vec<Box<dyn RuntimeLayer>> = (1..=2)
            .map(|id| {
                Box::new(OrderLayer {
                    id,
                    log: Arc::clone(&log),
                }) as Box<dyn RuntimeLayer>
            })
            .collect();

        let ctx = empty_ctx();
        let resp = ChatResponse::Text("ok".into());

        for layer in &layers {
            let _ = layer.before_chat(&ctx, &[], &[]).await;
        }
        for layer in &layers {
            layer.after_chat(&ctx, &resp).await;
        }

        let events = log.lock().unwrap().clone();
        assert_eq!(
            events,
            vec!["before_1", "before_2", "after_1", "after_2"],
            "hooks must fire in registration order"
        );
    }

    /// Checks that `after_chat` sees whatever response it is handed.
    ///
    /// NOTE(review): the response is passed in directly; that the runtime forwards a
    /// `before_chat` short-circuit response to `after_chat` is not verified here.
    #[tokio::test]
    async fn after_chat_receives_short_circuit_response() {
        struct CapturingAfter {
            captured: Arc<Mutex<Option<String>>>,
        }
        impl RuntimeLayer for CapturingAfter {
            fn after_chat<'a>(
                &'a self,
                _ctx: &'a LayerContext<'_>,
                response: &'a ChatResponse,
            ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
                if let ChatResponse::Text(t) = response {
                    *self.captured.lock().unwrap() = Some(t.clone());
                }
                Box::pin(std::future::ready(()))
            }
        }

        let captured: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
        let layer = CapturingAfter {
            captured: Arc::clone(&captured),
        };
        let ctx = empty_ctx();

        let sc_response = ChatResponse::Text("short-circuit".into());
        layer.after_chat(&ctx, &sc_response).await;

        let got = captured.lock().unwrap().clone();
        assert_eq!(
            got.as_deref(),
            Some("short-circuit"),
            "after_chat must receive the short-circuit response"
        );
    }

    /// Tool-hook counterpart of `multiple_layers_called_in_registration_order`.
    ///
    /// NOTE(review): as there, the test iterates the layers itself.
    #[tokio::test]
    async fn multi_layer_before_after_tool_ordering() {
        struct ToolOrderLayer {
            id: u32,
            log: Arc<Mutex<Vec<String>>>,
        }
        impl RuntimeLayer for ToolOrderLayer {
            fn before_tool<'a>(
                &'a self,
                _ctx: &'a LayerContext<'_>,
                _call: &'a ToolCall,
            ) -> Pin<Box<dyn Future<Output = BeforeToolResult> + Send + 'a>> {
                self.log
                    .lock()
                    .unwrap()
                    .push(format!("before_tool_{}", self.id));
                Box::pin(std::future::ready(None))
            }

            fn after_tool<'a>(
                &'a self,
                _ctx: &'a LayerContext<'_>,
                _call: &'a ToolCall,
                _result: &'a Result<Option<ToolOutput>, ToolError>,
            ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
                self.log
                    .lock()
                    .unwrap()
                    .push(format!("after_tool_{}", self.id));
                Box::pin(std::future::ready(()))
            }
        }

        let log: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let layers: Vec<Box<dyn RuntimeLayer>> = (1..=2)
            .map(|id| {
                Box::new(ToolOrderLayer {
                    id,
                    log: Arc::clone(&log),
                }) as Box<dyn RuntimeLayer>
            })
            .collect();

        let ctx = empty_ctx();
        let call = shell_call();
        let result: Result<Option<ToolOutput>, ToolError> = Ok(None);

        for layer in &layers {
            let _ = layer.before_tool(&ctx, &call).await;
        }
        for layer in &layers {
            layer.after_tool(&ctx, &call, &result).await;
        }

        let events = log.lock().unwrap().clone();
        assert_eq!(
            events,
            vec![
                "before_tool_1",
                "before_tool_2",
                "after_tool_1",
                "after_tool_2"
            ],
            "tool hooks must fire in registration order"
        );
    }
}