1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::time::{Duration, Instant};
9use tokio::time::sleep;
10
11use crate::error::LlmError;
12use crate::performance::{MonitorConfig, PerformanceMonitor};
13use crate::traits::ChatCapability;
14use crate::types::{ChatMessage, MessageContent, MessageMetadata, MessageRole};
15
/// Tunable parameters controlling a benchmark run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkConfig {
    /// Number of concurrent worker tasks to spawn.
    pub concurrency: usize,
    /// Total number of requests, split as evenly as possible across workers.
    pub total_requests: usize,
    /// Optional wall-clock cap for the run.
    /// NOTE(review): not enforced anywhere in `BenchmarkRunner` — confirm intent.
    pub duration: Option<Duration>,
    /// Warm-up phase length; a zero duration skips warm-up entirely.
    pub warmup_duration: Duration,
    /// Optional target request rate (presumably requests/second).
    /// NOTE(review): not enforced anywhere in `BenchmarkRunner` — confirm intent.
    pub rate_limit: Option<f64>,
    /// Scenarios to exercise; each carries a relative selection weight.
    pub scenarios: Vec<BenchmarkScenario>,
}
32
33impl Default for BenchmarkConfig {
34 fn default() -> Self {
35 Self {
36 concurrency: 10,
37 total_requests: 100,
38 duration: None,
39 warmup_duration: Duration::from_secs(5),
40 rate_limit: None,
41 scenarios: vec![BenchmarkScenario::default()],
42 }
43 }
44}
45
/// A single request template exercised during the benchmark.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkScenario {
    /// Identifier used to group results per scenario.
    pub name: String,
    /// Conversation sent verbatim on every request of this scenario.
    pub messages: Vec<ChatMessage>,
    /// Constraints the response is validated against.
    pub expected: ExpectedResponse,
    /// Relative selection weight (higher is intended to run more often).
    pub weight: f64,
}
58
59impl Default for BenchmarkScenario {
60 fn default() -> Self {
61 Self {
62 name: "basic_chat".to_string(),
63 messages: vec![ChatMessage {
64 role: MessageRole::User,
65 content: MessageContent::Text("Hello, how are you?".to_string()),
66 metadata: MessageMetadata::default(),
67 tool_calls: None,
68 tool_call_id: None,
69 }],
70 expected: ExpectedResponse::default(),
71 weight: 1.0,
72 }
73 }
74}
75
/// Constraints used to validate a scenario's responses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpectedResponse {
    /// Minimum response text length in bytes, if bounded.
    pub min_length: Option<usize>,
    /// Maximum response text length in bytes, if bounded.
    pub max_length: Option<usize>,
    /// Acceptable (min, max) response-time window, if bounded.
    /// NOTE(review): not checked by `validate_response` — confirm intent.
    pub response_time_range: Option<(Duration, Duration)>,
    /// Substrings that must appear in the response text.
    pub required_keywords: Vec<String>,
    /// Substrings that must NOT appear in the response text.
    pub forbidden_keywords: Vec<String>,
}
90
91impl Default for ExpectedResponse {
92 fn default() -> Self {
93 Self {
94 min_length: Some(1),
95 max_length: None,
96 response_time_range: None,
97 required_keywords: Vec::new(),
98 forbidden_keywords: Vec::new(),
99 }
100 }
101}
102
/// Aggregated outcome of a completed benchmark run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkResults {
    /// Number of requests for which a result was collected.
    pub total_requests: usize,
    /// Requests whose chat call returned `Ok`.
    pub successful_requests: usize,
    /// Requests whose chat call returned `Err`.
    pub failed_requests: usize,
    /// Wall-clock time of the measured phase (warm-up excluded).
    pub total_duration: Duration,
    /// Throughput: total_requests / total_duration.
    pub requests_per_second: f64,
    /// Latency distribution summary.
    pub latency_stats: LatencyStats,
    /// Count of failures keyed by error message string.
    pub error_breakdown: HashMap<String, usize>,
    /// Per-scenario rollups keyed by scenario name.
    pub scenario_results: HashMap<String, ScenarioResults>,
    /// Resource consumption during the run.
    pub resource_usage: ResourceUsage,
}
125
/// Latency distribution summary over all recorded request durations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LatencyStats {
    /// Arithmetic mean latency.
    pub mean: Duration,
    /// 50th percentile latency.
    pub median: Duration,
    /// 95th percentile latency.
    pub p95: Duration,
    /// 99th percentile latency.
    pub p99: Duration,
    /// 99.9th percentile latency.
    pub p999: Duration,
    /// Fastest observed request.
    pub min: Duration,
    /// Slowest observed request.
    pub max: Duration,
    /// Standard deviation of request latencies.
    pub std_dev: Duration,
}
146
/// Rollup of all requests belonging to one scenario.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScenarioResults {
    /// Scenario name (mirrors the map key in `BenchmarkResults`).
    pub name: String,
    /// Number of requests attributed to this scenario.
    pub requests: usize,
    /// Fraction of those requests that succeeded, in [0, 1].
    pub success_rate: f64,
    /// Mean latency across the scenario's requests.
    pub avg_response_time: Duration,
    /// Aggregated validation outcomes for the scenario.
    pub validation_results: ValidationResults,
}
161
/// Pass/fail tallies from checking responses against `ExpectedResponse`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationResults {
    /// Number of individual checks that passed.
    pub passed: usize,
    /// Number of individual checks that failed.
    pub failed: usize,
    /// Failure counts keyed by reason bucket
    /// (e.g. "min_length", "missing_keyword").
    pub failure_reasons: HashMap<String, usize>,
}
172
/// Resource consumption observed during a benchmark run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceUsage {
    /// Peak memory usage (presumably bytes — TODO confirm units).
    pub peak_memory: u64,
    /// Average memory usage (presumably bytes — TODO confirm units).
    pub avg_memory: u64,
    /// CPU usage (presumably a fraction or percent — TODO confirm units).
    pub cpu_usage: f64,
    /// Total bytes sent over the network.
    pub bytes_sent: u64,
    /// Total bytes received over the network.
    pub bytes_received: u64,
}
187
/// Drives a benchmark: spawns worker tasks, runs scenarios against a chat
/// client, and aggregates the results into a `BenchmarkResults` report.
pub struct BenchmarkRunner {
    /// Parameters for the run.
    config: BenchmarkConfig,
    /// Records per-request timing and success/error metrics during the run.
    monitor: PerformanceMonitor,
}
195
196impl BenchmarkRunner {
197 pub fn new(config: BenchmarkConfig) -> Self {
199 let monitor_config = MonitorConfig {
200 detailed_metrics: true,
201 memory_tracking: true,
202 ..MonitorConfig::default()
203 };
204
205 Self {
206 config,
207 monitor: PerformanceMonitor::new(monitor_config),
208 }
209 }
210
211 pub async fn run<T: ChatCapability + Send + Sync + 'static>(
213 &self,
214 client: std::sync::Arc<T>,
215 ) -> Result<BenchmarkResults, LlmError> {
216 println!(
217 "🚀 Starting benchmark with {} concurrent requests",
218 self.config.concurrency
219 );
220
221 if !self.config.warmup_duration.is_zero() {
223 println!("🔥 Warming up for {:?}", self.config.warmup_duration);
224 self.warmup(&*client).await?;
225 }
226
227 let start_time = Instant::now();
228 let mut handles = Vec::new();
229 let mut results = Vec::new();
230
231 let semaphore = std::sync::Arc::new(tokio::sync::Semaphore::new(self.config.concurrency));
233
234 let requests_per_worker = self.config.total_requests / self.config.concurrency;
236 let remaining_requests = self.config.total_requests % self.config.concurrency;
237
238 for worker_id in 0..self.config.concurrency {
239 let worker_requests = if worker_id < remaining_requests {
240 requests_per_worker + 1
241 } else {
242 requests_per_worker
243 };
244
245 if worker_requests == 0 {
246 continue;
247 }
248
249 let semaphore = semaphore.clone();
250 let scenarios = self.config.scenarios.clone();
251 let monitor = self.monitor.clone();
252 let client = client.clone();
253
254 let handle = tokio::spawn(async move {
255 let _permit = semaphore.acquire().await.unwrap();
256 Self::run_worker(worker_id, worker_requests, scenarios, &*client, monitor).await
257 });
258
259 handles.push(handle);
260 }
261
262 for handle in handles {
264 match handle.await {
265 Ok(worker_results) => results.extend(worker_results),
266 Err(e) => eprintln!("Worker failed: {e}"),
267 }
268 }
269
270 let total_duration = start_time.elapsed();
271
272 self.compile_results(results, total_duration).await
274 }
275
276 async fn warmup<T: ChatCapability + Send + Sync>(&self, client: &T) -> Result<(), LlmError> {
278 let warmup_requests = (self.config.concurrency * 2).min(10);
279 let scenario = &self.config.scenarios[0];
280
281 for _ in 0..warmup_requests {
282 let _ = client
283 .chat_with_tools(scenario.messages.clone(), None)
284 .await;
285 sleep(Duration::from_millis(100)).await;
286 }
287
288 Ok(())
289 }
290
    /// Issue `requests` sequential chat calls and record one `RequestResult`
    /// per call.
    ///
    /// For each request: pick a scenario, start a monitor timer, send the
    /// scenario's messages, then record success (with response validation) or
    /// failure (with the stringified error). Runs inside one spawned task, so
    /// requests within a worker are strictly sequential.
    async fn run_worker<T: ChatCapability + Send + Sync>(
        worker_id: usize,
        requests: usize,
        scenarios: Vec<BenchmarkScenario>,
        client: &T,
        monitor: PerformanceMonitor,
    ) -> Vec<RequestResult> {
        let mut results = Vec::new();

        for request_id in 0..requests {
            let scenario = Self::select_scenario(&scenarios);

            // Timer must start before the call so `duration` covers the full
            // round trip, including failures.
            let timer = monitor.start_request().await;

            match client
                .chat_with_tools(scenario.messages.clone(), None)
                .await
            {
                Ok(response) => {
                    let duration = timer.finish().await;
                    monitor.record_success(None, duration).await;

                    // Check the response against the scenario's expectations.
                    let validation = Self::validate_response(&response, &scenario.expected);

                    results.push(RequestResult {
                        worker_id,
                        request_id,
                        scenario_name: scenario.name.clone(),
                        success: true,
                        duration,
                        error: None,
                        // Byte length of the text portion, if any.
                        response_length: response.content.text().map(str::len),
                        validation,
                    });
                }
                Err(error) => {
                    let duration = timer.finish().await;
                    // All failures share one monitor bucket; the specific
                    // message is preserved in `error` below.
                    monitor.record_error("request_failed", None).await;

                    results.push(RequestResult {
                        worker_id,
                        request_id,
                        scenario_name: scenario.name.clone(),
                        success: false,
                        duration,
                        error: Some(error.to_string()),
                        response_length: None,
                        // A failed call counts as one failed validation check.
                        validation: ValidationResults {
                            passed: 0,
                            failed: 1,
                            failure_reasons: [("error".to_string(), 1)].into_iter().collect(),
                        },
                    });
                }
            }
        }

        results
    }
