rust_integration_services/sftp/
sftp_client.rs1use std::{marker::PhantomData, path::{Path, PathBuf}, sync::Arc};
2
3use anyhow::Ok;
4use bytes::Bytes;
5use russh::{client::Handle, keys::{HashAlg, PrivateKeyWithHashAlg}};
6use russh_sftp::client::SftpSession;
7use tokio::{io::{AsyncReadExt, AsyncWriteExt}, sync::Mutex};
8use tokio_util::io::ReaderStream;
9
10use crate::{common::stream::ByteStream, sftp::{sftp_client_config::SftpClientConfig, ssh_client::SshClient}};
11
/// Type-state marker for a client that has no target path yet.
///
/// `SftpClient<Empty>` exposes `new`, `get_file`, `put_file` and `delete_file`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Empty;

/// Type-state marker for a client that downloads one remote file
/// (`as_bytes`, `as_stream`).
#[derive(Debug, Clone, Copy, Default)]
pub struct GetFile;

/// Type-state marker for a client that uploads to one remote file
/// (`from_bytes`, `from_stream`).
#[derive(Debug, Clone, Copy, Default)]
pub struct PutFile;
15
/// SFTP client whose available operations are selected by the type-state
/// parameter `State` (`Empty`, `GetFile` or `PutFile`).
///
/// Clients derived via `get_file` / `put_file` clone the `Arc`s below. They
/// therefore share the configuration and the cached SSH session of the
/// client they were created from.
pub struct SftpClient<State> {
    // Connection and authentication settings, shared between derived clients.
    config: Arc<SftpClientConfig>,
    // Remote path targeted by `GetFile` / `PutFile` clients; `None` for `Empty`.
    path: Option<PathBuf>,
    // SSH session cache: `None` until the first operation connects, then reused
    // while the handle is still open (see `get_session`).
    session: Arc<Mutex<Option<Handle<SshClient>>>>,
    // Zero-sized marker carrying the type-state; no `State` value is stored.
    _state: PhantomData<State>,
}
22
23impl SftpClient<Empty> {
24 pub fn new(config: SftpClientConfig) -> Self {
25 SftpClient {
26 config: Arc::new(config),
27 path: None,
28 session: Arc::new(Mutex::new(None)),
29 _state: PhantomData
30 }
31 }
32
33 pub fn get_file(&self, path: impl Into<PathBuf>) -> SftpClient<GetFile> {
34 SftpClient {
35 config: self.config.clone(),
36 path: Some(path.into()),
37 session: self.session.clone(),
38 _state: PhantomData
39 }
40 }
41
42 pub fn put_file(&self, path: impl Into<PathBuf>) -> SftpClient<PutFile> {
43 SftpClient {
44 config: self.config.clone(),
45 path: Some(path.into()),
46 session: self.session.clone(),
47 _state: PhantomData
48 }
49 }
50
51 pub async fn delete_file(&mut self, path: impl AsRef<Path>) -> anyhow::Result<()> {
52 let session = self.get_session().await?;
53 let path = path.as_ref().to_string_lossy();
54
55 tracing::trace!("SFTP removing file {:?}", path);
56 session.remove_file(path).await?;
57
58 Ok(())
59 }
60}
61
62impl SftpClient<GetFile> {
63 pub async fn as_bytes(&mut self) -> anyhow::Result<Bytes> {
64 let session = self.get_session().await?;
65 let path = self.path.as_ref().unwrap().to_string_lossy();
66
67 let mut remote_file = session.open(path).await?;
68 let mut buffer = Vec::new();
69 remote_file.read_to_end(&mut buffer).await?;
70 remote_file.shutdown().await?;
71
72 Ok(Bytes::from(buffer))
73 }
74
75 pub async fn as_stream(&mut self) -> anyhow::Result<ByteStream> {
76 let session = self.get_session().await?;
77 let path = self.path.as_ref().unwrap().to_string_lossy();
78
79 let remote_file = session.open(path).await?;
80 let reader = ReaderStream::new(remote_file);
81
82 Ok(ByteStream::new(reader))
83 }
84}
85
86impl SftpClient<PutFile> {
87 pub async fn from_bytes(&mut self, bytes: impl Into<Bytes>) -> anyhow::Result<()> {
88 let session = self.get_session().await?;
89 let path = self.path.as_ref().unwrap().to_string_lossy();
90 tracing::trace!("SFTP uploading bytes to {:?}", path);
91
92 let mut remote_file = session.create(path).await?;
93 remote_file.write_all(&bytes.into()).await?;
94 remote_file.shutdown().await?;
95
96 tracing::trace!("SFTP upload complete");
97 Ok(())
98 }
99
100 pub async fn from_stream(&mut self, mut stream: ByteStream) -> anyhow::Result<()> {
101 let session = self.get_session().await?;
102 let path = self.path.as_ref().unwrap().to_string_lossy();
103 tracing::trace!("SFTP uploading bytes to {:?}", path);
104
105 let mut remote_file = session.create(path).await?;
106
107 while let Some(chunk) = stream.next().await {
108 let chunk = chunk?;
109 remote_file.write_all(&chunk).await?;
110 }
111 remote_file.shutdown().await?;
112
113 tracing::trace!("SFTP upload complete");
114 Ok(())
115 }
116}
117
118impl<State> SftpClient<State> {
119 async fn get_session(&mut self) -> anyhow::Result<SftpSession> {
120 let mut guard = self.session.lock().await;
121
122 let session = match guard.take() {
123 Some(session) if !session.is_closed() => {
124 tracing::trace!("SSH session reused");
125 session
126 },
127 _ => self.connect_session().await?
128 };
129
130 let sftp = self.connect_sftp(&session).await?;
131 *guard = Some(session);
132 Ok(sftp)
133 }
134
135 async fn connect_session(&self) -> anyhow::Result<Handle<SshClient>> {
136 let config = self.config.clone();
137 tracing::trace!("SSH connecting to {}", config.endpoint);
138 let mut session = russh::client::connect(Arc::new(russh::client::Config::default()), &config.endpoint, SshClient {}).await?;
139
140 let mut authenticated = false;
141
142 if let Some(auth) = &config.auth_private_key {
144 let key = russh::keys::load_secret_key(&auth.path, auth.passphrase.as_deref())?;
145 let hash_alg = match &key.algorithm() {
146 russh::keys::Algorithm::Rsa { .. } => Some(HashAlg::Sha256),
147 _ => None,
148 };
149
150 let key_with_alg = PrivateKeyWithHashAlg::new(Arc::new(key), hash_alg);
151 authenticated = session.authenticate_publickey(&auth.user, key_with_alg).await?.success();
152 if authenticated {
153 tracing::trace!("SSH authenticated using public key authentication");
154 } else {
155 tracing::trace!("SSH public key authentication failed");
156 }
157 }
158
159 if !authenticated {
161 if let Some(auth) = &config.auth_basic {
162 authenticated = session.authenticate_password(&auth.user, &auth.password).await?.success();
163 if authenticated {
164 tracing::trace!("SSH authenticated using basic authentication");
165 } else {
166 tracing::trace!("SSH basic authentication failed");
167 }
168
169 }
170 }
171
172 if !authenticated {
173 return Err(anyhow::anyhow!("All authentication methods failed"))
174 }
175
176 Ok(session)
177 }
178
179 async fn connect_sftp(&self, session: &Handle<SshClient>) -> anyhow::Result<SftpSession> {
180 tracing::trace!("SSH requesting SFTP subsystem");
181 let channel = session.channel_open_session().await?;
182 channel.request_subsystem(true, "sftp").await?;
183 let sftp = SftpSession::new(channel.into_stream()).await?;
184 Ok(sftp)
185 }
186}