1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::sync::{Arc, Mutex};
5use std::time::{Duration, Instant};
6use tokio::sync::oneshot;
7use tracing::{debug, error, info, warn};
8use uuid::Uuid;
9
10use crate::error::RustRabbitError;
11
12#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
14pub struct CorrelationId(String);
15
16impl CorrelationId {
17 pub fn new() -> Self {
19 Self(Uuid::new_v4().to_string())
20 }
21
22 pub fn from_string(id: String) -> Self {
24 Self(id)
25 }
26
27 pub fn as_str(&self) -> &str {
29 &self.0
30 }
31}
32
33impl Default for CorrelationId {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39impl std::fmt::Display for CorrelationId {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 write!(f, "{}", self.0)
42 }
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct RequestMessage {
48 pub correlation_id: CorrelationId,
49 pub reply_to: String,
50 pub payload: Vec<u8>,
51 pub timeout: Duration,
52 pub timestamp: chrono::DateTime<chrono::Utc>,
53}
54
55impl RequestMessage {
56 pub fn new(payload: Vec<u8>, reply_to: String, timeout: Duration) -> Self {
57 Self {
58 correlation_id: CorrelationId::new(),
59 reply_to,
60 payload,
61 timeout,
62 timestamp: chrono::Utc::now(),
63 }
64 }
65
66 pub fn with_correlation_id(mut self, correlation_id: CorrelationId) -> Self {
67 self.correlation_id = correlation_id;
68 self
69 }
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct ResponseMessage {
75 pub correlation_id: CorrelationId,
76 pub payload: Vec<u8>,
77 pub success: bool,
78 pub error_message: Option<String>,
79 pub timestamp: chrono::DateTime<chrono::Utc>,
80}
81
82impl ResponseMessage {
83 pub fn success(correlation_id: CorrelationId, payload: Vec<u8>) -> Self {
84 Self {
85 correlation_id,
86 payload,
87 success: true,
88 error_message: None,
89 timestamp: chrono::Utc::now(),
90 }
91 }
92
93 pub fn error(correlation_id: CorrelationId, error: String) -> Self {
94 Self {
95 correlation_id,
96 payload: Vec::new(),
97 success: false,
98 error_message: Some(error),
99 timestamp: chrono::Utc::now(),
100 }
101 }
102}
103
104#[derive(Debug)]
106struct PendingRequest {
107 sender: oneshot::Sender<ResponseMessage>,
108 created_at: Instant,
109 timeout: Duration,
110}
111
112#[derive(Debug)]
114pub struct RequestResponseClient {
115 pending_requests: Arc<Mutex<HashMap<CorrelationId, PendingRequest>>>,
116 default_timeout: Duration,
117}
118
119impl RequestResponseClient {
120 pub fn new(default_timeout: Duration) -> Self {
121 let client = Self {
122 pending_requests: Arc::new(Mutex::new(HashMap::new())),
123 default_timeout,
124 };
125
126 let pending_requests = client.pending_requests.clone();
128 tokio::spawn(async move {
129 let mut interval = tokio::time::interval(Duration::from_secs(30));
130 loop {
131 interval.tick().await;
132 Self::cleanup_expired_requests(&pending_requests).await;
133 }
134 });
135
136 client
137 }
138
139 pub async fn send_request(
141 &self,
142 payload: Vec<u8>,
143 reply_to: String,
144 timeout: Option<Duration>,
145 ) -> Result<ResponseMessage> {
146 let timeout = timeout.unwrap_or(self.default_timeout);
147 let request = RequestMessage::new(payload, reply_to, timeout);
148 let correlation_id = request.correlation_id.clone();
149
150 let (sender, receiver) = oneshot::channel();
151 let pending_request = PendingRequest {
152 sender,
153 created_at: Instant::now(),
154 timeout,
155 };
156
157 {
159 let mut pending = self.pending_requests.lock().unwrap();
160 pending.insert(correlation_id.clone(), pending_request);
161 }
162
163 debug!(
164 correlation_id = %correlation_id,
165 timeout_ms = timeout.as_millis(),
166 "Registered pending request"
167 );
168
169 tokio::select! {
174 result = receiver => {
175 match result {
176 Ok(response) => {
177 info!(
178 correlation_id = %correlation_id,
179 success = response.success,
180 "Received response"
181 );
182 Ok(response)
183 }
184 Err(_) => {
185 warn!(correlation_id = %correlation_id, "Response channel closed");
186 Err(RustRabbitError::RequestTimeout.into())
187 }
188 }
189 }
190 _ = tokio::time::sleep(timeout) => {
191 {
193 let mut pending = self.pending_requests.lock().unwrap();
194 pending.remove(&correlation_id);
195 }
196 error!(correlation_id = %correlation_id, "Request timeout");
197 Err(RustRabbitError::RequestTimeout.into())
198 }
199 }
200 }
201
202 pub async fn handle_response(&self, response: ResponseMessage) -> Result<()> {
204 let correlation_id = response.correlation_id.clone();
205
206 let sender = {
207 let mut pending = self.pending_requests.lock().unwrap();
208 pending.remove(&correlation_id)
209 };
210
211 if let Some(pending_request) = sender {
212 debug!(
213 correlation_id = %correlation_id,
214 "Forwarding response to pending request"
215 );
216
217 if pending_request.sender.send(response).is_err() {
218 warn!(
219 correlation_id = %correlation_id,
220 "Failed to send response - receiver dropped"
221 );
222 }
223 } else {
224 warn!(
225 correlation_id = %correlation_id,
226 "Received response for unknown correlation ID"
227 );
228 }
229
230 Ok(())
231 }
232
233 pub fn pending_count(&self) -> usize {
235 self.pending_requests.lock().unwrap().len()
236 }
237
238 async fn cleanup_expired_requests(
240 pending_requests: &Arc<Mutex<HashMap<CorrelationId, PendingRequest>>>,
241 ) {
242 let now = Instant::now();
243 let mut expired_ids = Vec::new();
244
245 {
246 let pending = pending_requests.lock().unwrap();
247 for (correlation_id, request) in pending.iter() {
248 if now.duration_since(request.created_at) > request.timeout {
249 expired_ids.push(correlation_id.clone());
250 }
251 }
252 }
253
254 if !expired_ids.is_empty() {
255 let mut pending = pending_requests.lock().unwrap();
256 for correlation_id in expired_ids {
257 if let Some(expired_request) = pending.remove(&correlation_id) {
258 let _ = expired_request.sender.send(ResponseMessage::error(
259 correlation_id.clone(),
260 "Request timeout".to_string(),
261 ));
262
263 warn!(
264 correlation_id = %correlation_id,
265 "Cleaned up expired request"
266 );
267 }
268 }
269 }
270 }
271}
272
273pub struct RequestResponseServer {
275 handler: Arc<dyn RequestHandler + Send + Sync>,
276}
277
278#[async_trait::async_trait]
280pub trait RequestHandler {
281 async fn handle_request(&self, request: RequestMessage) -> Result<ResponseMessage>;
282}
283
284impl RequestResponseServer {
285 pub fn new(handler: Arc<dyn RequestHandler + Send + Sync>) -> Self {
286 Self { handler }
287 }
288
289 pub async fn process_request(&self, request: RequestMessage) -> Result<ResponseMessage> {
291 let correlation_id = request.correlation_id.clone();
292
293 debug!(
294 correlation_id = %correlation_id,
295 "Processing incoming request"
296 );
297
298 let start_time = Instant::now();
299 let response = self.handler.handle_request(request).await;
300 let processing_time = start_time.elapsed();
301
302 match &response {
303 Ok(resp) => {
304 info!(
305 correlation_id = %correlation_id,
306 processing_time_ms = processing_time.as_millis(),
307 success = resp.success,
308 "Request processed"
309 );
310 }
311 Err(err) => {
312 error!(
313 correlation_id = %correlation_id,
314 processing_time_ms = processing_time.as_millis(),
315 error = %err,
316 "Request processing failed"
317 );
318 return Ok(ResponseMessage::error(correlation_id, err.to_string()));
319 }
320 }
321
322 response
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::sleep;

    /// Handler that answers with the request payload prefixed by `Echo: `.
    struct TestHandler;

    #[async_trait::async_trait]
    impl RequestHandler for TestHandler {
        async fn handle_request(&self, request: RequestMessage) -> Result<ResponseMessage> {
            let echoed = format!("Echo: {}", String::from_utf8_lossy(&request.payload));
            let response = ResponseMessage::success(request.correlation_id, echoed.into_bytes());
            Ok(response)
        }
    }

    /// Two freshly generated ids must differ.
    #[tokio::test]
    async fn test_correlation_id_generation() {
        let first = CorrelationId::new();
        let second = CorrelationId::new();
        assert_ne!(first, second);
    }

    /// `RequestMessage::new` stores its arguments unchanged.
    #[tokio::test]
    async fn test_request_message_creation() {
        let body = b"test message".to_vec();
        let wait = Duration::from_secs(30);

        let request = RequestMessage::new(body.clone(), "reply.queue".to_string(), wait);

        assert_eq!(request.payload, body);
        assert_eq!(request.reply_to, "reply.queue");
        assert_eq!(request.timeout, Duration::from_secs(30));
    }

    /// The success and error constructors set the flag, payload and message as expected.
    #[tokio::test]
    async fn test_response_creation() {
        let correlation_id = CorrelationId::new();
        let body = b"response".to_vec();

        let ok = ResponseMessage::success(correlation_id.clone(), body.clone());
        assert!(ok.success);
        assert_eq!(ok.correlation_id, correlation_id);
        assert_eq!(ok.payload, body);

        let failed = ResponseMessage::error(correlation_id.clone(), "Error".to_string());
        assert!(!failed.success);
        assert_eq!(failed.error_message, Some("Error".to_string()));
    }

    /// The server passes the request to the handler and returns its response.
    #[tokio::test]
    async fn test_request_response_server() {
        let server = RequestResponseServer::new(Arc::new(TestHandler));

        let request = RequestMessage::new(
            b"hello".to_vec(),
            "reply.queue".to_string(),
            Duration::from_secs(30),
        );
        let expected_id = request.correlation_id.clone();

        let response = server.process_request(request).await.unwrap();

        assert_eq!(response.correlation_id, expected_id);
        assert!(response.success);
        assert_eq!(String::from_utf8_lossy(&response.payload), "Echo: hello");
    }

    /// A request that never gets a response fails and leaves no pending entry behind.
    #[tokio::test]
    async fn test_pending_requests_cleanup() {
        let client = RequestResponseClient::new(Duration::from_millis(100));

        let outcome = client
            .send_request(
                b"test".to_vec(),
                "reply.queue".to_string(),
                Some(Duration::from_millis(50)),
            )
            .await;
        assert!(outcome.is_err());

        sleep(Duration::from_millis(200)).await;

        assert_eq!(client.pending_count(), 0);
    }
}