siumai/
benchmarks.rs

1//! Benchmarking and Performance Testing
2//!
3//! This module provides benchmarking utilities and performance tests
4//! for the siumai library components.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::time::{Duration, Instant};
9use tokio::time::sleep;
10
11use crate::error::LlmError;
12use crate::performance::{MonitorConfig, PerformanceMonitor};
13use crate::traits::ChatCapability;
14use crate::types::{ChatMessage, MessageContent, MessageMetadata, MessageRole};
15
/// Benchmark configuration
///
/// Shapes the load generated by [`BenchmarkRunner`]: worker count, total
/// request volume, warmup, and the scenario mix.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkConfig {
    /// Number of concurrent requests (worker tasks spawned by the runner)
    pub concurrency: usize,
    /// Total number of requests to send, split across workers
    pub total_requests: usize,
    /// Duration to run the benchmark
    /// NOTE(review): not consumed by `BenchmarkRunner::run` in this module —
    /// confirm whether a time-bounded mode is planned.
    pub duration: Option<Duration>,
    /// Warmup period before starting measurements (warmup is skipped when zero)
    pub warmup_duration: Duration,
    /// Request rate limit (requests per second)
    /// NOTE(review): not enforced by the runner in this module — confirm intent.
    pub rate_limit: Option<f64>,
    /// Test scenarios to run
    pub scenarios: Vec<BenchmarkScenario>,
}
32
33impl Default for BenchmarkConfig {
34    fn default() -> Self {
35        Self {
36            concurrency: 10,
37            total_requests: 100,
38            duration: None,
39            warmup_duration: Duration::from_secs(5),
40            rate_limit: None,
41            scenarios: vec![BenchmarkScenario::default()],
42        }
43    }
44}
45
/// Benchmark scenario
///
/// A named request template plus the validation rules applied to its
/// responses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkScenario {
    /// Scenario name (used as the key in per-scenario result maps)
    pub name: String,
    /// Test messages sent with each request for this scenario
    pub messages: Vec<ChatMessage>,
    /// Expected response characteristics used for validation
    pub expected: ExpectedResponse,
    /// Weight of this scenario (for mixed workloads)
    pub weight: f64,
}
58
59impl Default for BenchmarkScenario {
60    fn default() -> Self {
61        Self {
62            name: "basic_chat".to_string(),
63            messages: vec![ChatMessage {
64                role: MessageRole::User,
65                content: MessageContent::Text("Hello, how are you?".to_string()),
66                metadata: MessageMetadata::default(),
67                tool_calls: None,
68                tool_call_id: None,
69            }],
70            expected: ExpectedResponse::default(),
71            weight: 1.0,
72        }
73    }
74}
75
/// Expected response characteristics for validation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpectedResponse {
    /// Minimum response length in bytes (`None` disables the check)
    pub min_length: Option<usize>,
    /// Maximum response length in bytes (`None` disables the check)
    pub max_length: Option<usize>,
    /// Expected response time range
    /// NOTE(review): not checked by `validate_response` in this module —
    /// confirm whether timing validation is planned.
    pub response_time_range: Option<(Duration, Duration)>,
    /// Keywords that must appear in the response text
    pub required_keywords: Vec<String>,
    /// Keywords that must not appear in the response text
    pub forbidden_keywords: Vec<String>,
}
90
91impl Default for ExpectedResponse {
92    fn default() -> Self {
93        Self {
94            min_length: Some(1),
95            max_length: None,
96            response_time_range: None,
97            required_keywords: Vec::new(),
98            forbidden_keywords: Vec::new(),
99        }
100    }
101}
102
/// Benchmark results
///
/// Aggregate outcome of one benchmark run, produced by
/// `BenchmarkRunner::compile_results`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkResults {
    /// Total requests sent
    pub total_requests: usize,
    /// Successful requests
    pub successful_requests: usize,
    /// Failed requests
    pub failed_requests: usize,
    /// Total wall-clock duration of the benchmark (excluding warmup)
    pub total_duration: Duration,
    /// Overall throughput in requests per second
    pub requests_per_second: f64,
    /// Latency statistics across all requests
    pub latency_stats: LatencyStats,
    /// Error breakdown, keyed by error message, with occurrence counts
    pub error_breakdown: HashMap<String, usize>,
    /// Per-scenario results, keyed by scenario name
    pub scenario_results: HashMap<String, ScenarioResults>,
    /// Resource usage during the run
    pub resource_usage: ResourceUsage,
}
125
/// Latency statistics
///
/// Percentiles are read from the sorted list of per-request durations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LatencyStats {
    /// Mean latency
    pub mean: Duration,
    /// Median latency (P50)
    pub median: Duration,
    /// 95th percentile
    pub p95: Duration,
    /// 99th percentile
    pub p99: Duration,
    /// 99.9th percentile
    pub p999: Duration,
    /// Minimum latency
    pub min: Duration,
    /// Maximum latency
    pub max: Duration,
    /// Standard deviation of latency
    pub std_dev: Duration,
}
146
/// Scenario-specific results
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScenarioResults {
    /// Scenario name
    pub name: String,
    /// Number of requests issued for this scenario
    pub requests: usize,
    /// Success rate in [0, 1]
    pub success_rate: f64,
    /// Average response time across this scenario's requests
    pub avg_response_time: Duration,
    /// Aggregated validation results for this scenario
    pub validation_results: ValidationResults,
}
161
/// Validation results
///
/// Tallies of individual validation checks (length bounds, keyword
/// presence/absence), not of whole responses — one response can
/// contribute several passes and failures.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationResults {
    /// Number of checks that passed
    pub passed: usize,
    /// Number of checks that failed
    pub failed: usize,
    /// Failure counts keyed by reason (e.g. "min_length", "missing_keyword")
    pub failure_reasons: HashMap<String, usize>,
}
172
/// Resource usage during benchmark
///
/// NOTE(review): `compile_results` currently fills this with zeros; actual
/// resource accounting is not wired up in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceUsage {
    /// Peak memory usage in bytes
    pub peak_memory: u64,
    /// Average memory usage in bytes
    pub avg_memory: u64,
    /// CPU usage percentage
    pub cpu_usage: f64,
    /// Network bytes sent
    pub bytes_sent: u64,
    /// Network bytes received
    pub bytes_received: u64,
}
187
/// Benchmark runner
///
/// Drives concurrent workers against a [`ChatCapability`] client and
/// aggregates their outcomes into [`BenchmarkResults`].
pub struct BenchmarkRunner {
    /// Load configuration for the run
    config: BenchmarkConfig,
    /// Performance monitor cloned into each worker for timing and counters
    monitor: PerformanceMonitor,
}
195
196impl BenchmarkRunner {
197    /// Create a new benchmark runner
198    pub fn new(config: BenchmarkConfig) -> Self {
199        let monitor_config = MonitorConfig {
200            detailed_metrics: true,
201            memory_tracking: true,
202            ..MonitorConfig::default()
203        };
204
205        Self {
206            config,
207            monitor: PerformanceMonitor::new(monitor_config),
208        }
209    }
210
    /// Run benchmark against a client
    ///
    /// Optionally warms up, then spawns `config.concurrency` worker tasks
    /// that split `config.total_requests` between them, waits for all of
    /// them, and aggregates per-request results via `compile_results`.
    /// A worker that panics is logged and its results dropped rather than
    /// failing the whole run.
    pub async fn run<T: ChatCapability + Send + Sync + 'static>(
        &self,
        client: std::sync::Arc<T>,
    ) -> Result<BenchmarkResults, LlmError> {
        println!(
            "🚀 Starting benchmark with {} concurrent requests",
            self.config.concurrency
        );

        // Warmup phase (skipped when warmup_duration is zero)
        if !self.config.warmup_duration.is_zero() {
            println!("🔥 Warming up for {:?}", self.config.warmup_duration);
            self.warmup(&*client).await?;
        }

        let start_time = Instant::now();
        let mut handles = Vec::new();
        let mut results = Vec::new();

        // Create semaphore for concurrency control.
        // NOTE(review): permits == spawned workers, and each worker holds its
        // permit for its whole lifetime, so as written the semaphore never
        // actually throttles anything — confirm whether it was meant to cap
        // in-flight requests instead.
        let semaphore = std::sync::Arc::new(tokio::sync::Semaphore::new(self.config.concurrency));

        // Split total_requests evenly; the first `remaining_requests` workers
        // take one extra request so the sum is exact.
        let requests_per_worker = self.config.total_requests / self.config.concurrency;
        let remaining_requests = self.config.total_requests % self.config.concurrency;

        for worker_id in 0..self.config.concurrency {
            let worker_requests = if worker_id < remaining_requests {
                requests_per_worker + 1
            } else {
                requests_per_worker
            };

            // Skip idle workers (possible when total_requests < concurrency).
            if worker_requests == 0 {
                continue;
            }

            // Clone shared state into the spawned task.
            let semaphore = semaphore.clone();
            let scenarios = self.config.scenarios.clone();
            let monitor = self.monitor.clone();
            let client = client.clone();

            let handle = tokio::spawn(async move {
                let _permit = semaphore.acquire().await.unwrap();
                Self::run_worker(worker_id, worker_requests, scenarios, &*client, monitor).await
            });

            handles.push(handle);
        }

        // Wait for all workers to complete, collecting their results.
        for handle in handles {
            match handle.await {
                Ok(worker_results) => results.extend(worker_results),
                Err(e) => eprintln!("Worker failed: {e}"),
            }
        }

        let total_duration = start_time.elapsed();

        // Compile results
        self.compile_results(results, total_duration).await
    }
