1#![doc(html_root_url = "https://docs.rs/spurs-util/0.3.1")]
16
17pub mod centos;
18pub mod ubuntu;
19
20use std::{
21 collections::{BTreeSet, HashMap, HashSet},
22 net::{IpAddr, ToSocketAddrs},
23};
24
25use spurs::{cmd, Execute, SshCommand, SshError};
26
/// Escapes `s` so it can be placed unquoted on a bash command line and still
/// be read back as the same literal text.
///
/// ASCII alphanumeric characters are kept as they are. Every other character
/// gets a preceding backslash, with one exception: a newline. In bash an
/// unquoted backslash-newline pair is a line continuation and is removed
/// entirely, which would silently drop the newline. A newline is therefore
/// emitted inside single quotes (`'` newline `'`), which bash keeps literally.
pub fn escape_for_bash(s: &str) -> String {
    let mut new = String::with_capacity(s.len());

    for ch in s.chars() {
        match ch {
            ch if ch.is_ascii_alphanumeric() => new.push(ch),
            // `\<newline>` would be swallowed as a line continuation.
            '\n' => new.push_str("'\n'"),
            ch => {
                new.push('\\');
                new.push(ch);
            }
        }
    }

    new
}
50
/// Resolves `addr` and returns the IP address and port of the first socket
/// address it yields.
///
/// # Panics
///
/// Panics if `addr` cannot be resolved, or if resolution succeeds but yields
/// no socket addresses at all.
pub fn get_host_ip<A: ToSocketAddrs>(addr: A) -> (IpAddr, u16) {
    let addr = addr
        .to_socket_addrs()
        .expect("failed to resolve address")
        .next()
        .expect("address resolved to no socket addresses");
    (addr.ip(), addr.port())
}
58
59pub fn set_cpu_scaling_governor(gov: &str) -> SshCommand {
68 cmd!("sudo cpupower frequency-set -g {}", gov)
69}
70
71pub fn swapoff(device: &str) -> SshCommand {
73 cmd!("sudo swapoff {}", device)
74}
75
76pub fn swapon(device: &str) -> SshCommand {
79 cmd!("sudo swapon {}", device)
80}
81
82pub fn add_to_group(group: &str) -> SshCommand {
84 cmd!("sudo usermod -aG {} `whoami`", group).use_bash()
85}
86
87pub fn write_gpt(device: &str) -> SshCommand {
91 cmd!("sudo parted -a optimal {} -s -- mklabel gpt", device)
92}
93
94pub fn create_partition(device: &str) -> SshCommand {
96 cmd!(
97 "sudo parted -a optimal {} -s -- mkpart primary 0% 100%",
98 device
99 )
100}
101
102pub fn format_partition_as_ext4<P: AsRef<std::path::Path>>(
130 shell: &impl Execute,
131 dry_run: bool,
132 partition: &str,
133 mount: P,
134 owner: &str,
135) -> Result<(), SshError> {
136 shell.run(cmd!("lsblk").dry_run(dry_run))?;
137
138 shell.run(cmd!("sudo mkfs.ext4 {}", partition).dry_run(dry_run))?;
140
141 shell.run(cmd!("mkdir -p /tmp/tmp_mnt").dry_run(dry_run))?;
143 shell.run(cmd!("sudo mount -t ext4 {} /tmp/tmp_mnt", partition).dry_run(dry_run))?;
144 shell.run(cmd!("sudo chown {} /tmp/tmp_mnt", owner).dry_run(dry_run))?;
145
146 shell.run(cmd!("rsync -a {}/ /tmp/tmp_mnt/", mount.as_ref().display()).dry_run(dry_run))?;
148
149 shell.run(cmd!("sync").dry_run(dry_run))?;
151 shell.run(cmd!("sudo umount /tmp/tmp_mnt").dry_run(dry_run))?;
152
153 shell.run(
155 cmd!(
156 "sudo mount -t ext4 {} {}",
157 partition,
158 mount.as_ref().display()
159 )
160 .dry_run(dry_run),
161 )?;
162 shell.run(cmd!("sudo chown {} {}", owner, mount.as_ref().display()).dry_run(dry_run))?;
163
164 let uuid = shell
166 .run(
167 cmd!("sudo blkid -o export {} | grep '^UUID='", partition)
168 .use_bash()
169 .dry_run(dry_run),
170 )?
171 .stdout;
172 let uuid = uuid.trim();
173 shell.run(
174 cmd!(
175 r#"echo "{} {} ext4 defaults 0 1" | sudo tee -a /etc/fstab"#,
176 uuid,
177 mount.as_ref().display()
178 )
179 .dry_run(dry_run),
180 )?;
181
182 shell.run(cmd!("lsblk").dry_run(dry_run))?;
184
185 Ok(())
186}
187
188pub fn get_partitions(
190 shell: &impl Execute,
191 device: &str,
192 dry_run: bool,
193) -> Result<HashSet<String>, SshError> {
194 Ok(shell
195 .run(cmd!("lsblk -o KNAME {}", device).dry_run(dry_run))?
196 .stdout
197 .lines()
198 .map(|line| line.trim().to_owned())
199 .skip(2)
200 .collect())
201}
202
203pub fn get_unpartitioned_devs(
205 shell: &impl Execute,
206 dry_run: bool,
207) -> Result<HashSet<String>, SshError> {
208 let lsblk = shell.run(cmd!("lsblk -o KNAME").dry_run(dry_run))?.stdout;
210 let mut devices: BTreeSet<&str> = lsblk.lines().map(|line| line.trim()).skip(1).collect();
211
212 let partitions: HashMap<_, _> = devices
214 .iter()
215 .map(|&dev| {
216 (
217 dev,
218 get_partitions(shell, &format!("/dev/{}", dev), dry_run),
219 )
220 })
221 .collect();
222
223 for (dev, parts) in partitions.into_iter() {
225 let parts = parts?;
226 if !parts.is_empty() {
227 devices.remove(dev);
228 for part in parts {
229 devices.remove(part.as_str());
230 }
231 }
232 }
233
234 Ok(devices.iter().map(|&dev| dev.to_owned()).collect())
235}
236
237pub fn get_mounted_devs(
239 shell: &impl Execute,
240 dry_run: bool,
241) -> Result<Vec<(String, String)>, SshError> {
242 let devices = shell
243 .run(cmd!("lsblk -o KNAME,MOUNTPOINT").dry_run(dry_run))?
244 .stdout;
245 let devices = devices.lines().skip(1);
246 let mut mounted = vec![];
247 for line in devices {
248 let split: Vec<_> = line
249 .split(char::is_whitespace)
250 .filter(|s| !s.is_empty())
251 .collect();
252
253 if split.len() > 1 && mounted.iter().all(|(d, _)| d != split[0]) {
255 mounted.push((split[0].to_owned(), split[1].to_owned()));
256 }
257 }
258 Ok(mounted)
259}
260
261pub fn get_dev_sizes(
263 shell: &impl Execute,
264 devs: Vec<&str>,
265 dry_run: bool,
266) -> Result<Vec<String>, SshError> {
267 let per_dev = devs
268 .iter()
269 .map(|dev| shell.run(cmd!("lsblk -o SIZE /dev/{}", dev).dry_run(dry_run)));
270
271 let mut sizes = vec![];
272 for size in per_dev {
273 sizes.push(size?.stdout.lines().nth(1).unwrap().trim().to_owned());
274 }
275
276 Ok(sizes)
277}
278
279pub fn reboot(shell: &mut impl Execute, dry_run: bool) -> Result<(), SshError> {
281 let _ = shell.run(cmd!("sudo reboot").dry_run(dry_run));
282
283 if !dry_run {
284 std::thread::sleep(std::time::Duration::from_secs(10));
286
287 shell.reconnect()?;
289 }
290
291 shell.run(cmd!("whoami").dry_run(dry_run))?;
293
294 Ok(())
295}
296
#[cfg(test)]
mod test {
    use log::info;

