Skip to main content

trustformers_core/cache/
cache_key.rs

1use serde::{Deserialize, Serialize};
2use std::collections::hash_map::DefaultHasher;
3use std::fmt;
4use std::hash::{Hash, Hasher};
5
6/// A cache key for inference requests
7#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
8pub struct CacheKey {
9    /// Hash of the input text or tokens
10    pub input_hash: u64,
11    /// Model identifier
12    pub model_id: String,
13    /// Task type (e.g., "text-classification", "text-generation")
14    pub task: String,
15    /// Additional parameters that affect the output
16    pub params_hash: u64,
17}
18
19impl fmt::Display for CacheKey {
20    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21        write!(
22            f,
23            "{}-{}-{}-{}",
24            self.model_id, self.task, self.input_hash, self.params_hash
25        )
26    }
27}
28
29impl CacheKey {
30    pub fn new(input_hash: u64, model_id: String, task: String, params_hash: u64) -> Self {
31        Self {
32            input_hash,
33            model_id,
34            task,
35            params_hash,
36        }
37    }
38}
39
/// Incrementally builds a [`CacheKey`] by hashing request inputs and parameters.
///
/// Input data (text, tokens) and output-affecting parameters are fed into two
/// independent hashers, so they end up in separate fields of the resulting key.
///
/// The builder is `Clone`: cloning copies the hashers' current state, which lets
/// a caller hash a shared prefix (e.g. the input text) once and then fork it for
/// several parameter combinations.
#[derive(Debug, Clone)]
pub struct CacheKeyBuilder {
    /// Model identifier, copied verbatim into the built key.
    model_id: String,
    /// Task name, copied verbatim into the built key.
    task: String,
    /// Accumulates the input data (text and/or token ids).
    input_hasher: DefaultHasher,
    /// Accumulates the name/value pairs of output-affecting parameters.
    params_hasher: DefaultHasher,
}

