rust_integration_services/sftp/
sftp_receiver.rs1use std::path::{Path, PathBuf};
2
3use async_ssh2_lite::{AsyncFile, AsyncSession, SessionConfiguration, TokioTcpStream};
4use futures_util::{AsyncReadExt};
5use regex::Regex;
6use tokio::{fs::{File, OpenOptions}, io::AsyncWriteExt, signal::unix::{signal, SignalKind}, sync::mpsc, task::JoinSet};
7use tokio::net::TcpStream;
8use uuid::Uuid;
9
10use super::sftp_auth::SftpAuth;
11use crate::utils::error::Error;
12
/// Lifecycle notifications emitted by [`SftpReceiver`] for each downloaded file.
///
/// Every variant carries the same payload: a per-download id (a freshly
/// generated UUIDv4 string, shared by all events of the same download) and the
/// local path the file is written to.
#[derive(Clone, Debug)]
pub enum SftpReceiverEventSignal {
    /// Emitted right before the file's bytes start being copied.
    OnDownloadStart(String, PathBuf),
    /// Emitted after the file was fully copied and the local file flushed.
    OnDownloadSuccess(String, PathBuf),
    /// Emitted when copying the file failed part-way.
    OnDownloadFailed(String, PathBuf),
}
19
20pub struct SftpReceiver {
21 host: String,
22 remote_path: PathBuf,
23 delete_after: bool,
24 regex: String,
25 auth: SftpAuth,
26 event_broadcast: mpsc::Sender<SftpReceiverEventSignal>,
27 event_receiver: Option<mpsc::Receiver<SftpReceiverEventSignal>>,
28 event_join_set: JoinSet<()>,
29}
30
31impl SftpReceiver {
32 pub fn new<T: AsRef<str>>(host: T, user: T) -> Self {
33 let (event_broadcast, event_receiver) = mpsc::channel(128);
34 SftpReceiver {
35 host: host.as_ref().to_string(),
36 remote_path: PathBuf::new(),
37 delete_after: false,
38 regex: String::from(r"^.+\.[^./\\]+$"),
39 auth: SftpAuth { user: user.as_ref().to_string(), password: None, private_key: None, private_key_passphrase: None },
40 event_broadcast,
41 event_receiver: Some(event_receiver),
42 event_join_set: JoinSet::new(),
43 }
44 }
45
46 pub fn on_event<T, Fut>(mut self, handler: T) -> Self
47 where
48 T: Fn(SftpReceiverEventSignal) -> Fut + Send + Sync + 'static,
49 Fut: Future<Output = ()> + Send + 'static,
50 {
51 let mut receiver = self.event_receiver.unwrap();
52 let mut sigterm = signal(SignalKind::terminate()).expect("Failed to start SIGTERM signal receiver.");
53 let mut sigint = signal(SignalKind::interrupt()).expect("Failed to start SIGINT signal receiver.");
54
55 self.event_join_set.spawn(async move {
56 loop {
57 tokio::select! {
58 _ = sigterm.recv() => break,
59 _ = sigint.recv() => break,
60 event = receiver.recv() => {
61 match event {
62 Some(event) => handler(event).await,
63 None => break,
64 }
65 }
66 }
67 }
68 });
69
70 self.event_receiver = None;
71 self
72 }
73
74 pub fn password<T: AsRef<str>>(mut self, password: T) -> Self {
76 self.auth.password = Some(password.as_ref().to_string());
77 self
78 }
79
80 pub fn private_key<T: AsRef<Path>, S: AsRef<str>>(mut self, key_path: T, passphrase: Option<S>) -> Self {
82 self.auth.private_key = Some(key_path.as_ref().to_path_buf());
83 self.auth.private_key_passphrase = match passphrase {
84 Some(passphrase) => Some(passphrase.as_ref().to_string()),
85 None => None,
86 };
87 self
88 }
89
90 pub fn remote_path<T: AsRef<Path>>(mut self, remote_path: T) -> Self {
92 self.remote_path = remote_path.as_ref().to_path_buf();
93 self
94 }
95
96 pub fn delete_after(mut self, delete_after: bool) -> Self {
98 self.delete_after = delete_after;
99 self
100 }
101
102 pub fn regex<T: AsRef<str>>(mut self, regex: T) -> Self {
106 self.regex = regex.as_ref().to_string();
107 self
108 }
109
110 pub async fn receive_files_to_path<T: AsRef<Path>>(mut self, target_local_path: T) -> tokio::io::Result<()> {
114 let local_path = target_local_path.as_ref();
115 if !local_path.try_exists()? {
116 return Err(Error::tokio_io(format!("The path '{:?}' does not exist!", &local_path)));
117 }
118
119 let tcp = TokioTcpStream::connect(&self.host).await?;
120 let mut session = AsyncSession::new(tcp, SessionConfiguration::default())?;
121 session.handshake().await?;
122
123 if let Some(password) = self.auth.password {
124 session.userauth_password(&self.auth.user, &password).await?;
125 }
126 if let Some(private_key) = self.auth.private_key {
127 session.userauth_pubkey_file(&self.auth.user, None, &private_key, self.auth.private_key_passphrase.as_deref()).await?;
128 }
129
130 let remote_path = Path::new(&self.remote_path);
131 let sftp = session.sftp().await?;
132 let entries = sftp.readdir(remote_path).await?;
133 let regex = Regex::new(&self.regex).unwrap();
134
135 for (entry, metadata) in entries {
136 if metadata.is_dir() {
137 continue;
138 }
139
140 let file_name = entry.file_name().unwrap().to_str().unwrap();
141 if regex.is_match(file_name) {
142
143 let remote_file_path = Path::new(&self.remote_path).join(file_name);
144 let remote_file = sftp.open(&remote_file_path).await?;
145 let local_file_path = local_path.join(file_name);
146 let local_file = OpenOptions::new().create(true).write(true).open(&local_file_path).await?;
147
148 let uuid = Uuid::new_v4().to_string();
149 self.event_broadcast.send(SftpReceiverEventSignal::OnDownloadStart(uuid.clone(), local_file_path.clone())).await.unwrap();
150
151 match Self::download_file(remote_file, local_file).await {
152 Ok(_) => {
153 self.event_broadcast.send(SftpReceiverEventSignal::OnDownloadSuccess(uuid.clone(), local_file_path.clone())).await.unwrap();
154
155 if self.delete_after {
156 sftp.unlink(&remote_file_path).await?;
157 }
158 },
159 Err(_) => self.event_broadcast.send(SftpReceiverEventSignal::OnDownloadFailed(uuid.clone(), local_file_path.clone())).await.unwrap(),
160 }
161 }
162 }
163
164 drop(self.event_broadcast);
165 while let Some(_) = self.event_join_set.join_next().await {}
166
167 Ok(())
168 }
169
170 async fn download_file(mut remote_file: AsyncFile<TcpStream>, mut local_file: File) -> tokio::io::Result<()> {
171 let mut buffer = vec![0u8; 1024 * 1024];
172 loop {
173 let bytes = remote_file.read(&mut buffer).await?;
174 if bytes == 0 {
175 break;
176 }
177 local_file.write_all(&buffer[..bytes]).await?;
178 }
179
180 local_file.flush().await?;
181 remote_file.close().await?;
182
183 Ok(())
184 }
185}