    use spurs::{Execute, SshCommand, SshError, SshOutput};

    /// A fake shell for unit tests.
    ///
    /// It records every command passed to `run` and answers a handful of known
    /// `lsblk`/`blkid` invocations with canned output. Any other command gets
    /// empty output. Nothing is actually executed.
    #[derive(Clone, Debug)]
    pub struct TestSshShell {
        /// Commands passed to `run`, in call order (shared between clones).
        pub commands: std::sync::Arc<std::sync::Mutex<Vec<SshCommand>>>,
    }

    impl TestSshShell {
        pub fn new() -> Self {
            use std::sync::Once;

            // Every test builds its own shell; the `Once` makes sure the
            // logger is only initialised the first time.
            static START: Once = Once::new();
            START.call_once(|| {
                env_logger::init();
            });

            Self {
                commands: std::sync::Arc::new(std::sync::Mutex::new(vec![])),
            }
        }
    }

    impl Execute for TestSshShell {
        fn run(&self, cmd: SshCommand) -> Result<SshOutput, SshError> {
            info!("Test run({:#?})", cmd);

            enum FakeCommand {
                Blkid,
                Kname1,
                Kname2,
                Kname3,
                Kname4,
                KnameMountpoint,
                Size1,
                Size2,
                Size3,
                Unknown,
            }

            // Order matters: more specific substrings must be tested before
            // the more general ones they contain.
            let short_cmd = {
                if cmd.cmd().contains("blkid") {
                    FakeCommand::Blkid
                } else if cmd.cmd().contains("KNAME /dev/foobar") {
                    FakeCommand::Kname1
                } else if cmd.cmd().contains("KNAME /dev/sd") {
                    FakeCommand::Kname3
                } else if cmd.cmd().contains("KNAME /dev/") {
                    FakeCommand::Kname4
                } else if cmd.cmd().contains("KNAME,MOUNTPOINT") {
                    FakeCommand::KnameMountpoint
                } else if cmd.cmd().contains("KNAME") {
                    FakeCommand::Kname2
                } else if cmd.cmd().contains("SIZE /dev/sda") {
                    FakeCommand::Size1
                } else if cmd.cmd().contains("SIZE /dev/sdb") {
                    FakeCommand::Size2
                } else if cmd.cmd().contains("SIZE /dev/sdc") {
                    FakeCommand::Size3
                } else {
                    FakeCommand::Unknown
                }
            };

            self.commands.lock().unwrap().push(cmd);

            let stdout = match short_cmd {
                FakeCommand::Blkid => "UUID=1fb958bf-de7e-428a-a0b7-a598f22e96fa\n".into(),
                FakeCommand::Kname1 => "KNAME\nfoobar\nfoo\nbar\nbaz\n".into(),
                FakeCommand::Kname2 => "KNAME\nfoobar\nfoo\nbar\nbaz\nsdb\nsdc".into(),
                FakeCommand::Kname3 => "KNAME\nsdb".into(),
                FakeCommand::Kname4 => "KNAME\nfoo".into(),
                FakeCommand::KnameMountpoint => {
                    "KNAME MOUNTPOINT\nfoobar\nfoo /mnt/foo\nbar /mnt/bar\nbaz\nsdb\nsdc".into()
                }
                FakeCommand::Size1 => "SIZE\n477G".into(),
                FakeCommand::Size2 => "SIZE\n400G".into(),
                FakeCommand::Size3 => "SIZE\n500G".into(),
                FakeCommand::Unknown => String::new(),
            };

            info!("Output: {}", stdout);

            Ok(SshOutput {
                stdout,
                stderr: String::new(),
            })
        }

