sqlx_etorreborre_postgres/options/
pgpass.rs

1use std::borrow::Cow;
2use std::env::var_os;
3use std::fs::File;
4use std::io::{BufRead, BufReader};
5use std::path::PathBuf;
6
7/// try to load a password from the various pgpass file locations
8pub fn load_password(
9    host: &str,
10    port: u16,
11    username: &str,
12    database: Option<&str>,
13) -> Option<String> {
14    let custom_file = var_os("PGPASSFILE");
15    if let Some(file) = custom_file {
16        if let Some(password) =
17            load_password_from_file(PathBuf::from(file), host, port, username, database)
18        {
19            return Some(password);
20        }
21    }
22
23    #[cfg(not(target_os = "windows"))]
24    let default_file = home::home_dir().map(|path| path.join(".pgpass"));
25    #[cfg(target_os = "windows")]
26    let default_file = {
27        use etcetera::BaseStrategy;
28
29        etcetera::base_strategy::Windows::new()
30            .ok()
31            .map(|basedirs| basedirs.data_dir().join("postgres").join("pgpass.conf"))
32    };
33    load_password_from_file(default_file?, host, port, username, database)
34}
35
36/// try to extract a password from a pgpass file
37fn load_password_from_file(
38    path: PathBuf,
39    host: &str,
40    port: u16,
41    username: &str,
42    database: Option<&str>,
43) -> Option<String> {
44    let file = File::open(&path).ok()?;
45
46    #[cfg(target_os = "linux")]
47    {
48        use std::os::unix::fs::PermissionsExt;
49
50        // check file permissions on linux
51
52        let metadata = file.metadata().ok()?;
53        let permissions = metadata.permissions();
54        let mode = permissions.mode();
55        if mode & 0o77 != 0 {
56            tracing::warn!(
57                path = %path.to_string_lossy(),
58                permissions = format!("{mode:o}"),
59                "Ignoring path. Permissions are not strict enough",
60            );
61            return None;
62        }
63    }
64
65    let reader = BufReader::new(file);
66    load_password_from_reader(reader, host, port, username, database)
67}
68
69fn load_password_from_reader(
70    mut reader: impl BufRead,
71    host: &str,
72    port: u16,
73    username: &str,
74    database: Option<&str>,
75) -> Option<String> {
76    let mut line = String::new();
77
78    // https://stackoverflow.com/a/55041833
79    fn trim_newline(s: &mut String) {
80        if s.ends_with('\n') {
81            s.pop();
82            if s.ends_with('\r') {
83                s.pop();
84            }
85        }
86    }
87
88    while let Ok(n) = reader.read_line(&mut line) {
89        if n == 0 {
90            break;
91        }
92
93        if line.starts_with('#') {
94            // comment, do nothing
95        } else {
96            // try to load password from line
97            trim_newline(&mut line);
98            if let Some(password) = load_password_from_line(&line, host, port, username, database) {
99                return Some(password);
100            }
101        }
102
103        line.clear();
104    }
105
106    None
107}
108
109/// try to check all fields & extract the password
110fn load_password_from_line(
111    mut line: &str,
112    host: &str,
113    port: u16,
114    username: &str,
115    database: Option<&str>,
116) -> Option<String> {
117    let whole_line = line;
118
119    // Pgpass line ordering: hostname, port, database, username, password
120    // See: https://www.postgresql.org/docs/9.3/libpq-pgpass.html
121    match line.trim_start().chars().next() {
122        None | Some('#') => None,
123        _ => {
124            matches_next_field(whole_line, &mut line, host)?;
125            matches_next_field(whole_line, &mut line, &port.to_string())?;
126            matches_next_field(whole_line, &mut line, database.unwrap_or_default())?;
127            matches_next_field(whole_line, &mut line, username)?;
128            Some(line.to_owned())
129        }
130    }
131}
132
133/// check if the next field matches the provided value
134fn matches_next_field(whole_line: &str, line: &mut &str, value: &str) -> Option<()> {
135    let field = find_next_field(line);
136    match field {
137        Some(field) => {
138            if field == "*" || field == value {
139                Some(())
140            } else {
141                None
142            }
143        }
144        None => {
145            tracing::warn!(line = whole_line, "Malformed line in pgpass file");
146            None
147        }
148    }
149}
150
/// Split the next `:`-terminated field off the front of a pgpass line.
///
/// A backslash makes the following character literal, so `\:` does not end
/// the field and `\\` yields a single backslash. The field is returned
/// borrowed when it contains no backslash and owned (unescaped) otherwise.
///
/// On success `line` is advanced to just behind the delimiter. When no
/// unescaped `:` is found, `None` is returned and `line` is left untouched.
fn find_next_field<'a>(line: &mut &'a str) -> Option<Cow<'a, str>> {
    let input: &'a str = *line;

    // Unescaped copy of the field; only allocated once a backslash is seen.
    let mut unescaped: Option<String> = None;
    // Byte offset of the first character not yet copied into `unescaped`.
    let mut copied_up_to = 0;
    // True while the previous character was a backslash that starts an escape.
    let mut after_backslash = false;

    for (pos, ch) in input.char_indices() {
        match ch {
            ':' if !after_backslash => {
                let raw = &input[..pos];
                *line = &input[pos + 1..];

                return Some(match unescaped {
                    Some(mut owned) => {
                        owned.push_str(&raw[copied_up_to..]);
                        Cow::Owned(owned)
                    }
                    None => Cow::Borrowed(raw),
                });
            }
            '\\' => {
                let buf = unescaped.get_or_insert_with(String::new);

                if after_backslash {
                    // second backslash of `\\`: emit one literal backslash
                    buf.push('\\');
                } else {
                    // flush the plain text preceding this backslash
                    buf.push_str(&input[copied_up_to..pos]);
                }

                after_backslash = !after_backslash;
                // a backslash is one byte, so the next character starts at pos + 1
                copied_up_to = pos + 1;
            }
            _ => after_backslash = false,
        }
    }

    None
}
189
#[cfg(test)]
mod tests {
    use super::{find_next_field, load_password_from_line, load_password_from_reader};
    use std::borrow::Cow;

