Skip to main content

zsync_rs/
assembly.rs

1use std::fs::{File, OpenOptions};
2use std::io::{Read, Seek, SeekFrom};
3use std::os::unix::fs::FileExt;
4use std::path::Path;
5
6use crate::checksum::calc_sha1;
7use crate::control::ControlFile;
8use crate::http::{
9    DEFAULT_RANGE_GAP_THRESHOLD, HttpClient, byte_ranges_from_block_ranges, merge_byte_ranges,
10};
11use crate::matcher::BlockMatcher;
12use crate::matcher::MatchError;
13
14#[derive(Debug, thiserror::Error)]
15pub enum AssemblyError {
16    #[error("IO error: {0}")]
17    Io(#[from] std::io::Error),
18    #[error("HTTP error: {0}")]
19    Http(#[from] crate::http::HttpError),
20    #[error("Matcher error: {0}")]
21    Matcher(#[from] MatchError),
22    #[error("Control file error: {0}")]
23    Control(String),
24    #[error("Checksum mismatch: expected {expected}, got {actual}")]
25    ChecksumMismatch { expected: String, actual: String },
26    #[error("No URLs available")]
27    NoUrls,
28}
29
/// Progress observer invoked as `callback(done_bytes, total_bytes)`.
///
/// Boxed so it can be stored on the assembler; `Send + Sync` so the assembler itself can
/// be moved or shared across threads.
pub type ProgressCallback = Box<dyn Fn(u64, u64) + Sync + Send>;
31
32pub struct ZsyncAssembly {
33    control: ControlFile,
34    base_url: Option<String>,
35    matcher: BlockMatcher,
36    http: HttpClient,
37    output_path: std::path::PathBuf,
38    temp_path: std::path::PathBuf,
39    file: Option<File>,
40    range_gap_threshold: u64,
41    progress_callback: Option<ProgressCallback>,
42}
43
44impl ZsyncAssembly {
45    pub fn new(control: ControlFile, output_path: &Path) -> Result<Self, AssemblyError> {
46        Self::with_base_url(control, output_path, None)
47    }
48
49    pub fn with_base_url(
50        control: ControlFile,
51        output_path: &Path,
52        base_url: Option<&str>,
53    ) -> Result<Self, AssemblyError> {
54        let matcher = BlockMatcher::new(&control);
55        let http = HttpClient::new();
56        let temp_path = output_path.with_extension("zsync-tmp");
57
58        Ok(Self {
59            control,
60            base_url: base_url.map(|s| s.to_string()),
61            matcher,
62            http,
63            output_path: output_path.to_path_buf(),
64            temp_path,
65            file: None,
66            range_gap_threshold: DEFAULT_RANGE_GAP_THRESHOLD,
67            progress_callback: None,
68        })
69    }
70
71    pub fn from_url(control_url: &str, output_path: &Path) -> Result<Self, AssemblyError> {
72        let http = HttpClient::new();
73        let control = http.fetch_control_file(control_url)?;
74        let base_url = extract_base_url(control_url);
75        Self::with_base_url(control, output_path, Some(&base_url))
76    }
77
78    pub fn set_range_gap_threshold(&mut self, threshold: u64) {
79        self.range_gap_threshold = threshold;
80    }
81
82    pub fn set_progress_callback<F>(&mut self, callback: F)
83    where
84        F: Fn(u64, u64) + Send + Sync + 'static,
85    {
86        self.progress_callback = Some(Box::new(callback));
87    }
88
89    fn report_progress(&self) {
90        if let Some(ref cb) = self.progress_callback {
91            let (done, total) = self.progress();
92            cb(done, total);
93        }
94    }
95
96    pub fn progress(&self) -> (u64, u64) {
97        let total = self.control.length;
98        let got = self.matcher.blocks_todo();
99        let blocks_done = self.matcher.total_blocks() - got;
100        let done_bytes = (blocks_done * self.control.blocksize) as u64;
101        (done_bytes.min(total), total)
102    }
103
104    pub fn is_complete(&self) -> bool {
105        self.matcher.is_complete()
106    }
107
108    pub fn block_stats(&self) -> (usize, usize) {
109        let total = self.matcher.total_blocks();
110        let todo = self.matcher.blocks_todo();
111        (total - todo, total)
112    }
113
114    pub fn submit_source_file(&mut self, path: &Path) -> Result<usize, AssemblyError> {
115        let mut file = File::open(path)?;
116        let mut buf = Vec::new();
117        file.read_to_end(&mut buf)?;
118
119        let blocksize = self.control.blocksize;
120        let context = blocksize * self.control.hash_lengths.seq_matches as usize;
121
122        if buf.len() < context {
123            return Ok(0);
124        }
125
126        // Zero-pad to allow scanning the last context bytes of the source
127        let original_len = buf.len();
128        buf.resize(original_len + context, 0);
129
130        let matched_blocks = self.matcher.submit_source_data(&buf, 0);
131
132        for (block_id, source_offset) in &matched_blocks {
133            let file_handle = self.ensure_file()?;
134            let offset = (block_id * blocksize) as u64;
135            let block_data = &buf[*source_offset..source_offset + blocksize];
136            Self::write_at_offset(file_handle, block_data, offset)?;
137        }
138
139        Ok(matched_blocks.len())
140    }
141
142    /// Scan the partially-assembled output for duplicate target blocks.
143    /// If the target has the same content at positions A and B, and A was matched
144    /// from a seed file, this finds B without downloading it.
145    pub fn submit_self_referential(&mut self) -> Result<usize, AssemblyError> {
146        if self.file.is_none() {
147            return Ok(0);
148        }
149
150        // Flush writes and read the temp file
151        let file = self.file.as_mut().unwrap();
152        file.sync_all()?;
153
154        let mut buf = Vec::new();
155        file.seek(SeekFrom::Start(0))?;
156        file.read_to_end(&mut buf)?;
157
158        let blocksize = self.control.blocksize;
159        let context = blocksize * self.control.hash_lengths.seq_matches as usize;
160
161        if buf.len() < context {
162            return Ok(0);
163        }
164
165        let original_len = buf.len();
166        buf.resize(original_len + context, 0);
167
168        let matched_blocks = self.matcher.submit_source_data(&buf, 0);
169
170        for (block_id, source_offset) in &matched_blocks {
171            let offset = (block_id * blocksize) as u64;
172            let block_data = &buf[*source_offset..source_offset + blocksize];
173            Self::write_at_offset(self.file.as_ref().unwrap(), block_data, offset)?;
174        }
175
176        Ok(matched_blocks.len())
177    }
178
179    fn write_at_offset(file: &File, data: &[u8], offset: u64) -> Result<(), AssemblyError> {
180        file.write_all_at(data, offset)?;
181        Ok(())
182    }
183
184    fn ensure_file(&mut self) -> Result<&mut File, AssemblyError> {
185        if self.file.is_none() {
186            let file = OpenOptions::new()
187                .read(true)
188                .write(true)
189                .create(true)
190                .truncate(false)
191                .open(&self.temp_path)?;
192            self.file = Some(file);
193        }
194        Ok(self.file.as_mut().unwrap())
195    }
196
197    pub fn download_missing_blocks(&mut self) -> Result<usize, AssemblyError> {
198        let relative_url = self
199            .control
200            .urls
201            .first()
202            .ok_or(AssemblyError::NoUrls)?
203            .clone();
204
205        let url = self
206            .base_url
207            .as_ref()
208            .map(|base| resolve_url(base, &relative_url))
209            .unwrap_or(relative_url);
210
211        let block_ranges = self.matcher.needed_block_ranges();
212        if block_ranges.is_empty() {
213            return Ok(0);
214        }
215
216        let byte_ranges = byte_ranges_from_block_ranges(
217            &block_ranges,
218            self.control.blocksize,
219            self.control.length,
220        );
221        let merged_ranges = merge_byte_ranges(&byte_ranges, self.range_gap_threshold);
222        let mut downloaded_blocks = 0;
223        let blocksize = self.control.blocksize;
224        let mut padded_buf = vec![0u8; blocksize];
225
226        for (start, end) in merged_ranges {
227            let data = self.http.fetch_range(&url, start, end)?;
228
229            let block_start = (start / blocksize as u64) as usize;
230            let total_blocks = self.matcher.total_blocks();
231            let num_blocks = data.len().div_ceil(blocksize);
232
233            for i in 0..num_blocks {
234                let block_id = block_start + i;
235                if block_id >= total_blocks {
236                    break;
237                }
238
239                if self.matcher.is_block_known(block_id) {
240                    continue;
241                }
242
243                let block_offset = i * blocksize;
244                let block_end = std::cmp::min(block_offset + blocksize, data.len());
245                let block_data = &data[block_offset..block_end];
246
247                padded_buf[..block_data.len()].copy_from_slice(block_data);
248                if block_data.len() < blocksize {
249                    padded_buf[block_data.len()..].fill(0);
250                }
251
252                if self.matcher.submit_blocks(&padded_buf, block_id)? {
253                    let file = self.ensure_file()?;
254                    let offset = (block_id * blocksize) as u64;
255                    Self::write_at_offset(file, block_data, offset)?;
256                    downloaded_blocks += 1;
257                    self.report_progress();
258                }
259            }
260        }
261
262        Ok(downloaded_blocks)
263    }
264
265    pub fn complete(mut self) -> Result<(), AssemblyError> {
266        if !self.matcher.is_complete() {
267            return Err(AssemblyError::Control(
268                "Not all blocks downloaded".to_string(),
269            ));
270        }
271
272        let file_length = self.control.length;
273        let expected_sha1 = self.control.sha1.clone();
274
275        let file = self.ensure_file()?;
276        file.set_len(file_length)?;
277
278        if let Some(ref expected) = expected_sha1 {
279            file.seek(SeekFrom::Start(0))?;
280            let mut buf = Vec::new();
281            file.read_to_end(&mut buf)?;
282
283            let actual_checksum = calc_sha1(&buf);
284            let actual_hex = hex_encode(&actual_checksum);
285
286            if !actual_hex.eq_ignore_ascii_case(expected) {
287                return Err(AssemblyError::ChecksumMismatch {
288                    expected: expected.clone(),
289                    actual: actual_hex,
290                });
291            }
292        }
293
294        drop(self.file);
295        std::fs::rename(&self.temp_path, &self.output_path)?;
296
297        Ok(())
298    }
299
300    pub fn abort(self) {
301        let _ = std::fs::remove_file(&self.temp_path);
302    }
303}
304
/// Encodes `bytes` as lowercase hexadecimal, two digits per byte.
///
/// Builds into a single preallocated `String` instead of allocating a `format!` string
/// per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(DIGITS[usize::from(byte >> 4)] as char);
        out.push(DIGITS[usize::from(byte & 0x0f)] as char);
    }
    out
}
308
/// Returns the "directory" part of `url`: everything up to and including the last `/` of
/// its path, suitable as a base for relative references.
///
/// - Query strings and fragments are ignored, because a `/` inside them is not a path
///   separator.
/// - A bare `scheme://host` yields `scheme://host/`.
/// - A string without any `/` yields an empty string.
fn extract_base_url(url: &str) -> String {
    let path_end = url
        .find(|c: char| c == '?' || c == '#')
        .unwrap_or(url.len());
    let url = &url[..path_end];

    // Never treat the slashes of "scheme://" as the end of the directory part.
    let authority_start = url.find("://").map_or(0, |i| i + 3);
    match url[authority_start..].rfind('/') {
        Some(i) => url[..=authority_start + i].to_string(),
        None if authority_start > 0 => format!("{}/", url),
        None => String::new(),
    }
}
314
/// Resolves `relative` against `base` (expected to be a directory URL ending in `/`).
///
/// Handled forms:
/// - An absolute URL (one with an RFC 3986 scheme) is returned unchanged.
/// - A network-path reference (`//host/...`) inherits only the base's scheme.
/// - An absolute path (`/...`) replaces the base's path.
/// - Anything else is appended to `base`.
fn resolve_url(base: &str, relative: &str) -> String {
    if has_url_scheme(relative) {
        return relative.to_string();
    }

    if relative.starts_with("//") {
        return match base.find("://") {
            // `base[..=i]` is the scheme including its ':' (e.g. "https:").
            Some(i) => format!("{}{}", &base[..=i], relative),
            None => relative.to_string(),
        };
    }

    if relative.starts_with('/') {
        let authority_start = base.find("://").map_or(0, |i| i + 3);
        // The authority ends at the first path, query, or fragment delimiter.
        let authority_end = base[authority_start..]
            .find(|c: char| matches!(c, '/' | '?' | '#'))
            .map_or(base.len(), |i| authority_start + i);
        return format!("{}{}", &base[..authority_end], relative);
    }

    format!("{}{}", base, relative)
}

/// Returns `true` if `url` starts with an RFC 3986 scheme followed by ':'.
///
/// A scheme is ALPHA *(ALPHA / DIGIT / "+" / "-" / "."). Unlike a `contains("://")` test,
/// this check ignores URLs embedded in a relative reference's query string.
fn has_url_scheme(url: &str) -> bool {
    let Some(colon) = url.find(':') else {
        return false;
    };
    let mut chars = url[..colon].chars();
    chars.next().is_some_and(|c| c.is_ascii_alphabetic())
        && chars.all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'))
}
330
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hex_encode() {
        let encoded = hex_encode(&[0x00, 0xff, 0x10]);
        assert_eq!(encoded, "00ff10");
    }

    #[test]
    fn test_extract_base_url() {
        let cases = [
            (
                "https://example.com/path/file.zsync",
                "https://example.com/path/",
            ),
            ("https://example.com/file.zsync", "https://example.com/"),
        ];
        for &(input, expected) in &cases {
            assert_eq!(extract_base_url(input), expected, "input: {}", input);
        }
    }

    #[test]
    fn test_resolve_url() {
        let base = "https://example.com/path/";
        let cases = [
            ("file.bin", "https://example.com/path/file.bin"),
            ("/file.bin", "https://example.com/file.bin"),
            ("https://other.com/file.bin", "https://other.com/file.bin"),
        ];
        for &(relative, expected) in &cases {
            assert_eq!(resolve_url(base, relative), expected, "relative: {}", relative);
        }
    }
}