1use crate::{
4 BatchError,
5 core::step::{RepeatStatus, StepExecution, Tasklet},
6 tasklet::s3::{S3ClientConfig, build_s3_client},
7};
8use aws_sdk_s3::primitives::ByteStream;
9use aws_sdk_s3::types::{CompletedMultipartUpload, CompletedPart};
10use log::{debug, info};
11use std::path::{Path, PathBuf};
12use tokio::runtime::Handle;
13
/// Default multipart chunk size: 8 MiB.
const DEFAULT_CHUNK_SIZE: usize = 8 * 1024 * 1024;

/// Tasklet that uploads a single local file to an S3 object.
///
/// Files smaller than `chunk_size` are sent with a single `PutObject`
/// request; files at or above `chunk_size` go through the multipart
/// upload path (see `upload_multipart`). Construct via `S3PutTaskletBuilder`.
#[derive(Debug)]
pub struct S3PutTasklet {
    // Destination bucket name.
    bucket: String,
    // Destination object key.
    key: String,
    // Path of the local file to upload.
    local_file: PathBuf,
    // Multipart part size in bytes; builder enforces the 5 MiB minimum.
    chunk_size: usize,
    // Region/endpoint/credential settings used by `build_s3_client`.
    config: S3ClientConfig,
}
50
51impl S3PutTasklet {
52 async fn execute_async(&self) -> Result<RepeatStatus, BatchError> {
53 info!(
54 "Uploading {} to s3://{}/{}",
55 self.local_file.display(),
56 self.bucket,
57 self.key
58 );
59
60 let client = build_s3_client(&self.config).await?;
61 let file_size = std::fs::metadata(&self.local_file)
62 .map_err(BatchError::Io)?
63 .len() as usize;
64
65 if file_size < self.chunk_size {
66 let body = ByteStream::from_path(&self.local_file).await.map_err(|e| {
68 BatchError::ItemWriter(format!("Failed to read file for upload: {}", e))
69 })?;
70
71 client
72 .put_object()
73 .bucket(&self.bucket)
74 .key(&self.key)
75 .body(body)
76 .send()
77 .await
78 .map_err(|e| BatchError::ItemWriter(format!("S3 put_object failed: {}", e)))?;
79 } else {
80 upload_multipart(
82 &client,
83 &self.bucket,
84 &self.key,
85 &self.local_file,
86 self.chunk_size,
87 )
88 .await?;
89 }
90
91 info!("Upload complete: s3://{}/{}", self.bucket, self.key);
92 Ok(RepeatStatus::Finished)
93 }
94}
95
impl Tasklet for S3PutTasklet {
    /// Bridges the async upload onto the synchronous `Tasklet` contract by
    /// blocking on the current Tokio runtime.
    ///
    /// NOTE(review): `block_in_place` panics on a current-thread Tokio
    /// runtime — this assumes steps execute on a multi-thread runtime;
    /// confirm against the framework's runtime setup.
    fn execute(&self, _step_execution: &StepExecution) -> Result<RepeatStatus, BatchError> {
        tokio::task::block_in_place(|| Handle::current().block_on(self.execute_async()))
    }
}
101
/// Builder for `S3PutTasklet`.
///
/// NOTE(review): the derived `Default` leaves `chunk_size` at 0, which
/// fails `build()` validation — construct via `new()` (which applies the
/// 8 MiB default) rather than `default()`.
#[derive(Debug, Default)]
pub struct S3PutTaskletBuilder {
    // Required: destination bucket.
    bucket: Option<String>,
    // Required: destination object key.
    key: Option<String>,
    // Required: local file to upload.
    local_file: Option<PathBuf>,
    // Multipart part size; validated to be >= 5 MiB in `build()`.
    chunk_size: usize,
    // Optional client settings (region, endpoint, credentials).
    config: S3ClientConfig,
}
132
133impl S3PutTaskletBuilder {
134 pub fn new() -> Self {
146 Self {
147 chunk_size: DEFAULT_CHUNK_SIZE,
148 ..Default::default()
149 }
150 }
151
152 pub fn bucket<S: Into<String>>(mut self, bucket: S) -> Self {
162 self.bucket = Some(bucket.into());
163 self
164 }
165
166 pub fn key<S: Into<String>>(mut self, key: S) -> Self {
176 self.key = Some(key.into());
177 self
178 }
179
180 pub fn local_file<P: AsRef<Path>>(mut self, path: P) -> Self {
190 self.local_file = Some(path.as_ref().to_path_buf());
191 self
192 }
193
194 pub fn region<S: Into<String>>(mut self, region: S) -> Self {
207 self.config.region = Some(region.into());
208 self
209 }
210
211 pub fn endpoint_url<S: Into<String>>(mut self, url: S) -> Self {
223 self.config.endpoint_url = Some(url.into());
224 self
225 }
226
227 pub fn access_key_id<S: Into<String>>(mut self, key_id: S) -> Self {
240 self.config.access_key_id = Some(key_id.into());
241 self
242 }
243
244 pub fn secret_access_key<S: Into<String>>(mut self, secret: S) -> Self {
256 self.config.secret_access_key = Some(secret.into());
257 self
258 }
259
260 pub fn chunk_size(mut self, size: usize) -> Self {
274 self.chunk_size = size;
275 self
276 }
277
278 pub fn build(self) -> Result<S3PutTasklet, BatchError> {
300 let bucket = self.bucket.ok_or_else(|| {
301 BatchError::Configuration("S3PutTasklet: 'bucket' is required".to_string())
302 })?;
303 let key = self.key.ok_or_else(|| {
304 BatchError::Configuration("S3PutTasklet: 'key' is required".to_string())
305 })?;
306 let local_file = self.local_file.ok_or_else(|| {
307 BatchError::Configuration("S3PutTasklet: 'local_file' is required".to_string())
308 })?;
309
310 if self.chunk_size < 5 * 1024 * 1024 {
311 return Err(BatchError::Configuration(
312 "S3PutTasklet: 'chunk_size' must be at least 5 MiB".to_string(),
313 ));
314 }
315
316 Ok(S3PutTasklet {
317 bucket,
318 key,
319 local_file,
320 chunk_size: self.chunk_size,
321 config: self.config,
322 })
323 }
324}
325
/// Tasklet that recursively uploads every file in a local folder to S3.
///
/// Each file is stored under `prefix` + its path relative to
/// `local_folder` (backslashes normalized to `/`). Construct via
/// `S3PutFolderTaskletBuilder`.
#[derive(Debug)]
pub struct S3PutFolderTasklet {
    // Destination bucket name.
    bucket: String,
    // Key prefix prepended to each relative file path. Note: no '/'
    // separator is inserted, so include a trailing '/' if one is wanted.
    prefix: String,
    // Root of the local folder tree to upload.
    local_folder: PathBuf,
    // Multipart part size in bytes; builder enforces the 5 MiB minimum.
    chunk_size: usize,
    // Region/endpoint/credential settings used by `build_s3_client`.
    config: S3ClientConfig,
}
364
365impl S3PutFolderTasklet {
366 async fn execute_async(&self) -> Result<RepeatStatus, BatchError> {
367 info!(
368 "Uploading folder {} to s3://{}/{}",
369 self.local_folder.display(),
370 self.bucket,
371 self.prefix
372 );
373
374 let client = build_s3_client(&self.config).await?;
375 let entries = collect_files(&self.local_folder)?;
376
377 for local_path in &entries {
378 let relative = local_path
379 .strip_prefix(&self.local_folder)
380 .map_err(|e| BatchError::Io(std::io::Error::other(e.to_string())))?;
381 let key = format!(
382 "{}{}",
383 self.prefix,
384 relative.to_string_lossy().replace('\\', "/")
385 );
386
387 let file_size = std::fs::metadata(local_path).map_err(BatchError::Io)?.len() as usize;
388
389 debug!(
390 "Uploading {} -> s3://{}/{}",
391 local_path.display(),
392 self.bucket,
393 key
394 );
395
396 if file_size < self.chunk_size {
397 let body = ByteStream::from_path(local_path).await.map_err(|e| {
398 BatchError::ItemWriter(format!(
399 "Failed to read {}: {}",
400 local_path.display(),
401 e
402 ))
403 })?;
404
405 client
406 .put_object()
407 .bucket(&self.bucket)
408 .key(&key)
409 .body(body)
410 .send()
411 .await
412 .map_err(|e| {
413 BatchError::ItemWriter(format!("S3 put_object failed for {}: {}", key, e))
414 })?;
415 } else {
416 upload_multipart(&client, &self.bucket, &key, local_path, self.chunk_size).await?;
417 }
418 }
419
420 info!(
421 "Folder upload complete: {} files uploaded to s3://{}/{}",
422 entries.len(),
423 self.bucket,
424 self.prefix
425 );
426 Ok(RepeatStatus::Finished)
427 }
428}
429
impl Tasklet for S3PutFolderTasklet {
    /// Bridges the async folder upload onto the synchronous `Tasklet`
    /// contract by blocking on the current Tokio runtime.
    ///
    /// NOTE(review): `block_in_place` panics on a current-thread Tokio
    /// runtime — this assumes steps execute on a multi-thread runtime;
    /// confirm against the framework's runtime setup.
    fn execute(&self, _step_execution: &StepExecution) -> Result<RepeatStatus, BatchError> {
        tokio::task::block_in_place(|| Handle::current().block_on(self.execute_async()))
    }
}
435
/// Builder for `S3PutFolderTasklet`.
///
/// NOTE(review): the derived `Default` leaves `chunk_size` at 0, which
/// fails `build()` validation — construct via `new()` (which applies the
/// 8 MiB default) rather than `default()`.
#[derive(Debug, Default)]
pub struct S3PutFolderTaskletBuilder {
    // Required: destination bucket.
    bucket: Option<String>,
    // Required: key prefix prepended to each relative file path.
    prefix: Option<String>,
    // Required: local folder tree to upload.
    local_folder: Option<PathBuf>,
    // Multipart part size; validated to be >= 5 MiB in `build()`.
    chunk_size: usize,
    // Optional client settings (region, endpoint, credentials).
    config: S3ClientConfig,
}
464
465impl S3PutFolderTaskletBuilder {
466 pub fn new() -> Self {
478 Self {
479 chunk_size: DEFAULT_CHUNK_SIZE,
480 ..Default::default()
481 }
482 }
483
484 pub fn bucket<S: Into<String>>(mut self, bucket: S) -> Self {
494 self.bucket = Some(bucket.into());
495 self
496 }
497
498 pub fn prefix<S: Into<String>>(mut self, prefix: S) -> Self {
510 self.prefix = Some(prefix.into());
511 self
512 }
513
514 pub fn local_folder<P: AsRef<Path>>(mut self, path: P) -> Self {
524 self.local_folder = Some(path.as_ref().to_path_buf());
525 self
526 }
527
528 pub fn region<S: Into<String>>(mut self, region: S) -> Self {
540 self.config.region = Some(region.into());
541 self
542 }
543
544 pub fn endpoint_url<S: Into<String>>(mut self, url: S) -> Self {
554 self.config.endpoint_url = Some(url.into());
555 self
556 }
557
558 pub fn access_key_id<S: Into<String>>(mut self, key_id: S) -> Self {
568 self.config.access_key_id = Some(key_id.into());
569 self
570 }
571
572 pub fn secret_access_key<S: Into<String>>(mut self, secret: S) -> Self {
582 self.config.secret_access_key = Some(secret.into());
583 self
584 }
585
586 pub fn chunk_size(mut self, size: usize) -> Self {
596 self.chunk_size = size;
597 self
598 }
599
600 pub fn build(self) -> Result<S3PutFolderTasklet, BatchError> {
622 let bucket = self.bucket.ok_or_else(|| {
623 BatchError::Configuration("S3PutFolderTasklet: 'bucket' is required".to_string())
624 })?;
625 let prefix = self.prefix.ok_or_else(|| {
626 BatchError::Configuration("S3PutFolderTasklet: 'prefix' is required".to_string())
627 })?;
628 let local_folder = self.local_folder.ok_or_else(|| {
629 BatchError::Configuration("S3PutFolderTasklet: 'local_folder' is required".to_string())
630 })?;
631
632 if self.chunk_size < 5 * 1024 * 1024 {
633 return Err(BatchError::Configuration(
634 "S3PutFolderTasklet: 'chunk_size' must be at least 5 MiB".to_string(),
635 ));
636 }
637
638 Ok(S3PutFolderTasklet {
639 bucket,
640 prefix,
641 local_folder,
642 chunk_size: self.chunk_size,
643 config: self.config,
644 })
645 }
646}
647
648async fn upload_multipart(
656 client: &aws_sdk_s3::Client,
657 bucket: &str,
658 key: &str,
659 local_file: &Path,
660 chunk_size: usize,
661) -> Result<(), BatchError> {
662 let create_resp = client
663 .create_multipart_upload()
664 .bucket(bucket)
665 .key(key)
666 .send()
667 .await
668 .map_err(|e| {
669 BatchError::ItemWriter(format!("create_multipart_upload failed for {}: {}", key, e))
670 })?;
671
672 let upload_id = create_resp
673 .upload_id()
674 .ok_or_else(|| {
675 BatchError::ItemWriter("create_multipart_upload returned no upload_id".to_string())
676 })?
677 .to_string();
678
679 let result = upload_parts(client, bucket, key, &upload_id, local_file, chunk_size).await;
680
681 if let Err(e) = result {
682 let _ = client
684 .abort_multipart_upload()
685 .bucket(bucket)
686 .key(key)
687 .upload_id(&upload_id)
688 .send()
689 .await;
690 return Err(e);
691 }
692
693 Ok(())
694}
695
696async fn upload_parts(
698 client: &aws_sdk_s3::Client,
699 bucket: &str,
700 key: &str,
701 upload_id: &str,
702 local_file: &Path,
703 chunk_size: usize,
704) -> Result<(), BatchError> {
705 use std::io::Read;
706
707 let file = std::fs::File::open(local_file).map_err(BatchError::Io)?;
708 let mut reader = std::io::BufReader::new(file);
709 let mut part_number = 1i32;
710 let mut completed_parts = Vec::new();
711
712 loop {
713 let mut buffer = Vec::with_capacity(chunk_size);
714 let bytes_read = reader
715 .by_ref()
716 .take(chunk_size as u64)
717 .read_to_end(&mut buffer)
718 .map_err(BatchError::Io)?;
719 if bytes_read == 0 {
720 break;
721 }
722
723 debug!(
724 "Multipart upload: part {} ({} bytes) -> s3://{}/{}",
725 part_number, bytes_read, bucket, key
726 );
727
728 let body = ByteStream::from(buffer);
729 let part_resp = client
730 .upload_part()
731 .bucket(bucket)
732 .key(key)
733 .upload_id(upload_id)
734 .part_number(part_number)
735 .body(body)
736 .send()
737 .await
738 .map_err(|e| {
739 BatchError::ItemWriter(format!("upload_part {} failed: {}", part_number, e))
740 })?;
741
742 let etag = part_resp
743 .e_tag()
744 .ok_or_else(|| {
745 BatchError::ItemWriter(format!("upload_part {} returned no ETag", part_number))
746 })?
747 .to_string();
748
749 completed_parts.push(
750 CompletedPart::builder()
751 .part_number(part_number)
752 .e_tag(etag)
753 .build(),
754 );
755
756 part_number += 1;
757 }
758
759 let completed = CompletedMultipartUpload::builder()
760 .set_parts(Some(completed_parts))
761 .build();
762
763 client
764 .complete_multipart_upload()
765 .bucket(bucket)
766 .key(key)
767 .upload_id(upload_id)
768 .multipart_upload(completed)
769 .send()
770 .await
771 .map_err(|e| {
772 BatchError::ItemWriter(format!(
773 "complete_multipart_upload failed for {}: {}",
774 key, e
775 ))
776 })?;
777
778 Ok(())
779}
780
781pub(crate) fn collect_files(dir: &Path) -> Result<Vec<PathBuf>, BatchError> {
783 let mut files = Vec::new();
784 for entry in std::fs::read_dir(dir).map_err(BatchError::Io)? {
785 let entry = entry.map_err(BatchError::Io)?;
786 let path = entry.path();
787 if path.is_dir() {
788 files.extend(collect_files(&path)?);
789 } else {
790 files.push(path);
791 }
792 }
793 Ok(files)
794}
795
#[cfg(test)]
mod tests {
    use super::*;
    use std::env::temp_dir;
    use std::fs;

