spurs/
lib.rs

1//! `spurs` is a library for executing commands remotely over SSH. I created it in an effort to
2//! automate setup and experimentation on a cluster of machines.
3//!
4//! `spurs` prioritizes ergonomics over performance. It is _not_ a high-performance way of getting
5//! stuff done in a cluster.
6//!
7//! `spurs` takes heavy inspiration from the Python
8//! [spur.py](https://github.com/mwilliamson/spur.py) library, which is amazing. At some point,
9//! though, my scripts were so big that Python was getting in my way, so I created `spurs` to allow
10//! me to build my cluster setup/experiments scripts/framework in Rust, with much greater
11//! productivity and refactorability.
12
13#![doc(html_root_url = "https://docs.rs/spurs/0.9.2")]
14
15use std::{
16    io::Read,
17    net::{SocketAddr, TcpStream, ToSocketAddrs},
18    path::{Path, PathBuf},
19    sync::{Arc, Mutex},
20    thread::JoinHandle,
21    time::Duration,
22};
23
24use log::{debug, info, trace};
25
26use ssh2::Session;
27
28/// The default timeout for the TCP stream of a SSH connection.
///
/// `SshShell::with_key` and `SshShell::from_existing` install it as both the read and the write
/// timeout of the stream; `reconnect` uses half of it as the connect timeout and as the pause
/// between attempts.
29const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);
30
/// A command to run on the remote machine together with the options controlling how it is run.
///
/// Created with `SshCommand::new` and refined through the chained, consuming setter methods.
31#[derive(Debug, PartialEq, Eq)]
32pub struct SshCommand {
    /// The command line as supplied by the caller.
33    cmd: String,
    /// Directory to `cd` into before running the command, if any.
34    cwd: Option<PathBuf>,
    /// Whether to wrap the command in `bash -c`.
35    use_bash: bool,
    /// Whether a non-zero exit status is tolerated instead of being reported as an error.
36    allow_error: bool,
    /// Whether to only print the command rather than execute it.
37    dry_run: bool,
    /// Whether to skip requesting a pseudo-terminal.
38    no_pty: bool,
39}
40
/// The output captured from a remote command. Both fields are empty after a dry run.
41#[derive(Debug)]
42pub struct SshOutput {
    /// Everything read from the channel's standard output stream.
43    pub stdout: String,
    /// Everything read from the channel's standard error stream.
44    pub stderr: String,
45}
46
47/// An error type representing things that could possibly go wrong when using an SshShell.
48#[derive(Debug)]
49pub enum SshError {
50    /// Unable to find the private key at the given path.
    ///
    /// Also returned when the home directory cannot be determined and when
    /// `SshShell::with_any_key` finds no key that yields a working connection.
51    KeyNotFound { file: String },
52
53    /// SSH authentication failed.
54    AuthFailed { key: std::path::PathBuf },
55
56    /// The command run over SSH returned with a non-zero exit code.
    ///
    /// `cmd` is the command as it was sent to the remote (after any `bash -c` wrapping and `cd`
    /// prefix); `exit` is the reported exit status.
57    NonZeroExit { cmd: String, exit: i32 },
58
59    /// An SSH error occurred.
60    SshError { error: ssh2::Error },
61
62    /// An I/O error occurred.
63    IoError { error: std::io::Error },
64}
65
66/// Represents a connection via SSH to a particular source.
67pub struct SshShell {
68    // The TCP stream needs to be in the struct to keep it alive while the session is active.
69    tcp: TcpStream,
    // User name and private key file used to authenticate; kept so that the connection can be
    // duplicated or re-established later.
70    username: String,
71    key: PathBuf,
72    remote_name: String, // used for printing
    // Socket address of the remote; used when opening further connections to it.
73    remote: SocketAddr,
    // The SSH session; locked for the duration of each command that is run.
74    sess: Arc<Mutex<Session>>,
    // When set, commands are printed but not executed.
75    dry_run_mode: bool,
76}
77
78/// A handle for a spawned remote command.
///
/// Wraps the thread that runs the command. The thread yields the shell it used together with the
/// command's result.
79pub struct SshSpawnHandle {
80    thread_handle: JoinHandle<(SshShell, Result<SshOutput, SshError>)>,
81}
82
83/// A trait representing types that can run an `SshCommand`.
84pub trait Execute: Sized {
85    /// Run a command on the remote machine, blocking until the command completes.
86    ///
87    /// Note that a command using `sudo` will hang indefinitely if `sudo` asks for a password.
88    fn run(&self, cmd: SshCommand) -> Result<SshOutput, SshError>;
89
90    /// Attempts to create a new `Self` with similar credentials to `self` but using an independent
91    /// connection. This is useful for running multiple commands in parallel without needing to
92    /// pass around the parameters everywhere.
93    fn duplicate(&self) -> Result<Self, SshError>;
94
95    /// Attempt to reconnect to the remote until it reconnects (possibly indefinitely).
96    fn reconnect(&mut self) -> Result<(), SshError>;
97}
98
99impl std::fmt::Display for SshError {
100    fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
101        match self {
102            SshError::KeyNotFound { file } => write!(f, "no such key: {}", file),
103            SshError::AuthFailed { key } => {
104                write!(f, "authentication failed with private key: {:?}", key)
105            }
106            SshError::NonZeroExit { cmd, exit } => {
107                write!(f, "non-zero exit ({}) for command: {}", exit, cmd)
108            }
109            SshError::SshError { error } => write!(f, "{}", error),
110            SshError::IoError { error } => write!(f, "{}", error),
111        }
112    }
113}
114
// Relies entirely on the trait's default methods; the wrapped `ssh2`/I/O errors are therefore not
// exposed through `source()`.
115impl std::error::Error for SshError {}
116
117impl std::convert::From<ssh2::Error> for SshError {
118    fn from(error: ssh2::Error) -> Self {
119        SshError::SshError { error }
120    }
121}
122
123impl std::convert::From<std::io::Error> for SshError {
124    fn from(error: std::io::Error) -> Self {
125        SshError::IoError { error }
126    }
127}
128
129impl SshCommand {
130    /// Create a new builder for the given command with default options.
131    pub fn new(cmd: &str) -> Self {
132        SshCommand {
133            cmd: cmd.to_owned(),
134            cwd: None,
135            use_bash: false,
136            allow_error: false,
137            dry_run: false,
138            no_pty: false,
139        }
140    }
141
142    /// Change the current working directory to `cwd` before executing.
143    pub fn cwd<P: AsRef<Path>>(self, cwd: P) -> Self {
144        SshCommand {
145            cwd: Some(cwd.as_ref().to_owned()),
146            ..self
147        }
148    }
149
150    /// Execute using bash.
151    pub fn use_bash(self) -> Self {
152        SshCommand {
153            use_bash: true,
154            ..self
155        }
156    }
157
158    /// Allow a non-zero exit code. Normally, an error would occur and we would return early.
159    pub fn allow_error(self) -> Self {
160        SshCommand {
161            allow_error: true,
162            ..self
163        }
164    }
165
166    /// Don't actually execute any command remotely. Just print the command that would be executed
167    /// and return success. Note that we still connect to the remote. This is useful for debugging.
168    pub fn dry_run(self, is_dry: bool) -> Self {
169        SshCommand {
170            dry_run: is_dry,
171            ..self
172        }
173    }
174
175    /// Don't request a psuedo-terminal (pty). It turns out that some commands behave differently
176    /// with a pty. I'm not really sure what causes this.
177    ///
178    /// NOTE: You need a pty for `sudo`.
179    pub fn no_pty(self) -> Self {
180        SshCommand {
181            no_pty: true,
182            ..self
183        }
184    }
185
186    /// Helper for tests that makes a `SshCommand` with the given values.
187    #[cfg(any(test, feature = "test"))]
188    pub fn make_cmd(
189        cmd: &str,
190        cwd: Option<PathBuf>,
191        use_bash: bool,
192        allow_error: bool,
193        dry_run: bool,
194        no_pty: bool,
195    ) -> Self {
196        SshCommand {
197            cmd: cmd.into(),
198            cwd,
199            use_bash,
200            allow_error,
201            dry_run,
202            no_pty,
203        }
204    }
205
206    /// Helper for tests to get the command from this `SshCommand`.
207    #[cfg(any(test, feature = "test"))]
208    pub fn cmd(&self) -> &str {
209        &self.cmd
210    }
211}
212
213impl SshShell {
214    /// Returns a shell connected via the default private key at `$HOME/.ssh/id_rsa` to the given
215    /// SSH server as the given user.
216    ///
217    /// ```rust,ignore
218    /// SshShell::with_default_key("markm", "myhost:22")?;
219    /// ```
220    pub fn with_default_key<A: ToSocketAddrs + std::fmt::Debug>(
221        username: &str,
222        remote: A,
223    ) -> Result<Self, SshError> {
224        const DEFAULT_KEY_SUFFIX: &str = ".ssh/id_rsa";
225        let home = if let Some(home) = dirs::home_dir() {
226            home
227        } else {
228            return Err(SshError::KeyNotFound {
229                file: DEFAULT_KEY_SUFFIX.into(),
230            }
231            .into());
232        };
233
234        SshShell::with_key(username, remote, home.join(DEFAULT_KEY_SUFFIX))
235    }
236
237    /// Returns a shell connected via the first private key found at `$HOME/.ssh/` to the given
238    /// SSH server as the given user.
239    ///
240    /// ```rust,ignore
241    /// SshShell::with_any_key("markm", "myhost:22")?;
242    /// ```
243    pub fn with_any_key<A: Copy + ToSocketAddrs + std::fmt::Debug>(
244        username: &str,
245        remote: A,
246    ) -> Result<Self, SshError> {
247        const DEFAULT_KEY_DIR: &str = ".ssh/";
248        let home = if let Some(home) = dirs::home_dir() {
249            home
250        } else {
251            return Err(SshError::KeyNotFound {
252                file: DEFAULT_KEY_DIR.into(),
253            });
254        };
255        let key_dir = home.join(DEFAULT_KEY_DIR);
256
257        for entry in std::fs::read_dir(&key_dir)? {
258            let entry = entry?;
259            let name = entry.file_name().into_string().unwrap();
260
261            // To find the private keys, find the public keys then chop off ".pub"
262            if !name.ends_with(".pub") {
263                continue;
264            }
265
266            let (priv_key, _) = name.split_at(name.len() - 4);
267            let shell = SshShell::with_key(username, remote, key_dir.join(priv_key));
268
269            if shell.is_ok() {
270                return shell;
271            }
272        }
273
274        Err(SshError::KeyNotFound {
275            file: DEFAULT_KEY_DIR.into(),
276        })
277    }
278
279    /// Returns a shell connected via private key file `key` to the given SSH server as the given
280    /// user.
281    ///
282    /// ```rust,ignore
283    /// SshShell::with_key("markm", "myhost:22", "/home/foo/.ssh/id_rsa")?;
284    /// ```
285    pub fn with_key<A: ToSocketAddrs + std::fmt::Debug, P: AsRef<Path>>(
286        username: &str,
287        remote: A,
288        key: P,
289    ) -> Result<Self, SshError> {
290        info!("New SSH shell: {}@{:?}", username, remote);
291        debug!("Using key: {:?}", key.as_ref());
292
293        debug!("Create new TCP stream...");
294
295        // Create a TCP connection
296        let tcp = TcpStream::connect(&remote)?;
297        tcp.set_read_timeout(Some(DEFAULT_TIMEOUT))?;
298        tcp.set_write_timeout(Some(DEFAULT_TIMEOUT))?;
299        let remote_name = format!("{:?}", remote);
300        let remote = remote.to_socket_addrs().unwrap().next().unwrap();
301
302        debug!("Create new SSH session...");
303
304        // Start an SSH session
305        let mut sess = Session::new().unwrap();
306        sess.handshake(&tcp)?;
307        trace!("SSH session handshook.");
308        sess.userauth_pubkey_file(username, None, key.as_ref(), None)?;
309        if !sess.authenticated() {
310            return Err(SshError::AuthFailed {
311                key: key.as_ref().to_path_buf(),
312            }
313            .into());
314        }
315        trace!("SSH session authenticated.");
316
317        println!(
318            "{}",
319            console::style(format!("{}@{} ({})", username, remote_name, remote))
320                .green()
321                .bold()
322        );
323
324        Ok(SshShell {
325            tcp,
326            username: username.to_owned(),
327            key: key.as_ref().to_owned(),
328            remote_name,
329            remote,
330            sess: Arc::new(Mutex::new(sess)),
331            dry_run_mode: false,
332        })
333    }
334
335    /// Returns a new shell connected via the same credentials as the given existing host.
336    ///
337    /// ```rust,ignore
338    /// SshShell::from_existing(&existing_ssh_shell)?;
339    /// ```
340    pub fn from_existing(shell: &SshShell) -> Result<Self, SshError> {
341        info!("New SSH shell: {}@{:?}", shell.username, shell.remote);
342        debug!("Using key: {:?}", shell.key);
343
344        debug!("Create new TCP stream...");
345
346        // Create a TCP connection
347        let tcp = TcpStream::connect(&shell.remote)?;
348        tcp.set_read_timeout(Some(DEFAULT_TIMEOUT))?;
349        tcp.set_write_timeout(Some(DEFAULT_TIMEOUT))?;
350        let remote = shell.remote.clone();
351
352        debug!("Create new SSH session...");
353
354        // Start an SSH session
355        let mut sess = Session::new().unwrap();
356        sess.handshake(&tcp)?;
357        trace!("SSH session handshook.");
358        sess.userauth_pubkey_file(&shell.username, None, shell.key.as_ref(), None)?;
359        if !sess.authenticated() {
360            return Err(SshError::AuthFailed {
361                key: shell.key.clone(),
362            }
363            .into());
364        }
365        trace!("SSH session authenticated.");
366
367        println!(
368            "{}",
369            console::style(format!(
370                "{}@{} ({})",
371                shell.username, shell.remote_name, remote
372            ))
373            .green()
374            .bold()
375        );
376
377        Ok(SshShell {
378            tcp,
379            username: shell.username.clone(),
380            key: shell.key.clone(),
381            remote_name: shell.remote_name.clone(),
382            remote,
383            sess: Arc::new(Mutex::new(sess)),
384            dry_run_mode: false,
385        })
386    }
387
388    /// Toggles _dry run mode_. In dry run mode, commands are not executed remotely; we only print
389    /// what commands we would execute. Note that we do connect remotely, though. This is off by
390    /// default: we default to actually running the commands.
391    pub fn set_dry_run(&mut self, on: bool) {
392        self.dry_run_mode = on;
393        info!(
394            "Toggled dry run mode: {}",
395            if self.dry_run_mode { "on" } else { "off" }
396        );
397    }
398
399    pub fn spawn(&self, cmd: SshCommand) -> Result<SshSpawnHandle, SshError> {
400        debug!("spawn({:?})", cmd);
401        let shell = Self::from_existing(self)?;
402        let cmd = if self.dry_run_mode {
403            cmd.dry_run(true)
404        } else {
405            cmd
406        };
407
408        let thread_handle = std::thread::spawn(move || {
409            let result = shell.run(cmd);
410            (shell, result)
411        });
412
413        debug!("spawned thread for command.");
414
415        Ok(SshSpawnHandle { thread_handle })
416    }
417
418    fn run_with_chan_and_opts(
419        host_and_username: String, // for printing
420        mut chan: ssh2::Channel,
421        cmd_opts: SshCommand,
422    ) -> Result<SshOutput, SshError> {
423        debug!("run_with_chan_and_opts({:?})", cmd_opts);
424
425        let SshCommand {
426            cwd,
427            cmd,
428            use_bash,
429            allow_error,
430            dry_run,
431            no_pty,
432        } = cmd_opts;
433
434        // Print the raw command. We are going to modify it slightly before executing (e.g. to
435        // switch directories)
436        let msg = cmd.clone();
437
438        // Construct the commmand in the right directory and using bash if needed.
439        let cmd = if use_bash {
440            format!("bash -c {}", escape_for_bash(&cmd))
441        } else {
442            cmd
443        };
444
445        debug!("After shell escaping: {:?}", cmd);
446
447        let cmd = if let Some(cwd) = &cwd {
448            format!("cd {} ; {}", cwd.display(), cmd)
449        } else {
450            cmd
451        };
452
453        debug!("After cwd: {:?}", cmd);
454
455        // print message
456        if let Some(cwd) = cwd {
457            println!(
458                "{:-<80}\n{}\n{}\n{}",
459                "",
460                console::style(host_and_username).blue(),
461                console::style(cwd.display()).blue(),
462                console::style(msg).yellow().bold()
463            );
464        } else {
465            println!(
466                "{:-<80}\n{}\n{}",
467                "",
468                console::style(host_and_username).blue(),
469                console::style(msg).yellow().bold()
470            );
471        }
472
473        let mut stdout = String::new();
474        let mut stderr = String::new();
475
476        // If dry run, close and return early without actually doing anything.
477        if dry_run {
478            chan.close()?;
479            chan.wait_close()?;
480
481            debug!("Closed channel after dry run.");
482
483            return Ok(SshOutput { stdout, stderr });
484        }
485
486        // request a pty so that `sudo` commands work fine
487        if !no_pty {
488            chan.request_pty("vt100", None, None)?;
489            debug!("Requested pty.");
490        }
491
492        // execute cmd remotely
493        debug!("Execute command remotely (asynchronous)...");
494        chan.exec(&cmd)?;
495
496        trace!("Read stdout...");
497
498        // print stdout
499        let mut buf = [0; 256];
500        while chan.read(&mut buf)? > 0 {
501            let out = String::from_utf8_lossy(&buf);
502            let out = out.trim_end_matches('\u{0}');
503            print!("{}", out);
504            stdout.push_str(out);
505
506            // clear buf
507            buf.iter_mut().for_each(|x| *x = 0);
508        }
509
510        trace!("No more stdout.");
511
512        // close and wait for remote to close
513        chan.close()?;
514        chan.wait_close()?;
515
516        debug!("Command completed remotely.");
517
518        // clear buf
519        buf.iter_mut().for_each(|x| *x = 0);
520
521        trace!("Read stderr...");
522
523        // print stderr
524        while chan.stderr().read(&mut buf)? > 0 {
525            let err = String::from_utf8_lossy(&buf);
526            let err = err.trim_end_matches('\u{0}');
527            print!("{}", err);
528            stderr.push_str(err);
529
530            // clear buf
531            buf.iter_mut().for_each(|x| *x = 0);
532        }
533
534        trace!("No more stderr.");
535        debug!("Checking exit status.");
536
537        // check the exit status
538        let exit = chan.exit_status()?;
539        debug!("Exit status: {}", exit);
540        if exit != 0 && !allow_error {
541            return Err(SshError::NonZeroExit { cmd, exit }.into());
542        }
543
544        trace!("Done with command.");
545
546        // return output
547        Ok(SshOutput { stdout, stderr })
548    }
549}
550
551impl Execute for SshShell {
552    fn run(&self, cmd: SshCommand) -> Result<SshOutput, SshError> {
553        debug!("run(cmd)");
554        let sess = self.sess.lock().unwrap();
555        debug!("Attempt to crate channel...");
556        let chan = sess.channel_session()?;
557        debug!("Channel created.");
558        let host_and_username = format!("{}@{}", self.username, self.remote_name);
559        let cmd = if self.dry_run_mode {
560            cmd.dry_run(true)
561        } else {
562            cmd
563        };
564        Self::run_with_chan_and_opts(host_and_username, chan, cmd)
565    }
566
567    fn duplicate(&self) -> Result<Self, SshError> {
568        Self::from_existing(self)
569    }
570
571    fn reconnect(&mut self) -> Result<(), SshError> {
572        info!("Reconnect attempt.");
573
574        trace!("Attempt to create new TCP stream...");
575        loop {
576            print!("{}", console::style("Attempt Reconnect ... ").red());
577            match TcpStream::connect_timeout(&self.remote, DEFAULT_TIMEOUT / 2) {
578                Ok(tcp) => {
579                    self.tcp = tcp;
580                    break;
581                }
582                Err(e) => {
583                    trace!("{:?}", e);
584                    println!("{}", console::style("failed, retrying").red());
585                    std::thread::sleep(DEFAULT_TIMEOUT / 2);
586                }
587            }
588        }
589
590        println!(
591            "{}",
592            console::style("TCP connected, doing SSH handshake").red()
593        );
594
595        // Start an SSH session
596        debug!("Attempt to create new SSH session...");
597        let mut sess = Session::new().unwrap();
598        sess.handshake(&self.tcp)?;
599        trace!("Handshook!");
600        sess.userauth_pubkey_file(&self.username, None, self.key.as_ref(), None)?;
601        if !sess.authenticated() {
602            return Err(SshError::AuthFailed {
603                key: self.key.clone(),
604            }
605            .into());
606        }
607        trace!("authenticated!");
608
609        // It should be safe to `Arc::get_mut` here. `reconnect` takes `self` by mutable reference,
610        // so no other thread should have access (even immutably) to `self.sess`.
611        let self_sess = Arc::get_mut(&mut self.sess).unwrap().get_mut().unwrap();
612        let _old_sess = std::mem::replace(self_sess, sess);
613
614        println!(
615            "{}",
616            console::style(format!("{}@{}", self.username, self.remote))
617                .green()
618                .bold()
619        );
620
621        Ok(())
622    }
623}
624
625impl std::fmt::Debug for SshShell {
626    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
627        write!(
628            f,
629            "SshShell {{ {}@{:?} dry_run={} key={:?} }}",
630            self.username, self.remote, self.dry_run_mode, self.key
631        )
632    }
633}
634
635impl SshSpawnHandle {
636    /// Block until the remote command completes.
637    pub fn join(self) -> (SshShell, Result<SshOutput, SshError>) {
638        debug!("Blocking on spawned commmand.");
639        let ret = self.thread_handle.join().unwrap();
640        debug!("Spawned commmand complete.");
641        ret
642    }
643}
644
645impl std::fmt::Debug for SshSpawnHandle {
646    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
647        write!(f, "SshSpawnHandle {{ running }}")
648    }
649}
650
651/// A useful macro that allows creating commands with format strings and arguments.
652///
653/// ```rust,ignore
654/// cmd!("ls {}", "foo")
655/// ```
656///
657/// is equivalent to the expression
658///
659/// ```rust,ignore
660/// SshCommand::new(&format!("ls {}", "foo"))
661/// ```
///
/// Both arms hand their input to `format!`, so the first argument has to be something `format!`
/// accepts as a format string.
662#[macro_export]
663macro_rules! cmd {
    // Format string only.
664    ($fmt:expr) => {
665        $crate::SshCommand::new(&format!($fmt))
666    };
    // Format string followed by its arguments, passed through untouched.
667    ($fmt:expr, $($arg:tt)*) => {
668        $crate::SshCommand::new(&format!($fmt, $($arg)*))
669    };
670}
671
/// Given a string, properly escape the string so that it can be passed as a command line argument
/// to bash.
///
/// This is useful for passing commands to `bash -c` (e.g. through ssh).
///
/// ASCII letters and digits are copied as they are and every other character is preceded by a
/// backslash. The one exception is the newline: a backslash followed by a newline is a line
/// continuation that the shell removes, so a newline is emitted inside single quotes instead,
/// which keeps it as a literal character.
fn escape_for_bash(s: &str) -> String {
    let mut new = String::with_capacity(s.len());

    for c in s.chars() {
        match c {
            c if c.is_ascii_alphanumeric() => new.push(c),
            // `\<newline>` would be swallowed by the shell; quote the newline instead.
            '\n' => new.push_str("'\n'"),
            c => {
                new.push('\\');
                new.push(c);
            }
        }
    }

    new
}
691
692///////////////////////////////////////////////////////////////////////////////
693// Tests
694///////////////////////////////////////////////////////////////////////////////
695
#[cfg(test)]
mod test {
    use crate::{cmd, SshCommand};

    /// `cmd!` must build the same command as formatting the string by hand.
    #[test]
    fn test_cmd_macro() {
        let built = cmd!("{} {}", "ls", 3);
        let by_hand = SshCommand::new("ls 3");

        assert_eq!(built, by_hand);
    }

    mod test_escape_for_bash {
        use super::super::escape_for_bash;

        /// A purely alphanumeric string comes back unchanged.
        #[test]
        fn simple() {
            const TEST_STRING: &str = "ls";

            let escaped = escape_for_bash(TEST_STRING);

            assert_eq!(escaped, "ls");
        }

        /// A string full of quotes and punctuation survives a round trip through a local `bash`:
        /// echoing the escaped form must print the original text.
        #[test]
        fn more_complex() {
            use std::process::Command;

            const TEST_STRING: &str =
                r#""Bob?!", said she, "I though you said 'I can't be there'!""#;

            let script = format!("echo {}", escape_for_bash(TEST_STRING));
            let finished = Command::new("bash")
                .arg("-c")
                .arg(&script)
                .output()
                .unwrap();
            let printed = String::from_utf8(finished.stdout).unwrap();

            assert_eq!(printed.trim(), TEST_STRING);
        }
    }
}
731}