// rust_integration_services/sftp/sftp_receiver.rs

1use std::path::{Path, PathBuf};
2
3use async_ssh2_lite::{AsyncFile, AsyncSession, SessionConfiguration, TokioTcpStream};
4use futures_util::{AsyncReadExt};
5use regex::Regex;
6use tokio::{fs::{File, OpenOptions}, io::AsyncWriteExt, task::JoinSet};
7use tokio::net::TcpStream;
8use uuid::Uuid;
9
10use super::sftp_auth::SftpAuth;
11use crate::{common::event_handler::EventHandler, sftp::sftp_receiver_event::SftpReceiverEvent, utils::error::Error};
12
13pub struct SftpReceiver {
14    host: String,
15    remote_path: PathBuf,
16    delete_after: bool,
17    regex: String,
18    auth: SftpAuth,
19    event_handler: EventHandler<SftpReceiverEvent>,
20    event_join_set: JoinSet<()>,
21}
22
23impl SftpReceiver {
24    pub fn new<T: AsRef<str>>(host: T, user: T) -> Self {
25        SftpReceiver { 
26            host: host.as_ref().to_string(),
27            remote_path: PathBuf::new(),
28            delete_after: false,
29            regex: String::from(r"^.+\.[^./\\]+$"),
30            auth: SftpAuth { user: user.as_ref().to_string(), password: None, private_key: None, private_key_passphrase: None },
31            event_handler: EventHandler::new(),
32            event_join_set: JoinSet::new(),
33        }
34    }
35
36    pub fn on_event<T, Fut>(mut self, handler: T) -> Self
37    where
38        T: Fn(SftpReceiverEvent) -> Fut + Send + Sync + 'static,
39        Fut: Future<Output = ()> + Send + 'static,
40    {
41        self.event_join_set = self.event_handler.init(handler);
42        self
43    }
44
45    /// Sets the password for authentication.
46    pub fn password<T: AsRef<str>>(mut self, password: T) -> Self {
47        self.auth.password = Some(password.as_ref().to_string());
48        self
49    }
50
51    /// Sets the private key path and passphrase for authentication.
52    pub fn private_key<T: AsRef<Path>, S: AsRef<str>>(mut self, key_path: T, passphrase: Option<S>) -> Self {
53        self.auth.private_key = Some(key_path.as_ref().to_path_buf());
54        self.auth.private_key_passphrase = match passphrase {
55            Some(passphrase) => Some(passphrase.as_ref().to_string()),
56            None => None,
57        };
58        self
59    }
60
61    /// Sets the remote directory for the user on the sftp server.
62    pub fn remote_path<T: AsRef<Path>>(mut self, remote_path: T) -> Self {
63        self.remote_path = remote_path.as_ref().to_path_buf();
64        self
65    }
66
67    /// Delete the remote file in sftp after successfully downloading it.
68    pub fn delete_after(mut self, delete_after: bool) -> Self {
69        self.delete_after = delete_after;
70        self
71    }
72
73    /// Sets the regex filter for what files will be downloaded from the sftp server.
74    /// 
75    /// The default regex is: ^.+\.[^./\\]+$
76    pub fn regex<T: AsRef<str>>(mut self, regex: T) -> Self {
77        self.regex = regex.as_ref().to_string();
78        self
79    }
80
81    /// Download files from the sftp server to the target local path.
82    /// 
83    /// Filters for files can be set with regex(), the default regex is: ^.+\.[^./\\]+$
84    pub async fn receive_once<T: AsRef<Path>>(mut self, target_local_path: T) -> tokio::io::Result<()> {
85        let local_path = target_local_path.as_ref();
86        if !local_path.try_exists()? {
87            return Err(Error::tokio_io(format!("The path '{:?}' does not exist!", &local_path)));
88        }
89
90        let tcp = TokioTcpStream::connect(&self.host).await?;
91        let mut session = AsyncSession::new(tcp, SessionConfiguration::default())?;
92        session.handshake().await?;
93
94        if let Some(password) = self.auth.password {
95            session.userauth_password(&self.auth.user, &password).await?;
96        }
97        if let Some(private_key) = self.auth.private_key {
98            session.userauth_pubkey_file(&self.auth.user, None, &private_key, self.auth.private_key_passphrase.as_deref()).await?;
99        }
100
101        let remote_path = Path::new(&self.remote_path);
102        let sftp = session.sftp().await?;
103        let entries = sftp.readdir(remote_path).await?;
104        let regex = Regex::new(&self.regex).unwrap();
105        let event_broadcast = self.event_handler.broadcast();
106
107        for (entry, metadata) in entries {
108            if metadata.is_dir() {
109                continue;
110            }
111
112            let file_name = entry.file_name().unwrap().to_str().unwrap();
113            if regex.is_match(file_name) {
114
115                let remote_file_path = Path::new(&self.remote_path).join(file_name);
116                let remote_file = sftp.open(&remote_file_path).await?;
117                let local_file_path = local_path.join(file_name);
118                let local_file = OpenOptions::new().create(true).write(true).open(&local_file_path).await?;
119
120                let uuid = Uuid::new_v4().to_string();
121                event_broadcast.send(SftpReceiverEvent::OnDownloadStart(uuid.clone(), local_file_path.clone())).await.unwrap();
122
123                match Self::download_file(remote_file, local_file).await {
124                    Ok(_) => {
125                        event_broadcast.send(SftpReceiverEvent::OnDownloadSuccess(uuid.clone(), local_file_path.clone())).await.unwrap();
126
127                        if self.delete_after {
128                            sftp.unlink(&remote_file_path).await?;
129                        }
130                    },
131                    Err(err) => event_broadcast.send(SftpReceiverEvent::OnError(uuid.clone(), err.to_string())).await.unwrap(),
132                }
133            }
134        }
135
136        drop(event_broadcast);
137        while let Some(_) = self.event_join_set.join_next().await {}
138
139        Ok(())
140    }
141
142    async fn download_file(mut remote_file: AsyncFile<TcpStream>, mut local_file: File) -> tokio::io::Result<()> {
143        let mut buffer = vec![0u8; 1024 * 1024];
144        loop {
145            let bytes = remote_file.read(&mut buffer).await?;
146            if bytes == 0 {
147                break;
148            }
149            local_file.write_all(&buffer[..bytes]).await?;
150        }
151
152        local_file.flush().await?;
153        remote_file.close().await?;
154
155        Ok(())
156    }
157}