spurs_util/
lib.rs

1//! This is a library containing a bunch of routines that I have found useful for setting up and
2//! running experiments remotely.
3//!
4//! Some of these utilities execute a sequence of steps. They require a shell as input and actually
5//! run a command remotely.
6//!
7//! The rest only construct a command that can be executed and return it to the caller _without
8//! executing anything_.
9//!
//! There are also some utilities that neither construct nor run commands. They are helper
//! functions that are useful when constructing commands (e.g. escaping arguments for bash).
12//!
13//! The `centos` and `ubuntu` submodules contain routines specifically useful for those platforms.
14
15#![doc(html_root_url = "https://docs.rs/spurs-util/0.3.1")]
16
17pub mod centos;
18pub mod ubuntu;
19
20use std::{
21    collections::{BTreeSet, HashMap, HashSet},
22    net::{IpAddr, ToSocketAddrs},
23};
24
25use spurs::{cmd, Execute, SshCommand, SshError};
26
27///////////////////////////////////////////////////////////////////////////////
28// Common useful routines
29///////////////////////////////////////////////////////////////////////////////
30
/// Given a string, escape it so that bash treats it as a single literal word, i.e. so it can be
/// passed as a command line argument to bash.
///
/// Every character other than an ASCII letter or digit is prefixed with a backslash. The one
/// exception is a newline: bash treats an unquoted backslash-newline pair as a line continuation
/// and removes both characters. Newlines are therefore emitted inside single quotes (`'\n'`),
/// where bash preserves them verbatim.
///
/// This is useful for passing commands to `bash -c` (e.g. through ssh).
pub fn escape_for_bash(s: &str) -> String {
    // Escaped characters grow to (at least) two bytes, so reserve for that up front.
    let mut escaped = String::with_capacity(s.len() * 2);

    for c in s.chars() {
        match c {
            c if c.is_ascii_alphanumeric() => escaped.push(c),
            // `\<newline>` would be swallowed as a line continuation; quote the newline instead.
            '\n' => escaped.push_str("'\n'"),
            c => {
                escaped.push('\\');
                escaped.push(c);
            }
        }
    }

    escaped
}
50
/// Given a `host:port` address, resolve it and return `(ip, port)` of the first socket address it
/// resolves to.
///
/// # Panics
///
/// Panics if `addr` cannot be resolved, or if it resolves to no addresses at all.
pub fn get_host_ip<A: ToSocketAddrs>(addr: A) -> (IpAddr, u16) {
    let addr = addr
        .to_socket_addrs()
        .expect("failed to resolve address")
        .next()
        .expect("address resolved to no socket addresses");
    (addr.ip(), addr.port())
}
58
59///////////////////////////////////////////////////////////////////////////////
// Below are utilities that just construct (but don't run) a command.
61///////////////////////////////////////////////////////////////////////////////
62
63/// Sets the CPU scaling governor to the given governor. This requires
64/// - `cpupower` to be installed,
65/// - `sudo` priveleges,
66/// - the necessary Linux kernel modules.
67pub fn set_cpu_scaling_governor(gov: &str) -> SshCommand {
68    cmd!("sudo cpupower frequency-set -g {}", gov)
69}
70
71/// Turn off the swap device. Requires `sudo` permissions.
72pub fn swapoff(device: &str) -> SshCommand {
73    cmd!("sudo swapoff {}", device)
74}
75
76/// Turn on the swap device. Requires `sudo` permissions. Assumes the device is already formatted
77/// as a swap device (i.e. with `mkswap`).
78pub fn swapon(device: &str) -> SshCommand {
79    cmd!("sudo swapon {}", device)
80}
81
82/// Add the executing user to the given group. Requires `sudo` permissions.
83pub fn add_to_group(group: &str) -> SshCommand {
84    cmd!("sudo usermod -aG {} `whoami`", group).use_bash()
85}
86
87/// Write a new general partition table (GPT) on the given device. Requires `sudo` permissions.
88///
89/// **NOTE**: this will destroy any data on the partition!
90pub fn write_gpt(device: &str) -> SshCommand {
91    cmd!("sudo parted -a optimal {} -s -- mklabel gpt", device)
92}
93
94/// Create a new partition on the given device. Requires `sudo` permissions.
95pub fn create_partition(device: &str) -> SshCommand {
96    cmd!(
97        "sudo parted -a optimal {} -s -- mkpart primary 0% 100%",
98        device
99    )
100}
101
102///////////////////////////////////////////////////////////////////////////////
// Below are utilities that actually run a command. These require a shell as input.
104///////////////////////////////////////////////////////////////////////////////
105
106/// Formats and mounts the given device as ext4 at the given mountpoint owned by the given user.
107/// The given partition and mountpoint are assumed to be valid (we don't check).  We will assume
108/// quite a few things for simplicity:
109/// - the disk _IS_ partitioned, but the partition is not formatted
110/// - the disk should be mounted at the mountpoint, which is a valid directory
111/// - you have `sudo` permissions
112/// - `owner` is a valid username
113///
114/// We need to be careful not to mess up the ssh keys, so we will first mount the
115/// new FS somewhere, copy over dotfiles, then unmount and mount to users...
116///
117/// In particular, this is useful for mounting a new partition as a home directory.
118///
119/// # Warning!
120///
121/// This can cause data loss and seriously mess up your system. **BE VERY CAREFUL**. Make sure you
122/// are formatting the write partition.
123///
124/// # Example
125///
126/// ```rust,ignore
127/// format_partition_as_ext4(root_shell, "/dev/sda4", "/home/foouser/")?;
128/// ```
129pub fn format_partition_as_ext4<P: AsRef<std::path::Path>>(
130    shell: &impl Execute,
131    dry_run: bool,
132    partition: &str,
133    mount: P,
134    owner: &str,
135) -> Result<(), SshError> {
136    shell.run(cmd!("lsblk").dry_run(dry_run))?;
137
138    // Make a filesystem on the first partition
139    shell.run(cmd!("sudo mkfs.ext4 {}", partition).dry_run(dry_run))?;
140
141    // Mount the FS in tmp
142    shell.run(cmd!("mkdir -p /tmp/tmp_mnt").dry_run(dry_run))?;
143    shell.run(cmd!("sudo mount -t ext4 {} /tmp/tmp_mnt", partition).dry_run(dry_run))?;
144    shell.run(cmd!("sudo chown {} /tmp/tmp_mnt", owner).dry_run(dry_run))?;
145
146    // Copy all existing files
147    shell.run(cmd!("rsync -a {}/ /tmp/tmp_mnt/", mount.as_ref().display()).dry_run(dry_run))?;
148
149    // Unmount from tmp
150    shell.run(cmd!("sync").dry_run(dry_run))?;
151    shell.run(cmd!("sudo umount /tmp/tmp_mnt").dry_run(dry_run))?;
152
153    // Mount the FS at `mount`
154    shell.run(
155        cmd!(
156            "sudo mount -t ext4 {} {}",
157            partition,
158            mount.as_ref().display()
159        )
160        .dry_run(dry_run),
161    )?;
162    shell.run(cmd!("sudo chown {} {}", owner, mount.as_ref().display()).dry_run(dry_run))?;
163
164    // Add to /etc/fstab
165    let uuid = shell
166        .run(
167            cmd!("sudo blkid -o export {} | grep '^UUID='", partition)
168                .use_bash()
169                .dry_run(dry_run),
170        )?
171        .stdout;
172    let uuid = uuid.trim();
173    shell.run(
174        cmd!(
175            r#"echo "{}    {}    ext4    defaults    0    1" | sudo tee -a /etc/fstab"#,
176            uuid,
177            mount.as_ref().display()
178        )
179        .dry_run(dry_run),
180    )?;
181
182    // Print for info
183    shell.run(cmd!("lsblk").dry_run(dry_run))?;
184
185    Ok(())
186}
187
188/// Returns a list of partitions of the given device. For example, `["sda1", "sda2"]`.
189pub fn get_partitions(
190    shell: &impl Execute,
191    device: &str,
192    dry_run: bool,
193) -> Result<HashSet<String>, SshError> {
194    Ok(shell
195        .run(cmd!("lsblk -o KNAME {}", device).dry_run(dry_run))?
196        .stdout
197        .lines()
198        .map(|line| line.trim().to_owned())
199        .skip(2)
200        .collect())
201}
202
203/// Returns a list of devices with no partitions. For example, `["sda", "sdb"]`.
204pub fn get_unpartitioned_devs(
205    shell: &impl Execute,
206    dry_run: bool,
207) -> Result<HashSet<String>, SshError> {
208    // List all devs
209    let lsblk = shell.run(cmd!("lsblk -o KNAME").dry_run(dry_run))?.stdout;
210    let mut devices: BTreeSet<&str> = lsblk.lines().map(|line| line.trim()).skip(1).collect();
211
212    // Get the partitions of each device.
213    let partitions: HashMap<_, _> = devices
214        .iter()
215        .map(|&dev| {
216            (
217                dev,
218                get_partitions(shell, &format!("/dev/{}", dev), dry_run),
219            )
220        })
221        .collect();
222
223    // Remove partitions and partitioned devices from the list of devices
224    for (dev, parts) in partitions.into_iter() {
225        let parts = parts?;
226        if !parts.is_empty() {
227            devices.remove(dev);
228            for part in parts {
229                devices.remove(part.as_str());
230            }
231        }
232    }
233
234    Ok(devices.iter().map(|&dev| dev.to_owned()).collect())
235}
236
237/// Returns the list of devices mounted and their mountpoints. For example, `[("sda2", "/")]`.
238pub fn get_mounted_devs(
239    shell: &impl Execute,
240    dry_run: bool,
241) -> Result<Vec<(String, String)>, SshError> {
242    let devices = shell
243        .run(cmd!("lsblk -o KNAME,MOUNTPOINT").dry_run(dry_run))?
244        .stdout;
245    let devices = devices.lines().skip(1);
246    let mut mounted = vec![];
247    for line in devices {
248        let split: Vec<_> = line
249            .split(char::is_whitespace)
250            .filter(|s| !s.is_empty())
251            .collect();
252
253        // Need to make sure there are no duplicates (which can happen with LVM)
254        if split.len() > 1 && mounted.iter().all(|(d, _)| d != split[0]) {
255            mounted.push((split[0].to_owned(), split[1].to_owned()));
256        }
257    }
258    Ok(mounted)
259}
260
261/// Returns the human-readable size of the devices `devs`. For example, `["477G", "500M"]`.
262pub fn get_dev_sizes(
263    shell: &impl Execute,
264    devs: Vec<&str>,
265    dry_run: bool,
266) -> Result<Vec<String>, SshError> {
267    let per_dev = devs
268        .iter()
269        .map(|dev| shell.run(cmd!("lsblk -o SIZE /dev/{}", dev).dry_run(dry_run)));
270
271    let mut sizes = vec![];
272    for size in per_dev {
273        sizes.push(size?.stdout.lines().nth(1).unwrap().trim().to_owned());
274    }
275
276    Ok(sizes)
277}
278
279/// Reboot and wait for the remote machine to come back up again. Requires `sudo`.
280pub fn reboot(shell: &mut impl Execute, dry_run: bool) -> Result<(), SshError> {
281    let _ = shell.run(cmd!("sudo reboot").dry_run(dry_run));
282
283    if !dry_run {
284        // If we try to reconnect immediately, the machine will not have gone down yet.
285        std::thread::sleep(std::time::Duration::from_secs(10));
286
287        // Attempt to reconnect.
288        shell.reconnect()?;
289    }
290
291    // Make sure it worked.
292    shell.run(cmd!("whoami").dry_run(dry_run))?;
293
294    Ok(())
295}
296
297///////////////////////////////////////////////////////////////////////////////
298// Tests
299///////////////////////////////////////////////////////////////////////////////
300
#[cfg(test)]
mod test {
    use log::info;

