1use serde::Serialize;
2
3#[derive(Debug, Clone, Serialize)]
4pub struct GenerateStats {
5 pub model: String,
6 pub device: String,
7 pub dtype: String,
8 pub prompt_tokens: usize,
9 pub generated_tokens: usize,
10 pub prefill_ms: f64,
11 pub ttft_ms: Option<f64>,
12 pub decode_tok_s: Option<f64>,
13 pub total_ms: f64,
14 pub stop_reason: StopReason,
15}
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
18#[serde(rename_all = "snake_case")]
19pub enum StopReason {
20 Eos,
21 MaxTokens,
22}
23
24#[derive(Debug, Clone, Serialize)]
25pub struct BenchStats {
26 pub model: String,
27 pub device: String,
28 pub dtype: String,
29 pub prompt_tokens: usize,
30 pub generated_tokens: usize,
31 pub model_load_ms: f64,
32 pub prefill_ms: f64,
33 pub prefill_tok_s: Option<f64>,
34 pub ttft_ms: Option<f64>,
35 pub decode_tok_s: Option<f64>,
36 pub total_generation_ms: f64,
37 pub peak_memory_mb: Option<u64>,
38 pub candle_version: &'static str,
39 pub rust_version: String,
40 pub git_commit: Option<String>,
41}
42
/// Candle version string exposed by this module.
///
/// NOTE(review): this is a hard-coded literal; presumably it has to be kept in
/// sync with the candle dependency version by hand. Verify against the manifest.
pub const CANDLE_VERSION: &str = "0.10.2";
44
#[cfg(test)]
mod tests {
    use super::{BenchStats, GenerateStats, StopReason, CANDLE_VERSION};

    /// A `GenerateStats` value must emit its stop reason using the snake_case
    /// variant name.
    #[test]
    fn generate_stats_serialize_stop_reason_as_snake_case() {
        let stats = GenerateStats {
            model: "local".to_string(),
            device: "cpu".to_string(),
            dtype: "f32".to_string(),
            prompt_tokens: 2,
            generated_tokens: 3,
            prefill_ms: 1.0,
            ttft_ms: Some(2.0),
            decode_tok_s: Some(3.0),
            total_ms: 4.0,
            stop_reason: StopReason::MaxTokens,
        };

        let json = serde_json::to_value(stats).expect("GenerateStats should serialize to JSON");

        assert_eq!(json["stop_reason"], "max_tokens");
    }

    /// The single-word variant must serialize in lowercase as well.
    #[test]
    fn stop_reason_eos_serializes_as_snake_case() {
        let json =
            serde_json::to_value(StopReason::Eos).expect("StopReason should serialize to JSON");

        assert_eq!(json, "eos");
    }

    /// Unknown optional values must appear as explicit JSON nulls rather than
    /// being dropped, and the candle version must be emitted as given.
    #[test]
    fn bench_stats_allow_unknown_peak_memory_and_git_commit() {
        let stats = BenchStats {
            model: "local".to_string(),
            device: "metal:0".to_string(),
            dtype: "f16".to_string(),
            prompt_tokens: 9,
            generated_tokens: 4,
            model_load_ms: 10.0,
            prefill_ms: 2.0,
            prefill_tok_s: Some(4.5),
            ttft_ms: None,
            decode_tok_s: None,
            total_generation_ms: 12.0,
            peak_memory_mb: None,
            // Use the shared constant so the test cannot drift from the module.
            candle_version: CANDLE_VERSION,
            rust_version: "rustc test".to_string(),
            git_commit: None,
        };

        let json = serde_json::to_value(stats).expect("BenchStats should serialize to JSON");

        assert!(json["peak_memory_mb"].is_null());
        assert!(json["git_commit"].is_null());
        assert_eq!(json["candle_version"], CANDLE_VERSION);
    }
}