        fn duplicate(&self) -> Result<Self, SshError> {
            Ok(self.clone())
        }

        fn reconnect(&mut self) -> Result<(), SshError> {
            info!("Test reconnect");

            Ok(())
        }
    }

    /// Asserts that `$shell` recorded exactly the listed commands, in order.
    /// With no commands listed, asserts that nothing was run.
    macro_rules! expect_cmd_sequence {
        ($shell:expr) => {
            // `commands` is an `Arc<Mutex<Vec<_>>>`; it must be locked before
            // the vector inside can be inspected.
            assert!($shell.commands.lock().unwrap().is_empty());
        };
        ($shell:expr, $($cmd:expr),+ $(,)?) => {
            let expected: &[SshCommand] = &[$($cmd),+];
            let locked = $shell.commands.lock().unwrap();

            if locked.len() != expected.len() {
                panic!("Number of commands run does not match expected number: \n Expected: {:#?}\nActual: {:#?}====\n", expected, locked);
            }

            let mut fail = false;
            let mut message = "Actual commands did not match expected commands: \n".to_owned();

            for (expected, actual) in expected.iter().zip(locked.iter()) {
                if expected != actual {
                    fail = true;
                    message.push_str(&format!("\nExpected: {:#?}\nActual: {:#?}\n=====\n", expected, actual));
                }
            }

            if fail {
                panic!("{}", message);
            }
        };
    }