    // --- S3PutTaskletBuilder: required-field validation ---

    #[test]
    fn should_fail_build_when_bucket_missing() {
        let result = S3PutTaskletBuilder::new()
            .key("file.csv")
            .local_file("/tmp/file.csv")
            .build();
        assert!(result.is_err(), "build should fail without bucket");
        assert!(
            result.unwrap_err().to_string().contains("bucket"),
            "error message should mention 'bucket'"
        );
    }

    #[test]
    fn should_fail_build_when_key_missing() {
        let result = S3PutTaskletBuilder::new()
            .bucket("my-bucket")
            .local_file("/tmp/file.csv")
            .build();
        assert!(result.is_err(), "build should fail without key");
        assert!(
            result.unwrap_err().to_string().contains("key"),
            "error message should mention 'key'"
        );
    }

    #[test]
    fn should_fail_build_when_local_file_missing() {
        let result = S3PutTaskletBuilder::new()
            .bucket("my-bucket")
            .key("file.csv")
            .build();
        assert!(result.is_err(), "build should fail without local_file");
        assert!(
            result.unwrap_err().to_string().contains("local_file"),
            "error message should mention 'local_file'"
        );
    }

    #[test]
    fn should_build_with_required_fields() {
        // build() never touches the filesystem, so the path need not exist.
        let result = S3PutTaskletBuilder::new()
            .bucket("my-bucket")
            .key("file.csv")
            .local_file("/tmp/file.csv")
            .build();
        assert!(
            result.is_ok(),
            "build should succeed with required fields: {:?}",
            result.err()
        );
    }

