// synaptic_middleware/model_call_limit.rs
use std::sync::atomic::{AtomicUsize, Ordering};

use async_trait::async_trait;
use synaptic_core::SynapticError;

use crate::{AgentMiddleware, ModelCaller, ModelRequest, ModelResponse};
7
/// Middleware that caps how many model calls an agent may make.
///
/// Calls routed through [`AgentMiddleware::wrap_model_call`] are counted. Once
/// `max_calls` calls have been let through, every further call is rejected with
/// [`SynapticError::MaxStepsExceeded`] before it reaches the wrapped model.
/// Call [`ModelCallLimitMiddleware::reset`] to restore the full budget.
///
/// The counter is an [`AtomicUsize`], so the limit is enforced through `&self`
/// without any external locking.
#[derive(Debug)]
pub struct ModelCallLimitMiddleware {
    /// Maximum number of model calls allowed before requests are rejected.
    max_calls: usize,
    /// Running tally maintained by `wrap_model_call`; read via `call_count()`.
    count: AtomicUsize,
}
16
17impl ModelCallLimitMiddleware {
18 pub fn new(max_calls: usize) -> Self {
19 Self {
20 max_calls,
21 count: AtomicUsize::new(0),
22 }
23 }
24
25 pub fn call_count(&self) -> usize {
26 self.count.load(Ordering::SeqCst)
27 }
28
29 pub fn reset(&self) {
30 self.count.store(0, Ordering::SeqCst);
31 }
32}
33
34#[async_trait]
35impl AgentMiddleware for ModelCallLimitMiddleware {
36 async fn wrap_model_call(
37 &self,
38 request: ModelRequest,
39 next: &dyn ModelCaller,
40 ) -> Result<ModelResponse, SynapticError> {
41 let current = self.count.fetch_add(1, Ordering::SeqCst);
42 if current >= self.max_calls {
43 return Err(SynapticError::MaxStepsExceeded {
44 max_steps: self.max_calls,
45 });
46 }
47 next.call(request).await
48 }
49}
50
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn starts_at_zero_with_configured_limit() {
        let mw = ModelCallLimitMiddleware::new(5);
        assert_eq!(mw.max_calls, 5);
        assert_eq!(mw.call_count(), 0);
    }

    #[test]
    fn tracks_count() {
        let mw = ModelCallLimitMiddleware::new(5);
        assert_eq!(mw.call_count(), 0);
        mw.count.fetch_add(1, Ordering::SeqCst);
        assert_eq!(mw.call_count(), 1);
        mw.reset();
        assert_eq!(mw.call_count(), 0);
    }
}