275
276    /// Run warmup requests
277    async fn warmup<T: ChatCapability + Send + Sync>(&self, client: &T) -> Result<(), LlmError> {
278        let warmup_requests = (self.config.concurrency * 2).min(10);
279        let scenario = &self.config.scenarios[0];
280
281        for _ in 0..warmup_requests {
282            let _ = client
283                .chat_with_tools(scenario.messages.clone(), None)
284                .await;
285            sleep(Duration::from_millis(100)).await;
286        }
287
288        Ok(())
289    }
290
    /// Run a single worker
    ///
    /// Issues `requests` sequential chat calls, timing each one through
    /// the shared monitor and validating successful responses against the
    /// selected scenario's expectations. A failed call is recorded as a
    /// failed result (with the error string) rather than aborting the
    /// worker.
    async fn run_worker<T: ChatCapability + Send + Sync>(
        worker_id: usize,
        requests: usize,
        scenarios: Vec<BenchmarkScenario>,
        client: &T,
        monitor: PerformanceMonitor,
    ) -> Vec<RequestResult> {
        let mut results = Vec::new();

        for request_id in 0..requests {
            // Select scenario based on weight
            let scenario = Self::select_scenario(&scenarios);

            // Start the per-request timer before issuing the call.
            let timer = monitor.start_request().await;

            match client
                .chat_with_tools(scenario.messages.clone(), None)
                .await
            {
                Ok(response) => {
                    // Stop the timer first, then record the success.
                    let duration = timer.finish().await;
                    monitor.record_success(None, duration).await;

                    let validation = Self::validate_response(&response, &scenario.expected);

                    results.push(RequestResult {
                        worker_id,
                        request_id,
                        scenario_name: scenario.name.clone(),
                        success: true,
                        duration,
                        error: None,
                        // Byte length of the textual part of the response, if any.
                        response_length: response.content.text().map(str::len),
                        validation,
                    });
                }
                Err(error) => {
                    let duration = timer.finish().await;
                    monitor.record_error("request_failed", None).await;

                    results.push(RequestResult {
                        worker_id,
                        request_id,
                        scenario_name: scenario.name.clone(),
                        success: false,
                        duration,
                        error: Some(error.to_string()),
                        response_length: None,
                        // A transport-level failure counts as one failed
                        // validation check with reason "error".
                        validation: ValidationResults {
                            passed: 0,
                            failed: 1,
                            failure_reasons: [("error".to_string(), 1)].into_iter().collect(),
                        },
                    });
                }
            }
        }

        results
    }