    // --- S3PutTaskletBuilder: chunk_size handling ---

    #[test]
    fn should_apply_default_chunk_size() {
        let tasklet = S3PutTaskletBuilder::new()
            .bucket("b")
            .key("k")
            .local_file("/tmp/f")
            .build()
            .unwrap();
        assert_eq!(
            tasklet.chunk_size, DEFAULT_CHUNK_SIZE,
            "default chunk_size should be 8 MiB"
        );
    }

    #[test]
    fn should_override_chunk_size() {
        let tasklet = S3PutTaskletBuilder::new()
            .bucket("b")
            .key("k")
            .local_file("/tmp/f")
            .chunk_size(16 * 1024 * 1024)
            .build()
            .unwrap();
        assert_eq!(tasklet.chunk_size, 16 * 1024 * 1024);
    }

    #[test]
    fn should_fail_build_when_chunk_size_below_minimum() {
        // 1024 bytes is well under the 5 MiB S3 part-size minimum.
        let result = S3PutTaskletBuilder::new()
            .bucket("b")
            .key("k")
            .local_file("/tmp/f")
            .chunk_size(1024)
            .build();
        assert!(result.is_err(), "build should fail with chunk_size < 5 MiB");
        assert!(
            result.unwrap_err().to_string().contains("chunk_size"),
            "error message should mention 'chunk_size'"
        );
    }

