Skip to main content

wax_core/
stats.rs

1use serde::Serialize;
2
3#[derive(Debug, Clone, Serialize)]
4pub struct GenerateStats {
5    pub model: String,
6    pub device: String,
7    pub dtype: String,
8    pub prompt_tokens: usize,
9    pub generated_tokens: usize,
10    pub prefill_ms: f64,
11    pub ttft_ms: Option<f64>,
12    pub decode_tok_s: Option<f64>,
13    pub total_ms: f64,
14    pub stop_reason: StopReason,
15}
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
18#[serde(rename_all = "snake_case")]
19pub enum StopReason {
20    Eos,
21    MaxTokens,
22}
23
24#[derive(Debug, Clone, Serialize)]
25pub struct BenchStats {
26    pub model: String,
27    pub device: String,
28    pub dtype: String,
29    pub prompt_tokens: usize,
30    pub generated_tokens: usize,
31    pub model_load_ms: f64,
32    pub prefill_ms: f64,
33    pub prefill_tok_s: Option<f64>,
34    pub ttft_ms: Option<f64>,
35    pub decode_tok_s: Option<f64>,
36    pub total_generation_ms: f64,
37    pub peak_memory_mb: Option<u64>,
38    pub candle_version: &'static str,
39    pub rust_version: String,
40    pub git_commit: Option<String>,
41}
42
/// Candle version string exposed by this module.
///
/// NOTE(review): presumably the value stored in `BenchStats::candle_version`
/// and meant to track the candle dependency version declared in `Cargo.toml`.
/// Nothing in this file enforces either, so check both when bumping candle.
pub const CANDLE_VERSION: &str = "0.10.2";
44
#[cfg(test)]
mod tests {
    use super::{BenchStats, GenerateStats, StopReason};

    /// `StopReason` is annotated with `rename_all = "snake_case"`, so
    /// `MaxTokens` must appear in the JSON output as `"max_tokens"`.
    #[test]
    fn generate_stats_serialize_stop_reason_as_snake_case() {
        let stats = GenerateStats {
            model: "local".to_string(),
            device: "cpu".to_string(),
            dtype: "f32".to_string(),
            prompt_tokens: 2,
            generated_tokens: 3,
            prefill_ms: 1.0,
            ttft_ms: Some(2.0),
            decode_tok_s: Some(3.0),
            total_ms: 4.0,
            stop_reason: StopReason::MaxTokens,
        };

        // Serializing a plain struct of strings and numbers is not expected to fail.
        let json = serde_json::to_value(stats).expect("GenerateStats should serialize to JSON");

        assert_eq!(json["stop_reason"], "max_tokens");
    }

    /// `peak_memory_mb` and `git_commit` may be `None`; both must then show
    /// up in the JSON output as `null`.
    #[test]
    fn bench_stats_allow_unknown_peak_memory_and_git_commit() {
        let stats = BenchStats {
            model: "local".to_string(),
            device: "metal:0".to_string(),
            dtype: "f16".to_string(),
            prompt_tokens: 9,
            generated_tokens: 4,
            model_load_ms: 10.0,
            prefill_ms: 2.0,
            prefill_tok_s: Some(4.5),
            ttft_ms: None,
            decode_tok_s: None,
            total_generation_ms: 12.0,
            peak_memory_mb: None,
            candle_version: "0.10.2",
            rust_version: "rustc test".to_string(),
            git_commit: None,
        };

        // Serializing a plain struct of strings and numbers is not expected to fail.
        let json = serde_json::to_value(stats).expect("BenchStats should serialize to JSON");

        assert!(json["peak_memory_mb"].is_null());
        assert!(json["git_commit"].is_null());
    }
}