1use crate::{
8 ansible::provisioning::PrivateNodeProvisionInventory,
9 error::{Error, Result},
10 inventory::VirtualMachine,
11 run_external_command,
12};
13use log::debug;
14use std::{
15 collections::HashMap,
16 net::IpAddr,
17 path::PathBuf,
18 sync::{Arc, RwLock},
19};
20
21#[derive(Clone, Debug)]
22pub struct RoutedVms {
23 full_cone_private_node_nat_gateway_ip_map: HashMap<VirtualMachine, IpAddr>,
24 symmetric_private_node_nat_gateway_ip_map: HashMap<VirtualMachine, IpAddr>,
25}
26
27impl RoutedVms {
28 fn find_symmetric_nat_routed_node(
29 &self,
30 ip_address: &IpAddr,
31 ) -> Option<(&VirtualMachine, &IpAddr)> {
32 debug!("Check if {ip_address} is a symmetric NAT routed node...");
33 self.symmetric_private_node_nat_gateway_ip_map
34 .iter()
35 .find_map(|(private_vm, gateway_ip)| {
36 if &private_vm.public_ip_addr == ip_address {
37 Some((private_vm, gateway_ip))
38 } else {
39 None
40 }
41 })
42 .inspect(|op| {
43 debug!("Found symmetric NAT routed node: {op:?}");
44 })
45 }
46
47 fn find_full_cone_nat_routed_node(
48 &self,
49 ip_address: &IpAddr,
50 ) -> Option<(&VirtualMachine, &IpAddr)> {
51 debug!("Check if {ip_address} is a full cone NAT routed node...");
52 self.full_cone_private_node_nat_gateway_ip_map
53 .iter()
54 .find_map(|(private_vm, gateway_ip)| {
55 if &private_vm.public_ip_addr == ip_address {
56 Some((private_vm, gateway_ip))
57 } else {
58 None
59 }
60 })
61 .inspect(|op| {
62 debug!("Found full cone NAT routed node: {op:?}");
63 })
64 }
65}
66
67#[derive(Clone)]
68pub struct SshClient {
69 pub private_key_path: PathBuf,
70 pub routed_vms: Arc<RwLock<Option<RoutedVms>>>,
72}
73impl SshClient {
74 pub fn new(private_key_path: PathBuf) -> SshClient {
75 SshClient {
76 private_key_path,
77 routed_vms: Arc::new(RwLock::new(None)),
78 }
79 }
80
81 pub fn set_full_cone_nat_routed_vms(
84 &self,
85 private_node_vms: &[VirtualMachine],
86 nat_gateway_vms: &[VirtualMachine],
87 ) -> Result<()> {
88 let private_node_nat_gateway_map =
89 PrivateNodeProvisionInventory::match_private_node_vm_and_gateway_vm(
90 private_node_vms,
91 nat_gateway_vms,
92 )?;
93 let full_cone_private_node_nat_gateway_ip_map = private_node_nat_gateway_map
94 .into_iter()
95 .map(|(private_node_vm, nat_gateway_vm)| {
96 (private_node_vm, nat_gateway_vm.public_ip_addr)
97 })
98 .collect::<HashMap<_, _>>();
99 let mut write_access = self.routed_vms.write().map_err(|err| {
100 log::error!("Failed to set routed VMs: {err}");
101 Error::SshSettingsRwLockError
102 })?;
103
104 debug!("Full Cone Private Routed VMs have been set to: {full_cone_private_node_nat_gateway_ip_map:?}");
105 match write_access.as_mut() {
106 Some(routed_vms) => {
107 routed_vms.full_cone_private_node_nat_gateway_ip_map =
108 full_cone_private_node_nat_gateway_ip_map;
109 }
110 None => {
111 *write_access = Some(RoutedVms {
112 full_cone_private_node_nat_gateway_ip_map,
113 symmetric_private_node_nat_gateway_ip_map: HashMap::new(),
114 });
115 }
116 }
117
118 Ok(())
119 }
120
121 pub fn set_symmetric_nat_routed_vms(
124 &self,
125 private_node_vms: &[VirtualMachine],
126 nat_gateway_vms: &[VirtualMachine],
127 ) -> Result<()> {
128 let private_node_nat_gateway_map =
129 PrivateNodeProvisionInventory::match_private_node_vm_and_gateway_vm(
130 private_node_vms,
131 nat_gateway_vms,
132 )?;
133 let symmetric_private_node_nat_gateway_ip_map = private_node_nat_gateway_map
134 .into_iter()
135 .map(|(private_node_vm, nat_gateway_vm)| {
136 (private_node_vm, nat_gateway_vm.public_ip_addr)
137 })
138 .collect::<HashMap<_, _>>();
139 let mut write_access = self.routed_vms.write().map_err(|err| {
140 log::error!("Failed to set routed VMs: {err}");
141 Error::SshSettingsRwLockError
142 })?;
143 debug!("Symmetric Private node Routed VMs have been set to: {symmetric_private_node_nat_gateway_ip_map:?}");
144
145 match write_access.as_mut() {
146 Some(routed_vms) => {
147 routed_vms.symmetric_private_node_nat_gateway_ip_map =
148 symmetric_private_node_nat_gateway_ip_map;
149 }
150 None => {
151 *write_access = Some(RoutedVms {
152 full_cone_private_node_nat_gateway_ip_map: HashMap::new(),
153 symmetric_private_node_nat_gateway_ip_map,
154 });
155 }
156 }
157
158 Ok(())
159 }
160
161 pub fn get_private_key_path(&self) -> PathBuf {
162 self.private_key_path.clone()
163 }
164
165 pub fn wait_for_ssh_availability(&self, ip_address: &IpAddr, user: &str) -> Result<()> {
166 let mut args = vec![
167 "-i".to_string(),
168 self.private_key_path.to_string_lossy().to_string(),
169 "-q".to_string(),
170 "-o".to_string(),
171 "BatchMode=yes".to_string(),
172 "-o".to_string(),
173 "ConnectTimeout=5".to_string(),
174 "-o".to_string(),
175 "StrictHostKeyChecking=no".to_string(),
176 ];
177 let routed_vm_read = self.routed_vms.read().map_err(|err| {
178 log::error!("Failed to read routed VMs: {err}");
179 Error::SshSettingsRwLockError
180 })?;
181 if let Some((vm, gateway_ip)) = routed_vm_read
182 .as_ref()
183 .and_then(|routed_vms| routed_vms.find_symmetric_nat_routed_node(ip_address))
184 {
185 println!(
186 "Checking for SSH availability at {} ({ip_address}) via symmetric NAT gateway {gateway_ip}...",
187 vm.private_ip_addr
188 );
189 debug!(
190 "Checking for SSH availability at {} ({ip_address}) via symmetric NAT gateway {gateway_ip}...",
191 vm.private_ip_addr
192 );
193 args.push("-o".to_string());
194 args.push(format!(
195 "ProxyCommand=ssh -i {} -W %h:%p {}@{}",
196 self.private_key_path.to_string_lossy(),
197 user,
198 gateway_ip
199 ));
200 args.push(format!("{user}@{}", vm.private_ip_addr));
201 } else if let Some((vm, gateway_ip)) = routed_vm_read
202 .as_ref()
203 .and_then(|routed_vms| routed_vms.find_full_cone_nat_routed_node(ip_address))
204 {
205 println!(
206 "Checking for SSH availability at {} ({ip_address}) via Full Cone NAT gateway {gateway_ip}...",
207 vm.private_ip_addr,
208 );
209 debug!(
210 "Checking for SSH availability at {} ({ip_address}) via Full Cone NAT gateway {gateway_ip}...",
211 vm.private_ip_addr,
212 );
213 args.push(format!("{user}@{gateway_ip}"));
214 } else {
215 println!("Checking for SSH availability at {ip_address}...");
216 args.push(format!("{user}@{ip_address}"));
217 }
218 args.push("bash".to_string());
219 args.push("--version".to_string());
220
221 let mut retries = 0;
222 let max_retries = 10;
223 while retries < max_retries {
224 let result = run_external_command(
225 PathBuf::from("ssh"),
226 std::env::current_dir()?,
227 args.clone(),
228 false,
229 false,
230 );
231 if result.is_ok() {
232 println!("SSH is available.");
233 return Ok(());
234 } else {
235 retries += 1;
236 println!("SSH is still unavailable after {retries} attempts.");
237 println!("Will sleep for 5 seconds then retry.");
238 std::thread::sleep(std::time::Duration::from_secs(5));
239 }
240 }
241
242 println!("The maximum number of connection retry attempts has been exceeded.");
243 Err(Error::SshUnavailable)
244 }
245
246 pub fn run_command(
247 &self,
248 ip_address: &IpAddr,
249 user: &str,
250 command: &str,
251 suppress_output: bool,
252 ) -> Result<Vec<String>> {
253 let command_args: Vec<String> = command.split_whitespace().map(String::from).collect();
254 let mut args = vec![
255 "-i".to_string(),
256 self.private_key_path.to_string_lossy().to_string(),
257 "-q".to_string(),
258 "-o".to_string(),
259 "BatchMode=yes".to_string(),
260 "-o".to_string(),
261 "ConnectTimeout=30".to_string(),
262 "-o".to_string(),
263 "StrictHostKeyChecking=no".to_string(),
264 ];
265 let routed_vm_read = self.routed_vms.read().map_err(|err| {
266 log::error!("Failed to read routed VMs: {err}");
267 Error::SshSettingsRwLockError
268 })?;
269
270 if let Some((vm, gateway)) = routed_vm_read
271 .as_ref()
272 .and_then(|routed_vms| routed_vms.find_symmetric_nat_routed_node(ip_address))
273 {
274 debug!(
275 "Running command '{}' on {} ({ip_address}) via symmetric NAT gateway {gateway}...",
276 command, vm.private_ip_addr
277 );
278 args.push("-o".to_string());
279 args.push(format!(
280 "ProxyCommand=ssh -i {} -W %h:%p {user}@{gateway}",
281 self.private_key_path.to_string_lossy(),
282 ));
283 args.push(format!("{user}@{}", vm.private_ip_addr));
284 } else if let Some((vm, gateway)) = routed_vm_read
285 .as_ref()
286 .and_then(|routed_vms| routed_vms.find_full_cone_nat_routed_node(ip_address))
287 {
288 debug!(
289 "Running command '{}' on {} ({ip_address}) via full cone NAT gateway {gateway}...",
290 command, vm.private_ip_addr
291 );
292 args.push(format!("{user}@{gateway}"));
293 } else {
294 debug!(
295 "Running command '{}' on {}@{}...",
296 command, user, ip_address
297 );
298 args.push(format!("{user}@{ip_address}"));
299 }
300 args.extend(command_args);
301
302 let output = run_external_command(
303 PathBuf::from("ssh"),
304 std::env::current_dir()?,
305 args,
306 suppress_output,
307 false,
308 )?;
309 Ok(output)
310 }
311
312 pub fn run_script(
313 &self,
314 ip_address: IpAddr,
315 user: &str,
316 script: PathBuf,
317 suppress_output: bool,
318 ) -> Result<Vec<String>> {
319 let file_name = script
320 .file_name()
321 .ok_or_else(|| {
322 Error::SshCommandFailed("Could not obtain file name from script path".to_string())
323 })?
324 .to_string_lossy()
325 .to_string();
326 let args = vec![
327 "-i".to_string(),
328 self.private_key_path.to_string_lossy().to_string(),
329 "-q".to_string(),
330 "-o".to_string(),
331 "BatchMode=yes".to_string(),
332 "-o".to_string(),
333 "ConnectTimeout=30".to_string(),
334 "-o".to_string(),
335 "StrictHostKeyChecking=no".to_string(),
336 script.to_string_lossy().to_string(),
337 format!("{}@{}:/tmp/{}", user, ip_address, file_name),
338 ];
339 run_external_command(
340 PathBuf::from("scp"),
341 std::env::current_dir()?,
342 args,
343 suppress_output,
344 false,
345 )
346 .map_err(|e| {
347 Error::SshCommandFailed(format!(
348 "Failed to copy script file to remote host {ip_address:?}: {e}"
349 ))
350 })?;
351
352 let args = vec![
353 "-i".to_string(),
354 self.private_key_path.to_string_lossy().to_string(),
355 "-q".to_string(),
356 "-o".to_string(),
357 "BatchMode=yes".to_string(),
358 "-o".to_string(),
359 "ConnectTimeout=30".to_string(),
360 "-o".to_string(),
361 "StrictHostKeyChecking=no".to_string(),
362 format!("{user}@{ip_address}"),
363 "bash".to_string(),
364 format!("/tmp/{file_name}"),
365 ];
366 let output = run_external_command(
367 PathBuf::from("ssh"),
368 std::env::current_dir()?,
369 args,
370 suppress_output,
371 false,
372 )
373 .map_err(|e| {
374 Error::SshCommandFailed(format!("Failed to execute command on remote host: {e}"))
375 })?;
376 Ok(output)
377 }
378}