// trustformers_core/cache/cache_key.rs
use serde::{Deserialize, Serialize};
use std::collections::hash_map::DefaultHasher;
use std::fmt;
use std::hash::{Hash, Hasher};

6#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
8pub struct CacheKey {
9 pub input_hash: u64,
11 pub model_id: String,
13 pub task: String,
15 pub params_hash: u64,
17}
18
19impl fmt::Display for CacheKey {
20 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21 write!(
22 f,
23 "{}-{}-{}-{}",
24 self.model_id, self.task, self.input_hash, self.params_hash
25 )
26 }
27}
28
29impl CacheKey {
30 pub fn new(input_hash: u64, model_id: String, task: String, params_hash: u64) -> Self {
31 Self {
32 input_hash,
33 model_id,
34 task,
35 params_hash,
36 }
37 }
38}
39
/// Incrementally builds a [`CacheKey`] by feeding the input and the named
/// parameters into two independent hashers.
///
/// `Clone` lets callers configure a shared prefix (model, task, common
/// parameters) once and fork it for several inputs. A builder has no effect
/// until `build()` is called, hence `#[must_use]`.
#[derive(Debug, Clone)]
#[must_use = "a CacheKeyBuilder does nothing until `build()` is called"]
pub struct CacheKeyBuilder {
    /// Copied verbatim into the resulting key's `model_id`.
    model_id: String,
    /// Copied verbatim into the resulting key's `task`.
    task: String,
    /// Accumulates everything passed to `with_text` / `with_tokens`.
    input_hasher: DefaultHasher,
    /// Accumulates everything passed to `with_param` / `with_generation_params`.
    params_hasher: DefaultHasher,
}

48impl CacheKeyBuilder {
49 pub fn new(model_id: impl Into<String>, task: impl Into<String>) -> Self {
50 Self {
51 model_id: model_id.into(),
52 task: task.into(),
53 input_hasher: DefaultHasher::new(),
54 params_hasher: DefaultHasher::new(),
55 }
56 }
57
58 pub fn with_text(mut self, text: &str) -> Self {
60 text.hash(&mut self.input_hasher);
61 self
62 }
63
64 pub fn with_tokens(mut self, tokens: &[u32]) -> Self {
66 tokens.hash(&mut self.input_hasher);
67 self
68 }
69
70 pub fn with_param<T: Hash>(mut self, name: &str, value: &T) -> Self {
72 name.hash(&mut self.params_hasher);
73 value.hash(&mut self.params_hasher);
74 self
75 }
76
77 pub fn with_generation_params(
79 mut self,
80 max_length: Option<usize>,
81 temperature: Option<f32>,
82 top_p: Option<f32>,
83 top_k: Option<usize>,
84 do_sample: bool,
85 num_beams: Option<usize>,
86 ) -> Self {
87 if let Some(v) = max_length {
88 "max_length".hash(&mut self.params_hasher);
89 v.hash(&mut self.params_hasher);
90 }
91 if let Some(v) = temperature {
92 "temperature".hash(&mut self.params_hasher);
93 v.to_bits().hash(&mut self.params_hasher);
94 }
95 if let Some(v) = top_p {
96 "top_p".hash(&mut self.params_hasher);
97 v.to_bits().hash(&mut self.params_hasher);
98 }
99 if let Some(v) = top_k {
100 "top_k".hash(&mut self.params_hasher);
101 v.hash(&mut self.params_hasher);
102 }
103 "do_sample".hash(&mut self.params_hasher);
104 do_sample.hash(&mut self.params_hasher);
105 if let Some(v) = num_beams {
106 "num_beams".hash(&mut self.params_hasher);
107 v.hash(&mut self.params_hasher);
108 }
109 self
110 }
111
112 pub fn build(self) -> CacheKey {
114 CacheKey::new(
115 self.input_hasher.finish(),
116 self.model_id,
117 self.task,
118 self.params_hasher.finish(),
119 )
120 }
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// Key for the classification scenario, varying only the input text.
    fn classification_key(text: &str) -> CacheKey {
        CacheKeyBuilder::new("bert-base", "text-classification")
            .with_text(text)
            .with_param("max_length", &512)
            .build()
    }

    /// Key for the generation scenario, varying only the temperature.
    fn generation_key(temperature: f32) -> CacheKey {
        CacheKeyBuilder::new("gpt2", "text-generation")
            .with_text("Once upon a time")
            .with_generation_params(
                Some(100),
                Some(temperature),
                Some(0.9),
                Some(50),
                true,
                Some(4),
            )
            .build()
    }

    #[test]
    fn test_cache_key_builder() {
        let baseline = classification_key("Hello world");
        let same_input = classification_key("Hello world");
        let other_input = classification_key("Hello world!");

        assert_eq!(baseline, same_input);
        assert_ne!(baseline, other_input);
    }

    #[test]
    fn test_generation_params_hashing() {
        let baseline = generation_key(0.8);
        let same_params = generation_key(0.8);
        let hotter = generation_key(0.9);

        assert_eq!(baseline, same_params);
        assert_ne!(baseline, hotter);
    }
}