Skip to main content

synaptic_middleware/
model_call_limit.rs

1use std::sync::atomic::{AtomicUsize, Ordering};
2
3use async_trait::async_trait;
4use synaptic_core::SynapticError;
5
6use crate::{AgentMiddleware, ModelCaller, ModelRequest, ModelResponse};
7
/// Caps how many model invocations may go through this middleware.
///
/// Intended to bound the number of model calls made during a single agent
/// run. Once `max_calls` invocations have been let through,
/// `wrap_model_call` fails with `SynapticError::MaxStepsExceeded` instead of
/// forwarding the request.
///
/// The counter lives in an [`AtomicUsize`], so it can be updated through a
/// shared `&self` reference.
#[derive(Debug)]
pub struct ModelCallLimitMiddleware {
    /// Upper bound on the number of model calls that are allowed through.
    max_calls: usize,
    /// Running call counter; starts at zero and can be cleared with `reset`.
    count: AtomicUsize,
}
16
17impl ModelCallLimitMiddleware {
18    pub fn new(max_calls: usize) -> Self {
19        Self {
20            max_calls,
21            count: AtomicUsize::new(0),
22        }
23    }
24
25    pub fn call_count(&self) -> usize {
26        self.count.load(Ordering::SeqCst)
27    }
28
29    pub fn reset(&self) {
30        self.count.store(0, Ordering::SeqCst);
31    }
32}
33
34#[async_trait]
35impl AgentMiddleware for ModelCallLimitMiddleware {
36    async fn wrap_model_call(
37        &self,
38        request: ModelRequest,
39        next: &dyn ModelCaller,
40    ) -> Result<ModelResponse, SynapticError> {
41        let current = self.count.fetch_add(1, Ordering::SeqCst);
42        if current >= self.max_calls {
43            return Err(SynapticError::MaxStepsExceeded {
44                max_steps: self.max_calls,
45            });
46        }
47        next.call(request).await
48    }
49}
50
#[cfg(test)]
mod tests {
    use super::*;

    /// The counter starts at zero, reflects a manual increment, and is
    /// cleared again by `reset`.
    #[test]
    fn tracks_count() {
        let limiter = ModelCallLimitMiddleware::new(5);

        // Fresh middleware: nothing counted yet.
        let initial = limiter.call_count();
        assert_eq!(initial, 0);

        // Bump the counter directly, bypassing `wrap_model_call`.
        limiter.count.fetch_add(1, Ordering::SeqCst);
        let after_bump = limiter.call_count();
        assert_eq!(after_bump, 1);

        // `reset` brings the counter back to zero.
        limiter.reset();
        let after_reset = limiter.call_count();
        assert_eq!(after_reset, 0);
    }
}