1use crate::{Result, ZiPatchError};
36use std::fs::File;
37use std::io::{Read, Seek, SeekFrom};
38use std::path::Path;
39
/// Random-access byte source over a chain of one or more patches.
///
/// Data is addressed by a patch index (the patch's position in the chain the
/// source was built from) plus a byte offset within that patch.
pub trait PatchSource {
    /// Fills the whole of `dst` with bytes from patch number `patch`, starting
    /// at byte `offset` within that patch.
    ///
    /// # Errors
    ///
    /// The implementations in this module return
    /// `ZiPatchError::PatchIndexOutOfRange` when `patch` does not name a loaded
    /// patch, and `ZiPatchError::PatchSourceTooShort` when fewer than
    /// `dst.len()` bytes are available at `offset`. Other implementors may
    /// report additional failures (e.g. I/O errors).
    fn read(&mut self, patch: u32, offset: u64, dst: &mut [u8]) -> Result<()>;
}
61
/// [`PatchSource`] backed by open file handles, one per patch in the chain.
///
/// Patch index `i` refers to the `i`-th handle. Each read seeks the handle,
/// so the source needs `&mut self` access.
#[derive(Debug)]
pub struct FilePatchSource {
    /// Open file handles, indexed by patch number.
    files: Vec<File>,
}
71
72impl FilePatchSource {
73 pub fn open(path: impl AsRef<Path>) -> Result<Self> {
81 let file = File::open(path)?;
82 Ok(Self { files: vec![file] })
83 }
84
85 pub fn open_chain<I, P>(paths: I) -> Result<Self>
93 where
94 I: IntoIterator<Item = P>,
95 P: AsRef<Path>,
96 {
97 let iter = paths.into_iter();
98 let mut files = Vec::with_capacity(iter.size_hint().0);
102 for p in iter {
103 files.push(File::open(p).map_err(ZiPatchError::Io)?);
104 }
105 Ok(Self { files })
106 }
107
108 #[must_use]
110 pub fn from_file(file: File) -> Self {
111 Self { files: vec![file] }
112 }
113
114 #[must_use]
116 pub fn patch_count(&self) -> usize {
117 self.files.len()
118 }
119}
120
121impl PatchSource for FilePatchSource {
122 fn read(&mut self, patch: u32, offset: u64, dst: &mut [u8]) -> Result<()> {
123 let count = self.files.len();
124 let file = self
125 .files
126 .get_mut(patch as usize)
127 .ok_or(ZiPatchError::PatchIndexOutOfRange { patch, count })?;
128 file.seek(SeekFrom::Start(offset))?;
129 match file.read_exact(dst) {
130 Ok(()) => Ok(()),
131 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
132 Err(ZiPatchError::PatchSourceTooShort {
133 offset,
134 requested: dst.len(),
135 })
136 }
137 Err(e) => Err(ZiPatchError::Io(e)),
138 }
139 }
140}
141
/// In-memory [`PatchSource`], compiled only for tests or the `test-utils`
/// feature, holding one byte buffer per patch in the chain.
///
/// Buffers are stored as `Arc<[u8]>`, so cloning the source shares the
/// underlying bytes rather than copying them.
#[cfg(any(test, feature = "test-utils"))]
#[derive(Debug, Clone)]
pub struct MemoryPatchSource {
    /// Patch contents, indexed by patch number.
    bufs: Vec<std::sync::Arc<[u8]>>,
}
151
#[cfg(any(test, feature = "test-utils"))]
impl MemoryPatchSource {
    /// Creates a one-patch source from an owned buffer.
    #[must_use]
    pub fn new(buf: Vec<u8>) -> Self {
        Self {
            bufs: vec![buf.into()],
        }
    }

    /// Creates a one-patch source holding a copy of `buf`.
    #[must_use]
    pub fn from_slice(buf: &[u8]) -> Self {
        // `Arc<[u8]>: From<&[u8]>` copies straight into the shared allocation.
        // Going through `Vec::from` first would allocate and copy twice.
        Self {
            bufs: vec![buf.into()],
        }
    }

    /// Creates a chain from owned buffers; `bufs[i]` becomes patch index `i`.
    #[must_use]
    pub fn new_chain(bufs: Vec<Vec<u8>>) -> Self {
        Self {
            bufs: bufs.into_iter().map(Into::into).collect(),
        }
    }