48impl CacheKeyBuilder {
49    pub fn new(model_id: impl Into<String>, task: impl Into<String>) -> Self {
50        Self {
51            model_id: model_id.into(),
52            task: task.into(),
53            input_hasher: DefaultHasher::new(),
54            params_hasher: DefaultHasher::new(),
55        }
56    }
57
58    /// Add text input to the key
59    pub fn with_text(mut self, text: &str) -> Self {
60        text.hash(&mut self.input_hasher);
61        self
62    }
63
64    /// Add tokenized input to the key
65    pub fn with_tokens(mut self, tokens: &[u32]) -> Self {
66        tokens.hash(&mut self.input_hasher);
67        self
68    }
69
70    /// Add a parameter that affects output
71    pub fn with_param<T: Hash>(mut self, name: &str, value: &T) -> Self {
72        name.hash(&mut self.params_hasher);
73        value.hash(&mut self.params_hasher);
74        self
75    }
76
77    /// Add generation-specific parameters
78    pub fn with_generation_params(
79        mut self,
80        max_length: Option<usize>,
81        temperature: Option<f32>,
82        top_p: Option<f32>,
83        top_k: Option<usize>,
84        do_sample: bool,
85        num_beams: Option<usize>,
86    ) -> Self {
87        if let Some(v) = max_length {
88            "max_length".hash(&mut self.params_hasher);
89            v.hash(&mut self.params_hasher);
90        }
91        if let Some(v) = temperature {
92            "temperature".hash(&mut self.params_hasher);
93            v.to_bits().hash(&mut self.params_hasher);
94        }
95        if let Some(v) = top_p {
96            "top_p".hash(&mut self.params_hasher);
97            v.to_bits().hash(&mut self.params_hasher);
98        }
99        if let Some(v) = top_k {
100            "top_k".hash(&mut self.params_hasher);
101            v.hash(&mut self.params_hasher);
102        }
103        "do_sample".hash(&mut self.params_hasher);
104        do_sample.hash(&mut self.params_hasher);
105        if let Some(v) = num_beams {
106            "num_beams".hash(&mut self.params_hasher);
107            v.hash(&mut self.params_hasher);
108        }
109        self
110    }
111
112    /// Build the final cache key
113    pub fn build(self) -> CacheKey {
114        CacheKey::new(
115            self.input_hasher.finish(),
116            self.model_id,
117            self.task,
118            self.params_hasher.finish(),
119        )
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_key_builder() {
        let key1 = CacheKeyBuilder::new("bert-base", "text-classification")
            .with_text("Hello world")
            .with_param("max_length", &512)
            .build();

        let key2 = CacheKeyBuilder::new("bert-base", "text-classification")
            .with_text("Hello world")
            .with_param("max_length", &512)
            .build();

        let key3 = CacheKeyBuilder::new("bert-base", "text-classification")
            .with_text("Hello world!")
            .with_param("max_length", &512)
            .build();

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    #[test]
    fn test_generation_params_hashing() {
        let key1 = CacheKeyBuilder::new("gpt2", "text-generation")
            .with_text("Once upon a time")
            .with_generation_params(Some(100), Some(0.8), Some(0.9), Some(50), true, Some(4))
            .build();

        let key2 = CacheKeyBuilder::new("gpt2", "text-generation")
            .with_text("Once upon a time")
            .with_generation_params(Some(100), Some(0.8), Some(0.9), Some(50), true, Some(4))
            .build();

        let key3 = CacheKeyBuilder::new("gpt2", "text-generation")
            .with_text("Once upon a time")
            .with_generation_params(Some(100), Some(0.9), Some(0.9), Some(50), true, Some(4))
            .build();

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    #[test]
    fn test_tokens_input_hashing() {
        let key1 = CacheKeyBuilder::new("bert-base", "text-classification")
            .with_tokens(&[101, 7592, 2088, 102])
            .build();
        let key2 = CacheKeyBuilder::new("bert-base", "text-classification")
            .with_tokens(&[101, 7592, 2088, 102])
            .build();
        let reordered = CacheKeyBuilder::new("bert-base", "text-classification")
            .with_tokens(&[101, 2088, 7592, 102])
            .build();

        assert_eq!(key1, key2);
        assert_ne!(key1, reordered);
    }

    #[test]
    fn test_model_and_task_distinguish_keys() {
        let base = CacheKeyBuilder::new("bert-base", "text-classification")
            .with_text("Hello")
            .build();
        let other_model = CacheKeyBuilder::new("bert-large", "text-classification")
            .with_text("Hello")
            .build();
        let other_task = CacheKeyBuilder::new("bert-base", "token-classification")
            .with_text("Hello")
            .build();

        assert_ne!(base, other_model);
        assert_ne!(base, other_task);
        // Only the metadata differs, so the hash fields themselves are shared.
        assert_eq!(base.input_hash, other_model.input_hash);
        assert_eq!(base.params_hash, other_task.params_hash);
    }

    #[test]
    fn test_generation_params_match_with_param() {
        let via_generation = CacheKeyBuilder::new("gpt2", "text-generation")
            .with_generation_params(Some(100), Some(0.8), None, None, false, None)
            .build();
        let via_param = CacheKeyBuilder::new("gpt2", "text-generation")
            .with_param("max_length", &100usize)
            .with_param("temperature", &0.8f32.to_bits())
            .with_param("do_sample", &false)
            .build();

        assert_eq!(via_generation, via_param);
    }

    #[test]
    fn test_generation_params_none_differs_from_some() {
        let without_top_k = CacheKeyBuilder::new("gpt2", "text-generation")
            .with_generation_params(Some(100), None, None, None, true, None)
            .build();
        let with_top_k = CacheKeyBuilder::new("gpt2", "text-generation")
            .with_generation_params(Some(100), None, None, Some(50), true, None)
            .build();

        assert_ne!(without_top_k, with_top_k);
    }

    #[test]
    fn test_display_format() {
        let key = CacheKey::new(
            1,
            "bert-base".to_string(),
            "text-classification".to_string(),
            2,
        );
        assert_eq!(key.to_string(), "bert-base-text-classification-1-2");
    }
}