1use std::sync::Arc;
8use tokio::sync::RwLock;
9
10use crate::error::XApiError;
11
12use super::types::{
13 MentionResponse, PostTweetRequest, PostTweetResponse, PostedTweet, RateLimitInfo, ReplyTo,
14 SearchResponse, SingleTweetResponse, Tweet, User, UserResponse, XApiErrorResponse,
15};
16use super::XApiClient;
17
/// Base URL requests are sent to when the client is built with `XApiHttpClient::new`.
const DEFAULT_BASE_URL: &str = "https://api.x.com/2";

/// Value sent as the `tweet.fields` query parameter on tweet-returning endpoints.
const TWEET_FIELDS: &str = "public_metrics,created_at,author_id,conversation_id";

/// Value sent as the `expansions` query parameter on tweet-returning endpoints.
const EXPANSIONS: &str = "author_id";

/// Value sent as the `user.fields` query parameter.
const USER_FIELDS: &str = "username,public_metrics";
29
30pub struct XApiHttpClient {
36 client: reqwest::Client,
37 base_url: String,
38 access_token: Arc<RwLock<String>>,
39}
40
41impl XApiHttpClient {
42 pub fn new(access_token: String) -> Self {
44 Self {
45 client: reqwest::Client::new(),
46 base_url: DEFAULT_BASE_URL.to_string(),
47 access_token: Arc::new(RwLock::new(access_token)),
48 }
49 }
50
51 pub fn with_base_url(access_token: String, base_url: String) -> Self {
53 Self {
54 client: reqwest::Client::new(),
55 base_url,
56 access_token: Arc::new(RwLock::new(access_token)),
57 }
58 }
59
60 pub fn access_token_lock(&self) -> Arc<RwLock<String>> {
62 self.access_token.clone()
63 }
64
65 pub async fn set_access_token(&self, token: String) {
67 let mut lock = self.access_token.write().await;
68 *lock = token;
69 }
70
71 fn parse_rate_limit_headers(headers: &reqwest::header::HeaderMap) -> RateLimitInfo {
73 let remaining = headers
74 .get("x-rate-limit-remaining")
75 .and_then(|v| v.to_str().ok())
76 .and_then(|v| v.parse::<u64>().ok());
77
78 let reset_at = headers
79 .get("x-rate-limit-reset")
80 .and_then(|v| v.to_str().ok())
81 .and_then(|v| v.parse::<u64>().ok());
82
83 RateLimitInfo {
84 remaining,
85 reset_at,
86 }
87 }
88
89 async fn map_error_response(response: reqwest::Response) -> XApiError {
91 let status = response.status().as_u16();
92 let rate_info = Self::parse_rate_limit_headers(response.headers());
93
94 let body = response.text().await.unwrap_or_default();
95 let error_detail = serde_json::from_str::<XApiErrorResponse>(&body).ok();
96
97 let message = error_detail
98 .as_ref()
99 .and_then(|e| e.detail.clone())
100 .unwrap_or_else(|| body.clone());
101
102 match status {
103 429 => {
104 let retry_after = rate_info.reset_at.and_then(|reset| {
105 let now = chrono::Utc::now().timestamp() as u64;
106 reset.checked_sub(now)
107 });
108 XApiError::RateLimited { retry_after }
109 }
110 401 => XApiError::AuthExpired,
111 403 => XApiError::Forbidden { message },
112 _ => XApiError::ApiError { status, message },
113 }
114 }
115
116 async fn get(
118 &self,
119 path: &str,
120 query: &[(&str, &str)],
121 ) -> Result<reqwest::Response, XApiError> {
122 let token = self.access_token.read().await;
123 let url = format!("{}{}", self.base_url, path);
124
125 let response = self
126 .client
127 .get(&url)
128 .bearer_auth(&*token)
129 .query(query)
130 .send()
131 .await
132 .map_err(|e| XApiError::Network { source: e })?;
133
134 let rate_info = Self::parse_rate_limit_headers(response.headers());
135 tracing::debug!(
136 path,
137 remaining = ?rate_info.remaining,
138 reset_at = ?rate_info.reset_at,
139 "X API response"
140 );
141
142 if response.status().is_success() {
143 Ok(response)
144 } else {
145 Err(Self::map_error_response(response).await)
146 }
147 }
148
149 async fn post_json<T: serde::Serialize>(
151 &self,
152 path: &str,
153 body: &T,
154 ) -> Result<reqwest::Response, XApiError> {
155 let token = self.access_token.read().await;
156 let url = format!("{}{}", self.base_url, path);
157
158 let response = self
159 .client
160 .post(&url)
161 .bearer_auth(&*token)
162 .json(body)
163 .send()
164 .await
165 .map_err(|e| XApiError::Network { source: e })?;
166
167 let rate_info = Self::parse_rate_limit_headers(response.headers());
168 tracing::debug!(
169 path,
170 remaining = ?rate_info.remaining,
171 reset_at = ?rate_info.reset_at,
172 "X API response"
173 );
174
175 if response.status().is_success() {
176 Ok(response)
177 } else {
178 Err(Self::map_error_response(response).await)
179 }
180 }
181}
182
183#[async_trait::async_trait]
184impl XApiClient for XApiHttpClient {
185 async fn search_tweets(
186 &self,
187 query: &str,
188 max_results: u32,
189 since_id: Option<&str>,
190 ) -> Result<SearchResponse, XApiError> {
191 tracing::debug!(query = %query, max_results = max_results, "Search tweets");
192 let max_str = max_results.to_string();
193 let mut params = vec![
194 ("query", query),
195 ("max_results", &max_str),
196 ("tweet.fields", TWEET_FIELDS),
197 ("expansions", EXPANSIONS),
198 ("user.fields", USER_FIELDS),
199 ];
200
201 let since_id_owned;
202 if let Some(sid) = since_id {
203 since_id_owned = sid.to_string();
204 params.push(("since_id", &since_id_owned));
205 }
206
207 let response = self.get("/tweets/search/recent", ¶ms).await?;
208 let resp: SearchResponse = response
209 .json()
210 .await
211 .map_err(|e| XApiError::Network { source: e })?;
212 tracing::debug!(
213 query = %query,
214 results = resp.data.len(),
215 "Search tweets completed",
216 );
217 Ok(resp)
218 }
219
220 async fn get_mentions(
221 &self,
222 user_id: &str,
223 since_id: Option<&str>,
224 ) -> Result<MentionResponse, XApiError> {
225 let path = format!("/users/{user_id}/mentions");
226 let mut params = vec![
227 ("tweet.fields", TWEET_FIELDS),
228 ("expansions", EXPANSIONS),
229 ("user.fields", USER_FIELDS),
230 ];
231
232 let since_id_owned;
233 if let Some(sid) = since_id {
234 since_id_owned = sid.to_string();
235 params.push(("since_id", &since_id_owned));
236 }
237
238 let response = self.get(&path, ¶ms).await?;
239 response
240 .json::<MentionResponse>()
241 .await
242 .map_err(|e| XApiError::Network { source: e })
243 }
244
245 async fn post_tweet(&self, text: &str) -> Result<PostedTweet, XApiError> {
246 tracing::debug!(chars = text.len(), "Posting tweet");
247 let body = PostTweetRequest {
248 text: text.to_string(),
249 reply: None,
250 };
251
252 let response = self.post_json("/tweets", &body).await?;
253 let resp: PostTweetResponse = response
254 .json()
255 .await
256 .map_err(|e| XApiError::Network { source: e })?;
257 Ok(resp.data)
258 }
259
260 async fn reply_to_tweet(
261 &self,
262 text: &str,
263 in_reply_to_id: &str,
264 ) -> Result<PostedTweet, XApiError> {
265 tracing::debug!(in_reply_to = %in_reply_to_id, chars = text.len(), "Posting reply");
266 let body = PostTweetRequest {
267 text: text.to_string(),
268 reply: Some(ReplyTo {
269 in_reply_to_tweet_id: in_reply_to_id.to_string(),
270 }),
271 };
272
273 let response = self.post_json("/tweets", &body).await?;
274 let resp: PostTweetResponse = response
275 .json()
276 .await
277 .map_err(|e| XApiError::Network { source: e })?;
278 Ok(resp.data)
279 }
280
281 async fn get_tweet(&self, tweet_id: &str) -> Result<Tweet, XApiError> {
282 let path = format!("/tweets/{tweet_id}");
283 let params = [
284 ("tweet.fields", TWEET_FIELDS),
285 ("expansions", EXPANSIONS),
286 ("user.fields", USER_FIELDS),
287 ];
288
289 let response = self.get(&path, ¶ms).await?;
290 let resp: SingleTweetResponse = response
291 .json()
292 .await
293 .map_err(|e| XApiError::Network { source: e })?;
294 Ok(resp.data)
295 }
296
297 async fn get_me(&self) -> Result<User, XApiError> {
298 let params = [("user.fields", USER_FIELDS)];
299
300 let response = self.get("/users/me", ¶ms).await?;
301 let resp: UserResponse = response
302 .json()
303 .await
304 .map_err(|e| XApiError::Network { source: e })?;
305 Ok(resp.data)
306 }
307
308 async fn get_user_tweets(
309 &self,
310 user_id: &str,
311 max_results: u32,
312 ) -> Result<SearchResponse, XApiError> {
313 let path = format!("/users/{user_id}/tweets");
314 let max_str = max_results.to_string();
315 let params = [
316 ("max_results", max_str.as_str()),
317 ("tweet.fields", TWEET_FIELDS),
318 ("expansions", EXPANSIONS),
319 ("user.fields", USER_FIELDS),
320 ];
321
322 let response = self.get(&path, ¶ms).await?;
323 response
324 .json::<SearchResponse>()
325 .await
326 .map_err(|e| XApiError::Network { source: e })
327 }
328
329 async fn follow_user(
330 &self,
331 source_user_id: &str,
332 target_user_id: &str,
333 ) -> Result<(), XApiError> {
334 let path = format!("/users/{source_user_id}/following");
335 let body = serde_json::json!({ "target_user_id": target_user_id });
336
337 self.post_json(&path, &body).await?;
338 Ok(())
339 }
340
341 async fn get_user_by_username(&self, username: &str) -> Result<User, XApiError> {
342 let path = format!("/users/by/username/{username}");
343 let params = [("user.fields", USER_FIELDS)];
344
345 let response = self.get(&path, ¶ms).await?;
346 let resp: UserResponse = response
347 .json()
348 .await
349 .map_err(|e| XApiError::Network { source: e })?;
350 Ok(resp.data)
351 }
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{header, method, path, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a client with a fixed test token that targets the mock server.
    async fn setup_client(server: &MockServer) -> XApiHttpClient {
        let base_url = server.uri();
        XApiHttpClient::with_base_url("test-token".to_string(), base_url)
    }

    /// A successful search is decoded, and the request carries the query,
    /// the result limit and the bearer token.
    #[tokio::test]
    async fn search_tweets_success() {
        let server = MockServer::start().await;
        let client = setup_client(&server).await;

        let payload = serde_json::json!({
            "data": [{"id": "1", "text": "Rust is great", "author_id": "a1"}],
            "meta": {"result_count": 1}
        });
        Mock::given(method("GET"))
            .and(path("/tweets/search/recent"))
            .and(query_param("query", "rust"))
            .and(query_param("max_results", "10"))
            .and(header("Authorization", "Bearer test-token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(payload))
            .mount(&server)
            .await;

        let found = client
            .search_tweets("rust", 10, None)
            .await
            .expect("search");
        assert_eq!(found.data.len(), 1);
        assert_eq!(found.data[0].text, "Rust is great");
    }

    /// `since_id` is forwarded as a query parameter.
    #[tokio::test]
    async fn search_tweets_with_since_id() {
        let server = MockServer::start().await;
        let client = setup_client(&server).await;

        let payload = serde_json::json!({
            "data": [],
            "meta": {"result_count": 0}
        });
        Mock::given(method("GET"))
            .and(path("/tweets/search/recent"))
            .and(query_param("since_id", "999"))
            .respond_with(ResponseTemplate::new(200).set_body_json(payload))
            .mount(&server)
            .await;

        let outcome = client.search_tweets("test", 10, Some("999")).await;
        assert!(outcome.is_ok());
    }

    /// Posting a tweet returns the id from the response body.
    #[tokio::test]
    async fn post_tweet_success() {
        let server = MockServer::start().await;
        let client = setup_client(&server).await;

        let payload = serde_json::json!({
            "data": {"id": "new_123", "text": "Hello world"}
        });
        Mock::given(method("POST"))
            .and(path("/tweets"))
            .and(header("Authorization", "Bearer test-token"))
            .respond_with(ResponseTemplate::new(201).set_body_json(payload))
            .mount(&server)
            .await;

        let posted = client.post_tweet("Hello world").await.expect("post");
        assert_eq!(posted.id, "new_123");
    }

    /// Replying succeeds when the server answers 201.
    #[tokio::test]
    async fn reply_to_tweet_success() {
        let server = MockServer::start().await;
        let client = setup_client(&server).await;

        let payload = serde_json::json!({
            "data": {"id": "reply_1", "text": "Nice point!"}
        });
        Mock::given(method("POST"))
            .and(path("/tweets"))
            .respond_with(ResponseTemplate::new(201).set_body_json(payload))
            .mount(&server)
            .await;

        let outcome = client.reply_to_tweet("Nice point!", "original_1").await;
        assert!(outcome.is_ok());
    }

    /// `/users/me` is decoded into a user with its public metrics.
    #[tokio::test]
    async fn get_me_success() {
        let server = MockServer::start().await;
        let client = setup_client(&server).await;

        let payload = serde_json::json!({
            "data": {
                "id": "u1",
                "username": "testuser",
                "name": "Test User",
                "public_metrics": {
                    "followers_count": 100,
                    "following_count": 50,
                    "tweet_count": 500
                }
            }
        });
        Mock::given(method("GET"))
            .and(path("/users/me"))
            .respond_with(ResponseTemplate::new(200).set_body_json(payload))
            .mount(&server)
            .await;

        let me = client.get_me().await.expect("get me");
        assert_eq!(me.username, "testuser");
        assert_eq!(me.public_metrics.followers_count, 100);
    }

    /// HTTP 429 becomes `RateLimited`.
    #[tokio::test]
    async fn error_429_maps_to_rate_limited() {
        let server = MockServer::start().await;
        let client = setup_client(&server).await;

        let reply = ResponseTemplate::new(429)
            .set_body_json(serde_json::json!({"detail": "Too Many Requests"}));
        Mock::given(method("GET"))
            .and(path("/tweets/search/recent"))
            .respond_with(reply)
            .mount(&server)
            .await;

        let outcome = client.search_tweets("test", 10, None).await;
        assert!(matches!(outcome, Err(XApiError::RateLimited { .. })));
    }

    /// HTTP 401 becomes `AuthExpired`.
    #[tokio::test]
    async fn error_401_maps_to_auth_expired() {
        let server = MockServer::start().await;
        let client = setup_client(&server).await;

        let reply = ResponseTemplate::new(401)
            .set_body_json(serde_json::json!({"detail": "Unauthorized"}));
        Mock::given(method("GET"))
            .and(path("/users/me"))
            .respond_with(reply)
            .mount(&server)
            .await;

        let outcome = client.get_me().await;
        assert!(matches!(outcome, Err(XApiError::AuthExpired)));
    }

    /// HTTP 403 becomes `Forbidden` and keeps the server's detail message.
    #[tokio::test]
    async fn error_403_maps_to_forbidden() {
        let server = MockServer::start().await;
        let client = setup_client(&server).await;

        let reply = ResponseTemplate::new(403).set_body_json(
            serde_json::json!({"detail": "You are not permitted to use this endpoint"}),
        );
        Mock::given(method("GET"))
            .and(path("/tweets/search/recent"))
            .respond_with(reply)
            .mount(&server)
            .await;

        match client.search_tweets("test", 10, None).await {
            Err(XApiError::Forbidden { message }) => {
                assert!(message.contains("not permitted"));
            }
            other => panic!("expected Forbidden, got: {other:?}"),
        }
    }

    /// Any other error status becomes `ApiError` carrying that status.
    #[tokio::test]
    async fn error_500_maps_to_api_error() {
        let server = MockServer::start().await;
        let client = setup_client(&server).await;

        let reply = ResponseTemplate::new(500)
            .set_body_json(serde_json::json!({"detail": "Internal Server Error"}));
        Mock::given(method("GET"))
            .and(path("/users/me"))
            .respond_with(reply)
            .mount(&server)
            .await;

        match client.get_me().await {
            Err(XApiError::ApiError { status, .. }) => assert_eq!(status, 500),
            other => panic!("expected ApiError, got: {other:?}"),
        }
    }

    /// Both rate-limit headers are parsed into numbers.
    #[tokio::test]
    async fn parse_rate_limit_headers_works() {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(
            "x-rate-limit-remaining",
            reqwest::header::HeaderValue::from_static("42"),
        );
        headers.insert(
            "x-rate-limit-reset",
            reqwest::header::HeaderValue::from_static("1700000000"),
        );

        let parsed = XApiHttpClient::parse_rate_limit_headers(&headers);
        assert_eq!(parsed.remaining, Some(42));
        assert_eq!(parsed.reset_at, Some(1700000000));
    }

    /// `set_access_token` overwrites the stored token.
    #[tokio::test]
    async fn set_access_token_updates() {
        let client = XApiHttpClient::new("old-token".to_string());
        client.set_access_token("new-token".to_string()).await;

        let stored = client.access_token.read().await;
        assert_eq!(stored.as_str(), "new-token");
    }

    /// A single-tweet lookup is decoded together with its metrics.
    #[tokio::test]
    async fn get_tweet_success() {
        let server = MockServer::start().await;
        let client = setup_client(&server).await;

        let payload = serde_json::json!({
            "data": {
                "id": "12345",
                "text": "Hello",
                "author_id": "a1",
                "public_metrics": {"like_count": 5, "retweet_count": 1, "reply_count": 0, "quote_count": 0}
            }
        });
        Mock::given(method("GET"))
            .and(path("/tweets/12345"))
            .respond_with(ResponseTemplate::new(200).set_body_json(payload))
            .mount(&server)
            .await;

        let fetched = client.get_tweet("12345").await.expect("get tweet");
        assert_eq!(fetched.id, "12345");
        assert_eq!(fetched.public_metrics.like_count, 5);
    }

    /// Mentions for a user are decoded from the mentions endpoint.
    #[tokio::test]
    async fn get_mentions_success() {
        let server = MockServer::start().await;
        let client = setup_client(&server).await;

        let payload = serde_json::json!({
            "data": [{"id": "m1", "text": "@testuser hello", "author_id": "a2"}],
            "meta": {"result_count": 1}
        });
        Mock::given(method("GET"))
            .and(path("/users/u1/mentions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(payload))
            .mount(&server)
            .await;

        let mentions = client.get_mentions("u1", None).await.expect("mentions");
        assert_eq!(mentions.data.len(), 1);
    }
}