1use std::fs::{File, OpenOptions};
2use std::io::{Read, Seek, SeekFrom};
3use std::os::unix::fs::FileExt;
4use std::path::Path;
5
6use crate::checksum::calc_sha1_stream;
7use crate::control::ControlFile;
8use crate::http::{
9 DEFAULT_RANGE_GAP_THRESHOLD, HttpClient, byte_ranges_from_block_ranges, merge_byte_ranges,
10};
11use crate::matcher::BlockMatcher;
12use crate::matcher::MatchError;
13
/// Read granularity (1 MiB) used when streaming local files through the matcher.
const STREAM_CHUNK_SIZE: usize = 1024 * 1024;
15
/// Errors that can occur while assembling the target file.
#[derive(Debug, thiserror::Error)]
pub enum AssemblyError {
    /// Underlying filesystem I/O failure.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// Failure fetching the control file or a byte range over HTTP.
    #[error("HTTP error: {0}")]
    Http(#[from] crate::http::HttpError),
    /// Error reported by the block matcher.
    #[error("Matcher error: {0}")]
    Matcher(#[from] MatchError),
    /// Invalid assembly state (e.g. `complete()` before all blocks exist).
    #[error("Control file error: {0}")]
    Control(String),
    /// Final SHA-1 of the assembled file did not match the control file.
    #[error("Checksum mismatch: expected {expected}, got {actual}")]
    ChecksumMismatch { expected: String, actual: String },
    /// The control file listed no download URLs.
    #[error("No URLs available")]
    NoUrls,
}
31
/// Progress observer invoked as `(bytes_done, bytes_total)`.
pub type ProgressCallback = Box<dyn Fn(u64, u64) + Send + Sync>;
33
/// Incrementally reconstructs a target file described by a zsync control
/// file, combining blocks copied from local source files with blocks
/// fetched over HTTP range requests.
pub struct ZsyncAssembly {
    // Parsed control data: blocksize, length, sha1, urls, hash_lengths.
    control: ControlFile,
    // Base used to resolve relative URLs from the control file.
    base_url: Option<String>,
    // Tracks which blocks are still needed and verifies candidate data.
    matcher: BlockMatcher,
    // Client for control-file and range downloads.
    http: HttpClient,
    // Final destination; written via rename from `temp_path` in `complete()`.
    output_path: std::path::PathBuf,
    // Working file (`<output>.zsync-tmp`) where blocks are assembled.
    temp_path: std::path::PathBuf,
    // Lazily created handle to `temp_path` (see `ensure_file`).
    file: Option<File>,
    // Maximum gap allowed when merging adjacent byte ranges into one request.
    range_gap_threshold: u64,
    // Optional observer, called with `(bytes_done, bytes_total)`.
    progress_callback: Option<ProgressCallback>,
}
45
46impl ZsyncAssembly {
    /// Creates an assembly for `control` writing to `output_path`, with no
    /// base URL (control-file URLs are used verbatim).
    pub fn new(control: ControlFile, output_path: &Path) -> Result<Self, AssemblyError> {
        Self::with_base_url(control, output_path, None)
    }
50
51 pub fn with_base_url(
52 control: ControlFile,
53 output_path: &Path,
54 base_url: Option<&str>,
55 ) -> Result<Self, AssemblyError> {
56 let matcher = BlockMatcher::new(&control);
57 let http = HttpClient::new();
58 let temp_path = output_path.with_extension("zsync-tmp");
59
60 Ok(Self {
61 control,
62 base_url: base_url.map(|s| s.to_string()),
63 matcher,
64 http,
65 output_path: output_path.to_path_buf(),
66 temp_path,
67 file: None,
68 range_gap_threshold: DEFAULT_RANGE_GAP_THRESHOLD,
69 progress_callback: None,
70 })
71 }
72
73 pub fn from_url(control_url: &str, output_path: &Path) -> Result<Self, AssemblyError> {
74 let http = HttpClient::new();
75 let control = http.fetch_control_file(control_url)?;
76 let base_url = extract_base_url(control_url);
77 Self::with_base_url(control, output_path, Some(&base_url))
78 }
79
    /// Sets the maximum gap (in bytes) allowed when coalescing adjacent
    /// byte ranges into a single HTTP request.
    pub fn set_range_gap_threshold(&mut self, threshold: u64) {
        self.range_gap_threshold = threshold;
    }
83
    /// Registers a progress observer, replacing any previous one. The
    /// callback receives `(bytes_done, bytes_total)` after each block is
    /// downloaded and accepted.
    pub fn set_progress_callback<F>(&mut self, callback: F)
    where
        F: Fn(u64, u64) + Send + Sync + 'static,
    {
        self.progress_callback = Some(Box::new(callback));
    }
90
91 fn report_progress(&self) {
92 if let Some(ref cb) = self.progress_callback {
93 let (done, total) = self.progress();
94 cb(done, total);
95 }
96 }
97
98 pub fn progress(&self) -> (u64, u64) {
99 let total = self.control.length;
100 let got = self.matcher.blocks_todo();
101 let blocks_done = self.matcher.total_blocks() - got;
102 let done_bytes = (blocks_done * self.control.blocksize) as u64;
103 (done_bytes.min(total), total)
104 }
105
    /// Returns `true` once every block of the target file is accounted for.
    pub fn is_complete(&self) -> bool {
        self.matcher.is_complete()
    }
109
110 pub fn block_stats(&self) -> (usize, usize) {
111 let total = self.matcher.total_blocks();
112 let todo = self.matcher.blocks_todo();
113 (total - todo, total)
114 }
115
116 pub fn submit_source_file(&mut self, path: &Path) -> Result<usize, AssemblyError> {
117 let file = File::open(path)?;
118 let file_size = file.metadata()?.len() as usize;
119
120 let blocksize = self.control.blocksize;
121 let context = blocksize * self.control.hash_lengths.seq_matches as usize;
122
123 if file_size < context {
124 return Ok(0);
125 }
126
127 let chunk_size = STREAM_CHUNK_SIZE.max(context * 2);
128 let mut total_matched = 0;
129 let mut buf = vec![0u8; chunk_size + 2 * context];
130 let mut file_offset = 0usize;
131
132 loop {
133 let overlap_start = file_offset.saturating_sub(context);
134 let overlap_len = file_offset - overlap_start;
135
136 if overlap_len > 0 {
137 file.read_at(&mut buf[..overlap_len], overlap_start as u64)?;
138 }
139
140 let read_start = overlap_len;
141 let read_len = chunk_size;
142
143 let bytes_read = file.read_at(
144 &mut buf[read_start..read_start + read_len],
145 file_offset as u64,
146 )?;
147 if bytes_read == 0 {
148 break;
149 }
150
151 let data_len = read_start + bytes_read;
152 let chunk_context = if file_offset + bytes_read < file_size {
153 let context_start = file_offset + bytes_read;
154 let context_available = file_size.saturating_sub(context_start).min(context);
155 file.read_at(
156 &mut buf[data_len..data_len + context_available],
157 context_start as u64,
158 )?;
159 if context_available < context {
160 buf[data_len + context_available..data_len + context].fill(0);
161 }
162 data_len + context
163 } else {
164 buf[data_len..data_len + context].fill(0);
165 data_len + context
166 };
167
168 let matched_blocks = self
169 .matcher
170 .submit_source_data(&buf[..chunk_context], overlap_start as u64);
171
172 for (block_id, source_offset) in &matched_blocks {
173 let file_handle = self.ensure_file()?;
174 let offset = (block_id * blocksize) as u64;
175 let buf_offset = source_offset.saturating_sub(overlap_start);
176 debug_assert!(
177 buf_offset + blocksize <= chunk_context,
178 "buf_offset {} + blocksize {} > chunk_context {} (source_offset={}, overlap_start={})",
179 buf_offset,
180 blocksize,
181 chunk_context,
182 source_offset,
183 overlap_start
184 );
185 let block_data = &buf[buf_offset..buf_offset + blocksize];
186 Self::write_at_offset(file_handle, block_data, offset)?;
187 }
188
189 total_matched += matched_blocks.len();
190 file_offset += bytes_read;
191
192 if bytes_read < read_len {
193 break;
194 }
195 }
196
197 Ok(total_matched)
198 }
199
200 pub fn submit_self_referential(&mut self) -> Result<usize, AssemblyError> {
201 if self.file.is_none() {
202 return Ok(0);
203 }
204
205 let file = self.file.as_mut().unwrap();
206 file.sync_all()?;
207
208 let file_size = file.metadata()?.len() as usize;
209
210 let blocksize = self.control.blocksize;
211 let context = blocksize * self.control.hash_lengths.seq_matches as usize;
212
213 if file_size < context {
214 return Ok(0);
215 }
216
217 let chunk_size = STREAM_CHUNK_SIZE.max(context * 2);
218 let mut total_matched = 0;
219 let mut buf = vec![0u8; chunk_size + 2 * context];
220 let mut file_offset = 0usize;
221
222 loop {
223 let overlap_start = file_offset.saturating_sub(context);
224 let overlap_len = file_offset - overlap_start;
225
226 if overlap_len > 0 {
227 file.read_at(&mut buf[..overlap_len], overlap_start as u64)?;
228 }
229
230 let read_start = overlap_len;
231 let read_len = chunk_size;
232
233 let bytes_read = file.read_at(
234 &mut buf[read_start..read_start + read_len],
235 file_offset as u64,
236 )?;
237 if bytes_read == 0 {
238 break;
239 }
240
241 let data_len = read_start + bytes_read;
242 let chunk_context = if file_offset + bytes_read < file_size {
243 let context_start = file_offset + bytes_read;
244 let context_available = file_size.saturating_sub(context_start).min(context);
245 file.read_at(
246 &mut buf[data_len..data_len + context_available],
247 context_start as u64,
248 )?;
249 if context_available < context {
250 buf[data_len + context_available..data_len + context].fill(0);
251 }
252 data_len + context
253 } else {
254 buf[data_len..data_len + context].fill(0);
255 data_len + context
256 };
257
258 let matched_blocks = self
259 .matcher
260 .submit_source_data(&buf[..chunk_context], overlap_start as u64);
261
262 for (block_id, source_offset) in &matched_blocks {
263 let offset = (block_id * blocksize) as u64;
264 let buf_offset = source_offset.saturating_sub(overlap_start);
265 debug_assert!(
266 buf_offset + blocksize <= chunk_context,
267 "buf_offset {} + blocksize {} > chunk_context {} (source_offset={}, overlap_start={})",
268 buf_offset,
269 blocksize,
270 chunk_context,
271 source_offset,
272 overlap_start
273 );
274 let block_data = &buf[buf_offset..buf_offset + blocksize];
275 Self::write_at_offset(file, block_data, offset)?;
276 }
277
278 total_matched += matched_blocks.len();
279 file_offset += bytes_read;
280
281 if bytes_read < read_len {
282 break;
283 }
284 }
285
286 Ok(total_matched)
287 }
288
289 fn write_at_offset(file: &File, data: &[u8], offset: u64) -> Result<(), AssemblyError> {
290 file.write_all_at(data, offset)?;
291 Ok(())
292 }
293
294 fn ensure_file(&mut self) -> Result<&mut File, AssemblyError> {
295 if self.file.is_none() {
296 let file = OpenOptions::new()
297 .read(true)
298 .write(true)
299 .create(true)
300 .truncate(false)
301 .open(&self.temp_path)?;
302 self.file = Some(file);
303 }
304 Ok(self.file.as_mut().unwrap())
305 }
306
307 pub fn download_missing_blocks(&mut self) -> Result<usize, AssemblyError> {
308 let relative_url = self
309 .control
310 .urls
311 .first()
312 .ok_or(AssemblyError::NoUrls)?
313 .clone();
314
315 let url = self
316 .base_url
317 .as_ref()
318 .map(|base| resolve_url(base, &relative_url))
319 .unwrap_or(relative_url);
320
321 let block_ranges = self.matcher.needed_block_ranges();
322 if block_ranges.is_empty() {
323 return Ok(0);
324 }
325
326 let byte_ranges = byte_ranges_from_block_ranges(
327 &block_ranges,
328 self.control.blocksize,
329 self.control.length,
330 );
331 let merged_ranges = merge_byte_ranges(&byte_ranges, self.range_gap_threshold);
332 let mut downloaded_blocks = 0;
333 let blocksize = self.control.blocksize;
334 let total_blocks = self.matcher.total_blocks();
335 let mut padded_buf = vec![0u8; blocksize];
336
337 for (range_start, range_end) in merged_ranges {
338 let mut reader = self.http.fetch_range_reader(&url, range_start, range_end)?;
339 let block_start = (range_start / blocksize as u64) as usize;
340 let initial_offset = (range_start % blocksize as u64) as usize;
341
342 let mut buf = vec![0u8; blocksize + 64 * 1024];
343 buf[..initial_offset].fill(0);
344 let mut buf_len = initial_offset;
345 let mut current_block_id = block_start;
346
347 let mut read_buf = [0u8; 64 * 1024];
348 loop {
349 let n = reader.read(&mut read_buf)?;
350 if n == 0 {
351 break;
352 }
353
354 if buf_len + n > buf.len() {
355 buf.resize(buf_len + n, 0);
356 }
357 buf[buf_len..buf_len + n].copy_from_slice(&read_buf[..n]);
358 buf_len += n;
359
360 while buf_len >= blocksize {
361 if current_block_id >= total_blocks {
362 break;
363 }
364
365 if !self.matcher.is_block_known(current_block_id) {
366 let block_data_end = if current_block_id == total_blocks - 1 {
367 let last_block_size = (self.control.length as usize) % blocksize;
368 if last_block_size == 0 {
369 blocksize
370 } else {
371 last_block_size
372 }
373 } else {
374 blocksize
375 };
376
377 let block_data = &buf[..block_data_end];
378 padded_buf[..block_data_end].copy_from_slice(block_data);
379 if block_data_end < blocksize {
380 padded_buf[block_data_end..].fill(0);
381 }
382
383 if self.matcher.submit_blocks(&padded_buf, current_block_id)? {
384 let file = self.ensure_file()?;
385 let file_offset = (current_block_id * blocksize) as u64;
386 Self::write_at_offset(file, block_data, file_offset)?;
387 downloaded_blocks += 1;
388 self.report_progress();
389 }
390 }
391
392 current_block_id += 1;
393 buf.copy_within(blocksize..buf_len, 0);
394 buf_len -= blocksize;
395 }
396 }
397
398 if buf_len > 0
399 && current_block_id < total_blocks
400 && !self.matcher.is_block_known(current_block_id)
401 {
402 let block_data = &buf[..buf_len];
403 padded_buf[..buf_len].copy_from_slice(block_data);
404 padded_buf[buf_len..].fill(0);
405
406 if self.matcher.submit_blocks(&padded_buf, current_block_id)? {
407 let file = self.ensure_file()?;
408 let file_offset = (current_block_id * blocksize) as u64;
409 Self::write_at_offset(file, block_data, file_offset)?;
410 downloaded_blocks += 1;
411 self.report_progress();
412 }
413 }
414 }
415
416 Ok(downloaded_blocks)
417 }
418
419 pub fn complete(mut self) -> Result<(), AssemblyError> {
420 if !self.matcher.is_complete() {
421 return Err(AssemblyError::Control(
422 "Not all blocks downloaded".to_string(),
423 ));
424 }
425
426 let file_length = self.control.length;
427 let expected_sha1 = self.control.sha1.clone();
428
429 let file = self.ensure_file()?;
430 file.set_len(file_length)?;
431
432 if let Some(ref expected) = expected_sha1 {
433 file.seek(SeekFrom::Start(0))?;
434 let actual_checksum = calc_sha1_stream(file)?;
435 let actual_hex = hex_encode(&actual_checksum);
436
437 if !actual_hex.eq_ignore_ascii_case(expected) {
438 return Err(AssemblyError::ChecksumMismatch {
439 expected: expected.clone(),
440 actual: actual_hex,
441 });
442 }
443 }
444
445 drop(self.file);
446 std::fs::rename(&self.temp_path, &self.output_path)?;
447
448 Ok(())
449 }
450
    /// Abandons the assembly and deletes the temporary file. Removal errors
    /// are ignored (the temp file may never have been created).
    pub fn abort(self) {
        let _ = std::fs::remove_file(&self.temp_path);
    }
454}
455
/// Returns the lowercase hexadecimal encoding of `bytes`.
fn hex_encode(bytes: &[u8]) -> String {
    use std::fmt::Write;

    // Preallocate and append in place: the previous version built a fresh
    // String per byte via `format!`.
    let mut out = String::with_capacity(bytes.len() * 2);
    for b in bytes {
        // Writing to a String is infallible.
        let _ = write!(out, "{:02x}", b);
    }
    out
}
459
/// Returns `url` up to and including its last `/` (the "directory" part),
/// or an empty string when `url` contains no slash.
fn extract_base_url(url: &str) -> String {
    match url.rfind('/') {
        Some(idx) => url[..=idx].to_string(),
        None => String::new(),
    }
}
465
/// Resolves `relative` against `base`: absolute URLs pass through, a leading
/// `/` is resolved against `base`'s scheme+authority, anything else is
/// appended to `base`.
fn resolve_url(base: &str, relative: &str) -> String {
    if relative.contains("://") {
        // Already a fully-qualified URL.
        relative.to_string()
    } else if relative.starts_with('/') {
        // Host-relative: keep only scheme and authority from `base`.
        let authority_start = base.find("://").map_or(0, |i| i + 3);
        let path_start = base[authority_start..]
            .find('/')
            .map_or(base.len(), |i| authority_start + i);
        format!("{}{}", &base[..path_start], relative)
    } else {
        // Document-relative: append to the base directory.
        format!("{}{}", base, relative)
    }
}
481
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hex_encode() {
        assert_eq!(hex_encode(&[0x00, 0xff, 0x10]), "00ff10");
    }

    #[test]
    fn test_extract_base_url() {
        // (input, expected base directory)
        let cases = [
            (
                "https://example.com/path/file.zsync",
                "https://example.com/path/",
            ),
            ("https://example.com/file.zsync", "https://example.com/"),
        ];
        for (input, expected) in cases {
            assert_eq!(extract_base_url(input), expected);
        }
    }

    #[test]
    fn test_resolve_url() {
        // ((base, relative), expected resolution)
        let cases = [
            (
                ("https://example.com/path/", "file.bin"),
                "https://example.com/path/file.bin",
            ),
            (
                ("https://example.com/path/", "/file.bin"),
                "https://example.com/file.bin",
            ),
            (
                ("https://example.com/path/", "https://other.com/file.bin"),
                "https://other.com/file.bin",
            ),
        ];
        for ((base, relative), expected) in cases {
            assert_eq!(resolve_url(base, relative), expected);
        }
    }
}