// sentinel_proxy/inference/providers.rs

use http::HeaderMap;
use serde_json::Value;
use sentinel_config::{InferenceProvider, TokenEstimation};
use tracing::trace;

use super::tiktoken::tiktoken_manager;

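/// Provider-specific strategy for extracting and estimating token usage.
/// The `Send + Sync` bound lets a boxed adapter be shared across threads.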
pub trait InferenceProviderAdapter: Send + Sync {
    /// Short identifier used in trace logs.
    fn name(&self) -> &'static str;

    /// Token usage derived from response headers, if available.
    fn tokens_from_headers(&self, headers: &HeaderMap) -> Option<u64>;

    /// Token usage parsed from a response body, if available.
    fn tokens_from_body(&self, body: &[u8]) -> Option<u64>;

    /// Heuristic token count for a request body.
    fn estimate_request_tokens(&self, body: &[u8], method: TokenEstimation) -> u64;

    /// Model identifier from headers or body, if available.
    fn extract_model(&self, headers: &HeaderMap, body: &[u8]) -> Option<String>;
}

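/// Returns the adapter matching the configured upstream provider.
///
/// ```ignore
/// // Hypothetical call site; the provider value normally comes from
/// // sentinel_config rather than being constructed inline.
/// let adapter = create_provider(&InferenceProvider::OpenAi);
/// assert_eq!(adapter.name(), "openai");
/// ```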
pub fn create_provider(provider: &InferenceProvider) -> Box<dyn InferenceProviderAdapter> {
    match provider {
        InferenceProvider::OpenAi => Box::new(OpenAiProvider),
        InferenceProvider::Anthropic => Box::new(AnthropicProvider),
        InferenceProvider::Generic => Box::new(GenericProvider),
    }
}

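/// Adapter for the OpenAI API.
///
/// Prefers exact usage from rate-limit headers, then falls back to the
/// `usage.total_tokens` field in JSON response bodies.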
struct OpenAiProvider;

impl InferenceProviderAdapter for OpenAiProvider {
    fn name(&self) -> &'static str {
        "openai"
    }

    fn tokens_from_headers(&self, headers: &HeaderMap) -> Option<u64> {
        // Prefer a direct usage header when the upstream sends one.
        if let Some(n) = headers
            .get("x-ratelimit-used-tokens")
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.parse::<u64>().ok())
        {
            trace!(tokens = n, "Got token count from OpenAI x-ratelimit-used-tokens");
            return Some(n);
        }

        // Otherwise derive usage as limit minus remaining.
        let limit = headers
            .get("x-ratelimit-limit-tokens")
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.parse::<u64>().ok());

        let remaining = headers
            .get("x-ratelimit-remaining-tokens")
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.parse::<u64>().ok());

        if let (Some(l), Some(r)) = (limit, remaining) {
            let used = l.saturating_sub(r);
            trace!(limit = l, remaining = r, used = used, "Calculated token usage from OpenAI headers");
            return Some(used);
        }

        None
    }

    fn tokens_from_body(&self, body: &[u8]) -> Option<u64> {
        // OpenAI responses report usage as `usage.total_tokens`.
        let json: Value = serde_json::from_slice(body).ok()?;
        let total = json.get("usage")?.get("total_tokens")?.as_u64();
        if let Some(t) = total {
            trace!(tokens = t, "Got token count from OpenAI response body");
        }
        total
    }

    fn estimate_request_tokens(&self, body: &[u8], method: TokenEstimation) -> u64 {
        estimate_tokens(body, method)
    }

    fn extract_model(&self, headers: &HeaderMap, body: &[u8]) -> Option<String> {
        // An explicit header wins over the body's `model` field.
        if let Some(model) = headers.get("x-model").and_then(|v| v.to_str().ok()) {
            return Some(model.to_string());
        }

        let json: Value = serde_json::from_slice(body).ok()?;
        json.get("model")?.as_str().map(|s| s.to_string())
    }
}

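/// Adapter for the Anthropic API.
///
/// Derives usage from the `anthropic-ratelimit-tokens-*` headers, or sums
/// `input_tokens` and `output_tokens` from JSON response bodies.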
struct AnthropicProvider;

impl InferenceProviderAdapter for AnthropicProvider {
    fn name(&self) -> &'static str {
        "anthropic"
    }

    fn tokens_from_headers(&self, headers: &HeaderMap) -> Option<u64> {
        // Anthropic exposes limit/remaining pairs; usage is the difference.
        let limit = headers
            .get("anthropic-ratelimit-tokens-limit")
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.parse::<u64>().ok());

        let remaining = headers
            .get("anthropic-ratelimit-tokens-remaining")
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.parse::<u64>().ok());

        if let (Some(l), Some(r)) = (limit, remaining) {
            let used = l.saturating_sub(r);
            trace!(limit = l, remaining = r, used = used, "Calculated token usage from Anthropic headers");
            return Some(used);
        }

        None
    }

    fn tokens_from_body(&self, body: &[u8]) -> Option<u64> {
        // Anthropic reports input and output counts separately; total them.
        let json: Value = serde_json::from_slice(body).ok()?;
        let usage = json.get("usage")?;

        let input = usage.get("input_tokens")?.as_u64().unwrap_or(0);
        let output = usage.get("output_tokens")?.as_u64().unwrap_or(0);
        let total = input + output;

        trace!(input = input, output = output, total = total, "Got token count from Anthropic response body");
        Some(total)
    }

    fn estimate_request_tokens(&self, body: &[u8], method: TokenEstimation) -> u64 {
        estimate_tokens(body, method)
    }

    fn extract_model(&self, headers: &HeaderMap, body: &[u8]) -> Option<String> {
        // An explicit header wins over the body's `model` field.
        if let Some(model) = headers.get("x-model").and_then(|v| v.to_str().ok()) {
            return Some(model.to_string());
        }

        let json: Value = serde_json::from_slice(body).ok()?;
        json.get("model")?.as_str().map(|s| s.to_string())
    }
}

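/// Fallback adapter for OpenAI-compatible or unknown upstreams.
///
/// Probes a set of conventional headers and JSON usage shapes rather than
/// relying on a single documented format.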
struct GenericProvider;

impl InferenceProviderAdapter for GenericProvider {
    fn name(&self) -> &'static str {
        "generic"
    }

    fn tokens_from_headers(&self, headers: &HeaderMap) -> Option<u64> {
        // Try a handful of conventional usage headers in order.
        let candidates = [
            "x-tokens-used",
            "x-token-count",
            "x-total-tokens",
        ];

        for header in candidates {
            if let Some(n) = headers
                .get(header)
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
            {
                trace!(header = header, tokens = n, "Got token count from generic header");
                return Some(n);
            }
        }

        None
    }

    fn tokens_from_body(&self, body: &[u8]) -> Option<u64> {
        let json: Value = serde_json::from_slice(body).ok()?;

        // OpenAI-style shape: `usage.total_tokens`.
        if let Some(total) = json
            .get("usage")
            .and_then(|u| u.get("total_tokens"))
            .and_then(|t| t.as_u64())
        {
            return Some(total);
        }

        // Anthropic-style shape: sum of `input_tokens` and `output_tokens`.
        if let Some(usage) = json.get("usage") {
            let input = usage.get("input_tokens").and_then(|t| t.as_u64()).unwrap_or(0);
            let output = usage.get("output_tokens").and_then(|t| t.as_u64()).unwrap_or(0);
            if input > 0 || output > 0 {
                return Some(input + output);
            }
        }

        None
    }

    fn estimate_request_tokens(&self, body: &[u8], method: TokenEstimation) -> u64 {
        estimate_tokens(body, method)
    }

    fn extract_model(&self, headers: &HeaderMap, body: &[u8]) -> Option<String> {
        // Headers take precedence; fall back to a `model` field in the body.
        let candidates = ["x-model", "x-model-id", "model"];
        for header in candidates {
            if let Some(model) = headers.get(header).and_then(|v| v.to_str().ok()) {
                return Some(model.to_string());
            }
        }

        let json: Value = serde_json::from_slice(body).ok()?;
        json.get("model")?.as_str().map(|s| s.to_string())
    }
}

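/// Estimates tokens for a body with no model hint; see
/// `estimate_tokens_with_model`.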
fn estimate_tokens(body: &[u8], method: TokenEstimation) -> u64 {
    estimate_tokens_with_model(body, method, None)
}

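/// Estimates tokens for a body using the configured heuristic.
///
/// `Chars` assumes roughly four characters per token and `Words` roughly
/// 1.3 tokens per word; both floor at one token. `Tiktoken` delegates to
/// the shared tiktoken manager.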
fn estimate_tokens_with_model(body: &[u8], method: TokenEstimation, model: Option<&str>) -> u64 {
    match method {
        TokenEstimation::Chars => {
            // Rule of thumb: ~4 characters per token for English text.
            let char_count = String::from_utf8_lossy(body).chars().count();
            (char_count / 4).max(1) as u64
        }
        TokenEstimation::Words => {
            // Rule of thumb: ~1.3 tokens per word.
            let text = String::from_utf8_lossy(body);
            let word_count = text.split_whitespace().count();
            ((word_count as f64 * 1.3).ceil() as u64).max(1)
        }
        TokenEstimation::Tiktoken => estimate_tokens_tiktoken(body, model),
    }
}

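/// Counts tokens via the shared tiktoken manager, which understands
/// chat-request framing. Whether a real tokenizer was available is
/// recorded in the trace output.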
fn estimate_tokens_tiktoken(body: &[u8], model: Option<&str>) -> u64 {
    let manager = tiktoken_manager();

    let tokens = manager.count_chat_request(body, model);

    trace!(
        token_count = tokens,
        model = ?model,
        tiktoken_available = manager.is_available(),
        "Tiktoken token count"
    );

    tokens
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_openai_body_parsing() {
        let body = br#"{"usage": {"prompt_tokens": 100, "completion_tokens": 50, "total_tokens": 150}}"#;
        let provider = OpenAiProvider;
        assert_eq!(provider.tokens_from_body(body), Some(150));
    }

    #[test]
    fn test_anthropic_body_parsing() {
        let body = br#"{"usage": {"input_tokens": 100, "output_tokens": 50}}"#;
        let provider = AnthropicProvider;
        assert_eq!(provider.tokens_from_body(body), Some(150));
    }

    #[test]
    fn test_token_estimation_chars() {
        let body = b"Hello world, this is a test message for token counting!";
        let estimate = estimate_tokens(body, TokenEstimation::Chars);
        assert!(estimate > 0 && estimate < 100);
    }

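    #[test]
    fn test_token_estimation_words() {
        // The Words heuristic above applies a 1.3 tokens-per-word ratio,
        // so four words estimate to ceil(4 * 1.3) = 6.
        let body = b"count these four words";
        let estimate = estimate_tokens(body, TokenEstimation::Words);
        assert_eq!(estimate, 6);
    }
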
    #[test]
    fn test_model_extraction() {
        let body = br#"{"model": "gpt-4", "messages": []}"#;
        let provider = OpenAiProvider;
        let headers = HeaderMap::new();
        assert_eq!(provider.extract_model(&headers, body), Some("gpt-4".to_string()));
    }

    #[test]
    fn test_token_estimation_tiktoken() {
        let body = b"Hello world, this is a test message for token counting!";
        let estimate = estimate_tokens(body, TokenEstimation::Tiktoken);
        assert!(estimate > 0 && estimate < 100);
    }

    #[test]
    #[cfg(feature = "tiktoken")]
    fn test_tiktoken_accurate_count() {
        // "Hello world" tokenizes to two tokens ("Hello", " world") in cl100k_base.
        let body = b"Hello world";
        let estimate = estimate_tokens_tiktoken(body, Some("gpt-4"));
        assert_eq!(estimate, 2);
    }

    #[test]
    fn test_tiktoken_chat_request() {
        let body = br#"{
            "model": "gpt-4",
            "messages": [
                {"role": "user", "content": "Hello!"}
            ]
        }"#;
        let estimate = estimate_tokens_tiktoken(body, None);
        assert!(estimate > 0);
    }
}