solana_bpf_loader_program/syscalls/
mem_ops.rs

1use {
2    super::*,
3    crate::translate_mut,
4    solana_program_runtime::invoke_context::SerializedAccountMetadata,
5    solana_sbpf::{error::EbpfError, memory_region::MemoryRegion},
6    std::slice,
7};
8
9fn mem_op_consume(invoke_context: &mut InvokeContext, n: u64) -> Result<(), Error> {
10    let compute_cost = invoke_context.get_execution_cost();
11    let cost = compute_cost.mem_op_base_cost.max(
12        n.checked_div(compute_cost.cpi_bytes_per_unit)
13            .unwrap_or(u64::MAX),
14    );
15    consume_compute_meter(invoke_context, cost)
16}
17
18/// Check that two regions do not overlap.
19pub(crate) fn is_nonoverlapping<N>(src: N, src_len: N, dst: N, dst_len: N) -> bool
20where
21    N: Ord + num_traits::SaturatingSub,
22{
23    // If the absolute distance between the ptrs is at least as big as the size of the other,
24    // they do not overlap.
25    if src > dst {
26        src.saturating_sub(&dst) >= dst_len
27    } else {
28        dst.saturating_sub(&src) >= src_len
29    }
30}
31
32declare_builtin_function!(
33    /// memcpy
34    SyscallMemcpy,
35    fn rust(
36        invoke_context: &mut InvokeContext,
37        dst_addr: u64,
38        src_addr: u64,
39        n: u64,
40        _arg4: u64,
41        _arg5: u64,
42        memory_mapping: &mut MemoryMapping,
43    ) -> Result<u64, Error> {
44        mem_op_consume(invoke_context, n)?;
45
46        if !is_nonoverlapping(src_addr, n, dst_addr, n) {
47            return Err(SyscallError::CopyOverlapping.into());
48        }
49
50        // host addresses can overlap so we always invoke memmove
51        memmove(invoke_context, dst_addr, src_addr, n, memory_mapping)
52    }
53);
54
55declare_builtin_function!(
56    /// memmove
57    SyscallMemmove,
58    fn rust(
59        invoke_context: &mut InvokeContext,
60        dst_addr: u64,
61        src_addr: u64,
62        n: u64,
63        _arg4: u64,
64        _arg5: u64,
65        memory_mapping: &mut MemoryMapping,
66    ) -> Result<u64, Error> {
67        mem_op_consume(invoke_context, n)?;
68
69        memmove(invoke_context, dst_addr, src_addr, n, memory_mapping)
70    }
71);
72
73declare_builtin_function!(
74    /// memcmp
75    SyscallMemcmp,
76    fn rust(
77        invoke_context: &mut InvokeContext,
78        s1_addr: u64,
79        s2_addr: u64,
80        n: u64,
81        cmp_result_addr: u64,
82        _arg5: u64,
83        memory_mapping: &mut MemoryMapping,
84    ) -> Result<u64, Error> {
85        mem_op_consume(invoke_context, n)?;
86
87        if invoke_context
88            .get_feature_set()
89            .bpf_account_data_direct_mapping
90        {
91            translate_mut!(
92                memory_mapping,
93                invoke_context.get_check_aligned(),
94                let cmp_result_ref_mut: &mut i32 = map(cmp_result_addr)?;
95            );
96            let syscall_context = invoke_context.get_syscall_context()?;
97
98            *cmp_result_ref_mut = memcmp_non_contiguous(s1_addr, s2_addr, n, &syscall_context.accounts_metadata, memory_mapping, invoke_context.get_check_aligned())?;
99        } else {
100            let s1 = translate_slice::<u8>(
101                memory_mapping,
102                s1_addr,
103                n,
104                invoke_context.get_check_aligned(),
105            )?;
106            let s2 = translate_slice::<u8>(
107                memory_mapping,
108                s2_addr,
109                n,
110                invoke_context.get_check_aligned(),
111            )?;
112
113            debug_assert_eq!(s1.len(), n as usize);
114            debug_assert_eq!(s2.len(), n as usize);
115            // Safety:
116            // memcmp is marked unsafe since it assumes that the inputs are at least
117            // `n` bytes long. `s1` and `s2` are guaranteed to be exactly `n` bytes
118            // long because `translate_slice` would have failed otherwise.
119            let result = unsafe { memcmp(s1, s2, n as usize) };
120
121            translate_mut!(
122                memory_mapping,
123                invoke_context.get_check_aligned(),
124                let cmp_result_ref_mut: &mut i32 = map(cmp_result_addr)?;
125            );
126            *cmp_result_ref_mut = result;
127        }
128
129        Ok(0)
130    }
131);
132
133declare_builtin_function!(
134    /// memset
135    SyscallMemset,
136    fn rust(
137        invoke_context: &mut InvokeContext,
138        dst_addr: u64,
139        c: u64,
140        n: u64,
141        _arg4: u64,
142        _arg5: u64,
143        memory_mapping: &mut MemoryMapping,
144    ) -> Result<u64, Error> {
145        mem_op_consume(invoke_context, n)?;
146
147        if invoke_context
148            .get_feature_set()
149            .bpf_account_data_direct_mapping
150        {
151            let syscall_context = invoke_context.get_syscall_context()?;
152
153            memset_non_contiguous(dst_addr, c as u8, n, &syscall_context.accounts_metadata, memory_mapping, invoke_context.get_check_aligned())
154        } else {
155            translate_mut!(
156                memory_mapping,
157                invoke_context.get_check_aligned(),
158                let s: &mut [u8] = map(dst_addr, n)?;
159            );
160            s.fill(c as u8);
161            Ok(0)
162        }
163    }
164);
165
166fn memmove(
167    invoke_context: &mut InvokeContext,
168    dst_addr: u64,
169    src_addr: u64,
170    n: u64,
171    memory_mapping: &MemoryMapping,
172) -> Result<u64, Error> {
173    if invoke_context
174        .get_feature_set()
175        .bpf_account_data_direct_mapping
176    {
177        let syscall_context = invoke_context.get_syscall_context()?;
178
179        memmove_non_contiguous(
180            dst_addr,
181            src_addr,
182            n,
183            &syscall_context.accounts_metadata,
184            memory_mapping,
185            invoke_context.get_check_aligned(),
186        )
187    } else {
188        translate_mut!(
189            memory_mapping,
190            invoke_context.get_check_aligned(),
191            let dst_ref_mut: &mut [u8] = map(dst_addr, n)?;
192        );
193        let dst_ptr = dst_ref_mut.as_mut_ptr();
194        let src_ptr = translate_slice::<u8>(
195            memory_mapping,
196            src_addr,
197            n,
198            invoke_context.get_check_aligned(),
199        )?
200        .as_ptr();
201
202        unsafe { std::ptr::copy(src_ptr, dst_ptr, n as usize) };
203        Ok(0)
204    }
205}
206
207fn memmove_non_contiguous(
208    dst_addr: u64,
209    src_addr: u64,
210    n: u64,
211    accounts: &[SerializedAccountMetadata],
212    memory_mapping: &MemoryMapping,
213    resize_area: bool,
214) -> Result<u64, Error> {
215    let reverse = dst_addr.wrapping_sub(src_addr) < n;
216    iter_memory_pair_chunks(
217        AccessType::Load,
218        src_addr,
219        AccessType::Store,
220        dst_addr,
221        n,
222        accounts,
223        memory_mapping,
224        reverse,
225        resize_area,
226        |src_host_addr, dst_host_addr, chunk_len| {
227            unsafe { std::ptr::copy(src_host_addr, dst_host_addr as *mut u8, chunk_len) };
228            Ok(0)
229        },
230    )
231}
232
/// Compares the first `n` bytes of `s1` and `s2`.
///
/// Returns `s1[i] - s2[i]` (as `i32`) for the first index `i` where they differ, or 0 if
/// the prefixes are equal.
///
/// # Safety
///
/// Both `s1` and `s2` must be at least `n` bytes long; no bounds checks are performed.
unsafe fn memcmp(s1: &[u8], s2: &[u8], n: usize) -> i32 {
    let lhs = s1.get_unchecked(..n);
    let rhs = s2.get_unchecked(..n);
    lhs.iter()
        .zip(rhs)
        .find(|(a, b)| a != b)
        .map_or(0, |(&a, &b)| i32::from(a).saturating_sub(i32::from(b)))
}
245
246fn memcmp_non_contiguous(
247    src_addr: u64,
248    dst_addr: u64,
249    n: u64,
250    accounts: &[SerializedAccountMetadata],
251    memory_mapping: &MemoryMapping,
252    resize_area: bool,
253) -> Result<i32, Error> {
254    let memcmp_chunk = |s1_addr, s2_addr, chunk_len| {
255        let res = unsafe {
256            let s1 = slice::from_raw_parts(s1_addr, chunk_len);
257            let s2 = slice::from_raw_parts(s2_addr, chunk_len);
258            // Safety:
259            // memcmp is marked unsafe since it assumes that s1 and s2 are exactly chunk_len
260            // long. The whole point of iter_memory_pair_chunks is to find same length chunks
261            // across two memory regions.
262            memcmp(s1, s2, chunk_len)
263        };
264        if res != 0 {
265            return Err(MemcmpError::Diff(res).into());
266        }
267        Ok(0)
268    };
269    match iter_memory_pair_chunks(
270        AccessType::Load,
271        src_addr,
272        AccessType::Load,
273        dst_addr,
274        n,
275        accounts,
276        memory_mapping,
277        false,
278        resize_area,
279        memcmp_chunk,
280    ) {
281        Ok(res) => Ok(res),
282        Err(error) => match error.downcast_ref() {
283            Some(MemcmpError::Diff(diff)) => Ok(*diff),
284            _ => Err(error),
285        },
286    }
287}
288
/// Internal signal used by `memcmp_non_contiguous` to stop the chunk walk early.
///
/// `Diff` carries the non-zero comparison result of the first differing chunk.
#[derive(Debug)]
enum MemcmpError {
    Diff(i32),
}