    #[test]
    fn should_store_optional_config_fields() {
        let tasklet = S3PutTaskletBuilder::new()
            .bucket("b")
            .key("k")
            .local_file("/tmp/f")
            .region("us-east-1")
            .endpoint_url("http://localhost:9000")
            .access_key_id("AKID")
            .secret_access_key("SECRET")
            .build()
            .unwrap();
        assert_eq!(tasklet.config.region.as_deref(), Some("us-east-1"));
        assert_eq!(
            tasklet.config.endpoint_url.as_deref(),
            Some("http://localhost:9000")
        );
        assert_eq!(tasklet.config.access_key_id.as_deref(), Some("AKID"));
        assert_eq!(tasklet.config.secret_access_key.as_deref(), Some("SECRET"));
    }

    // --- S3PutFolderTaskletBuilder: required-field validation ---

    #[test]
    fn should_fail_folder_build_when_bucket_missing() {
        let result = S3PutFolderTaskletBuilder::new()
            .prefix("backups/")
            .local_folder("/tmp/exports")
            .build();
        assert!(result.is_err(), "build should fail without bucket");
        assert!(result.unwrap_err().to_string().contains("bucket"));
    }

    #[test]
    fn should_fail_folder_build_when_prefix_missing() {
        let result = S3PutFolderTaskletBuilder::new()
            .bucket("my-bucket")
            .local_folder("/tmp/exports")
            .build();
        assert!(result.is_err(), "build should fail without prefix");
        assert!(result.unwrap_err().to_string().contains("prefix"));
    }

