xai_rust/api/
tokenizer.rs1use crate::client::XaiClient;
4use crate::models::tokenizer::{TokenizeRequest, TokenizeResponse};
5use crate::{Error, Result};
6
7#[derive(Debug, Clone)]
9pub struct TokenizerApi {
10 client: XaiClient,
11}
12
13impl TokenizerApi {
14 pub(crate) fn new(client: XaiClient) -> Self {
15 Self { client }
16 }
17
18 pub async fn tokenize(&self, request: TokenizeRequest) -> Result<TokenizeResponse> {
37 let url = format!("{}/tokenize-text", self.client.base_url());
38
39 let response = self
40 .client
41 .send(self.client.http().post(&url).json(&request))
42 .await?;
43
44 if !response.status().is_success() {
45 return Err(Error::from_response(response).await);
46 }
47
48 Ok(response.json().await?)
49 }
50
51 pub async fn tokenize_text(
70 &self,
71 model: impl Into<String>,
72 text: impl Into<String>,
73 ) -> Result<TokenizeResponse> {
74 self.tokenize(TokenizeRequest::new(model, text)).await
75 }
76
77 pub async fn count_tokens(
96 &self,
97 model: impl Into<String>,
98 text: impl Into<String>,
99 ) -> Result<usize> {
100 let response = self.tokenize_text(model, text).await?;
101 Ok(response.count())
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use wiremock::matchers::{body_partial_json, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a client pointed at `server` with a dummy API key.
    fn client_for(server: &MockServer) -> XaiClient {
        XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .expect("client builder should accept the mock server URL")
    }

    #[tokio::test]
    async fn tokenize_forwards_model_and_text_payload() {
        let server = MockServer::start().await;

        // Check the payload with a matcher instead of asserting inside the responder.
        // A panic there happens on the mock server's side, so the test would only see
        // an opaque transport error. A non-matching body gets wiremock's default 404
        // (surfacing as an API error here), and `expect(1)` makes the server verify
        // on drop that the request actually arrived.
        Mock::given(method("POST"))
            .and(path("/tokenize-text"))
            .and(body_partial_json(json!({
                "model": "grok-4",
                "text": "Hello tokenizer"
            })))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "tokens": [10, 20, 30],
                "token_count": 3
            })))
            .expect(1)
            .mount(&server)
            .await;

        let client = client_for(&server);

        let response = client
            .tokenizer()
            .tokenize(TokenizeRequest::new("grok-4", "Hello tokenizer"))
            .await
            .unwrap();
        assert_eq!(response.tokens, vec![10, 20, 30]);
        assert_eq!(response.count(), 3);
    }

    #[tokio::test]
    async fn count_tokens_prefers_explicit_token_count() {
        let server = MockServer::start().await;

        // `token_count` deliberately disagrees with `tokens.len()` to prove which wins.
        Mock::given(method("POST"))
            .and(path("/tokenize-text"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "tokens": [1, 2],
                "token_count": 9
            })))
            .expect(1)
            .mount(&server)
            .await;

        let client = client_for(&server);

        let count = client
            .tokenizer()
            .count_tokens("grok-4", "count this")
            .await
            .unwrap();
        assert_eq!(count, 9);
    }

    #[tokio::test]
    async fn count_tokens_falls_back_to_token_vector_length() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/tokenize-text"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "tokens": [7, 8, 9, 10]
            })))
            .expect(1)
            .mount(&server)
            .await;

        let client = client_for(&server);

        let count = client
            .tokenizer()
            .count_tokens("grok-4", "fallback count")
            .await
            .unwrap();
        assert_eq!(count, 4);
    }

    #[tokio::test]
    async fn tokenize_returns_api_error_for_non_success_response() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/tokenize-text"))
            .respond_with(ResponseTemplate::new(400).set_body_json(json!({
                "error": {
                    "message": "bad tokenize request",
                    "type": "invalid_request_error"
                }
            })))
            .expect(1)
            .mount(&server)
            .await;

        let client = client_for(&server);

        let err = client
            .tokenizer()
            .tokenize_text("grok-4", "bad input")
            .await
            .unwrap_err();
        match err {
            Error::Api {
                status,
                message,
                error_type,
            } => {
                assert_eq!(status, 400);
                assert_eq!(message, "bad tokenize request");
                assert_eq!(error_type.as_deref(), Some("invalid_request_error"));
            }
            other => panic!("expected api error, got {other:?}"),
        }
    }
}