    use spurs::{Execute, SshCommand, SshError, SshOutput};

    /// An `Execute` implementation for use in tests.
    ///
    /// Every command passed to `run` is recorded in `commands`, in order. A canned `stdout` is
    /// returned for the commands the tests below rely on; any other command gets empty output.
    #[derive(Clone, Debug)]
    pub struct TestSshShell {
        pub commands: std::sync::Arc<std::sync::Mutex<Vec<SshCommand>>>,
    }

    impl TestSshShell {
        pub fn new() -> Self {
            // Initialize logging only once, however many tests create a shell.
            use std::sync::Once;
            static START: Once = Once::new();
            START.call_once(|| {
                env_logger::init();
            });

            Self {
                commands: std::sync::Arc::new(std::sync::Mutex::new(vec![])),
            }
        }
    }

    impl Execute for TestSshShell {
        fn run(&self, cmd: SshCommand) -> Result<SshOutput, SshError> {
            info!("Test run({:#?})", cmd);

            enum FakeCommand {
                Blkid,
                Kname1,
                Kname2,
                Kname3,
                Kname4,
                KnameMountpoint,
                Size1,
                Size2,
                Size3,
                Unknown,
            }

            // NOTE: the order of these checks matters, since several patterns overlap.
            let short_cmd = {
                if cmd.cmd().contains("blkid") {
                    FakeCommand::Blkid
                } else if cmd.cmd().contains("KNAME /dev/foobar") {
                    FakeCommand::Kname1
                } else if cmd.cmd().contains("KNAME /dev/sd") {
                    FakeCommand::Kname3
                } else if cmd.cmd().contains("KNAME /dev/") {
                    FakeCommand::Kname4
                } else if cmd.cmd().contains("KNAME,MOUNTPOINT") {
                    FakeCommand::KnameMountpoint
                } else if cmd.cmd().contains("KNAME") {
                    FakeCommand::Kname2
                } else if cmd.cmd().contains("SIZE /dev/sda") {
                    FakeCommand::Size1
                } else if cmd.cmd().contains("SIZE /dev/sdb") {
                    FakeCommand::Size2
                } else if cmd.cmd().contains("SIZE /dev/sdc") {
                    FakeCommand::Size3
                } else {
                    FakeCommand::Unknown
                }
            };

            self.commands.lock().unwrap().push(cmd);

            let stdout = match short_cmd {
                FakeCommand::Blkid => "UUID=1fb958bf-de7e-428a-a0b7-a598f22e96fa\n".into(),
                FakeCommand::Kname1 => "KNAME\nfoobar\nfoo\nbar\nbaz\n".into(),
                FakeCommand::Kname2 => "KNAME\nfoobar\nfoo\nbar\nbaz\nsdb\nsdc".into(),
                FakeCommand::Kname3 => "KNAME\nsdb".into(),
                FakeCommand::Kname4 => "KNAME\nfoo".into(),
                FakeCommand::KnameMountpoint => {
                    "KNAME MOUNTPOINT\nfoobar\nfoo  /mnt/foo\nbar  /mnt/bar\nbaz\nsdb\nsdc".into()
                }
                FakeCommand::Size1 => "SIZE\n477G".into(),
                FakeCommand::Size2 => "SIZE\n400G".into(),
                FakeCommand::Size3 => "SIZE\n500G".into(),
                FakeCommand::Unknown => String::new(),
            };

            info!("Output: {}", stdout);

            Ok(SshOutput {
                stdout,
                stderr: String::new(),
            })
        }