    /// Creates a chain holding copies of `bufs`; `bufs[i]` becomes patch index `i`.
    #[must_use]
    pub fn from_slices(bufs: &[&[u8]]) -> Self {
        // One direct `&[u8] -> Arc<[u8]>` conversion per patch, with no
        // intermediate `Vec`.
        Self {
            bufs: bufs.iter().copied().map(Into::into).collect(),
        }
    }

    /// Number of patches in the chain.
    #[must_use]
    pub fn patch_count(&self) -> usize {
        self.bufs.len()
    }
}
192
#[cfg(any(test, feature = "test-utils"))]
impl PatchSource for MemoryPatchSource {
    /// Copies `dst.len()` bytes of patch `patch`, starting at `offset`, into `dst`.
    ///
    /// # Errors
    ///
    /// - `ZiPatchError::PatchIndexOutOfRange` if `patch` is not in the chain.
    /// - `ZiPatchError::PatchSourceTooShort` if the window `offset..offset + dst.len()`
    ///   does not fit inside the patch, including when it is not representable as
    ///   a `usize` range at all.
    fn read(&mut self, patch: u32, offset: u64, dst: &mut [u8]) -> Result<()> {
        let Some(buf) = self.bufs.get(patch as usize) else {
            return Err(ZiPatchError::PatchIndexOutOfRange {
                patch,
                count: self.bufs.len(),
            });
        };

        let requested = dst.len();
        // Every failure below (offset overflowing usize, end overflowing,
        // window past the end) collapses to the same "too short" error.
        let window = usize::try_from(offset)
            .ok()
            .and_then(|start| start.checked_add(requested).map(|end| start..end))
            .and_then(|range| buf.get(range))
            .ok_or_else(|| ZiPatchError::PatchSourceTooShort { offset, requested })?;

        dst.copy_from_slice(window);
        Ok(())
    }
}
221
#[cfg(test)]
mod tests {
    // Unit tests for both `PatchSource` implementations: in-range reads,
    // short reads, and patch-index bounds for single patches and chains.
    use super::*;

    #[test]
    fn memory_source_round_trips_arbitrary_ranges() {
        let bytes: Vec<u8> = (0..=255u8).collect();
        let mut src = MemoryPatchSource::new(bytes.clone());

        let mut head = [0u8; 16];
        src.read(0, 0, &mut head).unwrap();
        assert_eq!(&head, &bytes[..16]);

        let mut mid = [0u8; 32];
        src.read(0, 100, &mut mid).unwrap();
        assert_eq!(&mid, &bytes[100..132]);

        let mut tail = [0u8; 16];
        src.read(0, 240, &mut tail).unwrap();
        assert_eq!(&tail, &bytes[240..256]);

        let mut empty = [0u8; 0];
        src.read(0, 0, &mut empty).unwrap();
        src.read(0, 256, &mut empty).unwrap();
    }

    #[test]
    fn memory_source_out_of_range_returns_too_short() {
        let mut src = MemoryPatchSource::new(vec![0u8; 16]);
        let mut buf = [0u8; 4];

        let err = src
            .read(0, 15, &mut buf)
            .expect_err("read past end must fail");
        match err {
            ZiPatchError::PatchSourceTooShort { offset, requested } => {
                assert_eq!(offset, 15);
                assert_eq!(requested, 4);
            }
            other => panic!("expected PatchSourceTooShort, got {other:?}"),
        }

        let err = src
            .read(0, 1_000_000, &mut buf)
            .expect_err("read far past end must fail");
        assert!(matches!(err, ZiPatchError::PatchSourceTooShort { .. }));
    }

    #[test]
    fn memory_source_chain_indexes_each_patch() {
        let p0: Vec<u8> = (0..16u8).map(|i| 0xA0 | i).collect();
        let p1: Vec<u8> = (0..16u8).map(|i| 0xB0 | i).collect();
        let mut src = MemoryPatchSource::new_chain(vec![p0.clone(), p1.clone()]);

        let mut buf = [0u8; 4];
        src.read(0, 0, &mut buf).unwrap();
        assert_eq!(&buf, &p0[..4]);
        src.read(1, 0, &mut buf).unwrap();
        assert_eq!(&buf, &p1[..4]);
        src.read(0, 12, &mut buf).unwrap();
        assert_eq!(&buf, &p0[12..16]);
    }

