sqlx_core_oldapi/postgres/options/pgpass.rs

1use std::borrow::Cow;
2use std::env::var_os;
3use std::fs::File;
4use std::io::{BufRead, BufReader};
5use std::path::PathBuf;
6
7/// try to load a password from the various pgpass file locations
8pub fn load_password(
9    host: &str,
10    port: u16,
11    username: &str,
12    database: Option<&str>,
13) -> Option<String> {
14    let custom_file = var_os("PGPASSFILE");
15    if let Some(file) = custom_file {
16        if let Some(password) =
17            load_password_from_file(PathBuf::from(file), host, port, username, database)
18        {
19            return Some(password);
20        }
21    }
22
23    #[cfg(not(target_os = "windows"))]
24    let default_file = dirs::home_dir().map(|path| path.join(".pgpass"));
25    #[cfg(target_os = "windows")]
26    let default_file = dirs::data_dir().map(|path| path.join("postgres").join("pgpass.conf"));
27    load_password_from_file(default_file?, host, port, username, database)
28}
29
30/// try to extract a password from a pgpass file
31fn load_password_from_file(
32    path: PathBuf,
33    host: &str,
34    port: u16,
35    username: &str,
36    database: Option<&str>,
37) -> Option<String> {
38    let file = File::open(&path).ok()?;
39
40    #[cfg(target_os = "linux")]
41    {
42        use std::os::unix::fs::PermissionsExt;
43
44        // check file permissions on linux
45
46        let metadata = file.metadata().ok()?;
47        let permissions = metadata.permissions();
48        let mode = permissions.mode();
49        if mode & 0o77 != 0 {
50            log::warn!(
51                "ignoring {}: permissions for not strict enough: {:o}",
52                path.to_string_lossy(),
53                mode
54            );
55            return None;
56        }
57    }
58
59    let reader = BufReader::new(file);
60    load_password_from_reader(reader, host, port, username, database)
61}
62
63fn load_password_from_reader(
64    mut reader: impl BufRead,
65    host: &str,
66    port: u16,
67    username: &str,
68    database: Option<&str>,
69) -> Option<String> {
70    let mut line = String::new();
71
72    // https://stackoverflow.com/a/55041833
73    fn trim_newline(s: &mut String) {
74        if s.ends_with('\n') {
75            s.pop();
76            if s.ends_with('\r') {
77                s.pop();
78            }
79        }
80    }
81
82    while let Ok(n) = reader.read_line(&mut line) {
83        if n == 0 {
84            break;
85        }
86
87        if line.starts_with('#') {
88            // comment, do nothing
89        } else {
90            // try to load password from line
91            trim_newline(&mut line);
92            if let Some(password) = load_password_from_line(&line, host, port, username, database) {
93                return Some(password);
94            }
95        }
96
97        line.clear();
98    }
99
100    None
101}
102
103/// try to check all fields & extract the password
104fn load_password_from_line(
105    mut line: &str,
106    host: &str,
107    port: u16,
108    username: &str,
109    database: Option<&str>,
110) -> Option<String> {
111    let whole_line = line;
112
113    // Pgpass line ordering: hostname, port, database, username, password
114    // See: https://www.postgresql.org/docs/9.3/libpq-pgpass.html
115    match line.trim_start().chars().next() {
116        None | Some('#') => None,
117        _ => {
118            matches_next_field(whole_line, &mut line, host)?;
119            matches_next_field(whole_line, &mut line, &port.to_string())?;
120            matches_next_field(whole_line, &mut line, database.unwrap_or_default())?;
121            matches_next_field(whole_line, &mut line, username)?;
122            Some(line.to_owned())
123        }
124    }
125}
126
127/// check if the next field matches the provided value
128fn matches_next_field(whole_line: &str, line: &mut &str, value: &str) -> Option<()> {
129    let field = find_next_field(line);
130    match field {
131        Some(field) => {
132            if field == "*" || field == value {
133                Some(())
134            } else {
135                None
136            }
137        }
138        None => {
139            log::warn!("Malformed line in pgpass file: {}", whole_line);
140            None
141        }
142    }
143}
144
/// Extract the next `:`-delimited field from a pgpass line.
///
/// A backslash makes the following character literal, so `\:` and `\\` stand for `:`
/// and `\` inside a field. On success `line` is advanced past the field and its
/// terminating `:`, and the field is returned with escapes resolved. It is borrowed
/// from `line` when no backslash occurred, and owned otherwise. If no unescaped `:`
/// follows, `None` is returned and `line` is left untouched.
fn find_next_field<'a>(line: &mut &'a str) -> Option<Cow<'a, str>> {
    let input: &'a str = *line;
    // Unescaped text collected so far; only allocated once a backslash is seen.
    let mut decoded: Option<String> = None;
    // Byte offset of the first character not yet copied into `decoded`.
    let mut copied_up_to = 0;
    let mut after_backslash = false;

    for (pos, ch) in input.char_indices() {
        match ch {
            ':' if !after_backslash => {
                *line = &input[pos + 1..];
                let field = &input[..pos];
                return Some(match decoded {
                    None => Cow::Borrowed(field),
                    Some(mut owned) => {
                        owned.push_str(&field[copied_up_to..]);
                        Cow::Owned(owned)
                    }
                });
            }
            '\\' => {
                let buf = decoded.get_or_insert_with(String::new);
                if after_backslash {
                    buf.push('\\');
                } else {
                    buf.push_str(&input[copied_up_to..pos]);
                }
                after_backslash = !after_backslash;
                copied_up_to = pos + 1;
            }
            _ => after_backslash = false,
        }
    }

    None
}
183
#[cfg(test)]
mod tests {
    use super::{find_next_field, load_password_from_line, load_password_from_reader};
    use std::borrow::Cow;

