rust_rabbit/patterns/
request_response.rs

1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::sync::{Arc, Mutex};
5use std::time::{Duration, Instant};
6use tokio::sync::oneshot;
7use tracing::{debug, error, info, warn};
8use uuid::Uuid;
9
10use crate::error::RustRabbitError;
11
12/// Correlation ID for request-response tracking
13#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
14pub struct CorrelationId(String);
15
16impl CorrelationId {
17    /// Create a new unique correlation ID
18    pub fn new() -> Self {
19        Self(Uuid::new_v4().to_string())
20    }
21
22    /// Create from existing string
23    pub fn from_string(id: String) -> Self {
24        Self(id)
25    }
26
27    /// Get the correlation ID as string
28    pub fn as_str(&self) -> &str {
29        &self.0
30    }
31}
32
33impl Default for CorrelationId {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39impl std::fmt::Display for CorrelationId {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        write!(f, "{}", self.0)
42    }
43}
44
/// An RPC-style request: an opaque payload plus the metadata needed to route
/// and time out its reply.
///
/// Field order is part of the serialized form for positional serde formats;
/// do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestMessage {
    /// Identifier the reply must carry so the client can match it to this request.
    pub correlation_id: CorrelationId,
    /// Where the responder should send its reply (presumably a queue name — TODO confirm).
    pub reply_to: String,
    /// Opaque request body.
    pub payload: Vec<u8>,
    /// How long the requester is prepared to wait for a reply.
    pub timeout: Duration,
    /// Creation time in UTC.
    pub timestamp: chrono::DateTime<chrono::Utc>,
}
54
55impl RequestMessage {
56    pub fn new(payload: Vec<u8>, reply_to: String, timeout: Duration) -> Self {
57        Self {
58            correlation_id: CorrelationId::new(),
59            reply_to,
60            payload,
61            timeout,
62            timestamp: chrono::Utc::now(),
63        }
64    }
65
66    pub fn with_correlation_id(mut self, correlation_id: CorrelationId) -> Self {
67        self.correlation_id = correlation_id;
68        self
69    }
70}
71
/// A reply to a [`RequestMessage`], matched to it by correlation ID.
///
/// Field order is part of the serialized form for positional serde formats;
/// do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponseMessage {
    /// Correlation ID of the request being answered; the client routes the reply by it.
    pub correlation_id: CorrelationId,
    /// Response body; empty when built with `ResponseMessage::error`.
    pub payload: Vec<u8>,
    /// `true` when built with `ResponseMessage::success`, `false` when built with `error`.
    pub success: bool,
    /// Failure description; set only by `ResponseMessage::error`.
    pub error_message: Option<String>,
    /// Creation time in UTC.
    pub timestamp: chrono::DateTime<chrono::Utc>,
}
81
82impl ResponseMessage {
83    pub fn success(correlation_id: CorrelationId, payload: Vec<u8>) -> Self {
84        Self {
85            correlation_id,
86            payload,
87            success: true,
88            error_message: None,
89            timestamp: chrono::Utc::now(),
90        }
91    }
92
93    pub fn error(correlation_id: CorrelationId, error: String) -> Self {
94        Self {
95            correlation_id,
96            payload: Vec::new(),
97            success: false,
98            error_message: Some(error),
99            timestamp: chrono::Utc::now(),
100        }
101    }
102}
103
/// Bookkeeping for one in-flight request that is still waiting for its response.
#[derive(Debug)]
struct PendingRequest {
    // One-shot channel used to hand the matching response to the waiting caller.
    sender: oneshot::Sender<ResponseMessage>,
    // When the request was registered; the periodic cleanup compares the
    // elapsed time since this instant against `timeout`.
    created_at: Instant,
    // How long the caller is willing to wait for a response.
    timeout: Duration,
}
111
112/// Request-Response client for handling RPC-style messaging
113#[derive(Debug)]
114pub struct RequestResponseClient {
115    pending_requests: Arc<Mutex<HashMap<CorrelationId, PendingRequest>>>,
116    default_timeout: Duration,
117}
118
119impl RequestResponseClient {
120    pub fn new(default_timeout: Duration) -> Self {
121        let client = Self {
122            pending_requests: Arc::new(Mutex::new(HashMap::new())),
123            default_timeout,
124        };
125
126        // Start cleanup task for expired requests
127        let pending_requests = client.pending_requests.clone();
128        tokio::spawn(async move {
129            let mut interval = tokio::time::interval(Duration::from_secs(30));
130            loop {
131                interval.tick().await;
132                Self::cleanup_expired_requests(&pending_requests).await;
133            }
134        });
135
136        client
137    }
138
139    /// Send a request and wait for response
140    pub async fn send_request(
141        &self,
142        payload: Vec<u8>,
143        reply_to: String,
144        timeout: Option<Duration>,
145    ) -> Result<ResponseMessage> {
146        let timeout = timeout.unwrap_or(self.default_timeout);
147        let request = RequestMessage::new(payload, reply_to, timeout);
148        let correlation_id = request.correlation_id.clone();
149
150        let (sender, receiver) = oneshot::channel();
151        let pending_request = PendingRequest {
152            sender,
153            created_at: Instant::now(),
154            timeout,
155        };
156
157        // Store pending request
158        {
159            let mut pending = self.pending_requests.lock().unwrap();
160            pending.insert(correlation_id.clone(), pending_request);
161        }
162
163        debug!(
164            correlation_id = %correlation_id,
165            timeout_ms = timeout.as_millis(),
166            "Registered pending request"
167        );
168
169        // TODO: Send actual request message via RabbitMQ
170        // This will be integrated with the main RustRabbit client
171
172        // Wait for response with timeout
173        tokio::select! {
174            result = receiver => {
175                match result {
176                    Ok(response) => {
177                        info!(
178                            correlation_id = %correlation_id,
179                            success = response.success,
180                            "Received response"
181                        );
182                        Ok(response)
183                    }
184                    Err(_) => {
185                        warn!(correlation_id = %correlation_id, "Response channel closed");
186                        Err(RustRabbitError::RequestTimeout.into())
187                    }
188                }
189            }
190            _ = tokio::time::sleep(timeout) => {
191                // Remove from pending requests on timeout
192                {
193                    let mut pending = self.pending_requests.lock().unwrap();
194                    pending.remove(&correlation_id);
195                }
196                error!(correlation_id = %correlation_id, "Request timeout");
197                Err(RustRabbitError::RequestTimeout.into())
198            }
199        }
200    }
201
202    /// Handle incoming response message
203    pub async fn handle_response(&self, response: ResponseMessage) -> Result<()> {
204        let correlation_id = response.correlation_id.clone();
205
206        let sender = {
207            let mut pending = self.pending_requests.lock().unwrap();
208            pending.remove(&correlation_id)
209        };
210
211        if let Some(pending_request) = sender {
212            debug!(
213                correlation_id = %correlation_id,
214                "Forwarding response to pending request"
215            );
216
217            if pending_request.sender.send(response).is_err() {
218                warn!(
219                    correlation_id = %correlation_id,
220                    "Failed to send response - receiver dropped"
221                );
222            }
223        } else {
224            warn!(
225                correlation_id = %correlation_id,
226                "Received response for unknown correlation ID"
227            );
228        }
229
230        Ok(())
231    }
232
233    /// Get pending request count (for monitoring)
234    pub fn pending_count(&self) -> usize {
235        self.pending_requests.lock().unwrap().len()
236    }
237
238    /// Cleanup expired requests
239    async fn cleanup_expired_requests(
240        pending_requests: &Arc<Mutex<HashMap<CorrelationId, PendingRequest>>>,
241    ) {
242        let now = Instant::now();
243        let mut expired_ids = Vec::new();
244
245        {
246            let pending = pending_requests.lock().unwrap();
247            for (correlation_id, request) in pending.iter() {
248                if now.duration_since(request.created_at) > request.timeout {
249                    expired_ids.push(correlation_id.clone());
250                }
251            }
252        }
253
254        if !expired_ids.is_empty() {
255            let mut pending = pending_requests.lock().unwrap();
256            for correlation_id in expired_ids {
257                if let Some(expired_request) = pending.remove(&correlation_id) {
258                    let _ = expired_request.sender.send(ResponseMessage::error(
259                        correlation_id.clone(),
260                        "Request timeout".to_string(),
261                    ));
262
263                    warn!(
264                        correlation_id = %correlation_id,
265                        "Cleaned up expired request"
266                    );
267                }
268            }
269        }
270    }
271}
272
273/// Request-Response server for handling incoming requests
274pub struct RequestResponseServer {
275    handler: Arc<dyn RequestHandler + Send + Sync>,
276}
277
278/// Trait for handling incoming requests
279#[async_trait::async_trait]
280pub trait RequestHandler {
281    async fn handle_request(&self, request: RequestMessage) -> Result<ResponseMessage>;
282}
283
284impl RequestResponseServer {
285    pub fn new(handler: Arc<dyn RequestHandler + Send + Sync>) -> Self {
286        Self { handler }
287    }
288
289    /// Process incoming request and generate response
290    pub async fn process_request(&self, request: RequestMessage) -> Result<ResponseMessage> {
291        let correlation_id = request.correlation_id.clone();
292
293        debug!(
294            correlation_id = %correlation_id,
295            "Processing incoming request"
296        );
297
298        let start_time = Instant::now();
299        let response = self.handler.handle_request(request).await;
300        let processing_time = start_time.elapsed();
301
302        match &response {
303            Ok(resp) => {
304                info!(
305                    correlation_id = %correlation_id,
306                    processing_time_ms = processing_time.as_millis(),
307                    success = resp.success,
308                    "Request processed"
309                );
310            }
311            Err(err) => {
312                error!(
313                    correlation_id = %correlation_id,
314                    processing_time_ms = processing_time.as_millis(),
315                    error = %err,
316                    "Request processing failed"
317                );
318                return Ok(ResponseMessage::error(correlation_id, err.to_string()));
319            }
320        }
321
322        response
323    }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Echoes the request payload back, prefixed with "Echo: ".
    struct TestHandler;

    #[async_trait::async_trait]
    impl RequestHandler for TestHandler {
        async fn handle_request(&self, request: RequestMessage) -> Result<ResponseMessage> {
            let payload = format!("Echo: {}", String::from_utf8_lossy(&request.payload));
            Ok(ResponseMessage::success(
                request.correlation_id,
                payload.into_bytes(),
            ))
        }
    }

    /// Always fails, to exercise the server's error-to-response conversion.
    struct FailingHandler;

    #[async_trait::async_trait]
    impl RequestHandler for FailingHandler {
        async fn handle_request(&self, _request: RequestMessage) -> Result<ResponseMessage> {
            Err(anyhow::anyhow!("boom"))
        }
    }

    #[test]
    fn test_correlation_id_generation() {
        let id1 = CorrelationId::new();
        let id2 = CorrelationId::new();
        assert_ne!(id1, id2);
    }

    #[test]
    fn test_correlation_id_round_trip() {
        let id = CorrelationId::from_string("abc-123".to_string());
        assert_eq!(id.as_str(), "abc-123");
        assert_eq!(id.to_string(), "abc-123");
    }

    #[test]
    fn test_request_message_creation() {
        let payload = b"test message".to_vec();
        let request = RequestMessage::new(
            payload.clone(),
            "reply.queue".to_string(),
            Duration::from_secs(30),
        );

        assert_eq!(request.payload, payload);
        assert_eq!(request.reply_to, "reply.queue");
        assert_eq!(request.timeout, Duration::from_secs(30));
    }

    #[test]
    fn test_response_creation() {
        let correlation_id = CorrelationId::new();
        let payload = b"response".to_vec();

        let success_response = ResponseMessage::success(correlation_id.clone(), payload.clone());
        assert!(success_response.success);
        assert_eq!(success_response.correlation_id, correlation_id);
        assert_eq!(success_response.payload, payload);

        let error_response = ResponseMessage::error(correlation_id.clone(), "Error".to_string());
        assert!(!error_response.success);
        assert_eq!(error_response.error_message, Some("Error".to_string()));
    }

    #[tokio::test]
    async fn test_request_response_server() {
        let handler = Arc::new(TestHandler);
        let server = RequestResponseServer::new(handler);

        let request = RequestMessage::new(
            b"hello".to_vec(),
            "reply.queue".to_string(),
            Duration::from_secs(30),
        );
        let correlation_id = request.correlation_id.clone();

        let response = server.process_request(request).await.unwrap();
        assert_eq!(response.correlation_id, correlation_id);
        assert!(response.success);
        assert_eq!(String::from_utf8_lossy(&response.payload), "Echo: hello");
    }

    #[tokio::test]
    async fn test_request_response_server_handler_error() {
        let server = RequestResponseServer::new(Arc::new(FailingHandler));

        let request = RequestMessage::new(
            b"hello".to_vec(),
            "reply.queue".to_string(),
            Duration::from_secs(30),
        );
        let correlation_id = request.correlation_id.clone();

        // Handler errors are turned into an error response, not an Err.
        let response = server.process_request(request).await.unwrap();
        assert_eq!(response.correlation_id, correlation_id);
        assert!(!response.success);
        assert!(response.payload.is_empty());
        assert_eq!(response.error_message.as_deref(), Some("boom"));
    }

    #[tokio::test]
    async fn test_handle_response_for_unknown_correlation_id() {
        let client = RequestResponseClient::new(Duration::from_secs(1));
        let response = ResponseMessage::success(CorrelationId::new(), b"late".to_vec());

        assert!(client.handle_response(response).await.is_ok());
        assert_eq!(client.pending_count(), 0);
    }

    #[tokio::test]
    async fn test_pending_requests_cleanup() {
        let client = RequestResponseClient::new(Duration::from_millis(100));

        // Nothing answers this request, so it must time out.
        let result = client
            .send_request(
                b"test".to_vec(),
                "reply.queue".to_string(),
                Some(Duration::from_millis(50)),
            )
            .await;
        assert!(result.is_err());

        // `send_request` removes its own entry on timeout; there is no need to
        // wait for the periodic sweeper, which only runs every 30 seconds.
        assert_eq!(client.pending_count(), 0);
    }
}