1use std::fs::{File, OpenOptions};
2use std::io::{Read, Seek, SeekFrom};
3use std::os::unix::fs::FileExt;
4use std::path::Path;
5
6use crate::checksum::calc_sha1;
7use crate::control::ControlFile;
8use crate::http::{
9 DEFAULT_RANGE_GAP_THRESHOLD, HttpClient, byte_ranges_from_block_ranges, merge_byte_ranges,
10};
11use crate::matcher::BlockMatcher;
12use crate::matcher::MatchError;
13
/// Errors that can occur while assembling the target file.
#[derive(Debug, thiserror::Error)]
pub enum AssemblyError {
    /// Local filesystem I/O failed.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// A remote fetch (control file or byte range) failed.
    #[error("HTTP error: {0}")]
    Http(#[from] crate::http::HttpError),
    /// The block matcher rejected submitted data.
    #[error("Matcher error: {0}")]
    Matcher(#[from] MatchError),
    /// The control file was invalid, or an assembly precondition was not met
    /// (e.g. `complete()` called before all blocks were obtained).
    #[error("Control file error: {0}")]
    Control(String),
    /// The assembled file's SHA-1 did not match the control file's checksum.
    #[error("Checksum mismatch: expected {expected}, got {actual}")]
    ChecksumMismatch { expected: String, actual: String },
    /// The control file listed no download URLs.
    #[error("No URLs available")]
    NoUrls,
}
29
30pub type ProgressCallback = Box<dyn Fn(u64, u64) + Send + Sync>;
31
/// Drives zsync-style assembly of a target file: reuses matching blocks from
/// local sources and downloads the remainder via HTTP range requests into a
/// temporary file, which is renamed into place on completion.
pub struct ZsyncAssembly {
    control: ControlFile,           // parsed .zsync control file
    base_url: Option<String>,       // base for resolving relative control-file URLs
    matcher: BlockMatcher,          // tracks which blocks are still needed
    http: HttpClient,
    output_path: std::path::PathBuf, // final destination path
    temp_path: std::path::PathBuf,   // output path with "zsync-tmp" extension
    file: Option<File>,              // lazily created temp-file handle
    range_gap_threshold: u64,        // max gap when merging HTTP byte ranges
    progress_callback: Option<ProgressCallback>,
}
43
44impl ZsyncAssembly {
    /// Creates an assembly with no base URL; the control file's URLs are
    /// used as-is when downloading.
    pub fn new(control: ControlFile, output_path: &Path) -> Result<Self, AssemblyError> {
        Self::with_base_url(control, output_path, None)
    }
48
49 pub fn with_base_url(
50 control: ControlFile,
51 output_path: &Path,
52 base_url: Option<&str>,
53 ) -> Result<Self, AssemblyError> {
54 let matcher = BlockMatcher::new(&control);
55 let http = HttpClient::new();
56 let temp_path = output_path.with_extension("zsync-tmp");
57
58 Ok(Self {
59 control,
60 base_url: base_url.map(|s| s.to_string()),
61 matcher,
62 http,
63 output_path: output_path.to_path_buf(),
64 temp_path,
65 file: None,
66 range_gap_threshold: DEFAULT_RANGE_GAP_THRESHOLD,
67 progress_callback: None,
68 })
69 }
70
71 pub fn from_url(control_url: &str, output_path: &Path) -> Result<Self, AssemblyError> {
72 let http = HttpClient::new();
73 let control = http.fetch_control_file(control_url)?;
74 let base_url = extract_base_url(control_url);
75 Self::with_base_url(control, output_path, Some(&base_url))
76 }
77
    /// Sets the maximum byte gap across which adjacent HTTP byte ranges are
    /// merged into a single request (gap bytes are re-downloaded).
    pub fn set_range_gap_threshold(&mut self, threshold: u64) {
        self.range_gap_threshold = threshold;
    }
81
    /// Registers a callback invoked as `(bytes_done, bytes_total)` after each
    /// downloaded block is accepted.
    pub fn set_progress_callback<F>(&mut self, callback: F)
    where
        F: Fn(u64, u64) + Send + Sync + 'static,
    {
        self.progress_callback = Some(Box::new(callback));
    }
88
89 fn report_progress(&self) {
90 if let Some(ref cb) = self.progress_callback {
91 let (done, total) = self.progress();
92 cb(done, total);
93 }
94 }
95
96 pub fn progress(&self) -> (u64, u64) {
97 let total = self.control.length;
98 let got = self.matcher.blocks_todo();
99 let blocks_done = self.matcher.total_blocks() - got;
100 let done_bytes = (blocks_done * self.control.blocksize) as u64;
101 (done_bytes.min(total), total)
102 }
103
    /// True once every block of the target file has been satisfied.
    pub fn is_complete(&self) -> bool {
        self.matcher.is_complete()
    }
107
108 pub fn block_stats(&self) -> (usize, usize) {
109 let total = self.matcher.total_blocks();
110 let todo = self.matcher.blocks_todo();
111 (total - todo, total)
112 }
113
114 pub fn submit_source_file(&mut self, path: &Path) -> Result<usize, AssemblyError> {
115 let mut file = File::open(path)?;
116 let mut buf = Vec::new();
117 file.read_to_end(&mut buf)?;
118
119 let blocksize = self.control.blocksize;
120 let context = blocksize * self.control.hash_lengths.seq_matches as usize;
121
122 if buf.len() < context {
123 return Ok(0);
124 }
125
126 let original_len = buf.len();
128 buf.resize(original_len + context, 0);
129
130 let matched_blocks = self.matcher.submit_source_data(&buf, 0);
131
132 for (block_id, source_offset) in &matched_blocks {
133 let file_handle = self.ensure_file()?;
134 let offset = (block_id * blocksize) as u64;
135 let block_data = &buf[*source_offset..source_offset + blocksize];
136 Self::write_at_offset(file_handle, block_data, offset)?;
137 }
138
139 Ok(matched_blocks.len())
140 }
141
142 pub fn submit_self_referential(&mut self) -> Result<usize, AssemblyError> {
146 if self.file.is_none() {
147 return Ok(0);
148 }
149
150 let file = self.file.as_mut().unwrap();
152 file.sync_all()?;
153
154 let mut buf = Vec::new();
155 file.seek(SeekFrom::Start(0))?;
156 file.read_to_end(&mut buf)?;
157
158 let blocksize = self.control.blocksize;
159 let context = blocksize * self.control.hash_lengths.seq_matches as usize;
160
161 if buf.len() < context {
162 return Ok(0);
163 }
164
165 let original_len = buf.len();
166 buf.resize(original_len + context, 0);
167
168 let matched_blocks = self.matcher.submit_source_data(&buf, 0);
169
170 for (block_id, source_offset) in &matched_blocks {
171 let offset = (block_id * blocksize) as u64;
172 let block_data = &buf[*source_offset..source_offset + blocksize];
173 Self::write_at_offset(self.file.as_ref().unwrap(), block_data, offset)?;
174 }
175
176 Ok(matched_blocks.len())
177 }
178
    /// Writes `data` at absolute `offset` without moving the file cursor
    /// (positional write via the Unix-only `FileExt::write_all_at`).
    fn write_at_offset(file: &File, data: &[u8], offset: u64) -> Result<(), AssemblyError> {
        file.write_all_at(data, offset)?;
        Ok(())
    }
183
184 fn ensure_file(&mut self) -> Result<&mut File, AssemblyError> {
185 if self.file.is_none() {
186 let file = OpenOptions::new()
187 .read(true)
188 .write(true)
189 .create(true)
190 .truncate(false)
191 .open(&self.temp_path)?;
192 self.file = Some(file);
193 }
194 Ok(self.file.as_mut().unwrap())
195 }
196
    /// Downloads all still-missing blocks via HTTP range requests and writes
    /// each checksum-verified block into the temp file.
    ///
    /// Returns the number of blocks downloaded and accepted (0 when nothing
    /// is missing).
    ///
    /// # Errors
    /// `NoUrls` if the control file lists no URLs; `Http`/`Io`/`Matcher` on
    /// fetch, write, or verification failures.
    pub fn download_missing_blocks(&mut self) -> Result<usize, AssemblyError> {
        // Only the first URL from the control file is used.
        let relative_url = self
            .control
            .urls
            .first()
            .ok_or(AssemblyError::NoUrls)?
            .clone();

        // Resolve against the base URL when one is known; otherwise the
        // control-file URL is used verbatim.
        let url = self
            .base_url
            .as_ref()
            .map(|base| resolve_url(base, &relative_url))
            .unwrap_or(relative_url);

        let block_ranges = self.matcher.needed_block_ranges();
        if block_ranges.is_empty() {
            return Ok(0);
        }

        // Convert block ranges to byte ranges, then merge nearby ranges to
        // cut request count (bytes in the merged-over gaps are re-fetched).
        let byte_ranges = byte_ranges_from_block_ranges(
            &block_ranges,
            self.control.blocksize,
            self.control.length,
        );
        let merged_ranges = merge_byte_ranges(&byte_ranges, self.range_gap_threshold);
        let mut downloaded_blocks = 0;
        let blocksize = self.control.blocksize;
        // Scratch buffer reused for verifying (possibly short) blocks.
        let mut padded_buf = vec![0u8; blocksize];

        for (start, end) in merged_ranges {
            let data = self.http.fetch_range(&url, start, end)?;

            // Merged ranges start block-aligned (they begin at a needed
            // block's byte range), so this maps back to a block id.
            let block_start = (start / blocksize as u64) as usize;
            let total_blocks = self.matcher.total_blocks();
            let num_blocks = data.len().div_ceil(blocksize);

            for i in 0..num_blocks {
                let block_id = block_start + i;
                if block_id >= total_blocks {
                    break;
                }

                // Skip blocks already satisfied (e.g. gap blocks pulled in
                // by range merging).
                if self.matcher.is_block_known(block_id) {
                    continue;
                }

                let block_offset = i * blocksize;
                let block_end = std::cmp::min(block_offset + blocksize, data.len());
                let block_data = &data[block_offset..block_end];

                // Verify against the control checksums using a zero-padded
                // copy, since the final block of the file may be short.
                padded_buf[..block_data.len()].copy_from_slice(block_data);
                if block_data.len() < blocksize {
                    padded_buf[block_data.len()..].fill(0);
                }

                if self.matcher.submit_blocks(&padded_buf, block_id)? {
                    let file = self.ensure_file()?;
                    let offset = (block_id * blocksize) as u64;
                    // Write only the real bytes, not the zero padding.
                    Self::write_at_offset(file, block_data, offset)?;
                    downloaded_blocks += 1;
                    self.report_progress();
                }
            }
        }

        Ok(downloaded_blocks)
    }
264
265 pub fn complete(mut self) -> Result<(), AssemblyError> {
266 if !self.matcher.is_complete() {
267 return Err(AssemblyError::Control(
268 "Not all blocks downloaded".to_string(),
269 ));
270 }
271
272 let file_length = self.control.length;
273 let expected_sha1 = self.control.sha1.clone();
274
275 let file = self.ensure_file()?;
276 file.set_len(file_length)?;
277
278 if let Some(ref expected) = expected_sha1 {
279 file.seek(SeekFrom::Start(0))?;
280 let mut buf = Vec::new();
281 file.read_to_end(&mut buf)?;
282
283 let actual_checksum = calc_sha1(&buf);
284 let actual_hex = hex_encode(&actual_checksum);
285
286 if !actual_hex.eq_ignore_ascii_case(expected) {
287 return Err(AssemblyError::ChecksumMismatch {
288 expected: expected.clone(),
289 actual: actual_hex,
290 });
291 }
292 }
293
294 drop(self.file);
295 std::fs::rename(&self.temp_path, &self.output_path)?;
296
297 Ok(())
298 }
299
    /// Abandons the assembly, removing the temp file if present.
    /// Removal errors are deliberately ignored (best-effort cleanup).
    pub fn abort(self) {
        let _ = std::fs::remove_file(&self.temp_path);
    }
303}
304
/// Lowercase hex representation of `bytes`, two digits per byte.
///
/// Builds the result in a single pre-sized `String` via a nibble lookup
/// table instead of one `format!` allocation per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &b in bytes {
        out.push(HEX_DIGITS[(b >> 4) as usize] as char);
        out.push(HEX_DIGITS[(b & 0x0f) as usize] as char);
    }
    out
}
308
/// Everything up to and including the final '/' of `url`; an empty string
/// when the URL contains no slash at all.
fn extract_base_url(url: &str) -> String {
    match url.rfind('/') {
        Some(slash) => url[..=slash].to_string(),
        None => String::new(),
    }
}
314
/// Resolves `relative` against `base`:
/// - an absolute URL (contains "://") is returned verbatim;
/// - a root-relative path ("/x") replaces everything after base's authority;
/// - anything else is appended to `base`.
fn resolve_url(base: &str, relative: &str) -> String {
    if relative.contains("://") {
        return relative.to_string();
    }
    if !relative.starts_with('/') {
        return format!("{base}{relative}");
    }
    // Root-relative: keep only scheme + authority from the base URL.
    let after_scheme = base.find("://").map_or(0, |i| i + 3);
    let authority_end = base[after_scheme..]
        .find('/')
        .map_or(base.len(), |i| after_scheme + i);
    format!("{}{}", &base[..authority_end], relative)
}
330
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hex_encode() {
        // Two lowercase hex digits per byte, including leading zeros.
        assert_eq!(hex_encode(&[0x00, 0xff, 0x10]), "00ff10");
    }

    #[test]
    fn test_extract_base_url() {
        // The base is everything up to and including the last '/'.
        assert_eq!(
            extract_base_url("https://example.com/path/file.zsync"),
            "https://example.com/path/"
        );
        assert_eq!(
            extract_base_url("https://example.com/file.zsync"),
            "https://example.com/"
        );
    }

    #[test]
    fn test_resolve_url() {
        // Plain relative path: appended to the base.
        assert_eq!(
            resolve_url("https://example.com/path/", "file.bin"),
            "https://example.com/path/file.bin"
        );
        // Root-relative path: replaces everything after the authority.
        assert_eq!(
            resolve_url("https://example.com/path/", "/file.bin"),
            "https://example.com/file.bin"
        );
        // Absolute URL: used verbatim.
        assert_eq!(
            resolve_url("https://example.com/path/", "https://other.com/file.bin"),
            "https://other.com/file.bin"
        );
    }
}