    #[test]
    fn test_find_next_field() {
        fn test_case<'a>(mut input: &'a str, result: Option<Cow<'a, str>>, rest: &str) {
            assert_eq!(find_next_field(&mut input), result);
            assert_eq!(input, rest);
        }

        // normal field
        test_case("foo:bar:baz", Some(Cow::Borrowed("foo")), "bar:baz");
        // empty field
        test_case(":bar:baz", Some(Cow::Borrowed("")), "bar:baz");
        // \ escaped
        test_case(
            "foo\\\\:bar:baz",
            Some(Cow::Owned("foo\\".to_owned())),
            "bar:baz",
        );
        // : escaped
        test_case(
            "foo\\::bar:baz",
            Some(Cow::Owned("foo:".to_owned())),
            "bar:baz",
        );
        // unnecessary escape
        test_case(
            "foo\\a:bar:baz",
            Some(Cow::Owned("fooa".to_owned())),
            "bar:baz",
        );
        // other text after escape
        test_case(
            "foo\\\\a:bar:baz",
            Some(Cow::Owned("foo\\a".to_owned())),
            "bar:baz",
        );
        // double escape
        test_case(
            "foo\\\\\\\\a:bar:baz",
            Some(Cow::Owned("foo\\\\a".to_owned())),
            "bar:baz",
        );
        // utf8 support
        test_case("🦀:bar:baz", Some(Cow::Borrowed("🦀")), "bar:baz");

        // missing delimiter (eof)
        test_case("foo", None, "foo");
        // missing delimiter on empty input
        test_case("", None, "");
        // missing delimiter after escape
        test_case("foo\\:", None, "foo\\:");
        // missing delimiter after unused trailing escape
        test_case("foo\\", None, "foo\\");
    }

    #[test]
    fn test_load_password_from_line() {
        // normal
        assert_eq!(
            load_password_from_line(
                "localhost:5432:bar:foo:baz",
                "localhost",
                5432,
                "foo",
                Some("bar")
            ),
            Some("baz".to_owned())
        );
        // wildcard
        assert_eq!(
            load_password_from_line("*:5432:bar:foo:baz", "localhost", 5432, "foo", Some("bar")),
            Some("baz".to_owned())
        );
        // accept wildcard with missing db
        assert_eq!(
            load_password_from_line("localhost:5432:*:foo:baz", "localhost", 5432, "foo", None),
            Some("baz".to_owned())
        );

        // doesn't match
        assert_eq!(
            load_password_from_line(
                "thishost:5432:bar:foo:baz",
                "thathost",
                5432,
                "foo",
                Some("bar")
            ),
            None
        );
        // malformed entry
        assert_eq!(
            load_password_from_line(
                "localhost:5432:bar:foo",
                "localhost",
                5432,
                "foo",
                Some("bar")
            ),
            None
        );
        // comment, even with leading whitespace, never matches
        assert_eq!(
            load_password_from_line("  # *:*:*:*:baz", "localhost", 5432, "foo", Some("bar")),
            None
        );
        // blank line never matches
        assert_eq!(
            load_password_from_line("   ", "localhost", 5432, "foo", Some("bar")),
            None
        );
    }

    #[test]
    fn test_load_password_from_reader() {
        let file = b"\
            localhost:5432:bar:foo:baz\n\
            # mixed line endings (also a comment!)\n\
            *:5432:bar:foo:baz\r\n\
            # trailing space, comment with CRLF! \r\n\
            thishost:5432:bar:foo:baz \n\
            # malformed line \n\
            thathost:5432:foobar:foo\n\
            # missing trailing newline\n\
            localhost:5432:*:foo:baz
        ";

        // normal
        assert_eq!(
            load_password_from_reader(&mut &file[..], "localhost", 5432, "foo", Some("bar")),
            Some("baz".to_owned())
        );
        // wildcard
        assert_eq!(
            load_password_from_reader(&mut &file[..], "localhost", 5432, "foo", Some("foobar")),
            Some("baz".to_owned())
        );
        // accept wildcard with missing db
        assert_eq!(
            load_password_from_reader(&mut &file[..], "localhost", 5432, "foo", None),
            Some("baz".to_owned())
        );

        // doesn't match: no entry (not even a wildcard one) is for user "nobody"
        assert_eq!(
            load_password_from_reader(&mut &file[..], "localhost", 5432, "nobody", Some("bar")),
            None
        );
        // malformed entry: the only thathost/foobar line is missing its password field
        assert_eq!(
            load_password_from_reader(&mut &file[..], "thathost", 5432, "foo", Some("foobar")),
            None
        );
    }
}