Skip to main content

rust_integration_services/sftp/
sftp_client.rs

1use std::{marker::PhantomData, path::{Path, PathBuf}, sync::Arc};
2
3use anyhow::Ok;
4use bytes::Bytes;
5use russh::{client::Handle, keys::{HashAlg, PrivateKeyWithHashAlg}};
6use russh_sftp::client::SftpSession;
7use tokio::{io::{AsyncReadExt, AsyncWriteExt}, sync::Mutex};
8use tokio_util::io::ReaderStream;
9
10use crate::{common::stream::ByteStream, sftp::{sftp_client_config::SftpClientConfig, ssh_client::SshClient}};
11
/// Typestate marker for a client not yet bound to a remote file (created by `SftpClient::new`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Empty;
/// Typestate marker for a client that downloads one remote file (created by `get_file`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct GetFile;
/// Typestate marker for a client that uploads one remote file (created by `put_file`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct PutFile;
15
16pub struct SftpClient<State> {
17    config: Arc<SftpClientConfig>,
18    path: Option<PathBuf>,
19    session: Arc<Mutex<Option<Handle<SshClient>>>>,
20    _state: PhantomData<State>,
21}
22
23impl SftpClient<Empty> {
24    pub fn new(config: SftpClientConfig) -> Self {
25        SftpClient {
26            config: Arc::new(config),
27            path: None,
28            session: Arc::new(Mutex::new(None)),
29            _state: PhantomData
30        }
31    }
32
33    pub fn get_file(&self, path: impl Into<PathBuf>) -> SftpClient<GetFile> {
34        SftpClient {
35            config: self.config.clone(),
36            path: Some(path.into()),
37            session: self.session.clone(),
38            _state: PhantomData
39        }
40    }
41
42    pub fn put_file(&self, path: impl Into<PathBuf>) -> SftpClient<PutFile> {
43        SftpClient {
44            config: self.config.clone(),
45            path: Some(path.into()),
46            session: self.session.clone(),
47            _state: PhantomData
48        }
49    }
50
51    pub async fn delete_file(&mut self, path: impl AsRef<Path>) -> anyhow::Result<()> {
52        let session = self.get_session().await?;
53        let path = path.as_ref().to_string_lossy();
54
55        tracing::trace!("SFTP removing file {:?}", path);
56        session.remove_file(path).await?;
57
58        Ok(())
59    }
60}
61
62impl SftpClient<GetFile> {
63    pub async fn as_bytes(&mut self) -> anyhow::Result<Bytes> {
64        let session = self.get_session().await?;
65        let path = self.path.as_ref().unwrap().to_string_lossy();
66
67        let mut remote_file = session.open(path).await?;
68        let mut buffer = Vec::new();
69        remote_file.read_to_end(&mut buffer).await?;
70        remote_file.shutdown().await?;
71
72        Ok(Bytes::from(buffer))
73    }
74
75    pub async fn as_stream(&mut self) -> anyhow::Result<ByteStream> {
76        let session = self.get_session().await?;
77        let path = self.path.as_ref().unwrap().to_string_lossy();
78
79        let remote_file = session.open(path).await?;
80        let reader = ReaderStream::new(remote_file);
81
82        Ok(ByteStream::new(reader))
83    }
84}
85
86impl SftpClient<PutFile> {
87    pub async fn from_bytes(&mut self, bytes: impl Into<Bytes>) -> anyhow::Result<()> {
88        let session = self.get_session().await?;
89        let path = self.path.as_ref().unwrap().to_string_lossy();
90        tracing::trace!("SFTP uploading bytes to {:?}", path);
91
92        let mut remote_file = session.create(path).await?;
93        remote_file.write_all(&bytes.into()).await?;
94        remote_file.shutdown().await?;
95
96        tracing::trace!("SFTP upload complete");
97        Ok(())
98    }
99
100    pub async fn from_stream(&mut self, mut stream: ByteStream) -> anyhow::Result<()> {
101        let session = self.get_session().await?;
102        let path = self.path.as_ref().unwrap().to_string_lossy();
103        tracing::trace!("SFTP uploading bytes to {:?}", path);
104
105        let mut remote_file = session.create(path).await?;
106        
107        while let Some(chunk) = stream.next().await {
108            let chunk = chunk?; 
109            remote_file.write_all(&chunk).await?;
110        }
111        remote_file.shutdown().await?;
112
113        tracing::trace!("SFTP upload complete");
114        Ok(())
115    }
116}
117
118impl<State> SftpClient<State> {
119    async fn get_session(&mut self) -> anyhow::Result<SftpSession> {
120        let mut guard = self.session.lock().await;
121
122        let session = match guard.take() {
123            Some(session) if !session.is_closed() => {
124                tracing::trace!("SSH session reused");
125                session
126            },
127            _ => self.connect_session().await?
128        };
129
130        let sftp = self.connect_sftp(&session).await?;
131        *guard = Some(session);
132        Ok(sftp)
133    }
134
135    async fn connect_session(&self) -> anyhow::Result<Handle<SshClient>> {
136        let config = self.config.clone();
137        tracing::trace!("SSH connecting to {}", config.endpoint);
138        let mut session = russh::client::connect(Arc::new(russh::client::Config::default()), &config.endpoint, SshClient {}).await?;
139        
140        let mut authenticated = false;
141
142        // Try public key authentication first.
143        if let Some(auth) = &config.auth_private_key {
144            let key = russh::keys::load_secret_key(&auth.path, auth.passphrase.as_deref())?;
145            let hash_alg = match &key.algorithm() {
146                russh::keys::Algorithm::Rsa { .. } => Some(HashAlg::Sha256),
147                _ => None,
148            };
149
150            let key_with_alg = PrivateKeyWithHashAlg::new(Arc::new(key), hash_alg);
151            authenticated = session.authenticate_publickey(&auth.user, key_with_alg).await?.success();
152            if authenticated {
153                tracing::trace!("SSH authenticated using public key authentication");
154            } else {
155                tracing::trace!("SSH public key authentication failed");
156            }
157        }
158
159        // Try basic authentication if public key authentication failed or was not used.
160        if !authenticated {
161            if let Some(auth) = &config.auth_basic {
162                authenticated = session.authenticate_password(&auth.user, &auth.password).await?.success();
163                if authenticated {
164                    tracing::trace!("SSH authenticated using basic authentication");
165                } else {
166                    tracing::trace!("SSH basic authentication failed");
167                }
168                
169            }
170        }
171
172        if !authenticated {
173            return Err(anyhow::anyhow!("All authentication methods failed"))
174        }
175
176        Ok(session)
177    }
178
179    async fn connect_sftp(&self, session: &Handle<SshClient>) -> anyhow::Result<SftpSession> {
180        tracing::trace!("SSH requesting SFTP subsystem");
181        let channel = session.channel_open_session().await?;
182        channel.request_subsystem(true, "sftp").await?;
183        let sftp = SftpSession::new(channel.into_stream()).await?;
184        Ok(sftp)
185    }
186}