    #[test]
    fn should_fail_folder_build_when_local_folder_missing() {
        let result = S3PutFolderTaskletBuilder::new()
            .bucket("my-bucket")
            .prefix("backups/")
            .build();
        assert!(result.is_err(), "build should fail without local_folder");
        assert!(result.unwrap_err().to_string().contains("local_folder"));
    }

    #[test]
    fn should_build_folder_with_required_fields() {
        let result = S3PutFolderTaskletBuilder::new()
            .bucket("my-bucket")
            .prefix("backups/")
            .local_folder("/tmp/exports")
            .build();
        assert!(
            result.is_ok(),
            "build should succeed with required fields: {:?}",
            result.err()
        );
    }

    #[test]
    fn should_fail_folder_build_when_chunk_size_below_minimum() {
        let result = S3PutFolderTaskletBuilder::new()
            .bucket("b")
            .prefix("p/")
            .local_folder("/tmp/exports")
            .chunk_size(1024)
            .build();
        assert!(result.is_err(), "build should fail with chunk_size < 5 MiB");
        assert!(
            result.unwrap_err().to_string().contains("chunk_size"),
            "error message should mention 'chunk_size'"
        );
    }

    // --- collect_files: filesystem traversal ---

    #[test]
    fn should_collect_files_from_directory() {
        let dir = temp_dir().join("spring_batch_rs_test_collect");
        // Start from a clean slate in case a previous run left the dir behind.
        fs::remove_dir_all(&dir).ok();
        fs::create_dir_all(&dir).unwrap();
        fs::write(dir.join("a.txt"), "a").unwrap();
        fs::write(dir.join("b.txt"), "b").unwrap();
        let files = collect_files(&dir).unwrap();
        assert_eq!(files.len(), 2, "should collect 2 files, got: {:?}", files);

        fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn should_collect_files_from_nested_directories() {
        let dir = temp_dir().join("spring_batch_rs_test_collect_nested");
        let sub = dir.join("sub");
        // Start from a clean slate in case a previous run left the dir behind.
        fs::remove_dir_all(&dir).ok();
        fs::create_dir_all(&sub).unwrap();
        fs::write(dir.join("root.txt"), "r").unwrap();
        fs::write(sub.join("child.txt"), "c").unwrap();
        let files = collect_files(&dir).unwrap();
        assert_eq!(
            files.len(),
            2,
            "should collect files from nested dirs: {:?}",
            files
        );

        fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn should_return_error_for_missing_directory() {
        let result = collect_files(Path::new("/nonexistent/path/xyz"));
        assert!(result.is_err(), "should return error for missing directory");
    }
}