    /// Field splitting: delimiters, backslash escapes and missing delimiters.
    #[test]
    fn test_find_next_field() {
        fn test_case<'a>(mut input: &'a str, result: Option<Cow<'a, str>>, rest: &str) {
            assert_eq!(find_next_field(&mut input), result);
            assert_eq!(input, rest);
        }

        // normal field
        test_case("foo:bar:baz", Some(Cow::Borrowed("foo")), "bar:baz");
        // \ escaped
        test_case(
            "foo\\\\:bar:baz",
            Some(Cow::Owned("foo\\".to_owned())),
            "bar:baz",
        );
        // : escaped
        test_case(
            "foo\\::bar:baz",
            Some(Cow::Owned("foo:".to_owned())),
            "bar:baz",
        );
        // unnecessary escape
        test_case(
            "foo\\a:bar:baz",
            Some(Cow::Owned("fooa".to_owned())),
            "bar:baz",
        );
        // other text after escape
        test_case(
            "foo\\\\a:bar:baz",
            Some(Cow::Owned("foo\\a".to_owned())),
            "bar:baz",
        );
        // double escape
        test_case(
            "foo\\\\\\\\a:bar:baz",
            Some(Cow::Owned("foo\\\\a".to_owned())),
            "bar:baz",
        );
        // utf8 support
        test_case("🦀:bar:baz", Some(Cow::Borrowed("🦀")), "bar:baz");

        // missing delimiter (eof)
        test_case("foo", None, "foo");
        // missing delimiter after escape
        test_case("foo\\:", None, "foo\\:");
        // missing delimiter after unused trailing escape
        test_case("foo\\", None, "foo\\");
    }

    /// Matching of a single pgpass line against connection parameters.
    #[test]
    fn test_load_password_from_line() {
        // normal
        assert_eq!(
            load_password_from_line(
                "localhost:5432:bar:foo:baz",
                "localhost",
                5432,
                "foo",
                Some("bar")
            ),
            Some("baz".to_owned())
        );
        // wildcard
        assert_eq!(
            load_password_from_line("*:5432:bar:foo:baz", "localhost", 5432, "foo", Some("bar")),
            Some("baz".to_owned())
        );
        // accept wildcard with missing db
        assert_eq!(
            load_password_from_line("localhost:5432:*:foo:baz", "localhost", 5432, "foo", None),
            Some("baz".to_owned())
        );

        // doesn't match
        assert_eq!(
            load_password_from_line(
                "thishost:5432:bar:foo:baz",
                "thathost",
                5432,
                "foo",
                Some("bar")
            ),
            None
        );
        // malformed entry
        assert_eq!(
            load_password_from_line(
                "localhost:5432:bar:foo",
                "localhost",
                5432,
                "foo",
                Some("bar")
            ),
            None
        );
        // blank and whitespace-only lines are ignored
        assert_eq!(
            load_password_from_line("", "localhost", 5432, "foo", Some("bar")),
            None
        );
        assert_eq!(
            load_password_from_line("    ", "localhost", 5432, "foo", Some("bar")),
            None
        );
        // indented comment is ignored
        assert_eq!(
            load_password_from_line("  # a comment", "localhost", 5432, "foo", Some("bar")),
            None
        );
    }

    /// Scanning of a whole pgpass file with comments and mixed line endings.
    #[test]
    fn test_load_password_from_reader() {
        // The closing quote directly follows the last entry so that the final
        // line really has no trailing newline.
        let file = b"\
            localhost:5432:bar:foo:baz\n\
            # mixed line endings (also a comment!)\n\
            *:5432:bar:foo:baz\r\n\
            # trailing space, comment with CRLF! \r\n\
            thishost:5432:bar:foo:baz \n\
            # malformed line \n\
            thathost:5432:foobar:foo\n\
            # missing trailing newline\n\
            localhost:5432:*:foo:baz";

        // normal
        assert_eq!(
            load_password_from_reader(&mut &file[..], "localhost", 5432, "foo", Some("bar")),
            Some("baz".to_owned())
        );
        // wildcard (only the last entry, which has no trailing newline, matches)
        assert_eq!(
            load_password_from_reader(&mut &file[..], "localhost", 5432, "foo", Some("foobar")),
            Some("baz".to_owned())
        );
        // accept wildcard with missing db
        assert_eq!(
            load_password_from_reader(&mut &file[..], "localhost", 5432, "foo", None),
            Some("baz".to_owned())
        );

        // doesn't match (every entry is rejected on a well-formed field)
        assert_eq!(
            load_password_from_reader(&mut &file[..], "thathost", 5432, "foo", Some("other")),
            None
        );
        // malformed entry (the `thathost` line runs out of fields)
        assert_eq!(
            load_password_from_reader(&mut &file[..], "thathost", 5432, "foo", Some("foobar")),
            None
        );
    }
}