// examples/benchmark_server.rs

//! # Zerust Benchmark Server 性能测试程序
//!
//! 本示例用于测试 Zerust 框架的性能,并与 Go 的 Zinx 项目进行对比:
//! - 高并发连接处理能力
//! - 请求吞吐量(RPS - Requests Per Second)
//! - 请求延迟(Latency)
//! - 资源占用(CPU、内存)
//!
//! ✅ 运行方式:
//! ```bash
//! # 启动服务器
//! cargo run --release --example benchmark_server -- server
//!
//! # 在另一个终端运行客户端测试
//! cargo run --release --example benchmark_server -- client [连接数] [每连接请求数]
//! ```
//!
//! 例如:
//! ```bash
//! cargo run --release --example benchmark_server -- client 100 1000
//! ```
//! 将创建100个并发连接,每个连接发送1000个请求
24use std::env;
25use std::sync::{
26    Arc,
27    atomic::{AtomicUsize, Ordering},
28};
29use std::time::{Duration, Instant};
30use tokio::io::{AsyncReadExt, AsyncWriteExt};
31use tokio::net::{TcpListener, TcpStream};
32use tokio::sync::{Barrier, Semaphore, oneshot};
33use tokio::time::sleep;
34use zerust::datapack::DataPack;
35use zerust::{DefaultRouter, Response, Server};
36
37#[tokio::main]
38async fn main() -> Result<(), Box<dyn std::error::Error>> {
39    let args: Vec<String> = env::args().collect();
40
41    match args.get(1).map(|s| s.as_str()) {
42        Some("server") => run_server().await?,
43        Some("client") => {
44            let connections = args
45                .get(2)
46                .and_then(|s| s.parse::<usize>().ok())
47                .unwrap_or(100);
48            let requests_per_conn = args
49                .get(3)
50                .and_then(|s| s.parse::<usize>().ok())
51                .unwrap_or(1000);
52            run_client(connections, requests_per_conn).await?
53        }
54        _ => {
55            println!(
56                "用法: cargo run --release --example benchmark_server -- [server|client] [连接数] [每连接请求数]"
57            );
58            println!("  server          - 启动基准测试服务器");
59            println!("  client [连接数] [每连接请求数] - 启动客户端测试");
60        }
61    }
62
63    Ok(())
64}
65
66/// 运行基准测试服务器
67async fn run_server() -> Result<(), Box<dyn std::error::Error>> {
68    // 创建关闭通道
69    let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
70
71    // 创建路由器并注册回显处理函数
72    let router = Arc::new(DefaultRouter::new());
73    let router_clone = router.clone();
74
75    // 计数器,用于统计处理的请求数
76    let request_counter = Arc::new(AtomicUsize::new(0));
77    let counter_clone = request_counter.clone();
78
79    // 注册高性能回显处理函数 - 不打印日志,直接返回
80    router_clone.add_route(1, move |req| {
81        counter_clone.fetch_add(1, Ordering::Relaxed);
82        Response::new(req.msg_id(), req.data().to_vec())
83    });
84
85    // 启动服务器
86    let server_addr = "127.0.0.1:8888";
87    let server = Server::new(server_addr, router);
88    println!("[Server] 基准测试服务器启动在 {}", server_addr);
89
90    // 启动统计任务
91    let stats_handle = tokio::spawn(async move {
92        let mut last_count = 0;
93        let mut last_time = Instant::now();
94
95        loop {
96            sleep(Duration::from_secs(1)).await;
97            let current_count = request_counter.load(Ordering::Relaxed);
98            let current_time = Instant::now();
99            let elapsed = current_time.duration_since(last_time).as_secs_f64();
100
101            let rps = (current_count - last_count) as f64 / elapsed;
102            println!(
103                "[Stats] 当前RPS: {:.2} req/s, 总请求数: {}",
104                rps, current_count
105            );
106
107            last_count = current_count;
108            last_time = current_time;
109        }
110    });
111
112    // 启动服务器并等待Ctrl+C信号
113    let server_handle = tokio::spawn(async move {
114        if let Err(e) = server.run(shutdown_rx).await {
115            eprintln!("[Server] 运行时错误: {}", e);
116        }
117    });
118
119    println!("[Server] 按 Ctrl+C 停止服务器...");
120    tokio::signal::ctrl_c().await?;
121    println!("[Server] 接收到停止信号,正在关闭...");
122
123    // 发送关闭信号
124    let _ = shutdown_tx.send(());
125
126    // 等待服务器和统计任务完成
127    let _ = server_handle.await;
128    stats_handle.abort();
129
130    println!("[Server] 服务器已关闭");
131    Ok(())
132}
133
134/// 运行客户端基准测试
135async fn run_client(
136    connections: usize,
137    requests_per_conn: usize,
138) -> Result<(), Box<dyn std::error::Error>> {
139    println!(
140        "[Client] 开始基准测试: {} 并发连接, 每连接 {} 请求",
141        connections, requests_per_conn
142    );
143
144    // 创建信号量限制并发连接数
145    let semaphore = Arc::new(Semaphore::new(connections));
146
147    // 创建同步屏障,确保所有连接同时开始发送请求
148    let barrier = Arc::new(Barrier::new(connections + 1)); // +1 for main thread
149
150    // 统计数据
151    let total_requests = connections * requests_per_conn;
152    let completed_requests = Arc::new(AtomicUsize::new(0));
153    let total_latency = Arc::new(AtomicUsize::new(0)); // 以微秒为单位
154
155    // 启动客户端连接
156    let mut handles = Vec::with_capacity(connections);
157
158    let start_time = Instant::now();
159
160    for i in 0..connections {
161        let semaphore_clone = semaphore.clone();
162        let barrier_clone = barrier.clone();
163        let completed_clone = completed_requests.clone();
164        let latency_clone = total_latency.clone();
165
166        let handle = tokio::spawn(async move {
167            // 获取信号量许可
168            let _permit = semaphore_clone.acquire().await.unwrap();
169
170            // 连接到服务器
171            let mut stream = match TcpStream::connect("127.0.0.1:8888").await {
172                Ok(stream) => stream,
173                Err(e) => {
174                    eprintln!("[Client {}] 连接失败: {}", i, e);
175                    return;
176                }
177            };
178
179            // 等待所有连接就绪
180            barrier_clone.wait().await;
181
182            // 发送请求并测量延迟
183            for _ in 0..requests_per_conn {
184                // 准备请求数据 - 使用随机大小的负载
185                let payload = vec![b'A'; 64]; // 固定64字节负载
186                let request = DataPack::pack(1, &payload);
187
188                let request_start = Instant::now();
189
190                // 发送请求
191                if let Err(e) = stream.write_all(&request).await {
192                    eprintln!("[Client {}] 发送请求失败: {}", i, e);
193                    break;
194                }
195
196                // 读取响应头
197                let mut header = [0u8; 8];
198                if let Err(e) = stream.read_exact(&mut header).await {
199                    eprintln!("[Client {}] 读取响应头失败: {}", i, e);
200                    break;
201                }
202
203                let (msg_id, data_len) = match DataPack::unpack_header(&header) {
204                    Ok(result) => result,
205                    Err(e) => {
206                        eprintln!("[Client {}] 解析响应头失败: {}", i, e);
207                        break;
208                    }
209                };
210
211                // 读取响应数据
212                let mut data = vec![0u8; data_len as usize];
213                if let Err(e) = stream.read_exact(&mut data).await {
214                    eprintln!("[Client {}] 读取响应数据失败: {}", i, e);
215                    break;
216                }
217
218                // 计算延迟(微秒)
219                let latency = request_start.elapsed().as_micros() as usize;
220                latency_clone.fetch_add(latency, Ordering::Relaxed);
221
222                // 增加完成请求计数
223                completed_clone.fetch_add(1, Ordering::Relaxed);
224            }
225        });
226
227        handles.push(handle);
228    }
229
230    // 启动进度报告任务
231    let progress_completed = completed_requests.clone();
232    let progress_handle = tokio::spawn(async move {
233        loop {
234            sleep(Duration::from_secs(1)).await;
235            let completed = progress_completed.load(Ordering::Relaxed);
236            let progress = (completed as f64 / total_requests as f64) * 100.0;
237            println!(
238                "[Progress] {:.2}% ({}/{})",
239                progress, completed, total_requests
240            );
241
242            if completed >= total_requests {
243                break;
244            }
245        }
246    });
247
248    // 所有连接已建立,开始测试
249    println!("[Client] 所有连接已就绪,开始测试...");
250    barrier.wait().await;
251
252    // 等待所有客户端完成
253    for handle in handles {
254        let _ = handle.await;
255    }
256
257    // 停止进度报告
258    progress_handle.abort();
259
260    // 计算结果
261    let elapsed = start_time.elapsed();
262    let completed = completed_requests.load(Ordering::Relaxed);
263    let avg_latency = if completed > 0 {
264        total_latency.load(Ordering::Relaxed) as f64 / completed as f64
265    } else {
266        0.0
267    };
268
269    // 打印结果
270    println!("\n===== 基准测试结果 =====");
271    println!("总连接数: {}", connections);
272    println!("每连接请求数: {}", requests_per_conn);
273    println!("总请求数: {}", total_requests);
274    println!("完成请求数: {}", completed);
275    println!("总耗时: {:.2} 秒", elapsed.as_secs_f64());
276    println!("平均延迟: {:.2} 微秒", avg_latency);
277    println!(
278        "吞吐量: {:.2} 请求/秒",
279        completed as f64 / elapsed.as_secs_f64()
280    );
281
282    Ok(())
283}