1use crate::{
8 ansible::provisioning::PrivateNodeProvisionInventory,
9 error::{Error, Result},
10 inventory::VirtualMachine,
11 run_external_command,
12};
13use log::debug;
14use std::{
15 collections::HashMap,
16 net::IpAddr,
17 path::PathBuf,
18 sync::{Arc, RwLock},
19};
20
21#[derive(Clone, Debug)]
22pub struct RoutedVms {
23 full_cone_private_node_nat_gateway_ip_map: HashMap<VirtualMachine, IpAddr>,
24 symmetric_private_node_nat_gateway_ip_map: HashMap<VirtualMachine, IpAddr>,
25}
26
27impl RoutedVms {
28 pub fn find_symmetric_nat_routed_node(
29 &self,
30 ip_address: &IpAddr,
31 ) -> Option<(&VirtualMachine, &IpAddr)> {
32 debug!("Check if {ip_address} is a symmetric NAT routed node...");
33 self.symmetric_private_node_nat_gateway_ip_map
34 .iter()
35 .find_map(|(private_vm, gateway_ip)| {
36 if &private_vm.public_ip_addr == ip_address {
37 Some((private_vm, gateway_ip))
38 } else {
39 None
40 }
41 })
42 .inspect(|op| {
43 debug!("Found symmetric NAT routed node: {op:?}");
44 })
45 }
46
47 pub fn find_full_cone_nat_routed_node(
48 &self,
49 ip_address: &IpAddr,
50 ) -> Option<(&VirtualMachine, &IpAddr)> {
51 debug!("Check if {ip_address} is a full cone NAT routed node...");
52 self.full_cone_private_node_nat_gateway_ip_map
53 .iter()
54 .find_map(|(private_vm, gateway_ip)| {
55 if &private_vm.public_ip_addr == ip_address {
56 Some((private_vm, gateway_ip))
57 } else {
58 None
59 }
60 })
61 .inspect(|op| {
62 debug!("Found full cone NAT routed node: {op:?}");
63 })
64 }
65}
66
67#[derive(Clone)]
68pub struct SshClient {
69 pub private_key_path: PathBuf,
70 pub routed_vms: Arc<RwLock<Option<RoutedVms>>>,
72}
73impl SshClient {
74 pub fn new(private_key_path: PathBuf) -> SshClient {
75 SshClient {
76 private_key_path,
77 routed_vms: Arc::new(RwLock::new(None)),
78 }
79 }
80
81 pub fn set_full_cone_nat_routed_vms(
84 &self,
85 private_node_vms: &[VirtualMachine],
86 nat_gateway_vms: &[VirtualMachine],
87 ) -> Result<()> {
88 let private_node_nat_gateway_map =
89 PrivateNodeProvisionInventory::match_private_node_vm_and_gateway_vm(
90 private_node_vms,
91 nat_gateway_vms,
92 )?;
93 let full_cone_private_node_nat_gateway_ip_map = private_node_nat_gateway_map
94 .into_iter()
95 .map(|(private_node_vm, nat_gateway_vm)| {
96 (private_node_vm, nat_gateway_vm.public_ip_addr)
97 })
98 .collect::<HashMap<_, _>>();
99 let mut write_access = self.routed_vms.write().map_err(|err| {
100 log::error!("Failed to set routed VMs: {err}");
101 Error::SshSettingsRwLockError
102 })?;
103
104 debug!("Full Cone Private Routed VMs have been set to: {full_cone_private_node_nat_gateway_ip_map:?}");
105 match write_access.as_mut() {
106 Some(routed_vms) => {
107 routed_vms.full_cone_private_node_nat_gateway_ip_map =
108 full_cone_private_node_nat_gateway_ip_map;
109 }
110 None => {
111 *write_access = Some(RoutedVms {
112 full_cone_private_node_nat_gateway_ip_map,
113 symmetric_private_node_nat_gateway_ip_map: HashMap::new(),
114 });
115 }
116 }
117
118 Ok(())
119 }
120
121 pub fn set_symmetric_nat_routed_vms(
124 &self,
125 private_node_vms: &[VirtualMachine],
126 nat_gateway_vms: &[VirtualMachine],
127 ) -> Result<()> {
128 let private_node_nat_gateway_map =
129 PrivateNodeProvisionInventory::match_private_node_vm_and_gateway_vm(
130 private_node_vms,
131 nat_gateway_vms,
132 )?;
133 let symmetric_private_node_nat_gateway_ip_map = private_node_nat_gateway_map
134 .into_iter()
135 .map(|(private_node_vm, nat_gateway_vm)| {
136 (private_node_vm, nat_gateway_vm.public_ip_addr)
137 })
138 .collect::<HashMap<_, _>>();
139 let mut write_access = self.routed_vms.write().map_err(|err| {
140 log::error!("Failed to set routed VMs: {err}");
141 Error::SshSettingsRwLockError
142 })?;
143 debug!("Symmetric Private node Routed VMs have been set to: {symmetric_private_node_nat_gateway_ip_map:?}");
144
145 match write_access.as_mut() {
146 Some(routed_vms) => {
147 routed_vms.symmetric_private_node_nat_gateway_ip_map =
148 symmetric_private_node_nat_gateway_ip_map;
149 }
150 None => {
151 *write_access = Some(RoutedVms {
152 full_cone_private_node_nat_gateway_ip_map: HashMap::new(),
153 symmetric_private_node_nat_gateway_ip_map,
154 });
155 }
156 }
157
158 Ok(())
159 }
160
161 pub fn get_private_key_path(&self) -> PathBuf {
162 self.private_key_path.clone()
163 }
164
165 pub fn wait_for_ssh_availability(&self, ip_address: &IpAddr, user: &str) -> Result<()> {
166 let mut args = vec![
167 "-i".to_string(),
168 self.private_key_path.to_string_lossy().to_string(),
169 "-q".to_string(),
170 "-o".to_string(),
171 "BatchMode=yes".to_string(),
172 "-o".to_string(),
173 "ConnectTimeout=5".to_string(),
174 "-o".to_string(),
175 "StrictHostKeyChecking=no".to_string(),
176 ];
177 let routed_vm_read = self.routed_vms.read().map_err(|err| {
178 log::error!("Failed to read routed VMs: {err}");
179 Error::SshSettingsRwLockError
180 })?;
181 if let Some((vm, gateway_ip)) = routed_vm_read
182 .as_ref()
183 .and_then(|routed_vms| routed_vms.find_symmetric_nat_routed_node(ip_address))
184 {
185 println!(
186 "Checking for SSH availability at {} ({ip_address}) via symmetric NAT gateway {gateway_ip}...",
187 vm.private_ip_addr
188 );
189 debug!(
190 "Checking for SSH availability at {} ({ip_address}) via symmetric NAT gateway {gateway_ip}...",
191 vm.private_ip_addr
192 );
193 args.push("-o".to_string());
194 args.push(format!(
195 "ProxyCommand=ssh -i {} -W %h:%p {}@{}",
196 self.private_key_path.to_string_lossy(),
197 user,
198 gateway_ip
199 ));
200 args.push(format!("{user}@{}", vm.private_ip_addr));
201 } else if let Some((vm, gateway_ip)) = routed_vm_read
202 .as_ref()
203 .and_then(|routed_vms| routed_vms.find_full_cone_nat_routed_node(ip_address))
204 {
205 println!(
206 "Checking for SSH availability at {} ({ip_address}) via Full Cone NAT gateway {gateway_ip}...",
207 vm.private_ip_addr,
208 );
209 debug!(
210 "Checking for SSH availability at {} ({ip_address}) via Full Cone NAT gateway {gateway_ip}...",
211 vm.private_ip_addr,
212 );
213 args.push(format!("{user}@{gateway_ip}"));
214 } else {
215 println!("Checking for SSH availability at {ip_address}...");
216 args.push(format!("{user}@{ip_address}"));
217 }
218 args.push("bash".to_string());
219 args.push("--version".to_string());
220
221 let mut retries = 0;
222 let max_retries = 10;
223 while retries < max_retries {
224 let result = run_external_command(
225 PathBuf::from("ssh"),
226 std::env::current_dir()?,
227 args.clone(),
228 false,
229 false,
230 );
231 if result.is_ok() {
232 println!("SSH is available.");
233 return Ok(());
234 } else {
235 retries += 1;
236 println!("SSH is still unavailable after {retries} attempts.");
237 println!("Will sleep for 5 seconds then retry.");
238 std::thread::sleep(std::time::Duration::from_secs(5));
239 }
240 }
241
242 println!("The maximum number of connection retry attempts has been exceeded.");
243 Err(Error::SshUnavailable)
244 }
245
246 pub fn run_command(
247 &self,
248 ip_address: &IpAddr,
249 user: &str,
250 command: &str,
251 suppress_output: bool,
252 ) -> Result<Vec<String>> {
253 let command_args: Vec<String> = command.split_whitespace().map(String::from).collect();
254 let mut args = vec![
255 "-i".to_string(),
256 self.private_key_path.to_string_lossy().to_string(),
257 "-q".to_string(),
258 "-o".to_string(),
259 "BatchMode=yes".to_string(),
260 "-o".to_string(),
261 "ConnectTimeout=30".to_string(),
262 "-o".to_string(),
263 "StrictHostKeyChecking=no".to_string(),
264 ];
265 let routed_vm_read = self.routed_vms.read().map_err(|err| {
266 log::error!("Failed to read routed VMs: {err}");
267 Error::SshSettingsRwLockError
268 })?;
269
270 if let Some((vm, gateway)) = routed_vm_read
271 .as_ref()
272 .and_then(|routed_vms| routed_vms.find_symmetric_nat_routed_node(ip_address))
273 {
274 debug!(
275 "Running command '{}' on {} ({ip_address}) via symmetric NAT gateway {gateway}...",
276 command, vm.private_ip_addr
277 );
278 args.push("-o".to_string());
279 args.push(format!(
280 "ProxyCommand=ssh -i {} -W %h:%p {user}@{gateway}",
281 self.private_key_path.to_string_lossy(),
282 ));
283 args.push(format!("{user}@{}", vm.private_ip_addr));
284 } else if let Some((vm, gateway)) = routed_vm_read
285 .as_ref()
286 .and_then(|routed_vms| routed_vms.find_full_cone_nat_routed_node(ip_address))
287 {
288 debug!(
289 "Running command '{}' on {} ({ip_address}) via full cone NAT gateway {gateway}...",
290 command, vm.private_ip_addr
291 );
292 args.push(format!("{user}@{gateway}"));
293 } else {
294 debug!("Running command '{command}' on {user}@{ip_address}...");
295 args.push(format!("{user}@{ip_address}"));
296 }
297 args.extend(command_args);
298
299 let output = run_external_command(
300 PathBuf::from("ssh"),
301 std::env::current_dir()?,
302 args,
303 suppress_output,
304 false,
305 )?;
306 Ok(output)
307 }
308
309 pub fn run_script(
310 &self,
311 ip_address: IpAddr,
312 user: &str,
313 script: PathBuf,
314 suppress_output: bool,
315 ) -> Result<Vec<String>> {
316 let file_name = script
317 .file_name()
318 .ok_or_else(|| {
319 Error::SshCommandFailed("Could not obtain file name from script path".to_string())
320 })?
321 .to_string_lossy()
322 .to_string();
323 let args = vec![
324 "-i".to_string(),
325 self.private_key_path.to_string_lossy().to_string(),
326 "-q".to_string(),
327 "-o".to_string(),
328 "BatchMode=yes".to_string(),
329 "-o".to_string(),
330 "ConnectTimeout=30".to_string(),
331 "-o".to_string(),
332 "StrictHostKeyChecking=no".to_string(),
333 script.to_string_lossy().to_string(),
334 format!("{}@{}:/tmp/{}", user, ip_address, file_name),
335 ];
336 run_external_command(
337 PathBuf::from("scp"),
338 std::env::current_dir()?,
339 args,
340 suppress_output,
341 false,
342 )
343 .map_err(|e| {
344 Error::SshCommandFailed(format!(
345 "Failed to copy script file to remote host {ip_address:?}: {e}"
346 ))
347 })?;
348
349 let args = vec![
350 "-i".to_string(),
351 self.private_key_path.to_string_lossy().to_string(),
352 "-q".to_string(),
353 "-o".to_string(),
354 "BatchMode=yes".to_string(),
355 "-o".to_string(),
356 "ConnectTimeout=30".to_string(),
357 "-o".to_string(),
358 "StrictHostKeyChecking=no".to_string(),
359 format!("{user}@{ip_address}"),
360 "bash".to_string(),
361 format!("/tmp/{file_name}"),
362 ];
363 let output = run_external_command(
364 PathBuf::from("ssh"),
365 std::env::current_dir()?,
366 args,
367 suppress_output,
368 false,
369 )
370 .map_err(|e| {
371 Error::SshCommandFailed(format!("Failed to execute command on remote host: {e}"))
372 })?;
373 Ok(output)
374 }
375}