ssh_vault/vault/dio.rs

1use std::fs::{File, OpenOptions};
2use std::io::{self, IsTerminal, Read, Write};
3
/// Where input bytes are read from: the process's standard input or an opened file.
#[derive(Debug)]
pub enum InputSource {
    /// Read from standard input (selected by `InputSource::new` when no path, or `-`, is given).
    Stdin,
    /// Read from a file that has been opened for reading.
    File(File),
}
8
9impl InputSource {
10    pub fn new(input: Option<String>) -> io::Result<Self> {
11        if let Some(filename) = input {
12            // Use a file if the filename is not "-" (stdin)
13            if filename != "-" {
14                return Ok(Self::File(File::open(filename)?));
15            }
16        }
17
18        Ok(Self::Stdin)
19    }
20
21    #[must_use]
22    pub fn is_terminal(&self) -> bool {
23        matches!(self, Self::Stdin) && io::stdin().is_terminal()
24    }
25}
26
27impl Read for InputSource {
28    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
29        match self {
30            Self::Stdin => io::stdin().read(buf),
31            Self::File(file) => file.read(buf),
32        }
33    }
34}
35
/// Where output bytes are written: the process's standard output or a file.
///
/// The file variant can be any caller-supplied path, not only a temporary
/// file; see [`OutputDestination::new`].
#[derive(Debug)]
pub enum OutputDestination {
    /// Write to standard output (selected when no path, or `-`, is given).
    Stdout,
    /// Write to a file that has been opened for writing.
    File(File),
}
41
42impl OutputDestination {
43    #[allow(clippy::suspicious_open_options)]
44    pub fn new(output: Option<String>) -> io::Result<Self> {
45        if let Some(filename) = output {
46            // Use a file if the filename is not "-" (stdout)
47            if filename != "-" {
48                return Ok(Self::File(
49                    OpenOptions::new().write(true).create(true).open(filename)?,
50                ));
51            }
52        }
53
54        Ok(Self::Stdout)
55    }
56
57    pub fn truncate(&self) -> io::Result<()> {
58        match self {
59            Self::File(file) => file.set_len(0),
60            Self::Stdout => Ok(()), // Do nothing for stdout
61        }
62    }
63
64    // Check if the output is empty, preventing overwriting a non-empty file
65    pub fn is_empty(&self) -> io::Result<bool> {
66        match self {
67            Self::File(file) => Ok(file.metadata().map(|m| m.len() == 0).unwrap_or(false)),
68            Self::Stdout => Ok(true), // Do nothing for stdout
69        }
70    }
71}
72
73impl Write for OutputDestination {
74    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
75        match self {
76            Self::Stdout => io::stdout().write(buf),
77            Self::File(file) => file.write(buf),
78        }
79    }
80
81    fn flush(&mut self) -> io::Result<()> {
82        match self {
83            Self::Stdout => io::stdout().flush(),
84            Self::File(file) => file.flush(),
85        }
86    }
87}
88
89pub fn setup_io(
90    input: Option<String>,
91    output: Option<String>,
92) -> io::Result<(InputSource, OutputDestination)> {
93    let input = InputSource::new(input)?;
94    let output = OutputDestination::new(output)?;
95
96    Ok((input, output))
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    #[test]
    fn test_setup_io() {
        // Compare against the real state of stdin instead of assuming an
        // interactive terminal, so the test passes under any runner (CI,
        // IDEs, `cargo test < /dev/null`), not just an interactive shell.
        let stdin_is_tty = io::stdin().is_terminal();

        let (input, output) = setup_io(None, None).unwrap();
        assert!(matches!(input, InputSource::Stdin));
        assert_eq!(input.is_terminal(), stdin_is_tty);
        assert!(matches!(output, OutputDestination::Stdout));

        let (input, output) = setup_io(Some("-".to_string()), None).unwrap();
        assert!(matches!(input, InputSource::Stdin));
        assert_eq!(input.is_terminal(), stdin_is_tty);
        assert!(matches!(output, OutputDestination::Stdout));

        let rs = setup_io(Some("nonexistent".to_string()), None);
        assert!(rs.is_err());
    }

    #[test]
    fn test_setup_io_file() {
        let output_file = NamedTempFile::new().unwrap();

        let (input, output) = setup_io(Some("Cargo.toml".to_string()), None).unwrap();
        assert!(!input.is_terminal());
        assert!(matches!(output, OutputDestination::Stdout));

        let (input, output) =
            setup_io(Some("Cargo.toml".to_string()), Some("-".to_string())).unwrap();
        assert!(!input.is_terminal());
        assert!(matches!(output, OutputDestination::Stdout));

        let (input, output) = setup_io(
            Some("Cargo.toml".to_string()),
            Some(output_file.path().to_str().unwrap().to_string()),
        )
        .unwrap();
        assert!(!input.is_terminal());
        assert!(matches!(output, OutputDestination::File(_)));

        // Opening a directory for writing must fail.
        let rs = setup_io(Some("Cargo.toml".to_string()), Some("/".to_string()));
        assert!(rs.is_err());
    }

    #[test]
    fn test_input_source() {
        let mut input = InputSource::new(Some("Cargo.toml".to_string())).unwrap();
        let mut buf = [0; 1024];
        let n = input.read(&mut buf).unwrap();
        assert!(n > 0);

        let rs = InputSource::new(Some("nonexistent".to_string()));
        assert!(rs.is_err());
    }

    #[test]
    fn test_output_destination() {
        let mut output = OutputDestination::new(Some("-".to_string())).unwrap();
        let n = output.write(b"test").unwrap();
        assert_eq!(n, 4);

        let mut output = OutputDestination::new(None).unwrap();
        let n = output.write(b"test").unwrap();
        assert_eq!(n, 4);

        let output_file = NamedTempFile::new().unwrap();
        let mut output =
            OutputDestination::new(Some(output_file.path().to_str().unwrap().to_string())).unwrap();
        let n = output.write(b"test").unwrap();
        assert_eq!(n, 4);
    }

    #[test]
    fn test_output_destination_truncate() {
        let mut output_file = NamedTempFile::new().unwrap();
        let mut output =
            OutputDestination::new(Some(output_file.path().to_str().unwrap().to_string())).unwrap();
        let n = output.write(b"test").unwrap();
        assert_eq!(n, 4);

        output.truncate().unwrap();
        let mut buf = [0; 1024];
        let n = output_file.read(&mut buf).unwrap();
        assert_eq!(n, 0);
    }

    #[test]
    fn test_output_destination_is_empty() {
        let output_file = NamedTempFile::new().unwrap();
        let mut output =
            OutputDestination::new(Some(output_file.path().to_str().unwrap().to_string())).unwrap();
        let n = output.write(b"test").unwrap();
        assert_eq!(n, 4);

        let is_empty = output.is_empty().unwrap();
        assert!(!is_empty);

        output.truncate().unwrap();
        let is_empty = output.is_empty().unwrap();
        assert!(is_empty);
    }
}