352
353 fn select_scenario(scenarios: &[BenchmarkScenario]) -> &BenchmarkScenario {
355 &scenarios[0]
358 }
359
360 fn validate_response(
362 response: &crate::types::ChatResponse,
363 expected: &ExpectedResponse,
364 ) -> ValidationResults {
365 let mut passed = 0;
366 let mut failed = 0;
367 let mut failure_reasons = HashMap::new();
368
369 let response_text = response.content.text().unwrap_or("");
370 let response_length = response_text.len();
371
372 if let Some(min_length) = expected.min_length {
374 if response_length >= min_length {
375 passed += 1;
376 } else {
377 failed += 1;
378 *failure_reasons.entry("min_length".to_string()).or_insert(0) += 1;
379 }
380 }
381
382 if let Some(max_length) = expected.max_length {
383 if response_length <= max_length {
384 passed += 1;
385 } else {
386 failed += 1;
387 *failure_reasons.entry("max_length".to_string()).or_insert(0) += 1;
388 }
389 }
390
391 for keyword in &expected.required_keywords {
393 if response_text.contains(keyword) {
394 passed += 1;
395 } else {
396 failed += 1;
397 *failure_reasons
398 .entry("missing_keyword".to_string())
399 .or_insert(0) += 1;
400 }
401 }
402
403 for keyword in &expected.forbidden_keywords {
405 if !response_text.contains(keyword) {
406 passed += 1;
407 } else {
408 failed += 1;
409 *failure_reasons
410 .entry("forbidden_keyword".to_string())
411 .or_insert(0) += 1;
412 }
413 }
414
415 ValidationResults {
416 passed,
417 failed,
418 failure_reasons,
419 }
420 }
421
422 async fn compile_results(
424 &self,
425 results: Vec<RequestResult>,
426 total_duration: Duration,
427 ) -> Result<BenchmarkResults, LlmError> {
428 let total_requests = results.len();
429 let successful_requests = results.iter().filter(|r| r.success).count();
430 let failed_requests = total_requests - successful_requests;
431
432 let requests_per_second = total_requests as f64 / total_duration.as_secs_f64();
433
434 let mut durations: Vec<Duration> = results.iter().map(|r| r.duration).collect();
436 durations.sort();
437
438 let latency_stats = if !durations.is_empty() {
439 let len = durations.len();
440 LatencyStats {
441 mean: durations.iter().sum::<Duration>() / len as u32,
442 median: durations[len / 2],
443 p95: durations[(len * 95) / 100],
444 p99: durations[(len * 99) / 100],
445 p999: durations[(len * 999) / 1000],
446 min: durations[0],
447 max: durations[len - 1],
448 std_dev: Duration::ZERO, }
450 } else {
451 LatencyStats {
452 mean: Duration::ZERO,
453 median: Duration::ZERO,
454 p95: Duration::ZERO,
455 p99: Duration::ZERO,
456 p999: Duration::ZERO,
457 min: Duration::ZERO,
458 max: Duration::ZERO,
459 std_dev: Duration::ZERO,
460 }
461 };
462
463 let mut error_breakdown = HashMap::new();
465 for result in &results {
466 if let Some(ref error) = result.error {
467 *error_breakdown.entry(error.clone()).or_insert(0) += 1;
468 }
469 }
470
471 let mut scenario_results = HashMap::new();
473 for scenario in &self.config.scenarios {
474 let scenario_requests: Vec<_> = results
475 .iter()
476 .filter(|r| r.scenario_name == scenario.name)
477 .collect();
478
479 if !scenario_requests.is_empty() {
480 let success_count = scenario_requests.iter().filter(|r| r.success).count();
481 let success_rate = success_count as f64 / scenario_requests.len() as f64;
482
483 let avg_response_time = scenario_requests
484 .iter()
485 .map(|r| r.duration)
486 .sum::<Duration>()
487 / scenario_requests.len() as u32;
488
489 let validation_results = ValidationResults {
490 passed: scenario_requests.iter().map(|r| r.validation.passed).sum(),
491 failed: scenario_requests.iter().map(|r| r.validation.failed).sum(),
492 failure_reasons: HashMap::new(), };
494
495 scenario_results.insert(
496 scenario.name.clone(),
497 ScenarioResults {
498 name: scenario.name.clone(),
499 requests: scenario_requests.len(),
500 success_rate,
501 avg_response_time,
502 validation_results,
503 },
504 );
505 }
506 }
507
508 Ok(BenchmarkResults {
509 total_requests,
510 successful_requests,
511 failed_requests,
512 total_duration,
513 requests_per_second,
514 latency_stats,
515 error_breakdown,
516 scenario_results,
517 resource_usage: ResourceUsage {
518 peak_memory: 0,
519 avg_memory: 0,
520 cpu_usage: 0.0,
521 bytes_sent: 0,
522 bytes_received: 0,
523 },
524 })
525 }
526}
527
/// Outcome of a single benchmark request, produced by `run_worker`.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct RequestResult {
    // Index of the worker task that issued the request.
    worker_id: usize,
    // Per-worker sequence number of the request.
    request_id: usize,
    // Name of the scenario this request exercised.
    scenario_name: String,
    // Whether the chat call returned `Ok`.
    success: bool,
    // End-to-end latency of the call, including failures.
    duration: Duration,
    // Stringified error when the call failed.
    error: Option<String>,
    // Byte length of the text response, when one was returned.
    response_length: Option<usize>,
    // Pass/fail tallies from validating the response against expectations.
    validation: ValidationResults,
}
541
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config must match its documented values exactly.
    #[test]
    fn test_benchmark_config() {
        let config = BenchmarkConfig::default();
        assert_eq!(config.concurrency, 10);
        assert_eq!(config.total_requests, 100);
        assert_eq!(config.scenarios.len(), 1);
        assert_eq!(config.warmup_duration, Duration::from_secs(5));
        assert!(config.duration.is_none());
        assert!(config.rate_limit.is_none());
    }

    /// Field values must round-trip through construction unchanged.
    /// (The old test only asserted `is_some()` on values it had just set,
    /// which any `Some` satisfies — a vacuous check.)
    #[test]
    fn test_expected_response_validation() {
        let expected = ExpectedResponse {
            min_length: Some(5),
            max_length: Some(100),
            response_time_range: None,
            required_keywords: vec!["hello".to_string()],
            forbidden_keywords: vec!["error".to_string()],
        };

        assert_eq!(expected.min_length, Some(5));
        assert_eq!(expected.max_length, Some(100));
        assert_eq!(expected.required_keywords, vec!["hello".to_string()]);
        assert_eq!(expected.forbidden_keywords, vec!["error".to_string()]);
    }
}
568}