        fn duplicate(&self) -> Result<Self, SshError> {
            Ok(self.clone())
        }

        fn reconnect(&mut self) -> Result<(), SshError> {
            info!("Test reconnect");

            Ok(())
        }
    }

    // Asserts that `$shell` (a `TestSshShell`) ran exactly the given commands, in order. With no
    // commands listed, asserts that nothing was run.
    macro_rules! expect_cmd_sequence {
        ($shell:expr) => {
            assert!($shell.commands.lock().unwrap().is_empty());
        };
        ($shell:expr, $($cmd:expr),+ $(,)?) => {
            let expected: &[SshCommand] = &[$($cmd),+];
            let locked = $shell.commands.lock().unwrap();

            if locked.len() != expected.len() {
                panic!("Number of commands run does not match expected number: \n Expected: {:#?}\nActual:  {:#?}====\n", expected, locked);
            }

            let mut fail = false;
            let mut message = "Actual commands did not match expected commands: \n".to_owned();

            for (expected, actual) in expected.iter().zip(locked.iter()) {
                if expected != actual {
                    fail = true;
                    message.push_str(&format!("\nExpected: {:#?}\nActual:  {:#?}\n=====\n", expected, actual));
                }
            }

            if fail {
                panic!("{}", message);
            }
        };
    }

