//! assembly.rs — target-file assembly for zsync_rs: reconstructs the output
//! file from local seed data and HTTP range downloads, then verifies the
//! final SHA-1 before renaming the staged temp file into place.
1use std::fs::{File, OpenOptions};
2use std::io::{Read, Seek, SeekFrom};
3use std::os::unix::fs::FileExt;
4use std::path::Path;
5
6use crate::checksum::calc_sha1_stream;
7use crate::control::ControlFile;
8use crate::http::{
9    DEFAULT_RANGE_GAP_THRESHOLD, HttpClient, byte_ranges_from_block_ranges, merge_byte_ranges,
10};
11use crate::matcher::BlockMatcher;
12use crate::matcher::MatchError;
13
14const STREAM_CHUNK_SIZE: usize = 1024 * 1024;
15
16#[derive(Debug, thiserror::Error)]
17pub enum AssemblyError {
18    #[error("IO error: {0}")]
19    Io(#[from] std::io::Error),
20    #[error("HTTP error: {0}")]
21    Http(#[from] crate::http::HttpError),
22    #[error("Matcher error: {0}")]
23    Matcher(#[from] MatchError),
24    #[error("Control file error: {0}")]
25    Control(String),
26    #[error("Checksum mismatch: expected {expected}, got {actual}")]
27    ChecksumMismatch { expected: String, actual: String },
28    #[error("No URLs available")]
29    NoUrls,
30}
31
/// Progress callback, invoked as `(bytes_done, bytes_total)`.
pub type ProgressCallback = Box<dyn Fn(u64, u64) + Send + Sync>;
34pub struct ZsyncAssembly {
35    control: ControlFile,
36    base_url: Option<String>,
37    matcher: BlockMatcher,
38    http: HttpClient,
39    output_path: std::path::PathBuf,
40    temp_path: std::path::PathBuf,
41    file: Option<File>,
42    range_gap_threshold: u64,
43    progress_callback: Option<ProgressCallback>,
44}
45
46impl ZsyncAssembly {
47    pub fn new(control: ControlFile, output_path: &Path) -> Result<Self, AssemblyError> {
48        Self::with_base_url(control, output_path, None)
49    }
50
51    pub fn with_base_url(
52        control: ControlFile,
53        output_path: &Path,
54        base_url: Option<&str>,
55    ) -> Result<Self, AssemblyError> {
56        let matcher = BlockMatcher::new(&control);
57        let http = HttpClient::new();
58        let temp_path = output_path.with_extension("zsync-tmp");
59
60        Ok(Self {
61            control,
62            base_url: base_url.map(|s| s.to_string()),
63            matcher,
64            http,
65            output_path: output_path.to_path_buf(),
66            temp_path,
67            file: None,
68            range_gap_threshold: DEFAULT_RANGE_GAP_THRESHOLD,
69            progress_callback: None,
70        })
71    }
72
73    pub fn from_url(control_url: &str, output_path: &Path) -> Result<Self, AssemblyError> {
74        let http = HttpClient::new();
75        let control = http.fetch_control_file(control_url)?;
76        let base_url = extract_base_url(control_url);
77        Self::with_base_url(control, output_path, Some(&base_url))
78    }
79
80    pub fn set_range_gap_threshold(&mut self, threshold: u64) {
81        self.range_gap_threshold = threshold;
82    }
83
84    pub fn set_progress_callback<F>(&mut self, callback: F)
85    where
86        F: Fn(u64, u64) + Send + Sync + 'static,
87    {
88        self.progress_callback = Some(Box::new(callback));
89    }
90
91    fn report_progress(&self) {
92        if let Some(ref cb) = self.progress_callback {
93            let (done, total) = self.progress();
94            cb(done, total);
95        }
96    }
97
98    pub fn progress(&self) -> (u64, u64) {
99        let total = self.control.length;
100        let got = self.matcher.blocks_todo();
101        let blocks_done = self.matcher.total_blocks() - got;
102        let done_bytes = (blocks_done * self.control.blocksize) as u64;
103        (done_bytes.min(total), total)
104    }
105
106    pub fn is_complete(&self) -> bool {
107        self.matcher.is_complete()
108    }
109
110    pub fn block_stats(&self) -> (usize, usize) {
111        let total = self.matcher.total_blocks();
112        let todo = self.matcher.blocks_todo();
113        (total - todo, total)
114    }
115
116    pub fn submit_source_file(&mut self, path: &Path) -> Result<usize, AssemblyError> {
117        let file = File::open(path)?;
118        let file_size = file.metadata()?.len() as usize;
119
120        let blocksize = self.control.blocksize;
121        let context = blocksize * self.control.hash_lengths.seq_matches as usize;
122
123        if file_size < context {
124            return Ok(0);
125        }
126
127        let chunk_size = STREAM_CHUNK_SIZE.max(context * 2);
128        let mut total_matched = 0;
129        let mut buf = vec![0u8; chunk_size + 2 * context];
130        let mut file_offset = 0usize;
131
132        loop {
133            let overlap_start = file_offset.saturating_sub(context);
134            let overlap_len = file_offset - overlap_start;
135
136            if overlap_len > 0 {
137                file.read_at(&mut buf[..overlap_len], overlap_start as u64)?;
138            }
139
140            let read_start = overlap_len;
141            let read_len = chunk_size;
142
143            let bytes_read = file.read_at(
144                &mut buf[read_start..read_start + read_len],
145                file_offset as u64,
146            )?;
147            if bytes_read == 0 {
148                break;
149            }
150
151            let data_len = read_start + bytes_read;
152            let chunk_context = if file_offset + bytes_read < file_size {
153                let context_start = file_offset + bytes_read;
154                let context_available = file_size.saturating_sub(context_start).min(context);
155                file.read_at(
156                    &mut buf[data_len..data_len + context_available],
157                    context_start as u64,
158                )?;
159                if context_available < context {
160                    buf[data_len + context_available..data_len + context].fill(0);
161                }
162                data_len + context
163            } else {
164                buf[data_len..data_len + context].fill(0);
165                data_len + context
166            };
167
168            let matched_blocks = self
169                .matcher
170                .submit_source_data(&buf[..chunk_context], overlap_start as u64);
171
172            for (block_id, source_offset) in &matched_blocks {
173                let file_handle = self.ensure_file()?;
174                let offset = (block_id * blocksize) as u64;
175                let buf_offset = source_offset.saturating_sub(overlap_start);
176                debug_assert!(
177                    buf_offset + blocksize <= chunk_context,
178                    "buf_offset {} + blocksize {} > chunk_context {} (source_offset={}, overlap_start={})",
179                    buf_offset,
180                    blocksize,
181                    chunk_context,
182                    source_offset,
183                    overlap_start
184                );
185                let block_data = &buf[buf_offset..buf_offset + blocksize];
186                Self::write_at_offset(file_handle, block_data, offset)?;
187            }
188
189            total_matched += matched_blocks.len();
190            file_offset += bytes_read;
191
192            if bytes_read < read_len {
193                break;
194            }
195        }
196
197        Ok(total_matched)
198    }
199
200    pub fn submit_self_referential(&mut self) -> Result<usize, AssemblyError> {
201        if self.file.is_none() {
202            return Ok(0);
203        }
204
205        let file = self.file.as_mut().unwrap();
206        file.sync_all()?;
207
208        let file_size = file.metadata()?.len() as usize;
209
210        let blocksize = self.control.blocksize;
211        let context = blocksize * self.control.hash_lengths.seq_matches as usize;
212
213        if file_size < context {
214            return Ok(0);
215        }
216
217        let chunk_size = STREAM_CHUNK_SIZE.max(context * 2);
218        let mut total_matched = 0;
219        let mut buf = vec![0u8; chunk_size + 2 * context];
220        let mut file_offset = 0usize;
221
222        loop {
223            let overlap_start = file_offset.saturating_sub(context);
224            let overlap_len = file_offset - overlap_start;
225
226            if overlap_len > 0 {
227                file.read_at(&mut buf[..overlap_len], overlap_start as u64)?;
228            }
229
230            let read_start = overlap_len;
231            let read_len = chunk_size;
232
233            let bytes_read = file.read_at(
234                &mut buf[read_start..read_start + read_len],
235                file_offset as u64,
236            )?;
237            if bytes_read == 0 {
238                break;
239            }
240
241            let data_len = read_start + bytes_read;
242            let chunk_context = if file_offset + bytes_read < file_size {
243                let context_start = file_offset + bytes_read;
244                let context_available = file_size.saturating_sub(context_start).min(context);
245                file.read_at(
246                    &mut buf[data_len..data_len + context_available],
247                    context_start as u64,
248                )?;
249                if context_available < context {
250                    buf[data_len + context_available..data_len + context].fill(0);
251                }
252                data_len + context
253            } else {
254                buf[data_len..data_len + context].fill(0);
255                data_len + context
256            };
257
258            let matched_blocks = self
259                .matcher
260                .submit_source_data(&buf[..chunk_context], overlap_start as u64);
261
262            for (block_id, source_offset) in &matched_blocks {
263                let offset = (block_id * blocksize) as u64;
264                let buf_offset = source_offset.saturating_sub(overlap_start);
265                debug_assert!(
266                    buf_offset + blocksize <= chunk_context,
267                    "buf_offset {} + blocksize {} > chunk_context {} (source_offset={}, overlap_start={})",
268                    buf_offset,
269                    blocksize,
270                    chunk_context,
271                    source_offset,
272                    overlap_start
273                );
274                let block_data = &buf[buf_offset..buf_offset + blocksize];
275                Self::write_at_offset(file, block_data, offset)?;
276            }
277
278            total_matched += matched_blocks.len();
279            file_offset += bytes_read;
280
281            if bytes_read < read_len {
282                break;
283            }
284        }
285
286        Ok(total_matched)
287    }
288
289    fn write_at_offset(file: &File, data: &[u8], offset: u64) -> Result<(), AssemblyError> {
290        file.write_all_at(data, offset)?;
291        Ok(())
292    }
293
294    fn ensure_file(&mut self) -> Result<&mut File, AssemblyError> {
295        if self.file.is_none() {
296            let file = OpenOptions::new()
297                .read(true)
298                .write(true)
299                .create(true)
300                .truncate(false)
301                .open(&self.temp_path)?;
302            self.file = Some(file);
303        }
304        Ok(self.file.as_mut().unwrap())
305    }
306
307    pub fn download_missing_blocks(&mut self) -> Result<usize, AssemblyError> {
308        let relative_url = self
309            .control
310            .urls
311            .first()
312            .ok_or(AssemblyError::NoUrls)?
313            .clone();
314
315        let url = self
316            .base_url
317            .as_ref()
318            .map(|base| resolve_url(base, &relative_url))
319            .unwrap_or(relative_url);
320
321        let block_ranges = self.matcher.needed_block_ranges();
322        if block_ranges.is_empty() {
323            return Ok(0);
324        }
325
326        let byte_ranges = byte_ranges_from_block_ranges(
327            &block_ranges,
328            self.control.blocksize,
329            self.control.length,
330        );
331        let merged_ranges = merge_byte_ranges(&byte_ranges, self.range_gap_threshold);
332        let mut downloaded_blocks = 0;
333        let blocksize = self.control.blocksize;
334        let total_blocks = self.matcher.total_blocks();
335        let mut padded_buf = vec![0u8; blocksize];
336
337        for (range_start, range_end) in merged_ranges {
338            let mut reader = self.http.fetch_range_reader(&url, range_start, range_end)?;
339            let block_start = (range_start / blocksize as u64) as usize;
340            let initial_offset = (range_start % blocksize as u64) as usize;
341
342            let mut buf = vec![0u8; blocksize + 64 * 1024];
343            buf[..initial_offset].fill(0);
344            let mut buf_len = initial_offset;
345            let mut current_block_id = block_start;
346
347            let mut read_buf = [0u8; 64 * 1024];
348            loop {
349                let n = reader.read(&mut read_buf)?;
350                if n == 0 {
351                    break;
352                }
353
354                if buf_len + n > buf.len() {
355                    buf.resize(buf_len + n, 0);
356                }
357                buf[buf_len..buf_len + n].copy_from_slice(&read_buf[..n]);
358                buf_len += n;
359
360                while buf_len >= blocksize {
361                    if current_block_id >= total_blocks {
362                        break;
363                    }
364
365                    if !self.matcher.is_block_known(current_block_id) {
366                        let block_data_end = if current_block_id == total_blocks - 1 {
367                            let last_block_size = (self.control.length as usize) % blocksize;
368                            if last_block_size == 0 {
369                                blocksize
370                            } else {
371                                last_block_size
372                            }
373                        } else {
374                            blocksize
375                        };
376
377                        let block_data = &buf[..block_data_end];
378                        padded_buf[..block_data_end].copy_from_slice(block_data);
379                        if block_data_end < blocksize {
380                            padded_buf[block_data_end..].fill(0);
381                        }
382
383                        if self.matcher.submit_blocks(&padded_buf, current_block_id)? {
384                            let file = self.ensure_file()?;
385                            let file_offset = (current_block_id * blocksize) as u64;
386                            Self::write_at_offset(file, block_data, file_offset)?;
387                            downloaded_blocks += 1;
388                            self.report_progress();
389                        }
390                    }
391
392                    current_block_id += 1;
393                    buf.copy_within(blocksize..buf_len, 0);
394                    buf_len -= blocksize;
395                }
396            }
397
398            if buf_len > 0
399                && current_block_id < total_blocks
400                && !self.matcher.is_block_known(current_block_id)
401            {
402                let block_data = &buf[..buf_len];
403                padded_buf[..buf_len].copy_from_slice(block_data);
404                padded_buf[buf_len..].fill(0);
405
406                if self.matcher.submit_blocks(&padded_buf, current_block_id)? {
407                    let file = self.ensure_file()?;
408                    let file_offset = (current_block_id * blocksize) as u64;
409                    Self::write_at_offset(file, block_data, file_offset)?;
410                    downloaded_blocks += 1;
411                    self.report_progress();
412                }
413            }
414        }
415
416        Ok(downloaded_blocks)
417    }
418
419    pub fn complete(mut self) -> Result<(), AssemblyError> {
420        if !self.matcher.is_complete() {
421            return Err(AssemblyError::Control(
422                "Not all blocks downloaded".to_string(),
423            ));
424        }
425
426        let file_length = self.control.length;
427        let expected_sha1 = self.control.sha1.clone();
428
429        let file = self.ensure_file()?;
430        file.set_len(file_length)?;
431
432        if let Some(ref expected) = expected_sha1 {
433            file.seek(SeekFrom::Start(0))?;
434            let actual_checksum = calc_sha1_stream(file)?;
435            let actual_hex = hex_encode(&actual_checksum);
436
437            if !actual_hex.eq_ignore_ascii_case(expected) {
438                return Err(AssemblyError::ChecksumMismatch {
439                    expected: expected.clone(),
440                    actual: actual_hex,
441                });
442            }
443        }
444
445        drop(self.file);
446        std::fs::rename(&self.temp_path, &self.output_path)?;
447
448        Ok(())
449    }
450
451    pub fn abort(self) {
452        let _ = std::fs::remove_file(&self.temp_path);
453    }
454}
455
/// Lowercase hexadecimal representation of `bytes`.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &b in bytes {
        out.push(DIGITS[usize::from(b >> 4)] as char);
        out.push(DIGITS[usize::from(b & 0x0f)] as char);
    }
    out
}
/// Everything up to and including the final `/` of `url`; the empty string
/// when `url` contains no slash at all.
fn extract_base_url(url: &str) -> String {
    match url.rfind('/') {
        Some(idx) => url[..=idx].to_string(),
        None => String::new(),
    }
}
/// Resolves `relative` against `base` with simplified URL semantics:
/// absolute URLs (containing "://") pass through unchanged, host-relative
/// paths (leading '/') replace everything after the authority of `base`,
/// and anything else is appended to `base`.
fn resolve_url(base: &str, relative: &str) -> String {
    // Already absolute: use as-is.
    if relative.contains("://") {
        return relative.to_string();
    }
    // Path-relative: append to the base directory.
    if !relative.starts_with('/') {
        return format!("{}{}", base, relative);
    }
    // Host-relative: keep scheme + authority of `base`, replace the path.
    let scheme_end = base.find("://").map_or(0, |i| i + 3);
    let host_end = base[scheme_end..]
        .find('/')
        .map_or(base.len(), |i| scheme_end + i);
    format!("{}{}", &base[..host_end], relative)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hex_encode() {
        assert_eq!(hex_encode(&[0x00, 0xff, 0x10]), "00ff10");
    }

    #[test]
    fn test_extract_base_url() {
        let cases = [
            ("https://example.com/path/file.zsync", "https://example.com/path/"),
            ("https://example.com/file.zsync", "https://example.com/"),
        ];
        for (input, expected) in cases {
            assert_eq!(extract_base_url(input), expected);
        }
    }

    #[test]
    fn test_resolve_url() {
        let base = "https://example.com/path/";
        assert_eq!(resolve_url(base, "file.bin"), "https://example.com/path/file.bin");
        assert_eq!(resolve_url(base, "/file.bin"), "https://example.com/file.bin");
        assert_eq!(
            resolve_url(base, "https://other.com/file.bin"),
            "https://other.com/file.bin"
        );
    }
}