1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
use crate::ffi::{close, dup, dup2, STDOUT_FILENO};

use std::fs::File;
use std::io;
use std::os::unix::io::IntoRawFd;
use std::os::unix::io::RawFd;
use std::os::unix::prelude::*;
use std::path::Path;
use std::sync::atomic::{AtomicBool, Ordering};

/// Memory ordering used for every access to `IS_REPLACED`.
const ORDERING: Ordering = Ordering::SeqCst;

/// Set to `true` when stdout gets redirected and back to `false` when the guard restores it.
static IS_REPLACED: AtomicBool = AtomicBool::new(false);

/// A guard over the stdout change.
///
/// When this guard is dropped, stdout goes back to the original descriptor and the
/// override target is closed. Binding it to `_` or discarding it undoes the override
/// immediately, hence `#[must_use]`.
#[must_use = "stdout is restored as soon as the guard is dropped"]
#[derive(Debug)]
pub struct StdoutOverrideGuard {
    /// Saved copy of the original stdout, restored onto fd 1 on drop.
    stdout_fd: RawFd,
    /// Descriptor stdout was redirected to; closed on drop.
    file_fd: RawFd,
}

/// Override the stdout file descriptor safely.
///
/// Only one override may be active at a time: requesting another one while a
/// `StdoutOverrideGuard` from a previous call is still alive panics.
#[derive(Debug)]
pub struct StdoutOverride;

impl StdoutOverride {
    /// Override the stdout by providing a path.
    /// This uses [`File::create`] so it will fail/succeed accordingly.
    ///
    /// [`File::create`]: https://doc.rust-lang.org/stable/std/fs/struct.File.html#method.create
    pub fn override_file<P: AsRef<Path>>(p: P) -> io::Result<StdoutOverrideGuard> {
        Self::check_override();

        let file = File::create(p)?;
        let file_fd = file.into_raw_fd();
        Self::override_fd(file_fd)
    }

    /// Override the stdout by providing something that can be turned into a file descriptor.
    /// This will accept Sockets, Files, and even Stdio's. [`AsRawFd`]
    ///
    /// [`AsRawFd`]: https://doc.rust-lang.org/stable/std/os/unix/io/trait.AsRawFd.html
    pub fn override_raw<FD: AsRawFd>(fd: FD) -> io::Result<StdoutOverrideGuard> {
        Self::check_override();

        let file_fd = fd.as_raw_fd();
        Self::override_fd(file_fd)
    }

    fn override_fd(file_fd: RawFd) -> io::Result<StdoutOverrideGuard> {
        let stdout_fd = unsafe { dup(STDOUT_FILENO) }?;
        let _ = unsafe { dup2(file_fd, STDOUT_FILENO) }?;

        IS_REPLACED.store(true, ORDERING);

        Ok(StdoutOverrideGuard { stdout_fd, file_fd })
    }

    fn check_override() {
        if IS_REPLACED.load(ORDERING) {
            panic!("Tried to override Stdout twice");
        }
    }
}

impl Drop for StdoutOverrideGuard {
    /// Restores the original stdout and releases both descriptors held by the guard.
    fn drop(&mut self) {
        // Push anything Rust has buffered for stdout into the override target first;
        // otherwise it would be written to the original stdout after the restore below.
        let _ = io::Write::flush(&mut io::stdout());

        // Ignoring syscall errors seems to be the most sensible thing to do in a Drop impl
        // https://github.com/rust-lang/rust/blob/bd177f3e/src/libstd/sys/unix/fd.rs#L293-L302
        let _ = unsafe { dup2(self.stdout_fd, STDOUT_FILENO) };
        // `stdout_fd` was only a saved copy of the original stdout; once it is duplicated
        // back onto fd 1 it is no longer needed and would otherwise leak.
        let _ = unsafe { close(self.stdout_fd) };
        let _ = unsafe { close(self.file_fd) };
        IS_REPLACED.store(false, ORDERING);
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use std::{
        fs::{read_to_string, remove_file, File},
        io::{stdout, Write},
        mem, thread,
    };

    // Both tests rewire the process-wide fd 1, and the test harness runs tests in
    // parallel by default, so they must not overlap. A spin lock on an `AtomicBool`
    // keeps this dependency-free; it is released on drop, including during unwinding.
    static FD_TEST_LOCK: AtomicBool = AtomicBool::new(false);

    /// Holds `FD_TEST_LOCK` for as long as it is alive.
    struct FdTestLock;

    impl FdTestLock {
        fn acquire() -> Self {
            while FD_TEST_LOCK
                .compare_exchange(false, true, ORDERING, ORDERING)
                .is_err()
            {
                thread::yield_now();
            }
            FdTestLock
        }
    }

    impl Drop for FdTestLock {
        fn drop(&mut self) {
            FD_TEST_LOCK.store(false, ORDERING);
        }
    }

    #[test]
    fn test_stdout() {
        let _lock = FdTestLock::acquire();
        let file_name = "./test1.txt";
        let data = "12345";
        let _ = remove_file(file_name);

        let guard = StdoutOverride::override_file(file_name).unwrap();
        // Write through the `Stdout` handle rather than `print!`: the test harness captures
        // the print macros per test, so their output would never reach fd 1.
        let mut out = stdout();
        out.write_all(data.as_bytes()).unwrap();
        out.flush().unwrap();
        mem::drop(guard);

        let contents = read_to_string(file_name).unwrap();
        assert_eq!(data, contents);
        println!("Outside!");

        remove_file(file_name).unwrap();
    }

    #[test]
    fn test_original() {
        let _lock = FdTestLock::acquire();
        let file_name = "./test2.txt";
        let _ = remove_file(file_name);

        let file = File::create(file_name).unwrap();
        let file = file.into_raw_fd();

        let real_stdout = unsafe { dup(STDOUT_FILENO) }.unwrap();

        unsafe { dup2(file, STDOUT_FILENO) }.unwrap();

        // Bypass the harness's print capture so the bytes really hit fd 1.
        let mut out = stdout();
        out.write_all(b"Let's see where it's saved\n").unwrap();
        out.flush().unwrap();

        unsafe { dup2(real_stdout, STDOUT_FILENO) }.unwrap();
        // fd 1 refers to the original stdout again; release the temporary descriptors.
        let _ = unsafe { close(real_stdout) };
        let _ = unsafe { close(file) };

        let contents = read_to_string(file_name).unwrap();
        assert_eq!("Let\'s see where it\'s saved\n", contents);

        println!("got back");
        remove_file(file_name).unwrap();
    }
}