Skip to main content

vtcode_core/tools/
async_middleware.rs

1//! Async middleware for LLM-compatible tool execution
2//!
3//! Proper composition pattern with async/await support.
4//! Suitable for tokio-based systems handling LLM operations.
5
6use crate::tools::improvements_errors::ObservabilityContext;
7use crate::types::CompactStr;
8use serde_json::{Map, Value};
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12use std::time::Instant;
13
14/// Type alias for the async continuation function
15type AsyncContinuation<'a> =
16    Box<dyn Fn(ToolRequest) -> Pin<Box<dyn Future<Output = ToolResult> + Send>> + Send + Sync + 'a>;
17
18/// Type alias for the owned async continuation function
19type AsyncContinuationOwned =
20    Box<dyn Fn(ToolRequest) -> Pin<Box<dyn Future<Output = ToolResult> + Send>> + Send + Sync>;
21
22/// Async middleware trait
23#[async_trait::async_trait]
24pub trait AsyncMiddleware: Send + Sync {
25    /// Middleware name
26    fn name(&self) -> &str;
27
28    /// Execute middleware
29    async fn execute<'a>(&'a self, request: ToolRequest, next: AsyncContinuation<'a>)
30    -> ToolResult;
31}
32
33/// Tool request
34#[derive(Clone, Debug)]
35pub struct ToolRequest {
36    pub tool_name: CompactStr,
37    pub arguments: String,
38    pub context: String,
39}
40
/// Outcome of a tool invocation as seen by the middleware chain.
#[derive(Clone, Debug)]
pub struct ToolResult {
    /// Whether the tool reported success.
    pub success: bool,
    /// Tool output, if any; the caching middleware stores this on success.
    pub output: Option<String>,
    /// Error message, if any; the retry middleware classifies it to decide on retries.
    pub error: Option<String>,
    /// Elapsed wall-clock milliseconds (overwritten by the logging middleware).
    pub duration_ms: u64,
    /// `true` when the result was served by the caching middleware.
    pub from_cache: bool,
}
50
51/// Async middleware chain executor
52pub struct AsyncMiddlewareChain {
53    middlewares: Vec<Arc<dyn AsyncMiddleware>>,
54}
55
56impl AsyncMiddlewareChain {
57    pub fn new() -> Self {
58        Self {
59            middlewares: Vec::new(),
60        }
61    }
62
63    pub fn with_middleware(mut self, middleware: Arc<dyn AsyncMiddleware>) -> Self {
64        self.middlewares.push(middleware);
65        self
66    }
67
68    /// Execute request through chain (simplified)
69    pub async fn execute_simple<F>(&self, request: ToolRequest, executor: F) -> ToolResult
70    where
71        F: Fn(ToolRequest) -> ToolResult + Send + Sync + 'static,
72    {
73        if self.middlewares.is_empty() {
74            return executor(request);
75        }
76
77        let executor = Arc::new(executor);
78        let middlewares = self.middlewares.clone();
79
80        fn build_chain(
81            middlewares: &[Arc<dyn AsyncMiddleware>],
82            executor: Arc<dyn Fn(ToolRequest) -> ToolResult + Send + Sync>,
83        ) -> AsyncContinuationOwned {
84            if middlewares.is_empty() {
85                Box::new(move |req: ToolRequest| {
86                    let result = executor(req);
87                    Box::pin(async move { result })
88                })
89            } else {
90                let current = middlewares[0].clone();
91                let rest = build_chain(&middlewares[1..], executor);
92                let rest = Arc::new(rest);
93                Box::new(move |req: ToolRequest| {
94                    let current = current.clone();
95                    let rest = rest.clone();
96                    Box::pin(async move {
97                        let next: AsyncContinuationOwned = Box::new(move |r: ToolRequest| {
98                            let rest = rest.clone();
99                            Box::pin(async move { rest(r).await })
100                        });
101                        current.execute(req, next).await
102                    })
103                })
104            }
105        }
106
107        let chain = build_chain(&middlewares, executor);
108        chain(request).await
109    }
110}
111
112impl Default for AsyncMiddlewareChain {
113    fn default() -> Self {
114        Self::new()
115    }
116}
117
118fn normalize_context(context: &str) -> String {
119    let mut normalized = Map::new();
120    let parsed: Value = serde_json::from_str(context).unwrap_or_else(|_| Value::Object(Map::new()));
121
122    if let Some(session) = parsed.get("session_id").and_then(Value::as_str)
123        && !session.is_empty()
124    {
125        normalized.insert("session_id".into(), Value::String(session.to_string()));
126    }
127
128    if let Some(task) = parsed.get("task_id").and_then(Value::as_str)
129        && !task.is_empty()
130    {
131        normalized.insert("task_id".into(), Value::String(task.to_string()));
132    }
133
134    if let Some(version) = parsed.get("plan_version").and_then(Value::as_u64) {
135        normalized.insert("plan_version".into(), Value::Number(version.into()));
136    }
137
138    if let Some(plan) = parsed.get("plan_summary").and_then(Value::as_object) {
139        let mut summary = Map::new();
140        if let Some(status) = plan.get("status").and_then(Value::as_str) {
141            summary.insert("status".into(), Value::String(status.to_string()));
142        }
143        if let Some(total) = plan.get("total_steps").and_then(Value::as_u64) {
144            summary.insert("total_steps".into(), Value::Number(total.into()));
145        }
146        if let Some(completed) = plan.get("completed_steps").and_then(Value::as_u64) {
147            summary.insert("completed_steps".into(), Value::Number(completed.into()));
148        }
149        if !summary.is_empty() {
150            normalized.insert("plan_summary".into(), Value::Object(summary));
151        }
152    }
153
154    if let Some(phase) = parsed
155        .get("plan_phase")
156        .and_then(|v| v.as_str())
157        .filter(|p| !p.is_empty())
158    {
159        normalized.insert("plan_phase".into(), Value::String(phase.to_string()));
160    }
161
162    serde_json::to_string(&Value::Object(normalized)).unwrap_or_else(|_| "{}".to_string())
163}
164
165/// Async logging middleware
166pub struct AsyncLoggingMiddleware {
167    obs_context: Arc<ObservabilityContext>,
168}
169
170impl AsyncLoggingMiddleware {
171    pub fn new(obs_context: Arc<ObservabilityContext>) -> Self {
172        Self { obs_context }
173    }
174}
175
176#[async_trait::async_trait]
177impl AsyncMiddleware for AsyncLoggingMiddleware {
178    fn name(&self) -> &str {
179        "async_logging"
180    }
181
182    async fn execute<'a>(
183        &'a self,
184        request: ToolRequest,
185        next: Box<
186            dyn Fn(ToolRequest) -> Pin<Box<dyn std::future::Future<Output = ToolResult> + Send>>
187                + Send
188                + Sync
189                + 'a,
190        >,
191    ) -> ToolResult {
192        let tool_name = request.tool_name.clone();
193        let normalized_context = normalize_context(&request.context);
194        let context_json: Option<Value> = serde_json::from_str(&normalized_context).ok();
195        let session_id = context_json
196            .as_ref()
197            .and_then(|v| v.get("session_id").and_then(|s| s.as_str()))
198            .unwrap_or("");
199        let task_id = context_json
200            .as_ref()
201            .and_then(|v| v.get("task_id").and_then(|s| s.as_str()))
202            .unwrap_or("");
203        let plan_summary = context_json.as_ref().and_then(|v| v.get("plan_summary"));
204        let plan_status = plan_summary
205            .and_then(|v| v.get("status").and_then(|s| s.as_str()))
206            .unwrap_or("");
207        let plan_phase = context_json
208            .as_ref()
209            .and_then(|v| v.get("plan_phase").and_then(|p| p.as_str()))
210            .unwrap_or("");
211        let plan_total_steps = plan_summary
212            .and_then(|v| v.get("total_steps").and_then(|n| n.as_u64()))
213            .unwrap_or(0);
214        let plan_completed_steps = plan_summary
215            .and_then(|v| v.get("completed_steps").and_then(|n| n.as_u64()))
216            .unwrap_or(0);
217        let plan_version = context_json
218            .as_ref()
219            .and_then(|v| v.get("plan_version").and_then(|n| n.as_u64()))
220            .unwrap_or(0);
221
222        tracing::debug!(
223            tool = %tool_name,
224            session_id = %session_id,
225            task_id = %task_id,
226            plan_version,
227            plan_status = %plan_status,
228            plan_phase = %plan_phase,
229            plan_total_steps,
230            plan_completed_steps,
231            "tool execution started"
232        );
233        tracing::trace!(
234            tool = %tool_name,
235            context = %normalized_context,
236            "tool execution context payload"
237        );
238
239        let start = Instant::now();
240        let mut result = next(request).await;
241        let duration = start.elapsed().as_millis().min(u64::MAX as u128) as u64;
242
243        result.duration_ms = duration;
244
245        if result.success {
246            tracing::debug!(
247                tool = %tool_name,
248                duration_ms = duration,
249                session_id = %session_id,
250                task_id = %task_id,
251                plan_version,
252                plan_status = %plan_status,
253                plan_phase = %plan_phase,
254                plan_total_steps,
255                plan_completed_steps,
256                from_cache = result.from_cache,
257                "tool execution completed"
258            );
259            self.obs_context.event(
260                crate::tools::EventType::ToolSelected,
261                "executor",
262                format!("executed {} in {}ms", tool_name, duration),
263                Some(1.0),
264            );
265        } else {
266            tracing::error!(
267                tool = %tool_name,
268                error = ?result.error,
269                session_id = %context_json
270                    .as_ref()
271                    .and_then(|v| v.get("session_id").and_then(|s| s.as_str()))
272                    .unwrap_or(""),
273                task_id = %context_json
274                    .as_ref()
275                    .and_then(|v| v.get("task_id").and_then(|s| s.as_str()))
276                    .unwrap_or(""),
277                "tool execution failed"
278            );
279        }
280
281        result
282    }
283}
284
285/// Async caching middleware with UnifiedCache (migrated from LruCache)
286pub struct AsyncCachingMiddleware {
287    cache: Arc<crate::cache::UnifiedCache<AsyncCacheKey, String>>,
288    obs_context: Arc<ObservabilityContext>,
289}
290
291#[derive(Debug, Clone, Hash, PartialEq, Eq)]
292struct AsyncCacheKey(String);
293
294impl crate::cache::CacheKey for AsyncCacheKey {
295    fn to_cache_key(&self) -> String {
296        self.0.clone()
297    }
298}
299
300impl AsyncCachingMiddleware {
301    pub fn new(
302        max_entries: usize,
303        ttl_seconds: u64,
304        obs_context: Arc<ObservabilityContext>,
305    ) -> Self {
306        let cache = crate::cache::UnifiedCache::new(
307            max_entries,
308            std::time::Duration::from_secs(ttl_seconds),
309            crate::cache::EvictionPolicy::Lru,
310        );
311
312        Self {
313            cache: Arc::new(cache),
314            obs_context,
315        }
316    }
317
318    fn cache_key(tool: &str, args: &str, context: &str) -> String {
319        // Use a hashed key to avoid creating large string cache keys while still uniquely identifying args
320        use std::collections::hash_map::DefaultHasher;
321        use std::hash::Hasher;
322        let mut hasher = DefaultHasher::new();
323        hasher.write(args.as_bytes());
324        let normalized = normalize_context(context);
325        if !normalized.is_empty() {
326            hasher.write(normalized.as_bytes());
327        }
328        format!("{}::{}", tool, hasher.finish())
329    }
330}
331
332#[async_trait::async_trait]
333impl AsyncMiddleware for AsyncCachingMiddleware {
334    fn name(&self) -> &str {
335        "async_caching"
336    }
337
338    async fn execute<'a>(
339        &'a self,
340        request: ToolRequest,
341        next: Box<
342            dyn Fn(ToolRequest) -> Pin<Box<dyn std::future::Future<Output = ToolResult> + Send>>
343                + Send
344                + Sync
345                + 'a,
346        >,
347    ) -> ToolResult {
348        let key = AsyncCacheKey(Self::cache_key(
349            &request.tool_name,
350            &request.arguments,
351            &request.context,
352        ));
353
354        // Check cache (migrated to UnifiedCache)
355        if let Some(cached) = self.cache.get_owned(&key) {
356            self.obs_context.event(
357                crate::tools::EventType::CacheHit,
358                "cache",
359                "returning cached result",
360                Some(1.0),
361            );
362
363            return ToolResult {
364                success: true,
365                output: Some(cached),
366                error: None,
367                duration_ms: 0,
368                from_cache: true,
369            };
370        }
371
372        // Execute
373        let result = next(request).await;
374
375        // Cache successful result (migrated to UnifiedCache)
376        if result.success
377            && let Some(ref output) = result.output
378        {
379            let size = output.len() as u64;
380            self.cache.insert(key, output.clone(), size);
381        }
382
383        result
384    }
385}
386
387/// Async retry middleware with exponential backoff
388pub struct AsyncRetryMiddleware {
389    max_attempts: u32,
390    initial_backoff_ms: u64,
391    max_backoff_ms: u64,
392    obs_context: Arc<ObservabilityContext>,
393}
394
395impl AsyncRetryMiddleware {
396    pub fn new(
397        max_attempts: u32,
398        initial_backoff_ms: u64,
399        max_backoff_ms: u64,
400        obs_context: Arc<ObservabilityContext>,
401    ) -> Self {
402        Self {
403            max_attempts,
404            initial_backoff_ms,
405            max_backoff_ms,
406            obs_context,
407        }
408    }
409
410    fn backoff_duration(&self, attempt: u32) -> std::time::Duration {
411        let backoff = self.initial_backoff_ms * 2_u64.pow(attempt);
412        std::time::Duration::from_millis(backoff.min(self.max_backoff_ms))
413    }
414}
415
416#[async_trait::async_trait]
417impl AsyncMiddleware for AsyncRetryMiddleware {
418    fn name(&self) -> &str {
419        "async_retry"
420    }
421
422    async fn execute<'a>(
423        &'a self,
424        request: ToolRequest,
425        next: Box<
426            dyn Fn(ToolRequest) -> Pin<Box<dyn std::future::Future<Output = ToolResult> + Send>>
427                + Send
428                + Sync
429                + 'a,
430        >,
431    ) -> ToolResult {
432        for attempt in 0..self.max_attempts {
433            if attempt > 0 {
434                let backoff = self.backoff_duration(attempt - 1);
435                tracing::debug!(
436                    attempt = attempt,
437                    backoff_ms = backoff.as_millis(),
438                    "retrying after backoff"
439                );
440                tokio::time::sleep(backoff).await;
441            }
442
443            let result = next(request.clone()).await;
444
445            if result.success {
446                if attempt > 0 {
447                    self.obs_context.event(
448                        crate::tools::EventType::FallbackSuccess,
449                        "retry",
450                        format!("succeeded on attempt {}", attempt + 1),
451                        Some(1.0),
452                    );
453                }
454                return result;
455            }
456
457            // Skip retry for non-retryable errors (auth failures, policy
458            // violations, invalid parameters) to fail fast.
459            if let Some(ref error_msg) = result.error {
460                let category = vtcode_commons::classify_error_message(error_msg);
461                if !category.is_retryable() {
462                    tracing::debug!(
463                        attempt = attempt,
464                        category = ?category,
465                        "non-retryable error, skipping remaining attempts"
466                    );
467                    return result;
468                }
469            }
470
471            self.obs_context.event(
472                crate::tools::EventType::FallbackAttempt,
473                "retry",
474                format!("attempt {} failed", attempt + 1),
475                None,
476            );
477        }
478
479        ToolResult {
480            success: false,
481            output: None,
482            error: Some(format!("all {} attempts failed", self.max_attempts)),
483            duration_ms: 0,
484            from_cache: false,
485        }
486    }
487}
488
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::atomic::{AtomicUsize, Ordering};

    type BoxedToolFuture = Pin<Box<dyn Future<Output = ToolResult> + Send>>;
    type BoxedExecutor = Box<dyn Fn(ToolRequest) -> BoxedToolFuture + Send + Sync>;

    /// Successful, non-cached result carrying `output`.
    fn ok_result(output: &str) -> ToolResult {
        ToolResult {
            success: true,
            output: Some(output.to_string()),
            error: None,
            duration_ms: 0,
            from_cache: false,
        }
    }

    /// Failed, non-cached result carrying `error`.
    fn err_result(error: &str) -> ToolResult {
        ToolResult {
            success: false,
            output: None,
            error: Some(error.to_string()),
            duration_ms: 0,
            from_cache: false,
        }
    }

    /// Request for `tool_name` with `{}` as both arguments and context.
    fn request(tool_name: &'static str) -> ToolRequest {
        ToolRequest {
            tool_name: tool_name.into(),
            arguments: "{}".to_string(),
            context: "{}".to_string(),
        }
    }

    /// Async executor that always succeeds with `output`.
    fn make_executor(output: &'static str) -> BoxedExecutor {
        Box::new(move |_req: ToolRequest| Box::pin(async move { ok_result(output) }))
    }

    #[tokio::test]
    async fn test_async_logging_middleware() {
        let obs = Arc::new(ObservabilityContext::noop());
        let middleware = AsyncLoggingMiddleware::new(obs);

        let request = ToolRequest {
            tool_name: "test_tool".into(),
            arguments: "arg1".to_string(),
            context: "ctx".to_string(),
        };

        let executor = make_executor("result");

        let result = middleware.execute(request, executor).await;

        assert!(result.success);
    }

    #[tokio::test]
    async fn test_async_caching_middleware() {
        let obs = Arc::new(ObservabilityContext::noop());
        let cache = AsyncCachingMiddleware::new(10, 60, obs);

        let request = ToolRequest {
            tool_name: "cached_tool".into(),
            arguments: "arg1".to_string(),
            context: "ctx".to_string(),
        };

        // First call
        let executor1 = make_executor("result1");

        let result1 = cache.execute(request.clone(), executor1).await;
        assert!(!result1.from_cache);

        // Second call (should be cached)
        let executor2 = make_executor("result2");

        let result2 = cache.execute(request, executor2).await;
        assert!(result2.from_cache);
        assert_eq!(result2.output, Some("result1".to_string())); // Returns cached value
    }

    #[tokio::test]
    async fn async_retry_skips_non_retryable_errors() {
        let obs = Arc::new(ObservabilityContext::noop());
        let middleware = AsyncRetryMiddleware::new(3, 1, 2, obs);
        let attempts = Arc::new(AtomicUsize::new(0));

        let executor_attempts = attempts.clone();
        let executor: BoxedExecutor = Box::new(move |_req: ToolRequest| {
            let executor_attempts = executor_attempts.clone();
            Box::pin(async move {
                executor_attempts.fetch_add(1, Ordering::SeqCst);
                err_result("invalid api key")
            })
        });

        let result = middleware.execute(request("auth_tool"), executor).await;

        assert!(!result.success);
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn async_retry_retries_retryable_errors_until_success() {
        let obs = Arc::new(ObservabilityContext::noop());
        let middleware = AsyncRetryMiddleware::new(3, 1, 2, obs);
        let attempts = Arc::new(AtomicUsize::new(0));

        let executor_attempts = attempts.clone();
        let executor: BoxedExecutor = Box::new(move |_req: ToolRequest| {
            let executor_attempts = executor_attempts.clone();
            Box::pin(async move {
                let attempt = executor_attempts.fetch_add(1, Ordering::SeqCst);
                if attempt < 2 {
                    err_result("429 Too Many Requests")
                } else {
                    ok_result("ok")
                }
            })
        });

        let result = middleware
            .execute(request("rate_limited_tool"), executor)
            .await;

        assert!(result.success);
        assert_eq!(result.output.as_deref(), Some("ok"));
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn async_retry_gives_up_after_max_attempts() {
        let obs = Arc::new(ObservabilityContext::noop());
        let middleware = AsyncRetryMiddleware::new(2, 1, 2, obs);
        let attempts = Arc::new(AtomicUsize::new(0));

        let executor_attempts = attempts.clone();
        let executor: BoxedExecutor = Box::new(move |_req: ToolRequest| {
            let executor_attempts = executor_attempts.clone();
            Box::pin(async move {
                executor_attempts.fetch_add(1, Ordering::SeqCst);
                err_result("429 Too Many Requests")
            })
        });

        let result = middleware.execute(request("flaky_tool"), executor).await;

        assert!(!result.success);
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
        assert!(
            result
                .error
                .as_deref()
                .is_some_and(|e| e.contains("all 2 attempts failed"))
        );
    }

    #[tokio::test]
    async fn chain_without_middlewares_calls_executor_directly() {
        let chain = AsyncMiddlewareChain::default();

        let result = chain
            .execute_simple(request("echo_tool"), |req: ToolRequest| {
                ok_result(&req.arguments)
            })
            .await;

        assert!(result.success);
        assert!(!result.from_cache);
        assert_eq!(result.output.as_deref(), Some("{}"));
    }

    #[tokio::test]
    async fn chain_serves_repeated_request_from_cache() {
        let obs = Arc::new(ObservabilityContext::noop());
        let chain = AsyncMiddlewareChain::new()
            .with_middleware(Arc::new(AsyncLoggingMiddleware::new(obs.clone())))
            .with_middleware(Arc::new(AsyncCachingMiddleware::new(10, 60, obs)));
        let calls = Arc::new(AtomicUsize::new(0));

        let executor_calls = calls.clone();
        let executor = move |_req: ToolRequest| {
            executor_calls.fetch_add(1, Ordering::SeqCst);
            ok_result("fresh")
        };

        let first = chain
            .execute_simple(request("chained_tool"), executor.clone())
            .await;
        assert!(first.success);
        assert!(!first.from_cache);
        assert_eq!(first.output.as_deref(), Some("fresh"));

        let second = chain.execute_simple(request("chained_tool"), executor).await;
        assert!(second.success);
        assert!(second.from_cache);
        assert_eq!(second.output.as_deref(), Some("fresh"));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}