1use super::{
2 ledger::IdempotencyLedger,
3 models::{Tweet, TweetFields, TweetMeta},
4};
5use anyhow::{anyhow, Context, Result};
6use serde::{Deserialize, Serialize};
7use std::str::FromStr;
8use uuid::Uuid;
9
/// Error returned by `TweetCommand::create` when the `client_request_id` is already
/// present in the idempotency ledger and the caller chose [`IfExistsPolicy::Error`].
#[derive(Debug)]
pub struct IdempotencyConflictError {
    /// The idempotency key that collided with an existing ledger entry.
    pub client_request_id: String,
}

impl std::fmt::Display for IdempotencyConflictError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { client_request_id } = self;
        write!(
            f,
            "Operation with client_request_id '{id}' already exists",
            id = client_request_id
        )
    }
}

/// Marker impl so the conflict can travel through `anyhow::Error` / `Box<dyn Error>`.
impl std::error::Error for IdempotencyConflictError {}
27
28#[derive(Debug, Clone, Copy, PartialEq, Eq)]
30pub enum IfExistsPolicy {
31 Return,
33 Error,
35}
36
37impl FromStr for IfExistsPolicy {
38 type Err = anyhow::Error;
39
40 fn from_str(s: &str) -> Result<Self> {
41 match s {
42 "return" => Ok(Self::Return),
43 "error" => Ok(Self::Error),
44 _ => Err(anyhow!(
45 "Invalid if-exists policy: {}. Valid values: return, error",
46 s
47 )),
48 }
49 }
50}
51
52impl IfExistsPolicy {
53 pub fn as_str(&self) -> &'static str {
54 match self {
55 Self::Return => "return",
56 Self::Error => "error",
57 }
58 }
59}
60
/// Input for [`TweetCommand::create`].
#[derive(Debug, Clone)]
pub struct CreateArgs {
    /// Text of the tweet; `create` also hashes it and records the hash in the ledger.
    pub text: String,
    /// Idempotency key. When `None`, `create` generates a random UUID v4.
    pub client_request_id: Option<String>,
    /// What to do when `client_request_id` is already present in the ledger.
    pub if_exists: IfExistsPolicy,
}
68
/// Input for [`TweetCommand::list`].
#[derive(Debug, Clone)]
pub struct ListArgs {
    /// Fields to keep on each returned tweet (passed to `Tweet::project`).
    pub fields: Vec<TweetFields>,
    /// Page size; `list` uses 10 when this is `None`.
    pub limit: Option<usize>,
    /// Cursor taken from a previous page's pagination metadata (`cursor_<offset>`).
    /// `list` starts at offset 0 when this is `None` or cannot be parsed.
    pub cursor: Option<String>,
}
76
/// Output of [`TweetCommand::create`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateResult {
    /// The newly created tweet, or the previously recorded tweet id when replayed.
    pub tweet: Tweet,
    /// Idempotency metadata: the effective `client_request_id` and the cache flag.
    pub meta: TweetMeta,
}
83
/// Cursors for navigating to neighbouring pages of a list result.
///
/// Serialized with camelCase keys (`nextCursor`, `prevCursor`); a cursor that is
/// `None` is omitted from the serialized output.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PaginationMeta {
    /// Cursor for the following page, if there is one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub next_cursor: Option<String>,
    /// Cursor for the preceding page, if there is one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prev_cursor: Option<String>,
}
93
/// Output of [`TweetCommand::list`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ListResult {
    /// Tweets on this page, already projected onto the requested fields.
    pub tweets: Vec<Tweet>,
    /// Pagination metadata; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub meta: Option<ListResultMeta>,
}
101
/// Metadata attached to a [`ListResult`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ListResultMeta {
    /// Cursors for the neighbouring pages.
    pub pagination: PaginationMeta,
}
107
/// Coarse classification of a failed request (retryable vs. not, plus timeouts).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorKind {
    /// Transient failure; `ClassifiedError::from_status_code` assigns it to 429 and 5xx.
    Retryable,
    /// Failure that a retry will not fix; assigned to every other status code.
    NonRetryable,
    /// The request timed out (see `ClassifiedError::timeout`); marked retryable.
    Timeout,
}
118
/// An error annotated with a retry classification and, when known, an HTTP status.
///
/// Its `Display` output is `message` alone.
#[derive(Debug)]
pub struct ClassifiedError {
    /// Retry classification.
    pub kind: ErrorKind,
    /// HTTP status that produced the error; `None` for timeouts.
    pub status_code: Option<u16>,
    /// Human-readable description.
    pub message: String,
    /// Whether a retry may succeed; the constructors keep this consistent with `kind`.
    pub is_retryable: bool,
    /// Suggested delay before retrying, in milliseconds, when one is known.
    pub retry_after_ms: Option<u64>,
}
128
129impl ClassifiedError {
130 pub fn from_status_code(status_code: u16, message: String) -> Self {
131 let (kind, is_retryable) = match status_code {
132 429 => (ErrorKind::Retryable, true),
133 500..=599 => (ErrorKind::Retryable, true),
134 400..=499 => (ErrorKind::NonRetryable, false),
135 _ => (ErrorKind::NonRetryable, false),
136 };
137
138 Self {
139 kind,
140 status_code: Some(status_code),
141 message,
142 is_retryable,
143 retry_after_ms: None,
144 }
145 }
146
147 pub fn timeout(message: String) -> Self {
148 Self {
149 kind: ErrorKind::Timeout,
150 status_code: None,
151 message,
152 is_retryable: true,
153 retry_after_ms: None,
154 }
155 }
156
157 pub fn with_retry_after(mut self, retry_after_ms: u64) -> Self {
158 self.retry_after_ms = Some(retry_after_ms);
159 self
160 }
161
162 pub fn to_error_code(&self) -> crate::protocol::ErrorCode {
164 use crate::protocol::ErrorCode;
165 match self.kind {
166 ErrorKind::Retryable => {
167 if let Some(429) = self.status_code {
168 ErrorCode::RateLimitExceeded
169 } else {
170 ErrorCode::ServiceUnavailable
171 }
172 }
173 ErrorKind::Timeout => ErrorCode::NetworkError,
174 ErrorKind::NonRetryable => ErrorCode::InternalError,
175 }
176 }
177}
178
179impl std::fmt::Display for ClassifiedError {
180 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181 write!(f, "{}", self.message)
182 }
183}
184
185impl std::error::Error for ClassifiedError {}
186
187pub struct TweetCommand {
189 ledger: IdempotencyLedger,
190}
191
192impl TweetCommand {
193 pub fn new(ledger: IdempotencyLedger) -> Self {
195 Self { ledger }
196 }
197
198 pub fn create(&self, args: CreateArgs) -> Result<CreateResult> {
200 if let Ok(error_type) = std::env::var("XCOM_SIMULATE_ERROR") {
202 match error_type.as_str() {
203 "rate_limit" => {
204 let retry_after = std::env::var("XCOM_RETRY_AFTER_MS")
205 .ok()
206 .and_then(|s| s.parse::<u64>().ok())
207 .unwrap_or(60000);
208 return Err(ClassifiedError::from_status_code(
209 429,
210 "Rate limit exceeded".to_string(),
211 )
212 .with_retry_after(retry_after)
213 .into());
214 }
215 "server_error" => {
216 return Err(ClassifiedError::from_status_code(
217 500,
218 "Internal server error".to_string(),
219 )
220 .into());
221 }
222 "timeout" => {
223 return Err(ClassifiedError::timeout("Request timeout".to_string()).into());
224 }
225 _ => {
226 }
228 }
229 }
230
231 let client_request_id = args
233 .client_request_id
234 .unwrap_or_else(|| Uuid::new_v4().to_string());
235
236 let request_hash = IdempotencyLedger::compute_request_hash(&args.text);
238
239 if let Some(entry) = self
241 .ledger
242 .lookup(&client_request_id)
243 .context("Failed to lookup operation in ledger")?
244 {
245 match args.if_exists {
247 IfExistsPolicy::Return => {
248 let mut tweet = Tweet::new(entry.tweet_id.clone());
250 tweet.text = Some(args.text.clone());
251
252 let meta = TweetMeta {
253 client_request_id: client_request_id.clone(),
254 from_cache: Some(true),
255 };
256
257 return Ok(CreateResult { tweet, meta });
258 }
259 IfExistsPolicy::Error => {
260 return Err(IdempotencyConflictError {
262 client_request_id: client_request_id.clone(),
263 }
264 .into());
265 }
266 }
267 }
268
269 let tweet_id = format!("tweet_{}", Uuid::new_v4());
271 let mut tweet = Tweet::new(tweet_id.clone());
272 tweet.text = Some(args.text);
273
274 self.ledger
276 .record(&client_request_id, &request_hash, &tweet_id, "success")
277 .context("Failed to record operation in ledger")?;
278
279 let meta = TweetMeta {
280 client_request_id,
281 from_cache: None,
282 };
283
284 Ok(CreateResult { tweet, meta })
285 }
286
287 pub fn list(&self, args: ListArgs) -> Result<ListResult> {
289 if let Ok(error_type) = std::env::var("XCOM_SIMULATE_ERROR") {
291 match error_type.as_str() {
292 "rate_limit" => {
293 let retry_after = std::env::var("XCOM_RETRY_AFTER_MS")
294 .ok()
295 .and_then(|s| s.parse::<u64>().ok())
296 .unwrap_or(60000);
297 return Err(ClassifiedError::from_status_code(
298 429,
299 "Rate limit exceeded".to_string(),
300 )
301 .with_retry_after(retry_after)
302 .into());
303 }
304 "server_error" => {
305 return Err(ClassifiedError::from_status_code(
306 500,
307 "Internal server error".to_string(),
308 )
309 .into());
310 }
311 "timeout" => {
312 return Err(ClassifiedError::timeout("Request timeout".to_string()).into());
313 }
314 _ => {
315 }
317 }
318 }
319
320 let limit = args.limit.unwrap_or(10);
322
323 let offset = if let Some(cursor) = &args.cursor {
325 cursor
327 .strip_prefix("cursor_")
328 .and_then(|s| s.parse::<usize>().ok())
329 .unwrap_or(0)
330 } else {
331 0
332 };
333
334 let mut tweets = Vec::new();
335 for i in offset..(offset + limit) {
336 let mut tweet = Tweet::new(format!("tweet_{}", i));
337 tweet.text = Some(format!("Tweet text {}", i));
338 tweet.author_id = Some(format!("user_{}", i));
339 tweet.created_at = Some("2024-01-01T00:00:00Z".to_string());
340
341 let projected = tweet.project(&args.fields);
343 tweets.push(projected);
344 }
345
346 let next_cursor = if tweets.len() == limit {
348 Some(format!("cursor_{}", offset + limit))
349 } else {
350 None
351 };
352
353 let prev_cursor = if offset > 0 {
354 Some(format!("cursor_{}", offset.saturating_sub(limit)))
355 } else {
356 None
357 };
358
359 let meta = Some(ListResultMeta {
360 pagination: PaginationMeta {
361 next_cursor,
362 prev_cursor,
363 },
364 });
365
366 Ok(ListResult { tweets, meta })
367 }
368}
369
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a `TweetCommand` over a fresh on-disk ledger.
    ///
    /// The returned `TempDir` must be kept alive for the duration of the test:
    /// dropping it removes the directory that holds the ledger database.
    fn create_test_command() -> (TweetCommand, TempDir) {
        let scratch = TempDir::new().unwrap();
        let ledger_path = scratch.path().join("test.db");
        let ledger = IdempotencyLedger::new(Some(&ledger_path)).unwrap();
        (TweetCommand::new(ledger), scratch)
    }

    /// `CreateArgs` for the "Hello world" tweet shared by the create tests.
    fn hello_args(client_request_id: Option<&str>, if_exists: IfExistsPolicy) -> CreateArgs {
        CreateArgs {
            text: "Hello world".to_string(),
            client_request_id: client_request_id.map(String::from),
            if_exists,
        }
    }

    #[test]
    fn test_create_generates_client_request_id() {
        let (cmd, _scratch) = create_test_command();

        let created = cmd.create(hello_args(None, IfExistsPolicy::Return)).unwrap();

        assert!(!created.meta.client_request_id.is_empty());
        assert_eq!(created.tweet.text.as_deref(), Some("Hello world"));
    }

    #[test]
    fn test_create_with_explicit_client_request_id() {
        let (cmd, _scratch) = create_test_command();

        let created = cmd
            .create(hello_args(Some("my-request-id"), IfExistsPolicy::Return))
            .unwrap();

        assert_eq!(created.meta.client_request_id, "my-request-id");
    }

    #[test]
    fn test_create_idempotency_return_policy() {
        let (cmd, _scratch) = create_test_command();
        let args = hello_args(Some("test-123"), IfExistsPolicy::Return);

        let first = cmd.create(args.clone()).unwrap();
        // Replaying the same client_request_id must hand back the original tweet id.
        let replay = cmd.create(args).unwrap();

        assert_eq!(replay.tweet.id, first.tweet.id);
        assert_eq!(replay.meta.from_cache, Some(true));
    }

    #[test]
    fn test_create_idempotency_error_policy() {
        let (cmd, _scratch) = create_test_command();
        let args = hello_args(Some("test-456"), IfExistsPolicy::Error);

        cmd.create(args.clone()).unwrap();

        let conflict = cmd
            .create(args)
            .expect_err("a repeated client_request_id must be rejected");
        assert!(conflict.to_string().contains("already exists"));
    }

    #[test]
    fn test_list_with_field_projection() {
        let (cmd, _scratch) = create_test_command();

        let page = cmd
            .list(ListArgs {
                fields: vec![TweetFields::Id, TweetFields::Text],
                limit: Some(5),
                cursor: None,
            })
            .unwrap();

        assert_eq!(page.tweets.len(), 5);
        for tweet in page.tweets.iter() {
            assert!(!tweet.id.is_empty());
            assert!(tweet.text.is_some());
            // author_id was not requested, so projection must have dropped it.
            assert!(tweet.author_id.is_none());
        }
    }

    #[test]
    fn test_list_pagination() {
        let (cmd, _scratch) = create_test_command();

        let page = cmd
            .list(ListArgs {
                fields: TweetFields::default_fields(),
                limit: Some(10),
                cursor: None,
            })
            .unwrap();

        assert_eq!(page.tweets.len(), 10);
        let pagination = page
            .meta
            .expect("list should always attach pagination metadata")
            .pagination;
        assert_eq!(pagination.next_cursor.as_deref(), Some("cursor_10"));
        assert!(pagination.prev_cursor.is_none());
    }

    #[test]
    fn test_error_classification() {
        let cases = [
            (429, "Rate limit", ErrorKind::Retryable, true),
            (500, "Server error", ErrorKind::Retryable, true),
            (400, "Bad request", ErrorKind::NonRetryable, false),
        ];
        for &(status, message, expected_kind, expected_retryable) in cases.iter() {
            let classified = ClassifiedError::from_status_code(status, message.to_string());
            assert_eq!(classified.kind, expected_kind);
            assert_eq!(classified.is_retryable, expected_retryable);
        }

        let timed_out = ClassifiedError::timeout("Timeout".to_string());
        assert_eq!(timed_out.kind, ErrorKind::Timeout);
        assert!(timed_out.is_retryable);
    }

    #[test]
    fn test_if_exists_policy_from_str() {
        assert_eq!(
            "return".parse::<IfExistsPolicy>().unwrap(),
            IfExistsPolicy::Return
        );
        assert_eq!(
            "error".parse::<IfExistsPolicy>().unwrap(),
            IfExistsPolicy::Error
        );
        assert!("invalid".parse::<IfExistsPolicy>().is_err());
    }
}