    #[test]
    fn test_set_cpu_scaling_governor() {
        assert_eq!(
            super::set_cpu_scaling_governor("foobar"),
            SshCommand::make_cmd(
                "sudo cpupower frequency-set -g foobar".into(),
                None,
                false,
                false,
                false,
                false,
            )
        );
    }

    #[test]
    fn test_swapoff() {
        assert_eq!(
            super::swapoff("foobar"),
            SshCommand::make_cmd(
                "sudo swapoff foobar".into(),
                None,
                false,
                false,
                false,
                false,
            )
        );
    }

    #[test]
    fn test_swapon() {
        assert_eq!(
            super::swapon("foobar"),
            SshCommand::make_cmd(
                "sudo swapon foobar".into(),
                None,
                false,
                false,
                false,
                false,
            )
        );
    }

    #[test]
    fn test_add_to_group() {
        assert_eq!(
            super::add_to_group("foobar"),
            SshCommand::make_cmd(
                "sudo usermod -aG foobar `whoami`".into(),
                None,
                true, // use_bash
                false,
                false,
                false,
            )
        );
    }

    #[test]
    fn test_write_gpt() {
        assert_eq!(
            super::write_gpt("foobar"),
            SshCommand::make_cmd(
                "sudo parted -a optimal foobar -s -- mklabel gpt".into(),
                None,
                false,
                false,
                false,
                false,
            )
        );
    }

