1use std::path::Path;
6
7use base64::Engine;
8use base64::engine::general_purpose::STANDARD as B64;
9
10use crate::client::WinrmClient;
11use crate::error::WinrmError;
12
/// Number of raw file bytes carried by each PowerShell command during an
/// upload; every chunk is base64-encoded before being embedded in a script.
///
/// NOTE(review): presumably sized to keep each generated command below a
/// remote command-length limit — confirm before raising it.
const CHUNK_SIZE: usize = 2_000;

/// Longest remote path accepted by `validate_remote_path`.
///
/// NOTE(review): presumably mirrors the classic Windows `MAX_PATH` (260).
const MAX_REMOTE_PATH_LEN: usize = 260;
23
24fn validate_remote_path(path: &str) -> Result<(), WinrmError> {
29 if path.len() > MAX_REMOTE_PATH_LEN {
30 return Err(WinrmError::Transfer(format!(
31 "remote path exceeds {MAX_REMOTE_PATH_LEN} characters"
32 )));
33 }
34 if path.chars().any(|c| c.is_control() && c != '\t') {
35 return Err(WinrmError::Transfer(
36 "remote path contains control characters".into(),
37 ));
38 }
39 Ok(())
40}
41
42impl WinrmClient {
43 pub async fn upload_file(
52 &self,
53 host: &str,
54 local_path: &Path,
55 remote_path: &str,
56 ) -> Result<u64, WinrmError> {
57 validate_remote_path(remote_path)?;
58
59 let data = std::fs::read(local_path).map_err(|e| {
60 WinrmError::Transfer(format!(
61 "failed to read local file {}: {e}",
62 local_path.display()
63 ))
64 })?;
65
66 let shell = self.open_shell(host).await?;
67 let total = data.len() as u64;
68 let escaped_path = remote_path.replace('\'', "''");
69
70 for (i, chunk) in data.chunks(CHUNK_SIZE).enumerate() {
71 let b64 = B64.encode(chunk);
72
73 let script = if i == 0 {
74 format!(
75 "$bytes = [Convert]::FromBase64String('{b64}'); \
76 [IO.File]::WriteAllBytes('{escaped_path}', $bytes)"
77 )
78 } else {
79 format!(
80 "$bytes = [Convert]::FromBase64String('{b64}'); \
81 $f = [IO.File]::Open('{escaped_path}', 'Append'); \
82 $f.Write($bytes, 0, $bytes.Length); $f.Close()"
83 )
84 };
85
86 let output = shell.run_powershell(&script).await?;
87 if output.exit_code != 0 {
88 shell.close().await.ok();
89 return Err(WinrmError::Transfer(format!(
90 "upload chunk {i} failed: {}",
91 String::from_utf8_lossy(&output.stderr)
92 )));
93 }
94 }
95
96 shell.close().await.ok();
97 Ok(total)
98 }
99
100 pub async fn download_file(
106 &self,
107 host: &str,
108 remote_path: &str,
109 local_path: &Path,
110 ) -> Result<u64, WinrmError> {
111 validate_remote_path(remote_path)?;
112
113 let escaped = remote_path.replace('\'', "''");
114 let script = format!("[Convert]::ToBase64String([IO.File]::ReadAllBytes('{escaped}'))");
115
116 let output = self.run_powershell(host, &script).await?;
117 if output.exit_code != 0 {
118 return Err(WinrmError::Transfer(format!(
119 "download failed: {}",
120 String::from_utf8_lossy(&output.stderr)
121 )));
122 }
123
124 let b64 = String::from_utf8_lossy(&output.stdout);
125 let data = B64
126 .decode(b64.trim_ascii())
127 .map_err(|e| WinrmError::Transfer(format!("base64 decode of downloaded file: {e}")))?;
128
129 let total = data.len() as u64;
130 std::fs::write(local_path, &data).map_err(|e| {
131 WinrmError::Transfer(format!(
132 "failed to write local file {}: {e}",
133 local_path.display()
134 ))
135 })?;
136
137 Ok(total)
138 }
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn validate_remote_path_ok() {
        assert!(validate_remote_path("C:\\Users\\admin\\file.txt").is_ok());
    }

    #[test]
    fn validate_remote_path_too_long() {
        let long_path = "C:\\".to_string() + &"a".repeat(260);
        assert!(validate_remote_path(&long_path).is_err());
    }

    #[test]
    fn validate_remote_path_control_chars() {
        assert!(validate_remote_path("C:\\bad\x00path").is_err());
        assert!(validate_remote_path("C:\\bad\x01path").is_err());
    }

    #[test]
    fn validate_remote_path_rejects_line_breaks() {
        // Line breaks are control characters and must never reach a script.
        assert!(validate_remote_path("C:\\bad\npath").is_err());
        assert!(validate_remote_path("C:\\bad\r\npath").is_err());
    }

    #[test]
    fn validate_remote_path_rejects_del() {
        // U+007F is in the Cc category, so `char::is_control` flags it.
        assert!(validate_remote_path("C:\\bad\x7fpath").is_err());
    }

    #[test]
    fn validate_remote_path_tab_allowed() {
        assert!(validate_remote_path("C:\\path\twith\ttabs").is_ok());
    }

    #[test]
    fn validate_remote_path_max_length_boundary() {
        let exact = "a".repeat(260);
        assert!(validate_remote_path(&exact).is_ok());
        let over = "a".repeat(261);
        assert!(validate_remote_path(&over).is_err());
    }
}
175}