1use crate::api::types::*;
3use crate::config::CONFIG;
4use crate::error::{BotError, Result};
5use bytes::Bytes;
6use rand::Rng;
7use reqwest::{
8 Body, Client, ClientBuilder, StatusCode, Url,
9 multipart::{Form, Part},
10};
11use std::time::Duration;
12use tokio::fs::File;
13use tokio::signal;
14use tokio::time::sleep;
15use tokio_util::codec::{BytesCodec, FramedRead};
16use tracing::{debug, error, trace, warn};
17#[derive(Debug, Clone)]
19pub struct ConnectionPool {
20 client: Client,
21 retries: usize,
22 max_backoff: Duration,
23}
24
25impl Default for ConnectionPool {
26 fn default() -> Self {
27 let cfg = &CONFIG.network;
28 Self::new(
29 Client::new(),
30 cfg.retries,
31 Duration::from_millis(cfg.max_backoff_ms),
32 )
33 }
34}
35
36impl ConnectionPool {
37 pub fn new(client: Client, retries: usize, max_backoff: Duration) -> Self {
39 Self {
40 client,
41 retries,
42 max_backoff,
43 }
44 }
45
46 pub fn optimized() -> Self {
48 let cfg = &CONFIG.network;
49 let client = build_optimized_client().unwrap_or_else(|e| {
50 warn!(
51 "Failed to build optimized client. Use default instead: {}",
52 e
53 );
54 Client::new()
55 });
56 let retries = cfg.retries;
57 let max_backoff = Duration::from_millis(cfg.max_backoff_ms);
58
59 Self {
60 client,
61 retries,
62 max_backoff,
63 }
64 }
65
66 pub async fn execute_with_retry<F, Fut, T>(&self, operation: F) -> Result<T>
68 where
69 F: Fn() -> Fut + Send + Sync,
70 Fut: std::future::Future<Output = Result<T>> + Send,
71 T: Send,
72 {
73 let mut retries = 0;
74 let mut backoff_ms = 100;
75
76 loop {
77 match operation().await {
78 Ok(result) => return Ok(result),
79 Err(e) => {
80 if let BotError::Network(ref req_err) = e {
81 if !should_retry(req_err) || retries >= self.retries {
82 return Err(e);
83 }
84
85 retries += 1;
86 let jitter = rand::random::<u64>() % 100;
87 let delay = Duration::from_millis(backoff_ms + jitter);
88
89 warn!(
90 "Request failed, retrying ({}/{}): {} after {:?}",
91 retries, self.retries, req_err, delay
92 );
93
94 sleep(delay).await;
95 backoff_ms =
96 std::cmp::min(backoff_ms * 2, self.max_backoff.as_millis() as u64);
97 } else {
98 return Err(e);
99 }
100 }
101 }
102 }
103 }
104
105 #[tracing::instrument(skip(self))]
107 pub async fn get_text(&self, url: Url) -> Result<String> {
108 debug!("Getting response from API at path {}...", url);
109
110 let url_str = url.as_str().to_string();
111 self.execute_with_retry(move || {
112 let client = self.client.clone();
113 let url_str = url_str.clone();
114
115 async move {
116 let response = client.get(&url_str).send().await?;
117 trace!("Response status: {}", response.status());
118
119 validate_response(&response.status())?;
120
121 let text = response.text().await?;
122 trace!("Response body length: {} bytes", text.len());
123 Ok(text)
124 }
125 })
126 .await
127 }
128
129 #[tracing::instrument(skip(self))]
131 pub async fn get_bytes(&self, url: Url) -> Result<Vec<u8>> {
132 debug!("Getting binary response from API at path {}...", url);
133
134 let url_str = url.as_str().to_string();
135 self.execute_with_retry(move || {
136 let client = self.client.clone();
137 let url_str = url_str.clone();
138
139 async move {
140 let response = client.get(&url_str).send().await?;
141 trace!("Response status: {}", response.status());
142
143 validate_response(&response.status())?;
144
145 let bytes = response.bytes().await?;
146 trace!("Response body size: {} bytes", bytes.len());
147 Ok(bytes.to_vec())
148 }
149 })
150 .await
151 }
152
153 #[tracing::instrument(skip(self, retryable_form))]
155 pub async fn post_file_retryable(
156 &self,
157 url: Url,
158 retryable_form: &RetryableMultipartForm,
159 ) -> Result<String> {
160 debug!(
161 "Sending file to API at path {} (size: {} bytes)...",
162 url,
163 retryable_form.size()
164 );
165
166 let mut attempts = 0;
167 let max_attempts = self.retries + 1;
168
169 loop {
170 attempts += 1;
171
172 let form = retryable_form.to_form();
174
175 trace!("Attempt {} of {}", attempts, max_attempts);
176
177 let response = self.client.post(url.as_str()).multipart(form).send().await;
178
179 match response {
180 Ok(response) => {
181 trace!("Response status: {}", response.status());
182
183 if let Err(e) = validate_response(&response.status()) {
185 if attempts >= max_attempts || !should_retry_status(&response.status()) {
186 return Err(e);
187 }
188
189 let backoff = calculate_backoff_duration(attempts, self.max_backoff);
190 warn!(
191 "HTTP error {}, retrying in {:?} (attempt {} of {})",
192 response.status(),
193 backoff,
194 attempts,
195 max_attempts
196 );
197 sleep(backoff).await;
198 continue;
199 }
200
201 let text = response.text().await?;
203 trace!("Response body length: {} bytes", text.len());
204 debug!("File uploaded successfully after {} attempt(s)", attempts);
205 return Ok(text);
206 }
207 Err(err) => {
208 if attempts >= max_attempts || !should_retry(&err) {
209 error!("File upload failed after {} attempt(s): {}", attempts, err);
210 return Err(BotError::Network(err));
211 }
212
213 let backoff = calculate_backoff_duration(attempts, self.max_backoff);
214 warn!(
215 "File upload failed, retrying in {:?} (attempt {} of {}): {}",
216 backoff, attempts, max_attempts, err
217 );
218 sleep(backoff).await;
219 }
220 }
221 }
222 }
223
224 #[tracing::instrument(skip(self, form))]
226 pub async fn post_file(&self, url: Url, form: Form) -> Result<String> {
227 debug!(
228 "Sending file to API at path {} (legacy method, no retry)...",
229 url
230 );
231
232 let response = self.client.post(url.as_str()).multipart(form).send().await;
233
234 match response {
235 Ok(response) => {
236 trace!("Response status: {}", response.status());
237 validate_response(&response.status())?;
238 let text = response.text().await?;
239 trace!("Response body length: {} bytes", text.len());
240 Ok(text)
241 }
242 Err(err) => {
243 warn!("File upload failed (no retry available): {}", err);
244 Err(BotError::Network(err))
245 }
246 }
247 }
248}
249
250fn validate_response(status: &StatusCode) -> Result<()> {
252 if status.is_success() {
253 Ok(())
254 } else if status.is_server_error() {
255 warn!("Server error: {}", status);
256 Err(BotError::System(format!("Server error: HTTP {status}")))
257 } else if status.is_client_error() {
258 error!("Client error: {}", status);
259 Err(BotError::Validation(format!("HTTP error: {status}")))
260 } else {
261 warn!("Unexpected status code: {}", status);
262 Err(BotError::System(format!(
263 "Unexpected HTTP status code: {status}"
264 )))
265 }
266}
267
268fn should_retry(err: &reqwest::Error) -> bool {
270 err.is_timeout()
271 || err.is_connect()
272 || err.is_request()
273 || (err.status().is_some_and(|s| s.is_server_error()))
274}
275
276pub fn should_retry_status(status: &StatusCode) -> bool {
278 match status.as_u16() {
279 500..=599 => true,
281 429 => true,
283 408 | 409 | 423 | 424 => true,
285 _ => false,
287 }
288}
289
290pub fn calculate_backoff_duration(attempt: usize, max_backoff: Duration) -> Duration {
292 let base_duration = Duration::from_millis(100); let exponential_backoff = base_duration * 2_u32.pow((attempt - 1) as u32);
294
295 let capped_backoff = std::cmp::min(exponential_backoff, max_backoff);
297
298 let jitter_range = capped_backoff.as_millis() / 4; let mut rng = rand::rng();
301 let jitter = rng.random_range(0..=(jitter_range as u64 * 2));
302 let jitter_offset = jitter as i64 - jitter_range as i64;
303
304 let final_duration = (capped_backoff.as_millis() as i64 + jitter_offset).max(0) as u64;
305 Duration::from_millis(final_duration)
306}
307
308fn build_optimized_client() -> Result<Client> {
310 let cfg = &CONFIG.network;
311 let builder = ClientBuilder::new()
312 .timeout(Duration::from_secs(cfg.request_timeout_secs))
313 .connect_timeout(Duration::from_secs(cfg.connect_timeout_secs))
314 .pool_idle_timeout(Duration::from_secs(cfg.pool_idle_timeout_secs))
315 .tcp_nodelay(true)
316 .pool_max_idle_per_host(cfg.max_idle_connections)
317 .use_rustls_tls();
318
319 builder.build().map_err(BotError::Network)
320}
321#[tracing::instrument(skip(client))]
330pub async fn get_bytes_response(client: Client, url: Url) -> Result<Vec<u8>> {
331 debug!("Getting binary response from API at path {}...", url);
332 let response = client.get(url.as_str()).send().await?;
333 trace!("Response status: {}", response.status());
334 let bytes = response.bytes().await?;
335 Ok(bytes.to_vec())
336}
337#[tracing::instrument(skip(file))]
344pub async fn file_to_retryable_multipart(file: &MultipartName) -> Result<RetryableMultipartForm> {
347 match file {
348 MultipartName::FilePath(path) | MultipartName::ImagePath(path) => {
349 RetryableMultipartForm::from_file_path(path.clone()).await
350 }
351 MultipartName::FileContent { filename, content }
352 | MultipartName::ImageContent { filename, content } => {
353 validate_filename(filename)?;
355
356 if content.is_empty() {
358 return Err(BotError::Validation(
359 "File content cannot be empty".to_string(),
360 ));
361 }
362
363 Ok(RetryableMultipartForm::from_content(
364 filename.clone(),
365 filename.clone(),
366 content.clone(),
367 ))
368 }
369 _ => Err(BotError::Validation("File not specified".to_string())),
370 }
371}
372
373pub async fn file_to_multipart(file: &MultipartName) -> Result<Form> {
376 match file {
378 MultipartName::FilePath(name) | MultipartName::ImagePath(name) => {
379 validate_file_path(name)?;
381
382 let file_stream = make_stream(name).await?;
383 let part = Part::stream(file_stream).file_name(name.to_string());
384 Ok(Form::new().part(name.to_string(), part))
385 }
386 MultipartName::FileContent { filename, content }
387 | MultipartName::ImageContent { filename, content } => {
388 validate_filename(filename)?;
390
391 if content.is_empty() {
393 return Err(BotError::Validation(
394 "File content cannot be empty".to_string(),
395 ));
396 }
397
398 let part = Part::bytes(content.clone()).file_name(filename.clone());
399 Ok(Form::new().part(filename.to_string(), part))
400 }
401 _ => Err(BotError::Validation("File not specified".to_string())),
402 }
403}
404#[tracing::instrument(skip(path))]
410async fn make_stream(path: &String) -> Result<Body> {
411 let file = File::open(path).await?;
413 let file_stream = Body::wrap_stream(FramedRead::new(file, BytesCodec::new()));
415 Ok(file_stream)
416}
417pub async fn shutdown_signal() {
422 let ctrl_c = async {
423 signal::ctrl_c()
424 .await
425 .map_err(|e| BotError::System(format!("Failed to set up Ctrl+C handler: {e}")))
426 .unwrap_or_else(|e| panic!("{}", e));
427 };
428
429 #[cfg(unix)]
430 let terminate = async {
431 signal::unix::signal(signal::unix::SignalKind::terminate())
432 .map_err(|e| BotError::System(format!("Failed to set up signal handler: {e}")))
433 .unwrap_or_else(|e| panic!("{}", e))
434 .recv()
435 .await;
436 };
437
438 #[cfg(not(unix))]
439 let terminate = std::future::pending::<()>();
440
441 tokio::select! {
442 _ = ctrl_c => {},
443 _ = terminate => {},
444 }
445}
446
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::StatusCode;
    use std::time::Duration;

    #[tokio::test]
    async fn test_connection_pool_new_and_default() {
        let client = reqwest::Client::new();
        let pool = ConnectionPool::new(client.clone(), 2, Duration::from_millis(100));
        assert_eq!(pool.retries, 2);
        assert_eq!(pool.max_backoff, Duration::from_millis(100));
        let _default = ConnectionPool::default();
    }

    #[tokio::test]
    async fn test_validate_response_success() {
        assert!(validate_response(&StatusCode::OK).is_ok());
    }

    #[tokio::test]
    async fn test_validate_response_client_error() {
        let err = validate_response(&StatusCode::BAD_REQUEST).unwrap_err();
        match err {
            BotError::Validation(msg) => assert!(msg.contains("HTTP error")),
            _ => panic!("Expected Validation error"),
        }
    }

    #[tokio::test]
    async fn test_validate_response_server_error() {
        let err = validate_response(&StatusCode::INTERNAL_SERVER_ERROR).unwrap_err();
        match err {
            BotError::System(msg) => assert!(msg.contains("Server error")),
            _ => panic!("Expected System error"),
        }
    }

    #[tokio::test]
    async fn test_validate_response_unexpected_status() {
        let status = StatusCode::SWITCHING_PROTOCOLS;
        let err = validate_response(&status).unwrap_err();
        match err {
            BotError::System(msg) => assert!(msg.contains("Unexpected HTTP status code")),
            _ => panic!("Expected System error"),
        }
    }

    #[tokio::test]
    async fn test_should_retry_timeout() {
        let err = reqwest::ClientBuilder::new()
            .timeout(Duration::from_millis(1))
            .build()
            .unwrap()
            .get("http://httpbin.org/delay/10")
            .send()
            .await
            .unwrap_err();

        assert!(should_retry(&err));
    }

    #[tokio::test]
    async fn test_should_retry_server_error() {
        let client = reqwest::Client::new();
        // A 500 *response* is `Ok(Response)`, so the old `if let Err(..)` never ran its
        // assertion. `error_for_status` turns it into a `reqwest::Error` that carries
        // the status, which is what `should_retry` inspects. Without network access
        // the send itself fails, which must also be retryable.
        let err = match client.get("http://httpbin.org/status/500").send().await {
            Ok(response) if response.status().is_server_error() => {
                response.error_for_status().unwrap_err()
            }
            // Something other than the endpoint answered (e.g. a proxy): nothing to check.
            Ok(_) => return,
            Err(err) => err,
        };

        assert!(should_retry(&err));
    }

    #[test]
    fn test_should_retry_status_codes() {
        for code in [500, 502, 503, 504, 408, 409, 423, 424, 429] {
            let status = StatusCode::from_u16(code).unwrap();
            assert!(should_retry_status(&status), "{code} should be retryable");
        }
        for code in [200, 301, 400, 401, 403, 404, 422] {
            let status = StatusCode::from_u16(code).unwrap();
            assert!(!should_retry_status(&status), "{code} should not be retryable");
        }
    }

    #[test]
    fn test_calculate_backoff_duration_within_jitter_bounds() {
        // attempt 1 -> 100 ms base with ±25 % jitter.
        for _ in 0..20 {
            let delay = calculate_backoff_duration(1, Duration::from_secs(10));
            assert!(
                delay >= Duration::from_millis(75) && delay <= Duration::from_millis(125),
                "unexpected delay {delay:?}"
            );
        }
        // attempt 4 -> 800 ms, capped at 200 ms, then ±25 % jitter.
        for _ in 0..20 {
            let delay = calculate_backoff_duration(4, Duration::from_millis(200));
            assert!(
                delay >= Duration::from_millis(150) && delay <= Duration::from_millis(250),
                "unexpected delay {delay:?}"
            );
        }
    }

    #[tokio::test]
    async fn test_build_optimized_client() {
        let result = build_optimized_client();
        assert!(
            result.is_ok(),
            "Failed to build optimized client: {:?}",
            result.err()
        );

        let client = result.unwrap();
        assert!(client.get("https://example.com").build().is_ok());
    }

    #[tokio::test]
    async fn test_connection_pool_optimized() {
        let pool = ConnectionPool::optimized();
        assert!(pool.retries > 0);
        assert!(pool.max_backoff > Duration::from_millis(0));
    }

    #[tokio::test]
    async fn test_connection_pool_execute_with_retry_success() {
        let pool = ConnectionPool::new(reqwest::Client::new(), 2, Duration::from_millis(100));

        let counter = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let counter_clone = counter.clone();

        let result = pool
            .execute_with_retry(|| {
                let counter = counter_clone.clone();
                async move {
                    counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    Ok::<String, BotError>("success".to_string())
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_connection_pool_execute_with_retry_failure() {
        let pool = ConnectionPool::new(reqwest::Client::new(), 0, Duration::from_millis(10));

        let result = pool
            .execute_with_retry(|| async {
                Err::<String, BotError>(BotError::Network(
                    reqwest::ClientBuilder::new()
                        .build()
                        .unwrap()
                        .get("http://invalid-url-that-does-not-exist.invalid")
                        .send()
                        .await
                        .unwrap_err(),
                ))
            })
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_connection_pool_execute_with_retry_non_retryable_error() {
        let pool = ConnectionPool::new(reqwest::Client::new(), 2, Duration::from_millis(10));

        let result = pool
            .execute_with_retry(|| async {
                Err::<String, BotError>(BotError::Validation("Non-retryable error".to_string()))
            })
            .await;

        assert!(result.is_err());
        match result.unwrap_err() {
            BotError::Validation(msg) => assert_eq!(msg, "Non-retryable error"),
            _ => panic!("Expected Validation error"),
        }
    }

    #[tokio::test]
    async fn test_file_to_multipart_filepath() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let mut temp_file = NamedTempFile::new().unwrap();
        write!(temp_file, "test content").unwrap();
        let temp_path = temp_file.path().to_string_lossy().to_string();

        let multipart = MultipartName::FilePath(temp_path);
        let result = file_to_multipart(&multipart).await;

        assert!(
            result.is_ok(),
            "Failed to create multipart: {:?}",
            result.err()
        );
    }

    #[tokio::test]
    async fn test_file_to_multipart_file_content() {
        let multipart = MultipartName::FileContent {
            filename: "test.txt".to_string(),
            content: b"test content".to_vec(),
        };

        let result = file_to_multipart(&multipart).await;
        assert!(
            result.is_ok(),
            "Failed to create multipart from content: {:?}",
            result.err()
        );
    }

    #[tokio::test]
    async fn test_file_to_multipart_image_content() {
        let multipart = MultipartName::ImageContent {
            filename: "test.jpg".to_string(),
            content: b"fake image content".to_vec(),
        };

        let result = file_to_multipart(&multipart).await;
        assert!(
            result.is_ok(),
            "Failed to create multipart from image content: {:?}",
            result.err()
        );
    }

    #[tokio::test]
    async fn test_file_to_multipart_invalid() {
        let multipart = MultipartName::FilePath("/non/existent/file.txt".to_string());
        let result = file_to_multipart(&multipart).await;

        assert!(result.is_err());
        match result.unwrap_err() {
            BotError::Validation(msg) => assert!(msg.contains("File does not exist")),
            _ => panic!("Expected Validation error"),
        }
    }

    #[tokio::test]
    async fn test_file_to_multipart_path_traversal() {
        let multipart = MultipartName::FilePath("../../../etc/passwd".to_string());
        let result = file_to_multipart(&multipart).await;

        assert!(result.is_err());
        match result.unwrap_err() {
            BotError::Validation(msg) => assert!(msg.contains("parent directory references")),
            _ => panic!("Expected Validation error"),
        }
    }

    #[tokio::test]
    async fn test_file_to_multipart_empty_path() {
        let multipart = MultipartName::FilePath("".to_string());
        let result = file_to_multipart(&multipart).await;

        assert!(result.is_err());
        match result.unwrap_err() {
            BotError::Validation(msg) => assert_eq!(msg, "File path cannot be empty"),
            _ => panic!("Expected Validation error"),
        }
    }

    #[tokio::test]
    async fn test_file_to_multipart_invalid_filename() {
        let multipart = MultipartName::FileContent {
            filename: "file<name>.txt".to_string(),
            content: b"test content".to_vec(),
        };

        let result = file_to_multipart(&multipart).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            BotError::Validation(msg) => assert!(msg.contains("forbidden character")),
            _ => panic!("Expected Validation error"),
        }
    }

    #[tokio::test]
    async fn test_file_to_multipart_empty_content() {
        let multipart = MultipartName::FileContent {
            filename: "empty.txt".to_string(),
            content: Vec::new(),
        };

        let result = file_to_multipart(&multipart).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            BotError::Validation(msg) => assert_eq!(msg, "File content cannot be empty"),
            _ => panic!("Expected Validation error"),
        }
    }

    #[tokio::test]
    async fn test_validate_filename_reserved_names() {
        let reserved_names = ["CON", "PRN", "AUX", "NUL", "COM1", "LPT1"];

        for name in reserved_names.iter() {
            let result = validate_filename(name);
            assert!(result.is_err());
            match result.unwrap_err() {
                BotError::Validation(msg) => assert!(msg.contains("reserved name")),
                _ => panic!("Expected Validation error for {name}"),
            }
        }
    }

    #[tokio::test]
    async fn test_validate_filename_valid() {
        let valid_names = ["document.txt", "image.jpg", "data.json", "archive.zip"];

        for name in valid_names.iter() {
            let result = validate_filename(name);
            assert!(result.is_ok(), "Filename {name} should be valid");
        }
    }

    #[test]
    fn test_validate_filename_hidden_file_allowed() {
        assert!(validate_filename(".env").is_ok());
        assert!(validate_filename("name.").is_err());
        assert!(validate_filename("name ").is_err());
    }

    #[test]
    fn test_retryable_form_from_content() {
        let form = RetryableMultipartForm::from_content(
            "a.txt".to_string(),
            "file".to_string(),
            b"test content".to_vec(),
        );
        assert_eq!(form.size(), 12);
        assert_eq!(form.filename, "a.txt");
        // The form must be rebuildable once per upload attempt.
        let _first = form.to_form();
        let _second = form.to_form();
    }

    #[tokio::test]
    async fn test_make_stream_valid_file() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let mut temp_file = NamedTempFile::new().unwrap();
        write!(temp_file, "test stream content").unwrap();
        let temp_path = temp_file.path().to_string_lossy().to_string();

        let result = make_stream(&temp_path).await;
        assert!(
            result.is_ok(),
            "Failed to create stream: {:?}",
            result.err()
        );
    }

    #[tokio::test]
    async fn test_make_stream_invalid_file() {
        let invalid_path = "/path/that/does/not/exist/file.txt".to_string();
        let result = make_stream(&invalid_path).await;

        assert!(result.is_err());
        match result.unwrap_err() {
            BotError::Io(_) => {}
            _ => panic!("Expected IO error"),
        }
    }

    #[tokio::test]
    async fn test_validate_response_all_success_codes() {
        let success_codes = [
            StatusCode::OK,
            StatusCode::CREATED,
            StatusCode::ACCEPTED,
            StatusCode::NO_CONTENT,
        ];

        for code in success_codes.iter() {
            assert!(
                validate_response(code).is_ok(),
                "Status code {code:?} should be valid"
            );
        }
    }

    #[tokio::test]
    async fn test_validate_response_all_client_error_codes() {
        let client_error_codes = [
            StatusCode::BAD_REQUEST,
            StatusCode::UNAUTHORIZED,
            StatusCode::FORBIDDEN,
            StatusCode::NOT_FOUND,
            StatusCode::METHOD_NOT_ALLOWED,
        ];

        for code in client_error_codes.iter() {
            let result = validate_response(code);
            assert!(result.is_err(), "Status code {code:?} should be error");
            match result.unwrap_err() {
                BotError::Validation(_) => {}
                _ => panic!("Expected Validation error for code {code:?}"),
            }
        }
    }

    #[tokio::test]
    async fn test_validate_response_all_server_error_codes() {
        let server_error_codes = [
            StatusCode::INTERNAL_SERVER_ERROR,
            StatusCode::NOT_IMPLEMENTED,
            StatusCode::BAD_GATEWAY,
            StatusCode::SERVICE_UNAVAILABLE,
            StatusCode::GATEWAY_TIMEOUT,
        ];

        for code in server_error_codes.iter() {
            let result = validate_response(code);
            assert!(result.is_err(), "Status code {code:?} should be error");
            match result.unwrap_err() {
                BotError::System(_) => {}
                _ => panic!("Expected System error for code {code:?}"),
            }
        }
    }

    #[tokio::test]
    async fn test_connection_pool_clone() {
        let pool1 = ConnectionPool::new(reqwest::Client::new(), 3, Duration::from_millis(200));
        let pool2 = pool1.clone();

        assert_eq!(pool1.retries, pool2.retries);
        assert_eq!(pool1.max_backoff, pool2.max_backoff);
    }

    #[test]
    fn test_connection_pool_debug() {
        let pool = ConnectionPool::new(reqwest::Client::new(), 2, Duration::from_millis(100));
        let debug_str = format!("{pool:?}");
        assert!(debug_str.contains("ConnectionPool"));
    }

    #[tokio::test]
    async fn test_deprecated_get_bytes_response() {
        let client = reqwest::Client::new();
        let url = reqwest::Url::parse("https://httpbin.org/bytes/10").unwrap();

        // Network-dependent: only check the payload when the request succeeded.
        let result = get_bytes_response(client, url).await;
        if let Ok(bytes) = result {
            assert!(!bytes.is_empty());
        }
    }

    #[tokio::test]
    async fn test_shutdown_signal_setup() {
        let signal_task = tokio::spawn(async {
            tokio::time::timeout(Duration::from_millis(100), shutdown_signal()).await
        });

        let result = signal_task.await.unwrap();
        // No signal is sent, so the timeout must elapse.
        assert!(result.is_err());
    }
}
873fn validate_file_path(path: &str) -> Result<()> {
878 use std::path::Path;
879
880 if path.is_empty() {
882 return Err(BotError::Validation(
883 "File path cannot be empty".to_string(),
884 ));
885 }
886
887 if path.contains('\0') {
889 return Err(BotError::Validation(
890 "File path contains null bytes".to_string(),
891 ));
892 }
893
894 let path_obj = Path::new(path);
896
897 for component in path_obj.components() {
899 match component {
900 std::path::Component::ParentDir => {
901 return Err(BotError::Validation(
902 "File path contains parent directory references (..)".to_string(),
903 ));
904 }
905 std::path::Component::CurDir => {
906 return Err(BotError::Validation(
907 "File path contains current directory references (.)".to_string(),
908 ));
909 }
910 _ => {}
911 }
912 }
913
914 if path_obj.is_absolute() {
916 if !path_obj.exists() {
918 return Err(BotError::Validation(format!("File does not exist: {path}")));
919 }
920
921 if !path_obj.is_file() {
922 return Err(BotError::Validation(format!("Path is not a file: {path}")));
923 }
924 }
925
926 #[cfg(target_os = "windows")]
928 const MAX_PATH_LEN: usize = 260;
929 #[cfg(not(target_os = "windows"))]
930 const MAX_PATH_LEN: usize = 4096;
931
932 if path.len() > MAX_PATH_LEN {
933 return Err(BotError::Validation(format!(
934 "File path too long: {} characters (max: {})",
935 path.len(),
936 MAX_PATH_LEN
937 )));
938 }
939
940 Ok(())
941}
942
943fn validate_filename(filename: &str) -> Result<()> {
948 if filename.is_empty() {
950 return Err(BotError::Validation("Filename cannot be empty".to_string()));
951 }
952
953 if filename.contains('\0') {
955 return Err(BotError::Validation(
956 "Filename contains null bytes".to_string(),
957 ));
958 }
959
960 const FORBIDDEN_CHARS: &[char] = &['/', '\\', ':', '*', '?', '"', '<', '>', '|'];
962 for &forbidden_char in FORBIDDEN_CHARS {
963 if filename.contains(forbidden_char) {
964 return Err(BotError::Validation(format!(
965 "Filename contains forbidden character: '{forbidden_char}'"
966 )));
967 }
968 }
969
970 const RESERVED_NAMES: &[&str] = &[
972 "CON", "PRN", "AUX", "NUL", "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7", "COM8",
973 "COM9", "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9",
974 ];
975
976 let filename_upper = filename.to_uppercase();
977 let name_without_ext = filename_upper.split('.').next().unwrap_or("");
978
979 if RESERVED_NAMES.contains(&name_without_ext) {
980 return Err(BotError::Validation(format!(
981 "Filename uses reserved name: {filename}"
982 )));
983 }
984
985 const MAX_FILENAME_LEN: usize = 255;
987 if filename.len() > MAX_FILENAME_LEN {
988 return Err(BotError::Validation(format!(
989 "Filename too long: {} characters (max: {})",
990 filename.len(),
991 MAX_FILENAME_LEN
992 )));
993 }
994
995 if filename.starts_with('.') && filename != "." && filename != ".." {
997 }
999
1000 if filename.ends_with(' ') || filename.ends_with('.') {
1001 return Err(BotError::Validation(
1002 "Filename cannot end with space or dot".to_string(),
1003 ));
1004 }
1005
1006 Ok(())
1007}
1008#[derive(Debug, Clone)]
1010pub struct RetryableMultipartForm {
1011 file_data: Bytes,
1012 pub filename: String,
1013 field_name: String,
1014}
1015
1016impl RetryableMultipartForm {
1017 pub fn from_content(filename: String, field_name: String, content: Vec<u8>) -> Self {
1019 Self {
1020 file_data: Bytes::from(content),
1021 filename,
1022 field_name,
1023 }
1024 }
1025
1026 pub async fn from_file_path(path: String) -> Result<Self> {
1028 validate_file_path_async(&path).await?;
1030
1031 let content = tokio::fs::read(&path).await.map_err(BotError::Io)?;
1033
1034 let filename = std::path::Path::new(&path)
1035 .file_name()
1036 .and_then(|name| name.to_str())
1037 .unwrap_or(&path)
1038 .to_string();
1039
1040 Ok(Self::from_content(filename.clone(), filename, content))
1041 }
1042
1043 pub fn to_form(&self) -> Form {
1045 let part = Part::bytes(self.file_data.clone().to_vec()).file_name(self.filename.clone());
1046 Form::new().part(self.field_name.clone(), part)
1047 }
1048
1049 pub fn size(&self) -> usize {
1051 self.file_data.len()
1052 }
1053}
1054
1055pub async fn validate_file_path_async(path: &str) -> Result<()> {
1060 if path.is_empty() {
1062 return Err(BotError::Validation(
1063 "File path cannot be empty".to_string(),
1064 ));
1065 }
1066
1067 if path.contains('\0') {
1069 return Err(BotError::Validation(
1070 "File path contains null bytes".to_string(),
1071 ));
1072 }
1073
1074 let path_obj = std::path::Path::new(path);
1076
1077 for component in path_obj.components() {
1079 match component {
1080 std::path::Component::ParentDir => {
1081 return Err(BotError::Validation(
1082 "File path contains parent directory references (..)".to_string(),
1083 ));
1084 }
1085 std::path::Component::CurDir => {
1086 return Err(BotError::Validation(
1087 "File path contains current directory references (.)".to_string(),
1088 ));
1089 }
1090 _ => {}
1091 }
1092 }
1093
1094 if path_obj.is_absolute() {
1096 let metadata = tokio::fs::metadata(path)
1098 .await
1099 .map_err(|e| BotError::Validation(format!("File does not exist: {path} ({e})")))?;
1100
1101 if !metadata.is_file() {
1102 return Err(BotError::Validation(format!("Path is not a file: {path}")));
1103 }
1104
1105 let _canonical = tokio::fs::canonicalize(path)
1107 .await
1108 .map_err(|e| BotError::Validation(format!("Cannot access file: {path} ({e})")))?;
1109 }
1110
1111 #[cfg(target_os = "windows")]
1113 const MAX_PATH_LEN: usize = 260;
1114 #[cfg(not(target_os = "windows"))]
1115 const MAX_PATH_LEN: usize = 4096;
1116
1117 if path.len() > MAX_PATH_LEN {
1118 return Err(BotError::Validation(format!(
1119 "File path too long: {} characters (max: {})",
1120 path.len(),
1121 MAX_PATH_LEN
1122 )));
1123 }
1124
1125 Ok(())
1126}