use std::env;
25use std::sync::{
26 Arc,
27 atomic::{AtomicUsize, Ordering},
28};
29use std::time::{Duration, Instant};
30use tokio::io::{AsyncReadExt, AsyncWriteExt};
31use tokio::net::{TcpListener, TcpStream};
32use tokio::sync::{Barrier, Semaphore, oneshot};
33use tokio::time::sleep;
34use zerust::datapack::DataPack;
35use zerust::{DefaultRouter, Response, Server};
36
37#[tokio::main]
38async fn main() -> Result<(), Box<dyn std::error::Error>> {
39 let args: Vec<String> = env::args().collect();
40
41 match args.get(1).map(|s| s.as_str()) {
42 Some("server") => run_server().await?,
43 Some("client") => {
44 let connections = args
45 .get(2)
46 .and_then(|s| s.parse::<usize>().ok())
47 .unwrap_or(100);
48 let requests_per_conn = args
49 .get(3)
50 .and_then(|s| s.parse::<usize>().ok())
51 .unwrap_or(1000);
52 run_client(connections, requests_per_conn).await?
53 }
54 _ => {
55 println!(
56 "用法: cargo run --release --example benchmark_server -- [server|client] [连接数] [每连接请求数]"
57 );
58 println!(" server - 启动基准测试服务器");
59 println!(" client [连接数] [每连接请求数] - 启动客户端测试");
60 }
61 }
62
63 Ok(())
64}
65
66async fn run_server() -> Result<(), Box<dyn std::error::Error>> {
68 let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
70
71 let router = Arc::new(DefaultRouter::new());
73 let router_clone = router.clone();
74
75 let request_counter = Arc::new(AtomicUsize::new(0));
77 let counter_clone = request_counter.clone();
78
79 router_clone.add_route(1, move |req| {
81 counter_clone.fetch_add(1, Ordering::Relaxed);
82 Response::new(req.msg_id(), req.data().to_vec())
83 });
84
85 let server_addr = "127.0.0.1:8888";
87 let server = Server::new(server_addr, router);
88 println!("[Server] 基准测试服务器启动在 {}", server_addr);
89
90 let stats_handle = tokio::spawn(async move {
92 let mut last_count = 0;
93 let mut last_time = Instant::now();
94
95 loop {
96 sleep(Duration::from_secs(1)).await;
97 let current_count = request_counter.load(Ordering::Relaxed);
98 let current_time = Instant::now();
99 let elapsed = current_time.duration_since(last_time).as_secs_f64();
100
101 let rps = (current_count - last_count) as f64 / elapsed;
102 println!(
103 "[Stats] 当前RPS: {:.2} req/s, 总请求数: {}",
104 rps, current_count
105 );
106
107 last_count = current_count;
108 last_time = current_time;
109 }
110 });
111
112 let server_handle = tokio::spawn(async move {
114 if let Err(e) = server.run(shutdown_rx).await {
115 eprintln!("[Server] 运行时错误: {}", e);
116 }
117 });
118
119 println!("[Server] 按 Ctrl+C 停止服务器...");
120 tokio::signal::ctrl_c().await?;
121 println!("[Server] 接收到停止信号,正在关闭...");
122
123 let _ = shutdown_tx.send(());
125
126 let _ = server_handle.await;
128 stats_handle.abort();
129
130 println!("[Server] 服务器已关闭");
131 Ok(())
132}
133
134async fn run_client(
136 connections: usize,
137 requests_per_conn: usize,
138) -> Result<(), Box<dyn std::error::Error>> {
139 println!(
140 "[Client] 开始基准测试: {} 并发连接, 每连接 {} 请求",
141 connections, requests_per_conn
142 );
143
144 let semaphore = Arc::new(Semaphore::new(connections));
146
147 let barrier = Arc::new(Barrier::new(connections + 1)); let total_requests = connections * requests_per_conn;
152 let completed_requests = Arc::new(AtomicUsize::new(0));
153 let total_latency = Arc::new(AtomicUsize::new(0)); let mut handles = Vec::with_capacity(connections);
157
158 let start_time = Instant::now();
159
160 for i in 0..connections {
161 let semaphore_clone = semaphore.clone();
162 let barrier_clone = barrier.clone();
163 let completed_clone = completed_requests.clone();
164 let latency_clone = total_latency.clone();
165
166 let handle = tokio::spawn(async move {
167 let _permit = semaphore_clone.acquire().await.unwrap();
169
170 let mut stream = match TcpStream::connect("127.0.0.1:8888").await {
172 Ok(stream) => stream,
173 Err(e) => {
174 eprintln!("[Client {}] 连接失败: {}", i, e);
175 return;
176 }
177 };
178
179 barrier_clone.wait().await;
181
182 for _ in 0..requests_per_conn {
184 let payload = vec![b'A'; 64]; let request = DataPack::pack(1, &payload);
187
188 let request_start = Instant::now();
189
190 if let Err(e) = stream.write_all(&request).await {
192 eprintln!("[Client {}] 发送请求失败: {}", i, e);
193 break;
194 }
195
196 let mut header = [0u8; 8];
198 if let Err(e) = stream.read_exact(&mut header).await {
199 eprintln!("[Client {}] 读取响应头失败: {}", i, e);
200 break;
201 }
202
203 let (msg_id, data_len) = match DataPack::unpack_header(&header) {
204 Ok(result) => result,
205 Err(e) => {
206 eprintln!("[Client {}] 解析响应头失败: {}", i, e);
207 break;
208 }
209 };
210
211 let mut data = vec![0u8; data_len as usize];
213 if let Err(e) = stream.read_exact(&mut data).await {
214 eprintln!("[Client {}] 读取响应数据失败: {}", i, e);
215 break;
216 }
217
218 let latency = request_start.elapsed().as_micros() as usize;
220 latency_clone.fetch_add(latency, Ordering::Relaxed);
221
222 completed_clone.fetch_add(1, Ordering::Relaxed);
224 }
225 });
226
227 handles.push(handle);
228 }
229
230 let progress_completed = completed_requests.clone();
232 let progress_handle = tokio::spawn(async move {
233 loop {
234 sleep(Duration::from_secs(1)).await;
235 let completed = progress_completed.load(Ordering::Relaxed);
236 let progress = (completed as f64 / total_requests as f64) * 100.0;
237 println!(
238 "[Progress] {:.2}% ({}/{})",
239 progress, completed, total_requests
240 );
241
242 if completed >= total_requests {
243 break;
244 }
245 }
246 });
247
248 println!("[Client] 所有连接已就绪,开始测试...");
250 barrier.wait().await;
251
252 for handle in handles {
254 let _ = handle.await;
255 }
256
257 progress_handle.abort();
259
260 let elapsed = start_time.elapsed();
262 let completed = completed_requests.load(Ordering::Relaxed);
263 let avg_latency = if completed > 0 {
264 total_latency.load(Ordering::Relaxed) as f64 / completed as f64
265 } else {
266 0.0
267 };
268
269 println!("\n===== 基准测试结果 =====");
271 println!("总连接数: {}", connections);
272 println!("每连接请求数: {}", requests_per_conn);
273 println!("总请求数: {}", total_requests);
274 println!("完成请求数: {}", completed);
275 println!("总耗时: {:.2} 秒", elapsed.as_secs_f64());
276 println!("平均延迟: {:.2} 微秒", avg_latency);
277 println!(
278 "吞吐量: {:.2} 请求/秒",
279 completed as f64 / elapsed.as_secs_f64()
280 );
281
282 Ok(())
283}