352
353    /// Select a scenario based on weights
354    fn select_scenario(scenarios: &[BenchmarkScenario]) -> &BenchmarkScenario {
355        // Simple implementation - just use the first scenario
356        // In a real implementation, you'd use weighted random selection
357        &scenarios[0]
358    }
359
360    /// Validate response against expected characteristics
361    fn validate_response(
362        response: &crate::types::ChatResponse,
363        expected: &ExpectedResponse,
364    ) -> ValidationResults {
365        let mut passed = 0;
366        let mut failed = 0;
367        let mut failure_reasons = HashMap::new();
368
369        let response_text = response.content.text().unwrap_or("");
370        let response_length = response_text.len();
371
372        // Check length constraints
373        if let Some(min_length) = expected.min_length {
374            if response_length >= min_length {
375                passed += 1;
376            } else {
377                failed += 1;
378                *failure_reasons.entry("min_length".to_string()).or_insert(0) += 1;
379            }
380        }
381
382        if let Some(max_length) = expected.max_length {
383            if response_length <= max_length {
384                passed += 1;
385            } else {
386                failed += 1;
387                *failure_reasons.entry("max_length".to_string()).or_insert(0) += 1;
388            }
389        }
390
391        // Check required keywords
392        for keyword in &expected.required_keywords {
393            if response_text.contains(keyword) {
394                passed += 1;
395            } else {
396                failed += 1;
397                *failure_reasons
398                    .entry("missing_keyword".to_string())
399                    .or_insert(0) += 1;
400            }
401        }
402
403        // Check forbidden keywords
404        for keyword in &expected.forbidden_keywords {
405            if !response_text.contains(keyword) {
406                passed += 1;
407            } else {
408                failed += 1;
409                *failure_reasons
410                    .entry("forbidden_keyword".to_string())
411                    .or_insert(0) += 1;
412            }
413        }
414
415        ValidationResults {
416            passed,
417            failed,
418            failure_reasons,
419        }
420    }
421
422    /// Compile final results
423    async fn compile_results(
424        &self,
425        results: Vec<RequestResult>,
426        total_duration: Duration,
427    ) -> Result<BenchmarkResults, LlmError> {
428        let total_requests = results.len();
429        let successful_requests = results.iter().filter(|r| r.success).count();
430        let failed_requests = total_requests - successful_requests;
431
432        let requests_per_second = total_requests as f64 / total_duration.as_secs_f64();
433
434        // Calculate latency statistics
435        let mut durations: Vec<Duration> = results.iter().map(|r| r.duration).collect();
436        durations.sort();
437
438        let latency_stats = if !durations.is_empty() {
439            let len = durations.len();
440            LatencyStats {
441                mean: durations.iter().sum::<Duration>() / len as u32,
442                median: durations[len / 2],
443                p95: durations[(len * 95) / 100],
444                p99: durations[(len * 99) / 100],
445                p999: durations[(len * 999) / 1000],
446                min: durations[0],
447                max: durations[len - 1],
448                std_dev: Duration::ZERO, // Simplified - would calculate actual std dev
449            }
450        } else {
451            LatencyStats {
452                mean: Duration::ZERO,
453                median: Duration::ZERO,
454                p95: Duration::ZERO,
455                p99: Duration::ZERO,
456                p999: Duration::ZERO,
457                min: Duration::ZERO,
458                max: Duration::ZERO,
459                std_dev: Duration::ZERO,
460            }
461        };
462
463        // Error breakdown
464        let mut error_breakdown = HashMap::new();
465        for result in &results {
466            if let Some(ref error) = result.error {
467                *error_breakdown.entry(error.clone()).or_insert(0) += 1;
468            }
469        }
470
471        // Scenario results
472        let mut scenario_results = HashMap::new();
473        for scenario in &self.config.scenarios {
474            let scenario_requests: Vec<_> = results
475                .iter()
476                .filter(|r| r.scenario_name == scenario.name)
477                .collect();
478
479            if !scenario_requests.is_empty() {
480                let success_count = scenario_requests.iter().filter(|r| r.success).count();
481                let success_rate = success_count as f64 / scenario_requests.len() as f64;
482
483                let avg_response_time = scenario_requests
484                    .iter()
485                    .map(|r| r.duration)
486                    .sum::<Duration>()
487                    / scenario_requests.len() as u32;
488
489                let validation_results = ValidationResults {
490                    passed: scenario_requests.iter().map(|r| r.validation.passed).sum(),
491                    failed: scenario_requests.iter().map(|r| r.validation.failed).sum(),
492                    failure_reasons: HashMap::new(), // Simplified
493                };
494
495                scenario_results.insert(
496                    scenario.name.clone(),
497                    ScenarioResults {
498                        name: scenario.name.clone(),
499                        requests: scenario_requests.len(),
500                        success_rate,
501                        avg_response_time,
502                        validation_results,
503                    },
504                );
505            }
506        }
507
508        Ok(BenchmarkResults {
509            total_requests,
510            successful_requests,
511            failed_requests,
512            total_duration,
513            requests_per_second,
514            latency_stats,
515            error_breakdown,
516            scenario_results,
517            resource_usage: ResourceUsage {
518                peak_memory: 0,
519                avg_memory: 0,
520                cpu_usage: 0.0,
521                bytes_sent: 0,
522                bytes_received: 0,
523            },
524        })
525    }
526}
527
/// Individual request result
///
/// One entry per request a worker issued; aggregated by `compile_results`.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct RequestResult {
    /// Index of the worker task that issued the request
    worker_id: usize,
    /// Sequence number of the request within its worker
    request_id: usize,
    /// Name of the scenario the request exercised
    scenario_name: String,
    /// Whether the chat call returned `Ok`
    success: bool,
    /// Request duration as reported by the monitor's timer
    duration: Duration,
    /// Stringified error when the call failed
    error: Option<String>,
    /// Byte length of the response text, when present
    response_length: Option<usize>,
    /// Per-request validation tallies
    validation: ValidationResults,
}
541
#[cfg(test)]
mod tests {
    use super::*;

    /// Default config should match the documented defaults.
    #[test]
    fn test_benchmark_config() {
        let config = BenchmarkConfig::default();
        assert_eq!(config.concurrency, 10);
        assert_eq!(config.total_requests, 100);
        assert_eq!(config.scenarios.len(), 1);
    }

    /// Smoke-check that `ExpectedResponse` can express length bounds and
    /// keyword constraints together.
    #[test]
    fn test_expected_response_validation() {
        let expected = ExpectedResponse {
            min_length: Some(5),
            max_length: Some(100),
            response_time_range: None,
            required_keywords: vec!["hello".to_string()],
            forbidden_keywords: vec!["error".to_string()],
        };

        // This would require a mock response to test properly
        assert!(expected.min_length.is_some());
        assert!(expected.max_length.is_some());
    }
}