// sentinel_proxy/inference/tokens.rs
use http::HeaderMap;
use sentinel_config::TokenEstimation;
use tracing::{debug, trace};

use super::providers::InferenceProviderAdapter;

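/// A token count together with how it was obtained and, when available,
/// the model name extracted from the request.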
#[derive(Debug, Clone)]
pub struct TokenEstimate {
    pub tokens: u64,
    pub source: TokenSource,
    pub model: Option<String>,
}

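/// Where a token count came from: an actual count reported by the provider
/// (in headers or body), or a local estimate.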
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenSource {
    /// Actual count reported in the provider's response headers.
    Header,
    /// Actual count parsed from the provider's response body.
    Body,
    /// Locally estimated; no authoritative count was available.
    Estimated,
}

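/// Counts tokens for a single inference provider, preferring actual counts
/// reported by the provider and falling back to local estimation.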
pub struct TokenCounter {
    provider: Box<dyn InferenceProviderAdapter>,
    estimation_method: TokenEstimation,
}

impl TokenCounter {
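    /// Creates a counter for the given provider adapter and estimation method.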
    pub fn new(provider: Box<dyn InferenceProviderAdapter>, estimation_method: TokenEstimation) -> Self {
        Self {
            provider,
            estimation_method,
        }
    }

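    /// Estimates the token count of an outgoing request body. The result is
    /// always marked `TokenSource::Estimated`, since the request has not yet
    /// reached the provider.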
    pub fn estimate_request(&self, headers: &HeaderMap, body: &[u8]) -> TokenEstimate {
        let model = self.provider.extract_model(headers, body);

        let tokens = self.provider.estimate_request_tokens(body, self.estimation_method);

        trace!(
            provider = self.provider.name(),
            tokens = tokens,
            model = ?model,
            method = ?self.estimation_method,
            "Estimated request tokens"
        );

        TokenEstimate {
            tokens,
            source: TokenSource::Estimated,
            model,
        }
    }

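    /// Extracts the actual token count from a response, checking headers
    /// first and then the body. If neither carries a count, returns a
    /// zero-token estimate as a fallback.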
    pub fn tokens_from_response(&self, headers: &HeaderMap, body: &[u8]) -> TokenEstimate {
        if let Some(tokens) = self.provider.tokens_from_headers(headers) {
            debug!(
                provider = self.provider.name(),
                tokens = tokens,
                source = "header",
                "Got actual token count from response headers"
            );
            return TokenEstimate {
                tokens,
                source: TokenSource::Header,
                model: None,
            };
        }

        if let Some(tokens) = self.provider.tokens_from_body(body) {
            debug!(
                provider = self.provider.name(),
                tokens = tokens,
                source = "body",
                "Got actual token count from response body"
            );
            return TokenEstimate {
                tokens,
                source: TokenSource::Body,
                model: None,
            };
        }

        trace!(
            provider = self.provider.name(),
            "Could not extract actual token count from response"
        );
        TokenEstimate {
            tokens: 0,
            source: TokenSource::Estimated,
            model: None,
        }
    }

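    /// Returns the name of the underlying provider adapter.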
    pub fn provider_name(&self) -> &'static str {
        self.provider.name()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::inference::providers::create_provider;
    use sentinel_config::InferenceProvider;

    #[test]
    fn test_request_estimation() {
        let provider = create_provider(&InferenceProvider::OpenAi);
        let counter = TokenCounter::new(provider, TokenEstimation::Chars);

        let body = br#"{"model": "gpt-4", "messages": [{"role": "user", "content": "Hello world"}]}"#;
        let headers = HeaderMap::new();

        let estimate = counter.estimate_request(&headers, body);
        assert!(estimate.tokens > 0);
        assert_eq!(estimate.source, TokenSource::Estimated);
        assert_eq!(estimate.model, Some("gpt-4".to_string()));
    }

    #[test]
    fn test_response_parsing() {
        let provider = create_provider(&InferenceProvider::OpenAi);
        let counter = TokenCounter::new(provider, TokenEstimation::Chars);

        let body = br#"{"usage": {"total_tokens": 150}}"#;
        let headers = HeaderMap::new();

        let estimate = counter.tokens_from_response(&headers, body);
        assert_eq!(estimate.tokens, 150);
        assert_eq!(estimate.source, TokenSource::Body);
    }
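
    // A minimal sketch of the fallback path: when neither the response
    // headers nor the body carry a usage count, tokens_from_response falls
    // back to a zero-token Estimated placeholder. This assumes the OpenAI
    // adapter finds no count in an empty header map or in a body without a
    // `usage` object; the test body below is illustrative, not from the
    // original suite.
    #[test]
    fn test_response_fallback_without_usage() {
        let provider = create_provider(&InferenceProvider::OpenAi);
        let counter = TokenCounter::new(provider, TokenEstimation::Chars);

        let body = br#"{"choices": []}"#;
        let headers = HeaderMap::new();

        let estimate = counter.tokens_from_response(&headers, body);
        assert_eq!(estimate.tokens, 0);
        assert_eq!(estimate.source, TokenSource::Estimated);
        assert_eq!(estimate.model, None);
    }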
}