1use crate::tools::improvements_errors::ObservabilityContext;
7use crate::types::CompactStr;
8use serde_json::{Map, Value};
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12use std::time::Instant;
13
/// Continuation passed to a middleware's `execute`: calling it with a request
/// returns the boxed, `Send` future of the next stage. The closure itself may
/// borrow data for the lifetime `'a`.
type AsyncContinuation<'a> =
    Box<dyn Fn(ToolRequest) -> Pin<Box<dyn Future<Output = ToolResult> + Send>> + Send + Sync + 'a>;
17
/// Same shape as `AsyncContinuation` but without a borrowed lifetime bound on
/// the closure; used when assembling the chain in `execute_simple`.
type AsyncContinuationOwned =
    Box<dyn Fn(ToolRequest) -> Pin<Box<dyn Future<Output = ToolResult> + Send>> + Send + Sync>;
21
/// One asynchronous stage wrapped around tool execution.
///
/// Implementors must be `Send + Sync`; the chain stores them as
/// `Arc<dyn AsyncMiddleware>`.
#[async_trait::async_trait]
pub trait AsyncMiddleware: Send + Sync {
    /// Returns the middleware's name.
    fn name(&self) -> &str;

    /// Processes `request` and produces the result for this stage.
    ///
    /// `next` runs the stage that follows this one. The implementations in this
    /// file call it zero times (cache hit), once, or several times (retry).
    async fn execute<'a>(&'a self, request: ToolRequest, next: AsyncContinuation<'a>)
    -> ToolResult;
}
32
/// Request handed through the middleware chain to the tool executor.
#[derive(Clone, Debug)]
pub struct ToolRequest {
    /// Name of the tool; used as the cache-key prefix and in log fields.
    pub tool_name: CompactStr,
    /// Argument payload as a string; its bytes are hashed into the cache key.
    pub arguments: String,
    /// Context payload as a string. `normalize_context` parses it as JSON and
    /// treats input that fails to parse as an empty object.
    pub context: String,
}
40
/// Outcome of a tool execution as seen by the middlewares.
#[derive(Clone, Debug)]
pub struct ToolResult {
    /// Success flag; the middlewares in this file branch on it for logging,
    /// caching and retrying.
    pub success: bool,
    /// Output text, if any. The caching middleware stores it when `success` is true.
    pub output: Option<String>,
    /// Error text, if any. The retry middleware classifies it to decide whether
    /// another attempt is worthwhile.
    pub error: Option<String>,
    /// Duration in milliseconds. The logging middleware overwrites this with the
    /// time it measured around `next`.
    pub duration_ms: u64,
    /// Set to `true` by the caching middleware when the result came from its cache.
    pub from_cache: bool,
}
50
/// Ordered collection of middlewares applied around a tool executor.
pub struct AsyncMiddlewareChain {
    /// Middlewares in the order they were added; index 0 is the outermost stage
    /// when the chain is executed.
    middlewares: Vec<Arc<dyn AsyncMiddleware>>,
}
55
56impl AsyncMiddlewareChain {
57 pub fn new() -> Self {
58 Self {
59 middlewares: Vec::new(),
60 }
61 }
62
63 pub fn with_middleware(mut self, middleware: Arc<dyn AsyncMiddleware>) -> Self {
64 self.middlewares.push(middleware);
65 self
66 }
67
68 pub async fn execute_simple<F>(&self, request: ToolRequest, executor: F) -> ToolResult
70 where
71 F: Fn(ToolRequest) -> ToolResult + Send + Sync + 'static,
72 {
73 if self.middlewares.is_empty() {
74 return executor(request);
75 }
76
77 let executor = Arc::new(executor);
78 let middlewares = self.middlewares.clone();
79
80 fn build_chain(
81 middlewares: &[Arc<dyn AsyncMiddleware>],
82 executor: Arc<dyn Fn(ToolRequest) -> ToolResult + Send + Sync>,
83 ) -> AsyncContinuationOwned {
84 if middlewares.is_empty() {
85 Box::new(move |req: ToolRequest| {
86 let result = executor(req);
87 Box::pin(async move { result })
88 })
89 } else {
90 let current = middlewares[0].clone();
91 let rest = build_chain(&middlewares[1..], executor);
92 let rest = Arc::new(rest);
93 Box::new(move |req: ToolRequest| {
94 let current = current.clone();
95 let rest = rest.clone();
96 Box::pin(async move {
97 let next: AsyncContinuationOwned = Box::new(move |r: ToolRequest| {
98 let rest = rest.clone();
99 Box::pin(async move { rest(r).await })
100 });
101 current.execute(req, next).await
102 })
103 })
104 }
105 }
106
107 let chain = build_chain(&middlewares, executor);
108 chain(request).await
109 }
110}
111
112impl Default for AsyncMiddlewareChain {
113 fn default() -> Self {
114 Self::new()
115 }
116}
117
118fn normalize_context(context: &str) -> String {
119 let mut normalized = Map::new();
120 let parsed: Value = serde_json::from_str(context).unwrap_or_else(|_| Value::Object(Map::new()));
121
122 if let Some(session) = parsed.get("session_id").and_then(Value::as_str)
123 && !session.is_empty()
124 {
125 normalized.insert("session_id".into(), Value::String(session.to_string()));
126 }
127
128 if let Some(task) = parsed.get("task_id").and_then(Value::as_str)
129 && !task.is_empty()
130 {
131 normalized.insert("task_id".into(), Value::String(task.to_string()));
132 }
133
134 if let Some(version) = parsed.get("plan_version").and_then(Value::as_u64) {
135 normalized.insert("plan_version".into(), Value::Number(version.into()));
136 }
137
138 if let Some(plan) = parsed.get("plan_summary").and_then(Value::as_object) {
139 let mut summary = Map::new();
140 if let Some(status) = plan.get("status").and_then(Value::as_str) {
141 summary.insert("status".into(), Value::String(status.to_string()));
142 }
143 if let Some(total) = plan.get("total_steps").and_then(Value::as_u64) {
144 summary.insert("total_steps".into(), Value::Number(total.into()));
145 }
146 if let Some(completed) = plan.get("completed_steps").and_then(Value::as_u64) {
147 summary.insert("completed_steps".into(), Value::Number(completed.into()));
148 }
149 if !summary.is_empty() {
150 normalized.insert("plan_summary".into(), Value::Object(summary));
151 }
152 }
153
154 if let Some(phase) = parsed
155 .get("plan_phase")
156 .and_then(|v| v.as_str())
157 .filter(|p| !p.is_empty())
158 {
159 normalized.insert("plan_phase".into(), Value::String(phase.to_string()));
160 }
161
162 serde_json::to_string(&Value::Object(normalized)).unwrap_or_else(|_| "{}".to_string())
163}
164
/// Middleware that emits `tracing` events around each tool execution.
pub struct AsyncLoggingMiddleware {
    /// Observability sink; receives a `ToolSelected` event after a successful run.
    obs_context: Arc<ObservabilityContext>,
}
169
170impl AsyncLoggingMiddleware {
171 pub fn new(obs_context: Arc<ObservabilityContext>) -> Self {
172 Self { obs_context }
173 }
174}
175
176#[async_trait::async_trait]
177impl AsyncMiddleware for AsyncLoggingMiddleware {
178 fn name(&self) -> &str {
179 "async_logging"
180 }
181
182 async fn execute<'a>(
183 &'a self,
184 request: ToolRequest,
185 next: Box<
186 dyn Fn(ToolRequest) -> Pin<Box<dyn std::future::Future<Output = ToolResult> + Send>>
187 + Send
188 + Sync
189 + 'a,
190 >,
191 ) -> ToolResult {
192 let tool_name = request.tool_name.clone();
193 let normalized_context = normalize_context(&request.context);
194 let context_json: Option<Value> = serde_json::from_str(&normalized_context).ok();
195 let session_id = context_json
196 .as_ref()
197 .and_then(|v| v.get("session_id").and_then(|s| s.as_str()))
198 .unwrap_or("");
199 let task_id = context_json
200 .as_ref()
201 .and_then(|v| v.get("task_id").and_then(|s| s.as_str()))
202 .unwrap_or("");
203 let plan_summary = context_json.as_ref().and_then(|v| v.get("plan_summary"));
204 let plan_status = plan_summary
205 .and_then(|v| v.get("status").and_then(|s| s.as_str()))
206 .unwrap_or("");
207 let plan_phase = context_json
208 .as_ref()
209 .and_then(|v| v.get("plan_phase").and_then(|p| p.as_str()))
210 .unwrap_or("");
211 let plan_total_steps = plan_summary
212 .and_then(|v| v.get("total_steps").and_then(|n| n.as_u64()))
213 .unwrap_or(0);
214 let plan_completed_steps = plan_summary
215 .and_then(|v| v.get("completed_steps").and_then(|n| n.as_u64()))
216 .unwrap_or(0);
217 let plan_version = context_json
218 .as_ref()
219 .and_then(|v| v.get("plan_version").and_then(|n| n.as_u64()))
220 .unwrap_or(0);
221
222 tracing::debug!(
223 tool = %tool_name,
224 session_id = %session_id,
225 task_id = %task_id,
226 plan_version,
227 plan_status = %plan_status,
228 plan_phase = %plan_phase,
229 plan_total_steps,
230 plan_completed_steps,
231 "tool execution started"
232 );
233 tracing::trace!(
234 tool = %tool_name,
235 context = %normalized_context,
236 "tool execution context payload"
237 );
238
239 let start = Instant::now();
240 let mut result = next(request).await;
241 let duration = start.elapsed().as_millis().min(u64::MAX as u128) as u64;
242
243 result.duration_ms = duration;
244
245 if result.success {
246 tracing::debug!(
247 tool = %tool_name,
248 duration_ms = duration,
249 session_id = %session_id,
250 task_id = %task_id,
251 plan_version,
252 plan_status = %plan_status,
253 plan_phase = %plan_phase,
254 plan_total_steps,
255 plan_completed_steps,
256 from_cache = result.from_cache,
257 "tool execution completed"
258 );
259 self.obs_context.event(
260 crate::tools::EventType::ToolSelected,
261 "executor",
262 format!("executed {} in {}ms", tool_name, duration),
263 Some(1.0),
264 );
265 } else {
266 tracing::error!(
267 tool = %tool_name,
268 error = ?result.error,
269 session_id = %context_json
270 .as_ref()
271 .and_then(|v| v.get("session_id").and_then(|s| s.as_str()))
272 .unwrap_or(""),
273 task_id = %context_json
274 .as_ref()
275 .and_then(|v| v.get("task_id").and_then(|s| s.as_str()))
276 .unwrap_or(""),
277 "tool execution failed"
278 );
279 }
280
281 result
282 }
283}
284
/// Middleware that serves repeated requests from an in-memory cache.
pub struct AsyncCachingMiddleware {
    /// Cache from request key to the output string of a successful execution.
    cache: Arc<crate::cache::UnifiedCache<AsyncCacheKey, String>>,
    /// Observability sink; receives a `CacheHit` event whenever a cached value is returned.
    obs_context: Arc<ObservabilityContext>,
}
290
/// Newtype around the string key built by `AsyncCachingMiddleware::cache_key`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct AsyncCacheKey(String);
293
294impl crate::cache::CacheKey for AsyncCacheKey {
295 fn to_cache_key(&self) -> String {
296 self.0.clone()
297 }
298}
299
300impl AsyncCachingMiddleware {
301 pub fn new(
302 max_entries: usize,
303 ttl_seconds: u64,
304 obs_context: Arc<ObservabilityContext>,
305 ) -> Self {
306 let cache = crate::cache::UnifiedCache::new(
307 max_entries,
308 std::time::Duration::from_secs(ttl_seconds),
309 crate::cache::EvictionPolicy::Lru,
310 );
311
312 Self {
313 cache: Arc::new(cache),
314 obs_context,
315 }
316 }
317
318 fn cache_key(tool: &str, args: &str, context: &str) -> String {
319 use std::collections::hash_map::DefaultHasher;
321 use std::hash::Hasher;
322 let mut hasher = DefaultHasher::new();
323 hasher.write(args.as_bytes());
324 let normalized = normalize_context(context);
325 if !normalized.is_empty() {
326 hasher.write(normalized.as_bytes());
327 }
328 format!("{}::{}", tool, hasher.finish())
329 }
330}
331
332#[async_trait::async_trait]
333impl AsyncMiddleware for AsyncCachingMiddleware {
334 fn name(&self) -> &str {
335 "async_caching"
336 }
337
338 async fn execute<'a>(
339 &'a self,
340 request: ToolRequest,
341 next: Box<
342 dyn Fn(ToolRequest) -> Pin<Box<dyn std::future::Future<Output = ToolResult> + Send>>
343 + Send
344 + Sync
345 + 'a,
346 >,
347 ) -> ToolResult {
348 let key = AsyncCacheKey(Self::cache_key(
349 &request.tool_name,
350 &request.arguments,
351 &request.context,
352 ));
353
354 if let Some(cached) = self.cache.get_owned(&key) {
356 self.obs_context.event(
357 crate::tools::EventType::CacheHit,
358 "cache",
359 "returning cached result",
360 Some(1.0),
361 );
362
363 return ToolResult {
364 success: true,
365 output: Some(cached),
366 error: None,
367 duration_ms: 0,
368 from_cache: true,
369 };
370 }
371
372 let result = next(request).await;
374
375 if result.success
377 && let Some(ref output) = result.output
378 {
379 let size = output.len() as u64;
380 self.cache.insert(key, output.clone(), size);
381 }
382
383 result
384 }
385}
386
/// Middleware that re-runs the next stage after failures, with exponential backoff.
pub struct AsyncRetryMiddleware {
    /// Upper bound on how many times the next stage is invoked for one request.
    max_attempts: u32,
    /// Delay in milliseconds before the first retry; doubled for each further retry.
    initial_backoff_ms: u64,
    /// Cap in milliseconds applied to every computed backoff delay.
    max_backoff_ms: u64,
    /// Observability sink; receives `FallbackAttempt` and `FallbackSuccess` events.
    obs_context: Arc<ObservabilityContext>,
}
394
395impl AsyncRetryMiddleware {
396 pub fn new(
397 max_attempts: u32,
398 initial_backoff_ms: u64,
399 max_backoff_ms: u64,
400 obs_context: Arc<ObservabilityContext>,
401 ) -> Self {
402 Self {
403 max_attempts,
404 initial_backoff_ms,
405 max_backoff_ms,
406 obs_context,
407 }
408 }
409
410 fn backoff_duration(&self, attempt: u32) -> std::time::Duration {
411 let backoff = self.initial_backoff_ms * 2_u64.pow(attempt);
412 std::time::Duration::from_millis(backoff.min(self.max_backoff_ms))
413 }
414}
415
416#[async_trait::async_trait]
417impl AsyncMiddleware for AsyncRetryMiddleware {
418 fn name(&self) -> &str {
419 "async_retry"
420 }
421
422 async fn execute<'a>(
423 &'a self,
424 request: ToolRequest,
425 next: Box<
426 dyn Fn(ToolRequest) -> Pin<Box<dyn std::future::Future<Output = ToolResult> + Send>>
427 + Send
428 + Sync
429 + 'a,
430 >,
431 ) -> ToolResult {
432 for attempt in 0..self.max_attempts {
433 if attempt > 0 {
434 let backoff = self.backoff_duration(attempt - 1);
435 tracing::debug!(
436 attempt = attempt,
437 backoff_ms = backoff.as_millis(),
438 "retrying after backoff"
439 );
440 tokio::time::sleep(backoff).await;
441 }
442
443 let result = next(request.clone()).await;
444
445 if result.success {
446 if attempt > 0 {
447 self.obs_context.event(
448 crate::tools::EventType::FallbackSuccess,
449 "retry",
450 format!("succeeded on attempt {}", attempt + 1),
451 Some(1.0),
452 );
453 }
454 return result;
455 }
456
457 if let Some(ref error_msg) = result.error {
460 let category = vtcode_commons::classify_error_message(error_msg);
461 if !category.is_retryable() {
462 tracing::debug!(
463 attempt = attempt,
464 category = ?category,
465 "non-retryable error, skipping remaining attempts"
466 );
467 return result;
468 }
469 }
470
471 self.obs_context.event(
472 crate::tools::EventType::FallbackAttempt,
473 "retry",
474 format!("attempt {} failed", attempt + 1),
475 None,
476 );
477 }
478
479 ToolResult {
480 success: false,
481 output: None,
482 error: Some(format!("all {} attempts failed", self.max_attempts)),
483 duration_ms: 0,
484 from_cache: false,
485 }
486 }
487}
488
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::atomic::{AtomicUsize, Ordering};

    type BoxedToolFuture = Pin<Box<dyn Future<Output = ToolResult> + Send>>;
    type BoxedExecutor = Box<dyn Fn(ToolRequest) -> BoxedToolFuture + Send + Sync>;

    /// Successful, non-cached result carrying `output`.
    fn succeeded(output: &str) -> ToolResult {
        ToolResult {
            success: true,
            output: Some(output.to_string()),
            error: None,
            duration_ms: 0,
            from_cache: false,
        }
    }

    /// Failed result carrying `message` as its error.
    fn failed(message: &str) -> ToolResult {
        ToolResult {
            success: false,
            output: None,
            error: Some(message.to_string()),
            duration_ms: 0,
            from_cache: false,
        }
    }

    /// Request fixture with the given tool name, arguments and context.
    fn request_for(tool: &'static str, arguments: &str, context: &str) -> ToolRequest {
        ToolRequest {
            tool_name: tool.into(),
            arguments: arguments.to_string(),
            context: context.to_string(),
        }
    }

    /// Executor that ignores the request and always succeeds with `output`.
    fn make_executor(output: &'static str) -> BoxedExecutor {
        Box::new(move |_req: ToolRequest| Box::pin(async move { succeeded(output) }))
    }

    /// Executor that bumps `calls` on every invocation and answers with
    /// `respond(n)`, where `n` is the number of invocations made before this one.
    fn counting_executor(
        calls: Arc<AtomicUsize>,
        respond: fn(usize) -> ToolResult,
    ) -> BoxedExecutor {
        Box::new(move |_req: ToolRequest| {
            let calls = Arc::clone(&calls);
            Box::pin(async move {
                let earlier_calls = calls.fetch_add(1, Ordering::SeqCst);
                respond(earlier_calls)
            })
        })
    }

    #[tokio::test]
    async fn test_async_logging_middleware() {
        let middleware = AsyncLoggingMiddleware::new(Arc::new(ObservabilityContext::noop()));

        let result = middleware
            .execute(request_for("test_tool", "arg1", "ctx"), make_executor("result"))
            .await;

        assert!(result.success);
    }

    #[tokio::test]
    async fn test_async_caching_middleware() {
        let cache = AsyncCachingMiddleware::new(10, 60, Arc::new(ObservabilityContext::noop()));
        let request = request_for("cached_tool", "arg1", "ctx");

        // First call misses and stores "result1".
        let first = cache
            .execute(request.clone(), make_executor("result1"))
            .await;
        assert!(!first.from_cache);

        // Second call must be served from the cache, so the new executor's
        // output never shows up.
        let second = cache.execute(request, make_executor("result2")).await;
        assert!(second.from_cache);
        assert_eq!(second.output, Some("result1".to_string()));
    }

    #[tokio::test]
    async fn async_retry_skips_non_retryable_errors() {
        let middleware = AsyncRetryMiddleware::new(3, 1, 2, Arc::new(ObservabilityContext::noop()));
        let attempts = Arc::new(AtomicUsize::new(0));

        let executor = counting_executor(attempts.clone(), |_| failed("invalid api key"));

        let result = middleware
            .execute(request_for("auth_tool", "{}", "{}"), executor)
            .await;

        assert!(!result.success);
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn async_retry_retries_retryable_errors_until_success() {
        let middleware = AsyncRetryMiddleware::new(3, 1, 2, Arc::new(ObservabilityContext::noop()));
        let attempts = Arc::new(AtomicUsize::new(0));

        // The first two calls are rate limited; the third succeeds.
        let executor = counting_executor(attempts.clone(), |earlier_calls| {
            if earlier_calls < 2 {
                failed("429 Too Many Requests")
            } else {
                succeeded("ok")
            }
        });

        let result = middleware
            .execute(request_for("rate_limited_tool", "{}", "{}"), executor)
            .await;

        assert!(result.success);
        assert_eq!(result.output.as_deref(), Some("ok"));
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }
}