    #[test]
    fn test_set_cpu_scaling_governor() {
        assert_eq!(
            super::set_cpu_scaling_governor("foobar"),
            SshCommand::make_cmd(
                "sudo cpupower frequency-set -g foobar".into(),
                None,
                false,
                false,
                false,
                false,
            )
        );
    }

    #[test]
    fn test_swapoff() {
        assert_eq!(
            super::swapoff("foobar"),
            SshCommand::make_cmd(
                "sudo swapoff foobar".into(),
                None,
                false,
                false,
                false,
                false,
            )
        );
    }

    #[test]
    fn test_swapon() {
        assert_eq!(
            super::swapon("foobar"),
            SshCommand::make_cmd(
                "sudo swapon foobar".into(),
                None,
                false,
                false,
                false,
                false,
            )
        );
    }

    #[test]
    fn test_add_to_group() {
        assert_eq!(
            super::add_to_group("foobar"),
            SshCommand::make_cmd(
                "sudo usermod -aG foobar `whoami`".into(),
                None,
                true,
                false,
                false,
                false,
            )
        );
    }

    #[test]
    fn test_write_gpt() {
        assert_eq!(
            super::write_gpt("foobar"),
            SshCommand::make_cmd(
                "sudo parted -a optimal foobar -s -- mklabel gpt".into(),
                None,
                false,
                false,
                false,
                false,
            )
        );
    }

    #[test]
    fn test_create_partition() {
        assert_eq!(
            super::create_partition("foobar"),
            SshCommand::make_cmd(
                "sudo parted -a optimal foobar -s -- mkpart primary 0% 100%".into(),
                None,
                false,
                false,
                false,
                false,
            )
        );
    }

