rust_integration_services/sftp/
sftp_receiver.rs

1use std::path::{Path, PathBuf};
2
3use async_ssh2_lite::{AsyncFile, AsyncSession, SessionConfiguration, TokioTcpStream};
4use futures_util::{AsyncReadExt};
5use regex::Regex;
6use tokio::{fs::{File, OpenOptions}, io::AsyncWriteExt, signal::unix::{signal, SignalKind}, sync::mpsc, task::JoinSet};
7use tokio::net::TcpStream;
8use uuid::Uuid;
9
10use super::sftp_auth::SftpAuth;
11use crate::utils::error::Error;
12
/// Lifecycle events emitted by [`SftpReceiver`] for every file it downloads.
///
/// Each variant carries a UUID string generated per download (the same value is used for the
/// start event and its matching success/failure event) and the local path the file is
/// written to.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SftpReceiverEventSignal {
    /// The local file was opened and the transfer is about to begin.
    OnDownloadStart(String, PathBuf),
    /// The transfer finished and both the local and remote files were flushed/closed.
    OnDownloadSuccess(String, PathBuf),
    /// Reading, writing, flushing or closing failed during the transfer.
    OnDownloadFailed(String, PathBuf),
}
19
/// Downloads the files of one directory on an sftp server into a local directory.
///
/// Created with [`SftpReceiver::new`], configured through the builder-style setters and
/// consumed by [`SftpReceiver::receive_files_to_path`].
pub struct SftpReceiver {
    /// Address handed to the TCP connect call (presumably `host:port` — confirm with callers).
    host: String,
    /// Remote directory whose entries are listed and downloaded.
    remote_path: PathBuf,
    /// Whether each remote file is unlinked after a successful download.
    delete_after: bool,
    /// Regex source matched against each remote file name.
    regex: String,
    /// User name plus the optional password and/or private key used to authenticate.
    auth: SftpAuth,
    /// Sending half of the event channel.
    event_broadcast: mpsc::Sender<SftpReceiverEventSignal>,
    /// Receiving half of the event channel; `Some` until a handler takes it.
    event_receiver: Option<mpsc::Receiver<SftpReceiverEventSignal>>,
    /// Tasks spawned for event handlers; joined at the end of a receive.
    event_join_set: JoinSet<()>,
}
30
31impl SftpReceiver {
32    pub fn new<T: AsRef<str>>(host: T, user: T) -> Self {
33        let (event_broadcast, event_receiver) = mpsc::channel(128);
34        SftpReceiver { 
35            host: host.as_ref().to_string(),
36            remote_path: PathBuf::new(),
37            delete_after: false,
38            regex: String::from(r"^.+\.[^./\\]+$"),
39            auth: SftpAuth { user: user.as_ref().to_string(), password: None, private_key: None, private_key_passphrase: None },
40            event_broadcast,
41            event_receiver: Some(event_receiver),
42            event_join_set: JoinSet::new(),
43        }
44    }
45
46    pub fn on_event<T, Fut>(mut self, handler: T) -> Self
47    where
48        T: Fn(SftpReceiverEventSignal) -> Fut + Send + Sync + 'static,
49        Fut: Future<Output = ()> + Send + 'static,
50    {
51        let mut receiver = self.event_receiver.unwrap();
52        let mut sigterm = signal(SignalKind::terminate()).expect("Failed to start SIGTERM signal receiver.");
53        let mut sigint = signal(SignalKind::interrupt()).expect("Failed to start SIGINT signal receiver.");
54        
55        self.event_join_set.spawn(async move {
56            loop {
57                tokio::select! {
58                    _ = sigterm.recv() => break,
59                    _ = sigint.recv() => break,
60                    event = receiver.recv() => {
61                        match event {
62                            Some(event) => handler(event).await,
63                            None => break,
64                        }
65                    }
66                }
67            }
68        });
69
70        self.event_receiver = None;
71        self
72    }
73
74    /// Sets the password for authentication.
75    pub fn password<T: AsRef<str>>(mut self, password: T) -> Self {
76        self.auth.password = Some(password.as_ref().to_string());
77        self
78    }
79
80    /// Sets the private key path and passphrase for authentication.
81    pub fn private_key<T: AsRef<Path>, S: AsRef<str>>(mut self, key_path: T, passphrase: Option<S>) -> Self {
82        self.auth.private_key = Some(key_path.as_ref().to_path_buf());
83        self.auth.private_key_passphrase = match passphrase {
84            Some(passphrase) => Some(passphrase.as_ref().to_string()),
85            None => None,
86        };
87        self
88    }
89
90    /// Sets the remote directory for the user on the sftp server.
91    pub fn remote_path<T: AsRef<Path>>(mut self, remote_path: T) -> Self {
92        self.remote_path = remote_path.as_ref().to_path_buf();
93        self
94    }
95
96    /// Delete the remote file in sftp after successfully downloading it.
97    pub fn delete_after(mut self, delete_after: bool) -> Self {
98        self.delete_after = delete_after;
99        self
100    }
101
102    /// Sets the regex filter for what files will be downloaded from the sftp server.
103    /// 
104    /// The default regex is: ^.+\.[^./\\]+$
105    pub fn regex<T: AsRef<str>>(mut self, regex: T) -> Self {
106        self.regex = regex.as_ref().to_string();
107        self
108    }
109
110    /// Download files from the sftp server to the target local path.
111    /// 
112    /// Filters for files can be set with regex(), the default regex is: ^.+\.[^./\\]+$
113    pub async fn receive_files_to_path<T: AsRef<Path>>(mut self, target_local_path: T) -> tokio::io::Result<()> {
114        let local_path = target_local_path.as_ref();
115        if !local_path.try_exists()? {
116            return Err(Error::tokio_io(format!("The path '{:?}' does not exist!", &local_path)));
117        }
118
119        let tcp = TokioTcpStream::connect(&self.host).await?;
120        let mut session = AsyncSession::new(tcp, SessionConfiguration::default())?;
121        session.handshake().await?;
122
123        if let Some(password) = self.auth.password {
124            session.userauth_password(&self.auth.user, &password).await?;
125        }
126        if let Some(private_key) = self.auth.private_key {
127            session.userauth_pubkey_file(&self.auth.user, None, &private_key, self.auth.private_key_passphrase.as_deref()).await?;
128        }
129
130        let remote_path = Path::new(&self.remote_path);
131        let sftp = session.sftp().await?;
132        let entries = sftp.readdir(remote_path).await?;
133        let regex = Regex::new(&self.regex).unwrap();
134
135        for (entry, metadata) in entries {
136            if metadata.is_dir() {
137                continue;
138            }
139
140            let file_name = entry.file_name().unwrap().to_str().unwrap();
141            if regex.is_match(file_name) {
142
143                let remote_file_path = Path::new(&self.remote_path).join(file_name);
144                let remote_file = sftp.open(&remote_file_path).await?;
145                let local_file_path = local_path.join(file_name);
146                let local_file = OpenOptions::new().create(true).write(true).open(&local_file_path).await?;
147
148                let uuid = Uuid::new_v4().to_string();
149                self.event_broadcast.send(SftpReceiverEventSignal::OnDownloadStart(uuid.clone(), local_file_path.clone())).await.unwrap();
150
151                match Self::download_file(remote_file, local_file).await {
152                    Ok(_) => {
153                        self.event_broadcast.send(SftpReceiverEventSignal::OnDownloadSuccess(uuid.clone(), local_file_path.clone())).await.unwrap();
154
155                        if self.delete_after {
156                            sftp.unlink(&remote_file_path).await?;
157                        }
158                    },
159                    Err(_) => self.event_broadcast.send(SftpReceiverEventSignal::OnDownloadFailed(uuid.clone(), local_file_path.clone())).await.unwrap(),
160                }
161            }
162        }
163
164        drop(self.event_broadcast);
165        while let Some(_) = self.event_join_set.join_next().await {}
166
167        Ok(())
168    }
169
170    async fn download_file(mut remote_file: AsyncFile<TcpStream>, mut local_file: File) -> tokio::io::Result<()> {
171        let mut buffer = vec![0u8; 1024 * 1024];
172        loop {
173            let bytes = remote_file.read(&mut buffer).await?;
174            if bytes == 0 {
175                break;
176            }
177            local_file.write_all(&buffer[..bytes]).await?;
178        }
179
180        local_file.flush().await?;
181        remote_file.close().await?;
182
183        Ok(())
184    }
185}