// rust_integration_services/sftp/sftp_receiver.rs

1use std::path::{Path, PathBuf};
2
3use async_ssh2_lite::{AsyncSession, SessionConfiguration, TokioTcpStream};
4use futures_util::{AsyncReadExt};
5use regex::Regex;
6use tokio::{fs::OpenOptions, io::AsyncWriteExt, signal::unix::{signal, SignalKind}, sync::mpsc, task::JoinSet};
7use uuid::Uuid;
8
9use super::sftp_auth::SftpAuth;
10use crate::utils::error::Error;
11
/// Events published by `SftpReceiver` while it downloads files.
///
/// Both variants carry a per-file identifier and the local destination path of the file.
/// In `SftpReceiver::receive_files` the identifier is a newly generated UUID string that
/// is shared by the start and success events of the same file, so a handler can pair them.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SftpReceiverEventSignal {
    /// A file download is about to begin: `(download id, local destination path)`.
    OnDownloadStart(String, PathBuf),
    /// All bytes of the remote file have been written locally: `(download id, local destination path)`.
    OnDownloadSuccess(String, PathBuf),
}
17
18pub struct SftpReceiver {
19    host: String,
20    remote_path: PathBuf,
21    delete_after: bool,
22    regex: String,
23    auth: SftpAuth,
24    event_broadcast: mpsc::Sender<SftpReceiverEventSignal>,
25    event_receiver: Option<mpsc::Receiver<SftpReceiverEventSignal>>,
26    event_join_set: JoinSet<()>,
27}
28
29impl SftpReceiver {
30    pub fn new<T: AsRef<str>>(host: T, user: T) -> Self {
31        let (event_broadcast, event_receiver) = mpsc::channel(128);
32        SftpReceiver { 
33            host: host.as_ref().to_string(),
34            remote_path: PathBuf::new(),
35            delete_after: false,
36            regex: String::from(r"^.+\.[^./\\]+$"),
37            auth: SftpAuth { user: user.as_ref().to_string(), password: None, private_key: None, private_key_passphrase: None },
38            event_broadcast,
39            event_receiver: Some(event_receiver),
40            event_join_set: JoinSet::new(),
41        }
42    }
43
44    pub fn on_event<T, Fut>(mut self, handler: T) -> Self
45    where
46        T: Fn(SftpReceiverEventSignal) -> Fut + Send + Sync + 'static,
47        Fut: Future<Output = ()> + Send + 'static,
48    {
49        let mut receiver = self.event_receiver.unwrap();
50        let mut sigterm = signal(SignalKind::terminate()).expect("Failed to start SIGTERM signal receiver.");
51        let mut sigint = signal(SignalKind::interrupt()).expect("Failed to start SIGINT signal receiver.");
52        
53        self.event_join_set.spawn(async move {
54            loop {
55                tokio::select! {
56                    _ = sigterm.recv() => break,
57                    _ = sigint.recv() => break,
58                    event = receiver.recv() => {
59                        match event {
60                            Some(event) => handler(event).await,
61                            None => break,
62                        }
63                    }
64                }
65            }
66        });
67
68        self.event_receiver = None;
69        self
70    }
71
72    /// Sets the password for authentication.
73    pub fn password<T: AsRef<str>>(mut self, password: T) -> Self {
74        self.auth.password = Some(password.as_ref().to_string());
75        self
76    }
77
78    /// Sets the private key path and passphrase for authentication.
79    pub fn private_key<T: AsRef<Path>, S: AsRef<str>>(mut self, key_path: T, passphrase: Option<S>) -> Self {
80        self.auth.private_key = Some(key_path.as_ref().to_path_buf());
81        self.auth.private_key_passphrase = match passphrase {
82            Some(passphrase) => Some(passphrase.as_ref().to_string()),
83            None => None,
84        };
85        self
86    }
87
88    /// Sets the remote directory for the user on the sftp server.
89    pub fn remote_path<T: AsRef<Path>>(mut self, remote_path: T) -> Self {
90        self.remote_path = remote_path.as_ref().to_path_buf();
91        self
92    }
93
94    /// Delete the remote file in sftp after successfully downloading it.
95    pub fn delete_after(mut self, delete_after: bool) -> Self {
96        self.delete_after = delete_after;
97        self
98    }
99
100    /// Sets the regex filter for what files will be downloaded from the sftp server.
101    /// 
102    /// The default regex is: ^.+\.[^./\\]+$
103    pub fn regex<T: AsRef<str>>(mut self, regex: T) -> Self {
104        self.regex = regex.as_ref().to_string();
105        self
106    }
107
108    /// Download files from the sftp server to the target local path.
109    /// 
110    /// Filters for files can be set with regex(), the default regex is: ^.+\.[^./\\]+$
111    pub async fn receive_files<T: AsRef<Path>>(mut self, target_local_path: T) -> tokio::io::Result<()> {
112        let local_path = target_local_path.as_ref();
113        if !local_path.try_exists()? {
114            return Err(Error::tokio_io(format!("The path '{:?}' does not exist!", &local_path)));
115        }
116
117        let tcp = TokioTcpStream::connect(&self.host).await?;
118        let mut session = AsyncSession::new(tcp, SessionConfiguration::default())?;
119        session.handshake().await?;
120
121        if let Some(password) = self.auth.password {
122            session.userauth_password(&self.auth.user, &password).await?;
123        }
124        if let Some(private_key) = self.auth.private_key {
125            session.userauth_pubkey_file(&self.auth.user, None, &private_key, self.auth.private_key_passphrase.as_deref()).await?;
126        }
127
128        let remote_path = Path::new(&self.remote_path);
129        let sftp = session.sftp().await?;
130        let entries = sftp.readdir(remote_path).await?;
131        let regex = Regex::new(&self.regex).unwrap();
132
133        for (entry, metadata) in entries {
134            if metadata.is_dir() {
135                continue;
136            }
137
138            let file_name = entry.file_name().unwrap().to_str().unwrap();
139            if regex.is_match(file_name) {
140
141                let remote_file_path = Path::new(&self.remote_path).join(file_name);
142                let mut remote_file = sftp.open(&remote_file_path).await?;
143                let local_file_path = local_path.join(file_name);
144                let mut local_file = OpenOptions::new().create(true).write(true).open(&local_file_path).await?;
145
146                let uuid = Uuid::new_v4().to_string();
147                self.event_broadcast.send(SftpReceiverEventSignal::OnDownloadStart(uuid.clone(), local_file_path.clone())).await.unwrap();
148
149                let mut buffer = vec![0u8; 1024 * 1024];
150                loop {
151                    let bytes = remote_file.read(&mut buffer).await?;
152                    if bytes == 0 {
153                        break;
154                    }
155                    local_file.write_all(&buffer[..bytes]).await?;
156                }
157
158                self.event_broadcast.send(SftpReceiverEventSignal::OnDownloadSuccess(uuid.clone(), local_file_path.clone())).await.unwrap();
159                local_file.flush().await?;
160                remote_file.close().await?;
161
162                if self.delete_after {
163                    sftp.unlink(&remote_file_path).await?;
164                }
165            }
166        }
167
168        while let Some(_) = self.event_join_set.join_next().await {}
169        
170        Ok(())
171    }
172}