    #[test]
    fn test_format_partition_as_ext4() {
        let mut shell = TestSshShell::new();
        super::format_partition_as_ext4(&mut shell, false, "/dev/foobar", "/mnt/point/", "me")
            .unwrap();
        expect_cmd_sequence! {
            shell,
            SshCommand::make_cmd("lsblk", None, false, false, false, false),
            SshCommand::make_cmd("sudo mkfs.ext4 /dev/foobar", None, false, false, false, false),
            SshCommand::make_cmd("mkdir -p /tmp/tmp_mnt", None, false, false, false, false),
            SshCommand::make_cmd("sudo mount -t ext4 /dev/foobar /tmp/tmp_mnt", None, false, false, false, false),
            SshCommand::make_cmd("sudo chown me /tmp/tmp_mnt", None, false, false, false, false),
            SshCommand::make_cmd("rsync -a /mnt/point// /tmp/tmp_mnt/", None, false, false, false, false),
            SshCommand::make_cmd("sync", None, false, false, false, false),
            SshCommand::make_cmd("sudo umount /tmp/tmp_mnt", None, false, false, false, false),
            SshCommand::make_cmd("sudo mount -t ext4 /dev/foobar /mnt/point/", None, false, false, false, false),
            SshCommand::make_cmd("sudo chown me /mnt/point/", None, false, false, false, false),
            SshCommand::make_cmd("sudo blkid -o export /dev/foobar | grep '^UUID='", None, true, false, false, false),
            SshCommand::make_cmd(r#"echo "UUID=1fb958bf-de7e-428a-a0b7-a598f22e96fa /mnt/point/ ext4 defaults 0 1" | sudo tee -a /etc/fstab"#, None, false, false, false, false),
            SshCommand::make_cmd("lsblk", None, false, false, false, false),
        };
    }

    #[test]
    fn test_get_partitions() {
        let mut shell = TestSshShell::new();
        let partitions = super::get_partitions(&mut shell, "/dev/foobar", false).unwrap();
        expect_cmd_sequence! {
            shell,
            SshCommand::make_cmd("lsblk -o KNAME /dev/foobar", None, false, false, false, false),
        }
        assert_eq!(
            {
                let mut set = std::collections::HashSet::new();
                set.insert("foo".into());
                set.insert("bar".into());
                set.insert("baz".into());
                set
            },
            partitions
        );
    }

    #[test]
    fn test_get_unpartitioned_devices() {
        let mut shell = TestSshShell::new();
        let devs = super::get_unpartitioned_devs(&mut shell, false).unwrap();
        expect_cmd_sequence! {
            shell,
            SshCommand::make_cmd("lsblk -o KNAME", None, false, false, false, false),
            SshCommand::make_cmd("lsblk -o KNAME /dev/bar", None, false, false, false, false),
            SshCommand::make_cmd("lsblk -o KNAME /dev/baz", None, false, false, false, false),
            SshCommand::make_cmd("lsblk -o KNAME /dev/foo", None, false, false, false, false),
            SshCommand::make_cmd("lsblk -o KNAME /dev/foobar", None, false, false, false, false),
            SshCommand::make_cmd("lsblk -o KNAME /dev/sdb", None, false, false, false, false),
            SshCommand::make_cmd("lsblk -o KNAME /dev/sdc", None, false, false, false, false),
        }
        assert_eq!(
            {
                let mut set = std::collections::HashSet::new();
                set.insert("sdb".into());
                set.insert("sdc".into());
                set
            },
            devs
        );
    }

    #[test]
    fn test_get_mounted_devs() {
        let mut shell = TestSshShell::new();
        let devs = super::get_mounted_devs(&mut shell, false).unwrap();
        expect_cmd_sequence! {
            shell,
            SshCommand::make_cmd("lsblk -o KNAME,MOUNTPOINT", None, false, false, false, false),
        }
        assert_eq!(
            vec![
                ("foo".to_owned(), "/mnt/foo".to_owned()),
                ("bar".to_owned(), "/mnt/bar".to_owned())
            ],
            devs
        );
    }

    #[test]
    fn test_get_dev_sizes() {
        let mut shell = TestSshShell::new();
        let devs = super::get_dev_sizes(&mut shell, vec!["sda", "sdb", "sdc"], false).unwrap();
        expect_cmd_sequence! {
            shell,
            SshCommand::make_cmd("lsblk -o SIZE /dev/sda", None, false, false, false, false),
            SshCommand::make_cmd("lsblk -o SIZE /dev/sdb", None, false, false, false, false),
            SshCommand::make_cmd("lsblk -o SIZE /dev/sdc", None, false, false, false, false),
        }
        assert_eq!(vec!["477G".to_owned(), "400G".into(), "500G".into()], devs);
    }

    mod test_escape_for_bash {
        use super::super::escape_for_bash;

        #[test]
        fn simple() {
            const TEST_STRING: &str = "ls";
            assert_eq!(escape_for_bash(TEST_STRING), "ls");
        }

        #[test]
        fn more_complex() {
            use std::process::Command;

            const TEST_STRING: &str =
                r#""Bob?!", said she, "I though you said 'I can't be there'!""#;

            let out = Command::new("bash")
                .arg("-c")
                .arg(&format!("echo {}", escape_for_bash(TEST_STRING)))
                .output()
                .unwrap();
            let out = String::from_utf8(out.stdout).unwrap();

            assert_eq!(out.trim(), TEST_STRING);
        }
    }

    #[test]
    fn test_get_host_ip() {
        const TEST_ADDR: &str = "localhost:2303";
        let (addr, port) = super::get_host_ip(TEST_ADDR);

        // `localhost` may resolve to 127.0.0.1 or to ::1 first depending on
        // the resolver configuration, so accept any loopback address.
        assert!(addr.is_loopback(), "{} is not a loopback address", addr);
        assert_eq!(port, 2303);
    }

    #[test]
    fn test_reboot() {
        let mut shell = TestSshShell::new();
        super::reboot(&mut shell, false).unwrap();
        expect_cmd_sequence! {
            shell,
            SshCommand::make_cmd("sudo reboot", None, false, false, false, false),
            SshCommand::make_cmd("whoami", None, false, false, false, false),
        };
    }
}