1use std::future::Future;
25use std::pin::Pin;
26
27use zeph_llm::provider::{ChatResponse, Message, ToolDefinition};
28use zeph_tools::ToolError;
29use zeph_tools::executor::{ToolCall, ToolOutput};
30
31pub type BeforeToolResult = Option<Result<Option<ToolOutput>, ToolError>>;
33
/// Per-invocation metadata handed to every [`RuntimeLayer`] hook.
///
/// The context only borrows its string data and otherwise holds a plain
/// integer, so it is cheap to construct for each hook call and is `Copy`.
/// `Default` yields a context with no conversation id on turn `0`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct LayerContext<'a> {
    /// Identifier of the conversation the call belongs to, when one is known.
    pub conversation_id: Option<&'a str>,
    /// Turn counter within the conversation.
    ///
    /// NOTE(review): whether this counts from 0 or 1 is not established in this
    /// module — confirm with the code that builds the context.
    pub turn_number: u32,
}
42
43pub trait RuntimeLayer: Send + Sync {
57 fn before_chat<'a>(
62 &'a self,
63 _ctx: &'a LayerContext<'_>,
64 _messages: &'a [Message],
65 _tools: &'a [ToolDefinition],
66 ) -> Pin<Box<dyn Future<Output = Option<ChatResponse>> + Send + 'a>> {
67 Box::pin(std::future::ready(None))
68 }
69
70 fn after_chat<'a>(
72 &'a self,
73 _ctx: &'a LayerContext<'_>,
74 _response: &'a ChatResponse,
75 ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
76 Box::pin(std::future::ready(()))
77 }
78
79 fn before_tool<'a>(
84 &'a self,
85 _ctx: &'a LayerContext<'_>,
86 _call: &'a ToolCall,
87 ) -> Pin<Box<dyn Future<Output = BeforeToolResult> + Send + 'a>> {
88 Box::pin(std::future::ready(None))
89 }
90
91 fn after_tool<'a>(
93 &'a self,
94 _ctx: &'a LayerContext<'_>,
95 _call: &'a ToolCall,
96 _result: &'a Result<Option<ToolOutput>, ToolError>,
97 ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
98 Box::pin(std::future::ready(()))
99 }
100}
101
102pub struct NoopLayer;
107
108impl RuntimeLayer for NoopLayer {}
109
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::{Arc, Mutex};

    use super::*;
    use zeph_llm::provider::Role;

    /// Context with no conversation id on turn 0, shared by most tests.
    fn empty_ctx() -> LayerContext<'static> {
        LayerContext {
            conversation_id: None,
            turn_number: 0,
        }
    }

    /// A `shell` tool call without parameters.
    fn shell_call() -> ToolCall {
        ToolCall {
            tool_id: "shell".into(),
            params: serde_json::Map::new(),
        }
    }

    /// Counts how many times each chat hook fires.
    ///
    /// `Relaxed` ordering is sufficient: the counters are only touched from the
    /// test's own task and read after the hooks have been awaited.
    #[derive(Default)]
    struct CountingLayer {
        before_chat_calls: AtomicU32,
        after_chat_calls: AtomicU32,
    }

    impl RuntimeLayer for CountingLayer {
        fn before_chat<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _messages: &'a [Message],
            _tools: &'a [ToolDefinition],
        ) -> Pin<Box<dyn Future<Output = Option<ChatResponse>> + Send + 'a>> {
            self.before_chat_calls.fetch_add(1, Ordering::Relaxed);
            Box::pin(std::future::ready(None))
        }

        fn after_chat<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _response: &'a ChatResponse,
        ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
            self.after_chat_calls.fetch_add(1, Ordering::Relaxed);
            Box::pin(std::future::ready(()))
        }
    }

    /// Answers every chat request itself with a fixed text response.
    struct ShortCircuitLayer;

    impl RuntimeLayer for ShortCircuitLayer {
        fn before_chat<'a>(
            &'a self,
            _ctx: &'a LayerContext<'_>,
            _messages: &'a [Message],
            _tools: &'a [ToolDefinition],
        ) -> Pin<Box<dyn Future<Output = Option<ChatResponse>> + Send + 'a>> {
            Box::pin(std::future::ready(Some(ChatResponse::Text(
                "short-circuited".into(),
            ))))
        }
    }

    #[test]
    fn noop_layer_compiles_and_is_runtime_layer() {
        fn assert_runtime_layer<T: RuntimeLayer>() {}
        assert_runtime_layer::<NoopLayer>();
    }

    #[test]
    fn runtime_layer_is_dyn_compatible() {
        let layers: Vec<Box<dyn RuntimeLayer>> = vec![Box::new(NoopLayer)];
        assert_eq!(layers.len(), 1);
    }

    #[tokio::test]
    async fn noop_layer_before_chat_returns_none() {
        let layer = NoopLayer;
        let ctx = empty_ctx();
        let result = layer.before_chat(&ctx, &[], &[]).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn noop_layer_after_chat_returns_unit() {
        let layer = NoopLayer;
        let ctx = empty_ctx();
        let resp = ChatResponse::Text("hello".into());
        layer.after_chat(&ctx, &resp).await;
    }

    #[tokio::test]
    async fn noop_layer_before_tool_returns_none() {
        let layer = NoopLayer;
        let ctx = empty_ctx();
        let call = shell_call();
        let result = layer.before_tool(&ctx, &call).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn noop_layer_after_tool_returns_unit() {
        let layer = NoopLayer;
        let ctx = empty_ctx();
        let call = shell_call();
        let result: Result<Option<ToolOutput>, ToolError> = Ok(None);
        layer.after_tool(&ctx, &call, &result).await;
    }

    #[tokio::test]
    async fn layer_hooks_are_called() {
        let layer = CountingLayer::default();
        let ctx = LayerContext {
            conversation_id: Some("conv-1"),
            turn_number: 3,
        };
        let resp = ChatResponse::Text("hello".into());

        let _ = layer.before_chat(&ctx, &[], &[]).await;
        layer.after_chat(&ctx, &resp).await;

        assert_eq!(layer.before_chat_calls.load(Ordering::Relaxed), 1);
        assert_eq!(layer.after_chat_calls.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn short_circuit_layer_returns_response() {
        let layer = ShortCircuitLayer;
        let ctx = empty_ctx();
        let result = layer.before_chat(&ctx, &[], &[]).await;
        assert!(matches!(result, Some(ChatResponse::Text(ref s)) if s == "short-circuited"));
    }

    #[tokio::test]
    async fn before_tool_short_circuit_returns_result() {
        struct ShortCircuitToolLayer;

        impl RuntimeLayer for ShortCircuitToolLayer {
            fn before_tool<'a>(
                &'a self,
                _ctx: &'a LayerContext<'_>,
                _call: &'a ToolCall,
            ) -> Pin<Box<dyn Future<Output = BeforeToolResult> + Send + 'a>> {
                Box::pin(std::future::ready(Some(Ok(None))))
            }
        }

        let layer = ShortCircuitToolLayer;
        let ctx = empty_ctx();
        let call = shell_call();
        let result = layer.before_tool(&ctx, &call).await;
        assert!(matches!(result, Some(Ok(None))));
    }

    #[test]
    fn message_from_legacy_compiles() {
        let _msg = Message::from_legacy(Role::User, "hello");
    }

    #[tokio::test]
    async fn multiple_layers_called_in_registration_order() {
        struct OrderLayer {
            id: u32,
            log: Arc<Mutex<Vec<String>>>,
        }

        impl RuntimeLayer for OrderLayer {
            fn before_chat<'a>(
                &'a self,
                _ctx: &'a LayerContext<'_>,
                _messages: &'a [Message],
                _tools: &'a [ToolDefinition],
            ) -> Pin<Box<dyn Future<Output = Option<ChatResponse>> + Send + 'a>> {
                self.log.lock().unwrap().push(format!("before_{}", self.id));
                Box::pin(std::future::ready(None))
            }

            fn after_chat<'a>(
                &'a self,
                _ctx: &'a LayerContext<'_>,
                _response: &'a ChatResponse,
            ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
                self.log.lock().unwrap().push(format!("after_{}", self.id));
                Box::pin(std::future::ready(()))
            }
        }

        let log: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        // Model the registration list as trait objects so the hooks are
        // exercised through dynamic dispatch.
        let layers: Vec<Box<dyn RuntimeLayer>> = vec![
            Box::new(OrderLayer {
                id: 1,
                log: Arc::clone(&log),
            }),
            Box::new(OrderLayer {
                id: 2,
                log: Arc::clone(&log),
            }),
        ];
        let ctx = empty_ctx();
        let resp = ChatResponse::Text("ok".into());

        for layer in &layers {
            let _ = layer.before_chat(&ctx, &[], &[]).await;
        }
        for layer in &layers {
            layer.after_chat(&ctx, &resp).await;
        }

        let events = log.lock().unwrap().clone();
        assert_eq!(
            events,
            vec!["before_1", "before_2", "after_1", "after_2"],
            "hooks must fire in registration order"
        );
    }

    #[tokio::test]
    async fn after_chat_receives_short_circuit_response() {
        struct CapturingAfter {
            captured: Arc<Mutex<Option<String>>>,
        }

        impl RuntimeLayer for CapturingAfter {
            fn after_chat<'a>(
                &'a self,
                _ctx: &'a LayerContext<'_>,
                response: &'a ChatResponse,
            ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
                if let ChatResponse::Text(t) = response {
                    *self.captured.lock().unwrap() = Some(t.clone());
                }
                Box::pin(std::future::ready(()))
            }
        }

        let captured: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
        let layer = CapturingAfter {
            captured: Arc::clone(&captured),
        };
        let ctx = empty_ctx();

        // Obtain the response from a real short-circuiting layer rather than
        // constructing it by hand.
        let short_circuit = ShortCircuitLayer;
        let sc_response = short_circuit
            .before_chat(&ctx, &[], &[])
            .await
            .expect("ShortCircuitLayer always answers");
        layer.after_chat(&ctx, &sc_response).await;

        let got = captured.lock().unwrap().clone();
        assert_eq!(
            got.as_deref(),
            Some("short-circuited"),
            "after_chat must receive the short-circuit response"
        );
    }

    #[tokio::test]
    async fn multi_layer_before_after_tool_ordering() {
        struct ToolOrderLayer {
            id: u32,
            log: Arc<Mutex<Vec<String>>>,
        }

        impl RuntimeLayer for ToolOrderLayer {
            fn before_tool<'a>(
                &'a self,
                _ctx: &'a LayerContext<'_>,
                _call: &'a ToolCall,
            ) -> Pin<Box<dyn Future<Output = BeforeToolResult> + Send + 'a>> {
                self.log
                    .lock()
                    .unwrap()
                    .push(format!("before_tool_{}", self.id));
                Box::pin(std::future::ready(None))
            }

            fn after_tool<'a>(
                &'a self,
                _ctx: &'a LayerContext<'_>,
                _call: &'a ToolCall,
                _result: &'a Result<Option<ToolOutput>, ToolError>,
            ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
                self.log
                    .lock()
                    .unwrap()
                    .push(format!("after_tool_{}", self.id));
                Box::pin(std::future::ready(()))
            }
        }

        let log: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let layers: Vec<Box<dyn RuntimeLayer>> = vec![
            Box::new(ToolOrderLayer {
                id: 1,
                log: Arc::clone(&log),
            }),
            Box::new(ToolOrderLayer {
                id: 2,
                log: Arc::clone(&log),
            }),
        ];
        let ctx = empty_ctx();
        let call = shell_call();
        let result: Result<Option<ToolOutput>, ToolError> = Ok(None);

        for layer in &layers {
            let _ = layer.before_tool(&ctx, &call).await;
        }
        for layer in &layers {
            layer.after_tool(&ctx, &call, &result).await;
        }

        let events = log.lock().unwrap().clone();
        assert_eq!(
            events,
            vec![
                "before_tool_1",
                "before_tool_2",
                "after_tool_1",
                "after_tool_2"
            ],
            "tool hooks must fire in registration order"
        );
    }
}