1use super::XApiClient;
8use crate::error::XApiError;
9
/// Access tier of the X API account.
///
/// The tier decides which features the application may use; see
/// [`ApiTier::capabilities`] for the exact mapping.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ApiTier {
    /// Lowest tier: only posting is enabled.
    Free,
    /// Tier with search, mentions, posting and the discovery loop enabled.
    Basic,
    /// Tier with the same feature set as `Basic`.
    Pro,
}
20
/// Feature flags enabled for a given [`ApiTier`].
///
/// Produced by [`ApiTier::capabilities`]. Derives `PartialEq`/`Eq` so that two
/// capability sets can be compared as a whole instead of flag by flag.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TierCapabilities {
    /// Whether tweet search may be used.
    pub search_available: bool,
    /// Whether mention retrieval may be used.
    pub mentions_available: bool,
    /// Whether posting (tweets and replies) may be used.
    pub posting_available: bool,
    /// Whether the discovery loop should be enabled.
    pub discovery_loop_enabled: bool,
}
33
34impl ApiTier {
35 pub fn capabilities(&self) -> TierCapabilities {
37 match self {
38 ApiTier::Free => TierCapabilities {
39 search_available: false,
40 mentions_available: false,
41 posting_available: true,
42 discovery_loop_enabled: false,
43 },
44 ApiTier::Basic | ApiTier::Pro => TierCapabilities {
45 search_available: true,
46 mentions_available: true,
47 posting_available: true,
48 discovery_loop_enabled: true,
49 },
50 }
51 }
52}
53
54impl std::fmt::Display for ApiTier {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 match self {
57 ApiTier::Free => write!(f, "Free"),
58 ApiTier::Basic => write!(f, "Basic"),
59 ApiTier::Pro => write!(f, "Pro"),
60 }
61 }
62}
63
64pub async fn detect_tier(client: &dyn XApiClient) -> Result<ApiTier, XApiError> {
70 match client.search_tweets("test", 10, None, None).await {
71 Ok(_) => {
72 let tier = ApiTier::Basic;
73 log_tier_detection(&tier);
74 Ok(tier)
75 }
76 Err(XApiError::Forbidden { .. }) => {
77 let tier = ApiTier::Free;
78 log_tier_detection(&tier);
79 Ok(tier)
80 }
81 Err(XApiError::RateLimited { .. }) => {
82 let tier = ApiTier::Basic;
84 log_tier_detection(&tier);
85 Ok(tier)
86 }
87 Err(XApiError::AuthExpired) => {
88 Err(XApiError::AuthExpired)
90 }
91 Err(XApiError::Network { .. }) => {
92 tracing::warn!("Network error during tier detection, defaulting to Free tier");
93 let tier = ApiTier::Free;
94 log_tier_detection(&tier);
95 Ok(tier)
96 }
97 Err(e) => {
98 tracing::warn!(error = %e, "Unexpected error during tier detection, defaulting to Free tier");
99 let tier = ApiTier::Free;
100 log_tier_detection(&tier);
101 Ok(tier)
102 }
103 }
104}
105
106fn log_tier_detection(tier: &ApiTier) {
108 let caps = tier.capabilities();
109 tracing::info!(
110 tier = %tier,
111 search = caps.search_available,
112 mentions = caps.mentions_available,
113 posting = caps.posting_available,
114 discovery_loop = caps.discovery_loop_enabled,
115 "Detected X API tier"
116 );
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use crate::x_api::types::*;

    /// Test double whose `search_tweets` replays a canned result.
    ///
    /// Every other `XApiClient` method is `unimplemented!()`, since tier
    /// detection only calls `search_tweets`.
    struct MockClient {
        search_result: Result<SearchResponse, XApiError>,
    }

    impl MockClient {
        /// Search succeeds with an empty result page.
        fn ok() -> Self {
            Self {
                search_result: Ok(SearchResponse {
                    data: vec![],
                    includes: None,
                    meta: SearchMeta {
                        newest_id: None,
                        oldest_id: None,
                        result_count: 0,
                        next_token: None,
                    },
                }),
            }
        }

        /// Search is rejected with 403-style `Forbidden`.
        fn forbidden() -> Self {
            Self {
                search_result: Err(XApiError::Forbidden {
                    message: "Not permitted".to_string(),
                }),
            }
        }

        /// Search is rate limited.
        fn rate_limited() -> Self {
            Self {
                search_result: Err(XApiError::RateLimited {
                    retry_after: Some(60),
                }),
            }
        }

        /// Search reports expired credentials.
        fn auth_expired() -> Self {
            Self {
                search_result: Err(XApiError::AuthExpired),
            }
        }

        /// Search fails with a generic server error.
        fn api_error() -> Self {
            Self {
                search_result: Err(XApiError::ApiError {
                    status: 500,
                    message: "Internal error".to_string(),
                }),
            }
        }
    }

    /// Rebuilds a stored error so the mock can return it on every call.
    ///
    /// The error is reconstructed by hand rather than cloned; presumably
    /// `XApiError` is not `Clone` (TODO confirm). Variants the mock is never
    /// configured with collapse to a generic `ApiError { status: 0 }`.
    fn replay_error(e: &XApiError) -> XApiError {
        match e {
            XApiError::Forbidden { message } => XApiError::Forbidden {
                message: message.clone(),
            },
            XApiError::RateLimited { retry_after } => XApiError::RateLimited {
                retry_after: *retry_after,
            },
            XApiError::AuthExpired => XApiError::AuthExpired,
            XApiError::ApiError { status, message } => XApiError::ApiError {
                status: *status,
                message: message.clone(),
            },
            _ => XApiError::ApiError {
                status: 0,
                message: "test error".to_string(),
            },
        }
    }

    #[async_trait::async_trait]
    impl XApiClient for MockClient {
        async fn search_tweets(
            &self,
            _query: &str,
            _max_results: u32,
            _since_id: Option<&str>,
            _pagination_token: Option<&str>,
        ) -> Result<SearchResponse, XApiError> {
            match &self.search_result {
                Ok(r) => Ok(r.clone()),
                Err(e) => Err(replay_error(e)),
            }
        }

        async fn get_mentions(
            &self,
            _user_id: &str,
            _since_id: Option<&str>,
            _pagination_token: Option<&str>,
        ) -> Result<MentionResponse, XApiError> {
            unimplemented!()
        }

        async fn post_tweet(&self, _text: &str) -> Result<PostedTweet, XApiError> {
            unimplemented!()
        }

        async fn reply_to_tweet(
            &self,
            _text: &str,
            _in_reply_to_id: &str,
        ) -> Result<PostedTweet, XApiError> {
            unimplemented!()
        }

        async fn get_tweet(&self, _tweet_id: &str) -> Result<Tweet, XApiError> {
            unimplemented!()
        }

        async fn get_me(&self) -> Result<User, XApiError> {
            unimplemented!()
        }

        async fn get_user_tweets(
            &self,
            _user_id: &str,
            _max_results: u32,
            _pagination_token: Option<&str>,
        ) -> Result<SearchResponse, XApiError> {
            unimplemented!()
        }

        async fn get_user_by_username(&self, _username: &str) -> Result<User, XApiError> {
            unimplemented!()
        }
    }

    #[tokio::test]
    async fn detect_basic_on_search_success() {
        let client = MockClient::ok();
        let tier = detect_tier(&client).await.expect("detect");
        assert_eq!(tier, ApiTier::Basic);
    }

    #[tokio::test]
    async fn detect_free_on_forbidden() {
        let client = MockClient::forbidden();
        let tier = detect_tier(&client).await.expect("detect");
        assert_eq!(tier, ApiTier::Free);
    }

    #[tokio::test]
    async fn detect_basic_on_rate_limited() {
        let client = MockClient::rate_limited();
        let tier = detect_tier(&client).await.expect("detect");
        assert_eq!(tier, ApiTier::Basic);
    }

    #[tokio::test]
    async fn detect_propagates_auth_expired() {
        let client = MockClient::auth_expired();
        let result = detect_tier(&client).await;
        assert!(matches!(result, Err(XApiError::AuthExpired)));
    }

    #[tokio::test]
    async fn detect_defaults_to_free_on_other_errors() {
        let client = MockClient::api_error();
        let tier = detect_tier(&client).await.expect("detect");
        assert_eq!(tier, ApiTier::Free);
    }

    #[test]
    fn free_tier_capabilities() {
        let caps = ApiTier::Free.capabilities();
        assert!(!caps.search_available);
        assert!(!caps.mentions_available);
        assert!(caps.posting_available);
        assert!(!caps.discovery_loop_enabled);
    }

    #[test]
    fn basic_tier_capabilities() {
        let caps = ApiTier::Basic.capabilities();
        assert!(caps.search_available);
        assert!(caps.mentions_available);
        assert!(caps.posting_available);
        assert!(caps.discovery_loop_enabled);
    }

    #[test]
    fn pro_tier_same_as_basic() {
        // Compare every flag, not just search/mentions, so a divergence in
        // posting or the discovery loop is caught too.
        let basic = ApiTier::Basic.capabilities();
        let pro = ApiTier::Pro.capabilities();
        assert_eq!(basic.search_available, pro.search_available);
        assert_eq!(basic.mentions_available, pro.mentions_available);
        assert_eq!(basic.posting_available, pro.posting_available);
        assert_eq!(basic.discovery_loop_enabled, pro.discovery_loop_enabled);
    }

    #[test]
    fn tier_display() {
        assert_eq!(ApiTier::Free.to_string(), "Free");
        assert_eq!(ApiTier::Basic.to_string(), "Basic");
        assert_eq!(ApiTier::Pro.to_string(), "Pro");
    }
}