    #[test]
    fn memory_source_chain_rejects_out_of_range_patch() {
        let mut src = MemoryPatchSource::new_chain(vec![vec![0u8; 16]]);
        let mut buf = [0u8; 4];
        let err = src
            .read(1, 0, &mut buf)
            .expect_err("patch 1 must be out of range");
        match err {
            ZiPatchError::PatchIndexOutOfRange { patch, count } => {
                assert_eq!(patch, 1);
                assert_eq!(count, 1);
            }
            other => panic!("expected PatchIndexOutOfRange, got {other:?}"),
        }
    }

    #[test]
    fn memory_source_from_slices_matches_new_chain() {
        let a: &[u8] = b"abcd";
        let b: &[u8] = b"efgh";
        let mut borrowed = MemoryPatchSource::from_slices(&[a, b]);
        let mut owned = MemoryPatchSource::new_chain(vec![a.to_vec(), b.to_vec()]);
        assert_eq!(borrowed.patch_count(), owned.patch_count());

        let mut lhs = [0u8; 4];
        let mut rhs = [0u8; 4];
        for patch in 0..2 {
            borrowed.read(patch, 0, &mut lhs).unwrap();
            owned.read(patch, 0, &mut rhs).unwrap();
            assert_eq!(lhs, rhs);
        }
    }

    #[test]
    fn memory_source_empty_chain_rejects_every_patch() {
        let mut src = MemoryPatchSource::new_chain(Vec::new());
        assert_eq!(src.patch_count(), 0);
        let mut empty = [0u8; 0];
        let err = src
            .read(0, 0, &mut empty)
            .expect_err("no patches are loaded");
        assert!(matches!(
            err,
            ZiPatchError::PatchIndexOutOfRange { patch: 0, count: 0 }
        ));
    }

    #[test]
    fn file_source_round_trips_arbitrary_ranges() {
        let bytes: Vec<u8> = (0..=255u8).collect();
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("source.bin");
        std::fs::write(&path, &bytes).unwrap();

        let mut src = FilePatchSource::open(&path).unwrap();

        let mut head = [0u8; 16];
        src.read(0, 0, &mut head).unwrap();
        assert_eq!(&head, &bytes[..16]);

        let mut mid = [0u8; 32];
        src.read(0, 100, &mut mid).unwrap();
        assert_eq!(&mid, &bytes[100..132]);
    }

    #[test]
    fn file_source_short_returns_too_short() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("source.bin");
        std::fs::write(&path, [0u8; 16]).unwrap();

        let mut src = FilePatchSource::open(&path).unwrap();
        let mut buf = [0u8; 32];
        let err = src
            .read(0, 0, &mut buf)
            .expect_err("read past end must fail");
        assert!(matches!(err, ZiPatchError::PatchSourceTooShort { .. }));
    }

    #[test]
    fn file_source_chain_indexes_each_file() {
        let tmp = tempfile::tempdir().unwrap();
        let p0 = tmp.path().join("p0.bin");
        let p1 = tmp.path().join("p1.bin");
        std::fs::write(&p0, b"AAAAAAAA").unwrap();
        std::fs::write(&p1, b"BBBBBBBB").unwrap();

        let mut src = FilePatchSource::open_chain([&p0, &p1]).unwrap();
        assert_eq!(src.patch_count(), 2);

        let mut buf = [0u8; 4];
        src.read(0, 0, &mut buf).unwrap();
        assert_eq!(&buf, b"AAAA");
        src.read(1, 4, &mut buf).unwrap();
        assert_eq!(&buf, b"BBBB");
    }

    #[test]
    fn file_source_chain_rejects_out_of_range_patch() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("only.bin");
        std::fs::write(&path, [0u8; 8]).unwrap();

        let mut src = FilePatchSource::open_chain([&path]).unwrap();
        let mut buf = [0u8; 4];
        let err = src
            .read(3, 0, &mut buf)
            .expect_err("patch 3 must be out of range");
        match err {
            ZiPatchError::PatchIndexOutOfRange { patch, count } => {
                assert_eq!(patch, 3);
                assert_eq!(count, 1);
            }
            other => panic!("expected PatchIndexOutOfRange, got {other:?}"),
        }
    }

    #[test]
    fn file_source_open_chain_reports_missing_file_as_io() {
        let tmp = tempfile::tempdir().unwrap();
        let present = tmp.path().join("present.bin");
        std::fs::write(&present, b"data").unwrap();
        let missing = tmp.path().join("missing.bin");

        let err = FilePatchSource::open_chain([&present, &missing])
            .expect_err("opening a missing file must fail");
        assert!(matches!(err, ZiPatchError::Io(_)));
    }
}