impl std::fmt::Display for MemcmpError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let MemcmpError::Diff(diff) = self;
        write!(f, "memcmp diff: {diff}")
    }
}

// The default `source` already returns `None`, which is what every variant needs.
impl std::error::Error for MemcmpError {}
309
310fn memset_non_contiguous(
311    dst_addr: u64,
312    c: u8,
313    n: u64,
314    accounts: &[SerializedAccountMetadata],
315    memory_mapping: &MemoryMapping,
316    check_aligned: bool,
317) -> Result<u64, Error> {
318    let dst_chunk_iter = MemoryChunkIterator::new(
319        memory_mapping,
320        accounts,
321        AccessType::Store,
322        dst_addr,
323        n,
324        check_aligned,
325    )?;
326    for item in dst_chunk_iter {
327        let (dst_region, dst_vm_addr, dst_len) = item?;
328        let dst_host_addr = dst_region
329            .vm_to_host(dst_vm_addr, dst_len as u64)
330            .ok_or_else(|| {
331                EbpfError::AccessViolation(AccessType::Store, dst_vm_addr, dst_len as u64, "")
332            })?;
333        unsafe { slice::from_raw_parts_mut(dst_host_addr as *mut u8, dst_len).fill(c) }
334    }
335
336    Ok(0)
337}
338
339#[allow(clippy::too_many_arguments)]
340fn iter_memory_pair_chunks<T, F>(
341    src_access: AccessType,
342    src_addr: u64,
343    dst_access: AccessType,
344    dst_addr: u64,
345    n_bytes: u64,
346    accounts: &[SerializedAccountMetadata],
347    memory_mapping: &MemoryMapping,
348    reverse: bool,
349    resize_area: bool,
350    mut fun: F,
351) -> Result<T, Error>
352where
353    T: Default,
354    F: FnMut(*const u8, *const u8, usize) -> Result<T, Error>,
355{
356    let mut src_chunk_iter = MemoryChunkIterator::new(
357        memory_mapping,
358        accounts,
359        src_access,
360        src_addr,
361        n_bytes,
362        resize_area,
363    )?;
364    let mut dst_chunk_iter = MemoryChunkIterator::new(
365        memory_mapping,
366        accounts,
367        dst_access,
368        dst_addr,
369        n_bytes,
370        resize_area,
371    )?;
372
373    let mut src_chunk = None;
374    let mut dst_chunk = None;
375
376    macro_rules! memory_chunk {
377        ($chunk_iter:ident, $chunk:ident) => {
378            if let Some($chunk) = &mut $chunk {
379                // Keep processing the current chunk
380                $chunk
381            } else {
382                // This is either the first call or we've processed all the bytes in the current
383                // chunk. Move to the next one.
384                let chunk = match if reverse {
385                    $chunk_iter.next_back()
386                } else {
387                    $chunk_iter.next()
388                } {
389                    Some(item) => item?,
390                    None => break,
391                };
392                $chunk.insert(chunk)
393            }
394        };
395    }
396
397    loop {
398        let (src_region, src_chunk_addr, src_remaining) = memory_chunk!(src_chunk_iter, src_chunk);
399        let (dst_region, dst_chunk_addr, dst_remaining) = memory_chunk!(dst_chunk_iter, dst_chunk);
400
401        // We always process same-length pairs
402        let chunk_len = *src_remaining.min(dst_remaining);
403
404        let (src_host_addr, dst_host_addr) = {
405            let (src_addr, dst_addr) = if reverse {
406                // When scanning backwards not only we want to scan regions from the end,
407                // we want to process the memory within regions backwards as well.
408                (
409                    src_chunk_addr
410                        .saturating_add(*src_remaining as u64)
411                        .saturating_sub(chunk_len as u64),
412                    dst_chunk_addr
413                        .saturating_add(*dst_remaining as u64)
414                        .saturating_sub(chunk_len as u64),
415                )
416            } else {
417                (*src_chunk_addr, *dst_chunk_addr)
418            };
419
420            (
421                src_region
422                    .vm_to_host(src_addr, chunk_len as u64)
423                    .ok_or_else(|| {
424                        EbpfError::AccessViolation(AccessType::Load, src_addr, chunk_len as u64, "")
425                    })?,
426                dst_region
427                    .vm_to_host(dst_addr, chunk_len as u64)
428                    .ok_or_else(|| {
429                        EbpfError::AccessViolation(
430                            AccessType::Store,
431                            dst_addr,
432                            chunk_len as u64,
433                            "",
434                        )
435                    })?,
436            )
437        };
438
439        fun(
440            src_host_addr as *const u8,
441            dst_host_addr as *const u8,
442            chunk_len,
443        )?;
444
445        // Update how many bytes we have left to scan in each chunk
446        *src_remaining = src_remaining.saturating_sub(chunk_len);
447        *dst_remaining = dst_remaining.saturating_sub(chunk_len);
448
449        if !reverse {
450            // We've scanned `chunk_len` bytes so we move the vm address forward. In reverse
451            // mode we don't do this since we make progress by decreasing src_len and
452            // dst_len.
453            *src_chunk_addr = src_chunk_addr.saturating_add(chunk_len as u64);
454            *dst_chunk_addr = dst_chunk_addr.saturating_add(chunk_len as u64);
455        }
456
457        if *src_remaining == 0 {
458            src_chunk = None;
459        }
460
461        if *dst_remaining == 0 {
462            dst_chunk = None;
463        }
464    }
465
466    Ok(T::default())
467}
468
469struct MemoryChunkIterator<'a> {
470    memory_mapping: &'a MemoryMapping<'a>,
471    accounts: &'a [SerializedAccountMetadata],
472    access_type: AccessType,
473    initial_vm_addr: u64,
474    vm_addr_start: u64,
475    // exclusive end index (start + len, so one past the last valid address)
476    vm_addr_end: u64,
477    len: u64,
478    account_index: Option<usize>,
479    is_account: Option<bool>,
480    resize_area: bool,
481}
482
483impl<'a> MemoryChunkIterator<'a> {
484    fn new(
485        memory_mapping: &'a MemoryMapping,
486        accounts: &'a [SerializedAccountMetadata],
487        access_type: AccessType,
488        vm_addr: u64,
489        len: u64,
490        resize_area: bool,
491    ) -> Result<MemoryChunkIterator<'a>, EbpfError> {
492        let vm_addr_end = vm_addr.checked_add(len).ok_or(EbpfError::AccessViolation(
493            access_type,
494            vm_addr,
495            len,
496            "unknown",
497        ))?;
498
499        Ok(MemoryChunkIterator {
500            memory_mapping,
501            accounts,
502            access_type,
503            initial_vm_addr: vm_addr,
504            len,
505            vm_addr_start: vm_addr,
506            vm_addr_end,
507            account_index: None,
508            is_account: None,
509            resize_area,
510        })
511    }
512
513    fn region(&mut self, vm_addr: u64) -> Result<&'a MemoryRegion, Error> {
514        match self.memory_mapping.region(self.access_type, vm_addr) {
515            Ok((_region_index, region)) => Ok(region),
516            Err(error) => match error {
517                EbpfError::AccessViolation(access_type, _vm_addr, _len, name) => Err(Box::new(
518                    EbpfError::AccessViolation(access_type, self.initial_vm_addr, self.len, name),
519                )),
520                EbpfError::StackAccessViolation(access_type, _vm_addr, _len, frame) => {
521                    Err(Box::new(EbpfError::StackAccessViolation(
522                        access_type,
523                        self.initial_vm_addr,
524                        self.len,
525                        frame,
526                    )))
527                }
528                _ => Err(error.into()),
529            },
530        }
531    }
532}
533
534impl<'a> Iterator for MemoryChunkIterator<'a> {
535    type Item = Result<(&'a MemoryRegion, u64, usize), Error>;
536
537    fn next(&mut self) -> Option<Self::Item> {
538        if self.vm_addr_start == self.vm_addr_end {
539            return None;
540        }
541
542        let region = match self.region(self.vm_addr_start) {
543            Ok(region) => region,
544            Err(e) => {
545                self.vm_addr_start = self.vm_addr_end;
546                return Some(Err(e));
547            }
548        };
549
550        let region_is_account;
551
552        let mut account_index = self.account_index.unwrap_or_default();
553        self.account_index = Some(account_index);
554
555        loop {
556            if let Some(account) = self.accounts.get(account_index) {
557                let account_addr = account.vm_data_addr;
558                let resize_addr = account_addr.saturating_add(account.original_data_len as u64);
559
560                if resize_addr < region.vm_addr {
561                    // region is after this account, move on next one
562                    account_index = account_index.saturating_add(1);
563                    self.account_index = Some(account_index);
564                } else {
565                    region_is_account = (account.original_data_len != 0 && region.vm_addr == account_addr)
566                        // unaligned programs do not have a resize area
567                        || (self.resize_area && region.vm_addr == resize_addr);
568                    break;
569                }
570            } else {
571                // address is after all the accounts
572                region_is_account = false;
573                break;
574            }
575        }
576
577        if let Some(is_account) = self.is_account {
578            if is_account != region_is_account {
579                return Some(Err(SyscallError::InvalidLength.into()));
580            }
581        } else {
582            self.is_account = Some(region_is_account);
583        }
584
585        let vm_addr = self.vm_addr_start;
586
587        let chunk_len = if region.vm_addr_end <= self.vm_addr_end {
588            // consume the whole region
589            let len = region.vm_addr_end.saturating_sub(self.vm_addr_start);
590            self.vm_addr_start = region.vm_addr_end;
591            len
592        } else {
593            // consume part of the region
594            let len = self.vm_addr_end.saturating_sub(self.vm_addr_start);
595            self.vm_addr_start = self.vm_addr_end;
596            len
597        };
598
599        Some(Ok((region, vm_addr, chunk_len as usize)))
600    }
601}
602
603impl DoubleEndedIterator for MemoryChunkIterator<'_> {
604    fn next_back(&mut self) -> Option<Self::Item> {
605        if self.vm_addr_start == self.vm_addr_end {
606            return None;
607        }
608
609        let region = match self.region(self.vm_addr_end.saturating_sub(1)) {
610            Ok(region) => region,
611            Err(e) => {
612                self.vm_addr_start = self.vm_addr_end;
613                return Some(Err(e));
614            }
615        };
616
617        let region_is_account;
618
619        let mut account_index = self
620            .account_index
621            .unwrap_or_else(|| self.accounts.len().saturating_sub(1));
622        self.account_index = Some(account_index);
623
624        loop {
625            let Some(account) = self.accounts.get(account_index) else {
626                // address is after all the accounts
627                region_is_account = false;
628                break;
629            };
630
631            let account_addr = account.vm_data_addr;
632            let resize_addr = account_addr.saturating_add(account.original_data_len as u64);
633
634            if account_index > 0 && account_addr > region.vm_addr {
635                account_index = account_index.saturating_sub(1);
636
637                self.account_index = Some(account_index);
638            } else {
639                region_is_account = (account.original_data_len != 0 && region.vm_addr == account_addr)
640                    // unaligned programs do not have a resize area
641                    || (self.resize_area && region.vm_addr == resize_addr);
642                break;
643            }
644        }
645
646        if let Some(is_account) = self.is_account {
647            if is_account != region_is_account {
648                return Some(Err(SyscallError::InvalidLength.into()));
649            }
650        } else {
651            self.is_account = Some(region_is_account);
652        }
653
654        let chunk_len = if region.vm_addr >= self.vm_addr_start {
655            // consume the whole region
656            let len = self.vm_addr_end.saturating_sub(region.vm_addr);
657            self.vm_addr_end = region.vm_addr;
658            len
659        } else {
660            // consume part of the region
661            let len = self.vm_addr_end.saturating_sub(self.vm_addr_start);
662            self.vm_addr_end = self.vm_addr_start;
663            len
664        };
665
666        Some(Ok((region, self.vm_addr_end, chunk_len as usize)))
667    }
668}
669
670#[cfg(test)]
671#[allow(clippy::indexing_slicing)]
672#[allow(clippy::arithmetic_side_effects)]
673mod tests {
674    use {
675        super::*,
676        assert_matches::assert_matches,
677        solana_sbpf::{ebpf::MM_RODATA_START, program::SBPFVersion},
678        test_case::test_case,
679    };
680
681    fn to_chunk_vec<'a>(
682        iter: impl Iterator<Item = Result<(&'a MemoryRegion, u64, usize), Error>>,
683    ) -> Vec<(u64, usize)> {
684        iter.flat_map(|res| res.map(|(_, vm_addr, len)| (vm_addr, len)))
685            .collect::<Vec<_>>()
686    }
687
688    #[test]
689    #[should_panic(expected = "AccessViolation")]
690    fn test_memory_chunk_iterator_no_regions() {
691        let config = Config {
692            aligned_memory_mapping: false,
693            ..Config::default()
694        };
695        let memory_mapping = MemoryMapping::new(vec![], &config, SBPFVersion::V3).unwrap();
696
697        let mut src_chunk_iter =
698            MemoryChunkIterator::new(&memory_mapping, &[], AccessType::Load, 0, 1, true).unwrap();
699        src_chunk_iter.next().unwrap().unwrap();
700    }
701
702    #[test]
703    #[should_panic(expected = "AccessViolation")]
704    fn test_memory_chunk_iterator_new_out_of_bounds_upper() {
705        let config = Config {
706            aligned_memory_mapping: false,
707            ..Config::default()
708        };
709        let memory_mapping = MemoryMapping::new(vec![], &config, SBPFVersion::V3).unwrap();
710
711        let mut src_chunk_iter =
712            MemoryChunkIterator::new(&memory_mapping, &[], AccessType::Load, u64::MAX, 1, true)
713                .unwrap();
714        src_chunk_iter.next().unwrap().unwrap();
715    }
716
717    #[test]
718    fn test_memory_chunk_iterator_out_of_bounds() {
719        let config = Config {
720            aligned_memory_mapping: false,
721            ..Config::default()
722        };
723        let mem1 = vec![0xFF; 42];
724        let memory_mapping = MemoryMapping::new(
725            vec![MemoryRegion::new_readonly(&mem1, MM_RODATA_START)],
726            &config,
727            SBPFVersion::V3,
728        )
729        .unwrap();
730
731        // check oob at the lower bound on the first next()
732        let mut src_chunk_iter = MemoryChunkIterator::new(
733            &memory_mapping,
734            &[],
735            AccessType::Load,
736            MM_RODATA_START - 1,
737            42,
738            true,
739        )
740        .unwrap();
741        assert_matches!(
742            src_chunk_iter.next().unwrap().unwrap_err().downcast_ref().unwrap(),
743            EbpfError::AccessViolation(AccessType::Load, addr, 42, "unknown") if *addr == MM_RODATA_START - 1
744        );
745
746        // check oob at the upper bound. Since the memory mapping isn't empty,
747        // this always happens on the second next().
748        let mut src_chunk_iter = MemoryChunkIterator::new(
749            &memory_mapping,
750            &[],
751            AccessType::Load,
752            MM_RODATA_START,
753            43,
754            true,
755        )
756        .unwrap();
757        assert!(src_chunk_iter.next().unwrap().is_ok());
758        assert_matches!(
759            src_chunk_iter.next().unwrap().unwrap_err().downcast_ref().unwrap(),
760            EbpfError::AccessViolation(AccessType::Load, addr, 43, "program") if *addr == MM_RODATA_START
761        );
762
763        // check oob at the upper bound on the first next_back()
764        let mut src_chunk_iter = MemoryChunkIterator::new(
765            &memory_mapping,
766            &[],
767            AccessType::Load,
768            MM_RODATA_START,
769            43,
770            true,
771        )
772        .unwrap()
773        .rev();
774        assert_matches!(
775            src_chunk_iter.next().unwrap().unwrap_err().downcast_ref().unwrap(),
776            EbpfError::AccessViolation(AccessType::Load, addr, 43, "program") if *addr == MM_RODATA_START
777        );
778
779        // check oob at the upper bound on the 2nd next_back()
780        let mut src_chunk_iter = MemoryChunkIterator::new(
781            &memory_mapping,
782            &[],
783            AccessType::Load,
784            MM_RODATA_START - 1,
785            43,
786            true,
787        )
788        .unwrap()
789        .rev();
790        assert!(src_chunk_iter.next().unwrap().is_ok());
791        assert_matches!(
792            src_chunk_iter.next().unwrap().unwrap_err().downcast_ref().unwrap(),
793            EbpfError::AccessViolation(AccessType::Load, addr, 43, "unknown") if *addr == MM_RODATA_START - 1
794        );
795    }
796
797    #[test]
798    fn test_memory_chunk_iterator_one() {
799        let config = Config {
800            aligned_memory_mapping: false,
801            ..Config::default()
802        };
803        let mem1 = vec![0xFF; 42];
804        let memory_mapping = MemoryMapping::new(
805            vec![MemoryRegion::new_readonly(&mem1, MM_RODATA_START)],
806            &config,
807            SBPFVersion::V3,
808        )
809        .unwrap();
810
811        // check lower bound
812        let mut src_chunk_iter = MemoryChunkIterator::new(
813            &memory_mapping,
814            &[],
815            AccessType::Load,
816            MM_RODATA_START - 1,
817            1,
818            true,
819        )
820        .unwrap();
821        assert!(src_chunk_iter.next().unwrap().is_err());
822
823        // check upper bound
824        let mut src_chunk_iter = MemoryChunkIterator::new(
825            &memory_mapping,
826            &[],
827            AccessType::Load,
828            MM_RODATA_START + 42,
829            1,
830            true,
831        )
832        .unwrap();
833        assert!(src_chunk_iter.next().unwrap().is_err());
834
835        for (vm_addr, len) in [
836            (MM_RODATA_START, 0),
837            (MM_RODATA_START + 42, 0),
838            (MM_RODATA_START, 1),
839            (MM_RODATA_START, 42),
840            (MM_RODATA_START + 41, 1),
841        ] {
842            for rev in [true, false] {
843                let iter = MemoryChunkIterator::new(
844                    &memory_mapping,
845                    &[],
846                    AccessType::Load,
847                    vm_addr,
848                    len,
849                    true,
850                )
851                .unwrap();
852                let res = if rev {
853                    to_chunk_vec(iter.rev())
854                } else {
855                    to_chunk_vec(iter)
856                };
857                if len == 0 {
858                    assert_eq!(res, &[]);
859                } else {
860                    assert_eq!(res, &[(vm_addr, len as usize)]);
861                }
862            }
863        }
864    }
865
866    #[test]
867    fn test_memory_chunk_iterator_two() {
868        let config = Config {
869            aligned_memory_mapping: false,
870            ..Config::default()
871        };
872        let mem1 = vec![0x11; 8];
873        let mem2 = vec![0x22; 4];
874        let memory_mapping = MemoryMapping::new(
875            vec![
876                MemoryRegion::new_readonly(&mem1, MM_RODATA_START),
877                MemoryRegion::new_readonly(&mem2, MM_RODATA_START + 8),
878            ],
879            &config,
880            SBPFVersion::V3,
881        )
882        .unwrap();
883
884        for (vm_addr, len, mut expected) in [
885            (MM_RODATA_START, 8, vec![(MM_RODATA_START, 8)]),
886            (
887                MM_RODATA_START + 7,
888                2,
889                vec![(MM_RODATA_START + 7, 1), (MM_RODATA_START + 8, 1)],
890            ),
891            (MM_RODATA_START + 8, 4, vec![(MM_RODATA_START + 8, 4)]),
892        ] {
893            for rev in [false, true] {
894                let iter = MemoryChunkIterator::new(
895                    &memory_mapping,
896                    &[],
897                    AccessType::Load,
898                    vm_addr,
899                    len,
900                    true,
901                )
902                .unwrap();
903                let res = if rev {
904                    expected.reverse();
905                    to_chunk_vec(iter.rev())
906                } else {
907                    to_chunk_vec(iter)
908                };
909
910                assert_eq!(res, expected);
911            }
912        }
913    }
914
915    #[test]
916    fn test_iter_memory_pair_chunks_short() {
917        let config = Config {
918            aligned_memory_mapping: false,
919            ..Config::default()
920        };
921        let mem1 = vec![0x11; 8];
922        let mem2 = vec![0x22; 4];
923        let memory_mapping = MemoryMapping::new(
924            vec![
925                MemoryRegion::new_readonly(&mem1, MM_RODATA_START),
926                MemoryRegion::new_readonly(&mem2, MM_RODATA_START + 8),
927            ],
928            &config,
929            SBPFVersion::V3,
930        )
931        .unwrap();
932
933        // dst is shorter than src
934        assert_matches!(
935            iter_memory_pair_chunks(
936                AccessType::Load,
937                MM_RODATA_START,
938                AccessType::Load,
939                MM_RODATA_START + 8,
940                8,
941                &[],
942                &memory_mapping,
943                false,
944                true,
945                |_src, _dst, _len| Ok::<_, Error>(0),
946            ).unwrap_err().downcast_ref().unwrap(),
947            EbpfError::AccessViolation(AccessType::Load, addr, 8, "program") if *addr == MM_RODATA_START + 8
948        );
949
950        // src is shorter than dst
951        assert_matches!(
952            iter_memory_pair_chunks(
953                AccessType::Load,
954                MM_RODATA_START + 10,
955                AccessType::Load,
956                MM_RODATA_START + 2,
957                3,
958                &[],
959                &memory_mapping,
960                false,
961                true,
962                |_src, _dst, _len| Ok::<_, Error>(0),
963            ).unwrap_err().downcast_ref().unwrap(),
964            EbpfError::AccessViolation(AccessType::Load, addr, 3, "program") if *addr == MM_RODATA_START + 10
965        );
966    }
967
968    #[test]
969    #[should_panic(expected = "AccessViolation(Store, 4294967296, 4")]
970    fn test_memmove_non_contiguous_readonly() {
971        let config = Config {
972            aligned_memory_mapping: false,
973            ..Config::default()
974        };
975        let mem1 = vec![0x11; 8];
976        let mem2 = vec![0x22; 4];
977        let memory_mapping = MemoryMapping::new(
978            vec![
979                MemoryRegion::new_readonly(&mem1, MM_RODATA_START),
980                MemoryRegion::new_readonly(&mem2, MM_RODATA_START + 8),
981            ],
982            &config,
983            SBPFVersion::V3,
984        )
985        .unwrap();
986
987        memmove_non_contiguous(
988            MM_RODATA_START,
989            MM_RODATA_START + 8,
990            4,
991            &[],
992            &memory_mapping,
993            true,
994        )
995        .unwrap();
996    }
997
998    #[test_case(&[], (0, 0, 0); "no regions")]
999    #[test_case(&[10], (1, 10, 0); "single region 0 len")]
1000    #[test_case(&[10], (0, 5, 5); "single region no overlap")]
1001    #[test_case(&[10], (0, 0, 10) ; "single region complete overlap")]
1002    #[test_case(&[10], (2, 0, 5); "single region partial overlap start")]
1003    #[test_case(&[10], (0, 1, 6); "single region partial overlap middle")]
1004    #[test_case(&[10], (2, 5, 5); "single region partial overlap end")]
1005    #[test_case(&[3, 5], (0, 5, 2) ; "two regions no overlap, single source region")]
1006    #[test_case(&[4, 7], (0, 5, 5) ; "two regions no overlap, multiple source regions")]
1007    #[test_case(&[3, 8], (0, 0, 11) ; "two regions complete overlap")]
1008    #[test_case(&[2, 9], (3, 0, 5) ; "two regions partial overlap start")]
1009    #[test_case(&[3, 9], (1, 2, 5) ; "two regions partial overlap middle")]
1010    #[test_case(&[7, 3], (2, 6, 4) ; "two regions partial overlap end")]
1011    #[test_case(&[2, 6, 3, 4], (0, 10, 2) ; "many regions no overlap, single source region")]
1012    #[test_case(&[2, 1, 2, 5, 6], (2, 10, 4) ; "many regions no overlap, multiple source regions")]
1013    #[test_case(&[8, 1, 3, 6], (0, 0, 18) ; "many regions complete overlap")]
1014    #[test_case(&[7, 3, 1, 4, 5], (5, 0, 8) ; "many regions overlap start")]
1015    #[test_case(&[1, 5, 2, 9, 3], (5, 4, 8) ; "many regions overlap middle")]
1016    #[test_case(&[3, 9, 1, 1, 2, 1], (2, 9, 8) ; "many regions overlap end")]
1017    fn test_memmove_non_contiguous(
1018        regions: &[usize],
1019        (src_offset, dst_offset, len): (usize, usize, usize),
1020    ) {
1021        let config = Config {
1022            aligned_memory_mapping: false,
1023            ..Config::default()
1024        };
1025        let (mem, memory_mapping) = build_memory_mapping(regions, &config);
1026
1027        // flatten the memory so we can memmove it with ptr::copy
1028        let mut expected_memory = flatten_memory(&mem);
1029        unsafe {
1030            std::ptr::copy(
1031                expected_memory.as_ptr().add(src_offset),
1032                expected_memory.as_mut_ptr().add(dst_offset),
1033                len,
1034            )
1035        };
1036
1037        // do our memmove
1038        memmove_non_contiguous(
1039            MM_RODATA_START + dst_offset as u64,
1040            MM_RODATA_START + src_offset as u64,
1041            len as u64,
1042            &[],
1043            &memory_mapping,
1044            true,
1045        )
1046        .unwrap();
1047
1048        // flatten memory post our memmove
1049        let memory = flatten_memory(&mem);
1050
1051        // compare libc's memmove with ours
1052        assert_eq!(expected_memory, memory);
1053    }
1054
1055    #[test]
1056    #[should_panic(expected = "AccessViolation(Store, 4294967296, 9")]
1057    fn test_memset_non_contiguous_readonly() {
1058        let config = Config {
1059            aligned_memory_mapping: false,
1060            ..Config::default()
1061        };
1062        let mut mem1 = vec![0x11; 8];
1063        let mem2 = vec![0x22; 4];
1064        let memory_mapping = MemoryMapping::new(
1065            vec![
1066                MemoryRegion::new_writable(&mut mem1, MM_RODATA_START),
1067                MemoryRegion::new_readonly(&mem2, MM_RODATA_START + 8),
1068            ],
1069            &config,
1070            SBPFVersion::V3,
1071        )
1072        .unwrap();
1073
1074        assert_eq!(
1075            memset_non_contiguous(MM_RODATA_START, 0x33, 9, &[], &memory_mapping, true).unwrap(),
1076            0
1077        );
1078    }
1079
1080    #[test]
1081    fn test_memset_non_contiguous() {
1082        let config = Config {
1083            aligned_memory_mapping: false,
1084            ..Config::default()
1085        };
1086        let mem1 = vec![0x11; 1];
1087        let mut mem2 = vec![0x22; 2];
1088        let mut mem3 = vec![0x33; 3];
1089        let mut mem4 = vec![0x44; 4];
1090        let memory_mapping = MemoryMapping::new(
1091            vec![
1092                MemoryRegion::new_readonly(&mem1, MM_RODATA_START),
1093                MemoryRegion::new_writable(&mut mem2, MM_RODATA_START + 1),
1094                MemoryRegion::new_writable(&mut mem3, MM_RODATA_START + 3),
1095                MemoryRegion::new_writable(&mut mem4, MM_RODATA_START + 6),
1096            ],
1097            &config,
1098            SBPFVersion::V3,
1099        )
1100        .unwrap();
1101
1102        assert_eq!(
1103            memset_non_contiguous(MM_RODATA_START + 1, 0x55, 7, &[], &memory_mapping, true)
1104                .unwrap(),
1105            0
1106        );
1107        assert_eq!(&mem1, &[0x11]);
1108        assert_eq!(&mem2, &[0x55, 0x55]);
1109        assert_eq!(&mem3, &[0x55, 0x55, 0x55]);
1110        assert_eq!(&mem4, &[0x55, 0x55, 0x44, 0x44]);
1111    }
1112
1113    #[test]
1114    fn test_memcmp_non_contiguous() {
1115        let config = Config {
1116            aligned_memory_mapping: false,
1117            ..Config::default()
1118        };
1119        let mem1 = b"foo".to_vec();
1120        let mem2 = b"barbad".to_vec();
1121        let mem3 = b"foobarbad".to_vec();
1122        let memory_mapping = MemoryMapping::new(
1123            vec![
1124                MemoryRegion::new_readonly(&mem1, MM_RODATA_START),
1125                MemoryRegion::new_readonly(&mem2, MM_RODATA_START + 3),
1126                MemoryRegion::new_readonly(&mem3, MM_RODATA_START + 9),
1127            ],
1128            &config,
1129            SBPFVersion::V3,
1130        )
1131        .unwrap();
1132
1133        // non contiguous src
1134        assert_eq!(
1135            memcmp_non_contiguous(
1136                MM_RODATA_START,
1137                MM_RODATA_START + 9,
1138                9,
1139                &[],
1140                &memory_mapping,
1141                true
1142            )
1143            .unwrap(),
1144            0
1145        );
1146
1147        // non contiguous dst
1148        assert_eq!(
1149            memcmp_non_contiguous(
1150                MM_RODATA_START + 10,
1151                MM_RODATA_START + 1,
1152                8,
1153                &[],
1154                &memory_mapping,
1155                true
1156            )
1157            .unwrap(),
1158            0
1159        );
1160
1161        // diff
1162        assert_eq!(
1163            memcmp_non_contiguous(
1164                MM_RODATA_START + 1,
1165                MM_RODATA_START + 11,
1166                5,
1167                &[],
1168                &memory_mapping,
1169                true
1170            )
1171            .unwrap(),
1172            unsafe { memcmp(b"oobar", b"obarb", 5) }
1173        );
1174    }
1175
1176    fn build_memory_mapping<'a>(
1177        regions: &[usize],
1178        config: &'a Config,
1179    ) -> (Vec<Vec<u8>>, MemoryMapping<'a>) {
1180        let mut regs = vec![];
1181        let mut mem = Vec::new();
1182        let mut offset = 0;
1183        for (i, region_len) in regions.iter().enumerate() {
1184            mem.push(
1185                (0..*region_len)
1186                    .map(|x| (i * 10 + x) as u8)
1187                    .collect::<Vec<_>>(),
1188            );
1189            regs.push(MemoryRegion::new_writable(
1190                &mut mem[i],
1191                MM_RODATA_START + offset as u64,
1192            ));
1193            offset += *region_len;
1194        }
1195
1196        let memory_mapping = MemoryMapping::new(regs, config, SBPFVersion::V3).unwrap();
1197
1198        (mem, memory_mapping)
1199    }
1200
1201    fn flatten_memory(mem: &[Vec<u8>]) -> Vec<u8> {
1202        mem.iter().flatten().copied().collect()
1203    }
1204
1205    #[test]
1206    fn test_is_nonoverlapping() {
1207        for dst in 0..8 {
1208            assert!(is_nonoverlapping(10, 3, dst, 3));
1209        }
1210        for dst in 8..13 {
1211            assert!(!is_nonoverlapping(10, 3, dst, 3));
1212        }
1213        for dst in 13..20 {
1214            assert!(is_nonoverlapping(10, 3, dst, 3));
1215        }
1216        assert!(is_nonoverlapping::<u8>(255, 3, 254, 1));
1217        assert!(!is_nonoverlapping::<u8>(255, 2, 254, 3));
1218    }
1219}