1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
//! A couple of crates that depend on tx5-core need to be able to write and
//! verify files on the local system. Enable the `file_check` feature to
//! provide that ability.

use crate::{Error, Result};

/// A handle to a verified system file. Keep this instance in memory as
/// long as you intend to keep using the validated file.
///
/// The handle records where the validated file lives and holds the open
/// file handle that was used to validate it.
#[derive(Debug)]
pub struct FileCheck {
    /// Location of the validated file on disk.
    path: std::path::PathBuf,
    /// Open handle to the validated file; never read, only held so that it
    /// stays open for as long as this `FileCheck` is alive.
    _file: Option<std::fs::File>,
}

impl FileCheck {
    /// Borrow the on-disk location of the file this handle validated.
    pub fn path(&self) -> &std::path::Path {
        self.path.as_path()
    }
}

/// Write a file if needed, verify the file, and return a handle to that file.
pub fn file_check(
    file_data: &[u8],
    file_hash: &str,
    file_name_prefix: &str,
    file_name_ext: &str,
) -> Result<FileCheck> {
    let file_name = format!("{file_name_prefix}-{file_hash}{file_name_ext}");

    let mut pref_path =
        dirs::data_local_dir().expect("failed to get data_local_dir");
    pref_path.push(&file_name);

    if let Ok(file) = validate(&pref_path, file_hash) {
        return Ok(FileCheck {
            path: pref_path,
            _file: Some(file),
        });
    }

    let tmp = write(file_data)?;

    // NOTE: This is NOT atomic, nor secure, but being able to validate the
    //       file hash post-op mitigates this a bit. And we can let the os
    //       clean up a dangling tmp file if it failed to unlink.
    match tmp.persist_noclobber(&pref_path) {
        Ok(mut file) => {
            set_perms(&mut file)?;

            drop(file);

            let file = validate(&pref_path, file_hash)?;

            Ok(FileCheck {
                path: pref_path,
                _file: Some(file),
            })
        }
        Err(err) => {
            let tempfile::PersistError { file: tmp, .. } = err;

            // First, check to see if a different process wrote correctly
            if let Ok(file) = validate(&pref_path, file_hash) {
                // we no longer need the tmp file, clean it up
                let _ = tmp.close();

                return Ok(FileCheck {
                    path: pref_path,
                    _file: Some(file),
                });
            }

            // we're just going to use the tmp file, do what we need to
            // do to make sure it isn't deleted when the handle drops.

            let path = tmp.path().to_owned();
            let tmp = tmp.into_temp_path();

            // This seems wrong, but it is how tempfile internally goes
            // about doing persist/keep, so we're using it already,
            // and it's only once-ish per process...
            std::mem::forget(tmp);

            let file = validate(&path, file_hash)?;

            Ok(FileCheck {
                path,
                _file: Some(file),
            })
        }
    }
}

/// Validate a file.
///
/// Opens `path` read-only, hashes its entire contents with SHA-256,
/// encodes the digest as url-safe base64 without padding, and compares it
/// to `hash`. The file must also report read-only permissions.
///
/// On success the open file handle is returned.
///
/// # Errors
///
/// Returns an error if the file cannot be opened or read, if its metadata
/// cannot be queried, if the computed hash differs from `hash`
/// (`FileCheckHashMiss`), or if the file is not read-only
/// (`FileCheckNotReadonly`).
fn validate(path: &std::path::Path, hash: &str) -> Result<std::fs::File> {
    use sha2::Digest;
    use std::io::Read;

    let mut file = std::fs::OpenOptions::new().read(true).open(path)?;

    // Propagate read failures instead of panicking: callers treat an `Err`
    // here as "this candidate is unusable" and fall back to another one.
    let mut data = Vec::new();
    file.read_to_end(&mut data)?;

    let mut hasher = sha2::Sha256::new();
    hasher.update(data);
    let on_disk_hash =
        base64::encode_config(hasher.finalize(), base64::URL_SAFE_NO_PAD);

    if on_disk_hash != hash {
        return Err(Error::err(format!("FileCheckHashMiss({path:?})")));
    }

    let perms = file.metadata()?.permissions();

    if !perms.readonly() {
        return Err(Error::err(format!("FileCheckNotReadonly({path:?})")));
    }

    tracing::trace!("success correct file_check: {path:?}");

    Ok(file)
}

/// Stage `file_data` in a freshly created named temporary file.
///
/// The bytes are written and flushed, and the file's permissions are
/// restricted via `set_perms` before the temp file handle is returned.
fn write(file_data: &[u8]) -> Result<tempfile::NamedTempFile> {
    use std::io::Write;

    let mut staged = tempfile::NamedTempFile::new()?;

    {
        let handle = staged.as_file_mut();
        handle.write_all(file_data)?;
        handle.flush()?;
        set_perms(handle)?;
    }

    Ok(staged)
}

/// Restrict the permissions of an open file.
///
/// The file is marked read-only; on unix targets its mode is additionally
/// set to `0o500` (owner read + execute, nothing for group/other).
fn set_perms(file: &mut std::fs::File) -> Result<()> {
    let metadata = file.metadata()?;
    let mut locked = metadata.permissions();

    locked.set_readonly(true);

    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        locked.set_mode(0o500);
    }

    file.set_permissions(locked)
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Run several `file_check` calls for the same random payload at the
    /// same time and require every one of them to succeed.
    #[tokio::test(flavor = "multi_thread")]
    async fn file_check_stress() {
        use rand::Rng;
        let mut data = vec![0; 1024 * 1024 * 10]; // 10 MiB
        rand::thread_rng().fill(&mut data[..]);
        let data = Arc::new(data);

        use sha2::Digest;
        let mut hasher = sha2::Sha256::new();
        hasher.update(&data[..]);
        let hash =
            base64::encode_config(hasher.finalize(), base64::URL_SAFE_NO_PAD);

        const COUNT: usize = 3;

        let mut task_list = Vec::with_capacity(COUNT);

        // The barrier holds every task until all COUNT have been spawned,
        // so the `file_check` calls start together.
        let barrier = Arc::new(std::sync::Barrier::new(COUNT));

        // The number of spawned tasks MUST equal the barrier size, so both
        // are driven by COUNT; a mismatch would deadlock or release early.
        for _ in 0..COUNT {
            let data = data.clone();
            let hash = hash.clone();
            let barrier = barrier.clone();
            task_list.push(tokio::task::spawn_blocking(move || {
                barrier.wait();

                file_check(
                    data.as_slice(),
                    &hash,
                    "tx5-core-file-check-test",
                    ".data",
                )
            }));
        }

        // make sure they're not dropped until the test is over
        let mut tmp = Vec::with_capacity(COUNT);
        for task in task_list {
            tmp.push(task.await.unwrap().unwrap());
        }

        // cleanup
        for tmp in tmp {
            let path = tmp.path().to_owned();
            drop(tmp);
            let _ = std::fs::remove_file(&path);
        }
    }
}