    #[test]
    fn test_create_partition() {
        assert_eq!(
            super::create_partition("foobar"),
            SshCommand::make_cmd(
                "sudo parted -a optimal foobar -s -- mkpart primary 0% 100%".into(),
                None,
                false,
                false,
                false,
                false,
            )
        );
    }

    #[test]
    fn test_format_partition_as_ext4() {
        let shell = TestSshShell::new();
        super::format_partition_as_ext4(&shell, false, "/dev/foobar", "/mnt/point/", "me")
            .unwrap();
        expect_cmd_sequence! {
            shell,
            SshCommand::make_cmd("lsblk", None, false, false, false, false),
            SshCommand::make_cmd("sudo mkfs.ext4 /dev/foobar", None, false, false, false, false),
            SshCommand::make_cmd("mkdir -p /tmp/tmp_mnt", None, false, false, false, false),
            SshCommand::make_cmd("sudo mount -t ext4 /dev/foobar /tmp/tmp_mnt", None, false, false, false, false),
            SshCommand::make_cmd("sudo chown me /tmp/tmp_mnt", None, false, false, false, false),
            SshCommand::make_cmd("rsync -a /mnt/point// /tmp/tmp_mnt/", None, false, false, false, false),
            SshCommand::make_cmd("sync", None, false, false, false, false),
            SshCommand::make_cmd("sudo umount /tmp/tmp_mnt", None, false, false, false, false),
            SshCommand::make_cmd("sudo mount -t ext4 /dev/foobar /mnt/point/", None, false, false, false, false),
            SshCommand::make_cmd("sudo chown me /mnt/point/", None, false, false, false, false),
            SshCommand::make_cmd("sudo blkid -o export /dev/foobar | grep '^UUID='", None, /* use_bash = */ true, false, false, false),
            SshCommand::make_cmd(r#"echo "UUID=1fb958bf-de7e-428a-a0b7-a598f22e96fa    /mnt/point/    ext4    defaults    0    1" | sudo tee -a /etc/fstab"#, None, false, false, false, false),
            SshCommand::make_cmd("lsblk", None, false, false, false, false),
        };
    }

    #[test]
    fn test_get_partitions() {
        let shell = TestSshShell::new();
        let partitions = super::get_partitions(&shell, "/dev/foobar", false).unwrap();
        expect_cmd_sequence! {
            shell,
            SshCommand::make_cmd("lsblk -o KNAME /dev/foobar", None, false, false, false, false),
        }
        assert_eq!(
            {
                let mut set = std::collections::HashSet::new();
                set.insert("foo".into());
                set.insert("bar".into());
                set.insert("baz".into());
                set
            },
            partitions
        );
    }

    #[test]
    fn test_get_unpartitioned_devices() {
        let shell = TestSshShell::new();
        let devs = super::get_unpartitioned_devs(&shell, false).unwrap();
        expect_cmd_sequence! {
            shell,
            SshCommand::make_cmd("lsblk -o KNAME", None, false, false, false, false),
            SshCommand::make_cmd("lsblk -o KNAME /dev/bar", None, false, false, false, false),
            SshCommand::make_cmd("lsblk -o KNAME /dev/baz", None, false, false, false, false),
            SshCommand::make_cmd("lsblk -o KNAME /dev/foo", None, false, false, false, false),
            SshCommand::make_cmd("lsblk -o KNAME /dev/foobar", None, false, false, false, false),
            SshCommand::make_cmd("lsblk -o KNAME /dev/sdb", None, false, false, false, false),
            SshCommand::make_cmd("lsblk -o KNAME /dev/sdc", None, false, false, false, false),
        }
        assert_eq!(
            {
                let mut set = std::collections::HashSet::new();
                set.insert("sdb".into());
                set.insert("sdc".into());
                set
            },
            devs
        );
    }

    #[test]
    fn test_get_mounted_devs() {
        let shell = TestSshShell::new();
        let devs = super::get_mounted_devs(&shell, false).unwrap();
        expect_cmd_sequence! {
            shell,
            SshCommand::make_cmd("lsblk -o KNAME,MOUNTPOINT", None, false, false, false, false),
        }
        assert_eq!(
            vec![
                ("foo".to_owned(), "/mnt/foo".to_owned()),
                ("bar".to_owned(), "/mnt/bar".to_owned())
            ],
            devs
        );
    }

    #[test]
    fn test_get_dev_sizes() {
        let shell = TestSshShell::new();
        let devs = super::get_dev_sizes(&shell, vec!["sda", "sdb", "sdc"], false).unwrap();
        expect_cmd_sequence! {
            shell,
            SshCommand::make_cmd("lsblk -o SIZE /dev/sda", None, false, false, false, false),
            SshCommand::make_cmd("lsblk -o SIZE /dev/sdb", None, false, false, false, false),
            SshCommand::make_cmd("lsblk -o SIZE /dev/sdc", None, false, false, false, false),
        }
        assert_eq!(vec!["477G".to_owned(), "400G".into(), "500G".into()], devs);
    }

    mod test_escape_for_bash {
        use super::super::escape_for_bash;

        #[test]
        fn simple() {
            const TEST_STRING: &str = "ls";
            assert_eq!(escape_for_bash(TEST_STRING), "ls");
        }

        #[test]
        fn more_complex() {
            use std::process::Command;

            const TEST_STRING: &str =
                r#""Bob?!", said she, "I thought you said 'I can't be there'!""#;

            let out = Command::new("bash")
                .arg("-c")
                .arg(&format!("echo {}", escape_for_bash(TEST_STRING)))
                .output()
                .unwrap();
            let out = String::from_utf8(out.stdout).unwrap();

            assert_eq!(out.trim(), TEST_STRING);
        }
    }

    #[test]
    fn test_get_host_ip() {
        const TEST_ADDR: &str = "localhost:2303";
        let (addr, port) = super::get_host_ip(TEST_ADDR);

        // `localhost` may resolve to 127.0.0.1 or ::1 first depending on the host's
        // configuration, so only require some loopback address.
        assert!(addr.is_loopback(), "{} is not a loopback address", addr);
        assert_eq!(port, 2303);
    }

    #[test]
    fn test_reboot() {
        let mut shell = TestSshShell::new();
        super::reboot(&mut shell, false).unwrap();
        expect_cmd_sequence! {
            shell,
            SshCommand::make_cmd("sudo reboot", None, false, false, false, false),
            SshCommand::make_cmd("whoami", None, false, false, false, false),
        };
    }
}