swamp_vm/
host.rs

1/*
2 * Copyright (c) Peter Bjorklund. All rights reserved. https://github.com/swamp/swamp
3 * Licensed under the MIT License. See LICENSE in the project root for license information.
4 */
5use crate::RegContents;
6use std::{
7    mem::{align_of, size_of},
8    ptr, slice,
9};
10use swamp_vm_types::{VEC_HEADER_MAGIC_CODE, VEC_HEADER_PAYLOAD_OFFSET, VecHeader};
11
/// Arguments handed to a host function: a raw view over the calling VM's memory and
/// register file.
///
/// The accessors interpret register values as byte offsets into `all_memory`. Because the
/// fields are raw pointers, a `HostArgs` is only valid while the memory and registers it
/// points into remain alive (see [`HostArgs::new`]).
pub struct HostArgs {
    // references into the Vm
    /// Base pointer of the VM memory that register offsets are relative to.
    all_memory: *mut u8,
    /// Size of the VM memory in bytes; the upper bound used by the checked accessors.
    all_memory_len: usize,
    /// Pointer to the first register slot.
    registers: *mut RegContents,
    /// Number of valid register slots behind `registers`.
    register_count: usize,
    /// NOTE(review): not read anywhere in this file; presumably the caller's stack/frame
    /// offset — confirm against the VM.
    stack_offset: usize,
    /// Identifier of the host function being invoked; presumably inspected by
    /// `HostFunctionCallback` implementations to pick a handler — TODO confirm.
    pub function_id: u16,
}
21
22impl HostArgs {
23    #[must_use]
24    pub unsafe fn new(
25        function_id: u16,
26        all_memory: *mut u8,
27        all_memory_len: usize,
28        stack_offset: usize,
29        registers: *mut RegContents,
30        register_count: usize,
31    ) -> Self {
32        // Ensure alignment
33        debug_assert_eq!(
34            all_memory.addr() % align_of::<u64>(),
35            0,
36            "Unaligned frame pointer",
37        );
38
39        Self {
40            all_memory,
41            all_memory_len,
42            registers,
43            stack_offset,
44            register_count,
45            function_id,
46        }
47    }
48
49    /// Get a raw pointer from a register value
50    pub fn ptr(&mut self, register: u8) -> *mut u8 {
51        let addr = unsafe { *self.registers.add(register as usize) };
52        unsafe { self.all_memory.add(addr as usize) }
53    }
54
55    pub fn print_bytes(label: &str, bytes: &[u8]) {
56        print!("{label}: [");
57        for (i, &b) in bytes.iter().enumerate() {
58            print!("{b:02X}");
59            if i < bytes.len() - 1 {
60                print!(" ");
61            }
62        }
63        println!("]");
64    }
65
66    pub unsafe fn ptr_to_slice<'a>(ptr: *const u8, len: usize) -> &'a [u8] {
67        unsafe { std::slice::from_raw_parts(ptr, len) }
68    }
69
70    pub unsafe fn ptr_to_slice_mut<'a>(ptr: *mut u8, len: usize) -> &'a mut [u8] {
71        unsafe { std::slice::from_raw_parts_mut(ptr, len) }
72    }
73
74    /// Get a typed pointer from a register
75    pub fn ptr_as<T>(&mut self, register_id: u8) -> *const T {
76        self.ptr(register_id) as *const T
77    }
78
79    /// Get a safe reference to T from a register with bounds and alignment checks
80    pub fn get<T>(&self, register_id: u8) -> &T {
81        assert!(
82            (register_id as usize) < self.register_count,
83            "Host call register out of bounds: register {} requested, but only {} registers available",
84            register_id,
85            self.register_count
86        );
87
88        let addr = unsafe { *self.registers.add(register_id as usize) } as usize;
89        let size_of_t = size_of::<T>();
90
91        // Bounds check: ensure the entire T fits within memory
92        assert!(
93            addr + size_of_t <= self.all_memory_len,
94            "Host call bounds violation: trying to read {} bytes at address {:#x}, but memory size is {:#x}",
95            size_of_t,
96            addr,
97            self.all_memory_len
98        );
99
100        assert_eq!(
101            addr % align_of::<T>(),
102            0,
103            "Host call alignment violation: address {:#x} is not aligned for type {} (requires {}-byte alignment)",
104            addr,
105            std::any::type_name::<T>(),
106            align_of::<T>()
107        );
108
109        unsafe { &*(self.all_memory.add(addr) as *const T) }
110    }
111
112    /// Get a reference to T from a register without safety checks (for performance)
113    pub unsafe fn get_unchecked<T>(&mut self, register_id: u8) -> &T {
114        let addr = unsafe { *self.registers.add(register_id as usize) } as usize;
115        unsafe { &*(self.all_memory.add(addr) as *const T) }
116    }
117
118    /// Get a mutable reference to T from a register with bounds and alignment checks
119    pub fn get_mut<T>(&mut self, register_id: u8) -> &mut T {
120        assert!(
121            (register_id as usize) < self.register_count,
122            "Host call register out of bounds: register {} requested, but only {} registers available",
123            register_id,
124            self.register_count
125        );
126
127        let addr = unsafe { *self.registers.add(register_id as usize) } as usize;
128        let size_of_t = size_of::<T>();
129
130        // Bounds check: ensure the entire T fits within memory
131        assert!(
132            addr + size_of_t <= self.all_memory_len,
133            "Host call bounds violation: trying to read {} bytes at address {:#x}, but memory size is {:#x}",
134            size_of_t,
135            addr,
136            self.all_memory_len
137        );
138
139        // Alignment check: ensure T is properly aligned
140        assert_eq!(
141            addr % align_of::<T>(),
142            0,
143            "Host call alignment violation: address {:#x} is not aligned for type {} (requires {}-byte alignment)",
144            addr,
145            std::any::type_name::<T>(),
146            align_of::<T>()
147        );
148
149        // Safe to create mutable reference - all checks passed
150        unsafe { &mut *self.all_memory.add(addr).cast::<T>() }
151    }
152
153    /// Get a mutable reference to T from a register without safety checks (for performance)
154    pub unsafe fn get_mut_unchecked<T>(&mut self, register_id: u8) -> &mut T {
155        let addr = unsafe { *self.registers.add(register_id as usize) } as usize;
156        unsafe { &mut *(self.all_memory.add(addr) as *mut T) }
157    }
158
159    /// Get the raw register value as u32
160    pub fn register(&self, register_id: u8) -> u32 {
161        unsafe { *self.registers.add(register_id as usize) }
162    }
163
164    /// Get the register value as i32
165    pub fn register_i32(&self, register_id: u8) -> i32 {
166        unsafe { *self.registers.add(register_id as usize) as i32 }
167    }
168
169    /// Set a register to a u32 value
170    pub fn set_register(&mut self, register_id: u8, data: u32) {
171        unsafe {
172            *self.registers.add(register_id as usize) = data;
173        }
174    }
175
176    /// Write data to the memory location pointed to by a register
177    pub fn write<T>(&mut self, register_id: u8, data: &T) {
178        let dest_ptr = self.ptr(register_id) as *mut T;
179        let src_ptr = data as *const T;
180
181        unsafe {
182            ptr::copy_nonoverlapping(src_ptr, dest_ptr, 1);
183        }
184    }
185
186    /// Get a string from a register
187    pub fn string(&self, register_id: u8) -> &str {
188        let string_header_addr = unsafe { *self.registers.add(register_id as usize) };
189        unsafe {
190            let string_header =
191                *(self.all_memory.add(string_header_addr as usize) as *const VecHeader);
192
193            // String data follows directly after the header
194            let string_data_ptr = self
195                .all_memory
196                .add(string_header_addr as usize + VEC_HEADER_PAYLOAD_OFFSET.0 as usize);
197            let string_byte_length = string_header.element_count as usize;
198
199            debug_assert!(
200                string_header_addr as usize + size_of::<VecHeader>() + string_byte_length
201                    <= self.all_memory_len,
202                "String read out-of-bounds in memory"
203            );
204
205            let bytes = slice::from_raw_parts(string_data_ptr, string_byte_length);
206
207            std::str::from_utf8_unchecked(bytes)
208        }
209    }
210
211    /// Write multiple elements to the end of a vector (bulk append)
212    ///
213    /// # Arguments
214    /// * `vec_register` - Register containing the vector header address, usually r0
215    /// * `data_slice` - Slice of data to append to the vector
216    ///
217    /// # Panics
218    /// * If the vector header is invalid or corrupted
219    /// * If there's insufficient capacity for all elements
220    /// * If the element size doesn't match the vector's element size
221    pub fn write_to_vector_bulk<T>(&mut self, vec_register: u8, data_slice: &[T])
222    where
223        T: Copy,
224    {
225        // Get the vector header address and read the header
226        let vec_header_addr = self.register(vec_register);
227        let vec_header_ptr =
228            unsafe { self.all_memory.add(vec_header_addr as usize) as *mut VecHeader };
229
230        let (element_count, capacity, element_size) = unsafe {
231            let header = &*vec_header_ptr;
232
233            // Validate the vector header
234            if header.padding != VEC_HEADER_MAGIC_CODE {
235                panic!("Invalid vector header - memory corruption detected");
236            }
237
238            if header.capacity == 0 {
239                panic!("Vector was never initialized");
240            }
241
242            (header.element_count, header.capacity, header.element_size)
243        };
244
245        // Check if there's enough capacity
246        let required_capacity = element_count as usize + data_slice.len();
247        if required_capacity > capacity as usize {
248            panic!("Not enough capacity: need {required_capacity}, have {capacity}");
249        }
250
251        // Verify element size matches
252        if size_of::<T>() != element_size as usize {
253            panic!(
254                "Element size mismatch: expected {}, got {}",
255                element_size,
256                size_of::<T>()
257            );
258        }
259
260        // Calculate the starting address for the new elements
261        let start_addr =
262            vec_header_addr + VEC_HEADER_PAYLOAD_OFFSET.0 + (element_count as u32) * element_size;
263
264        // Write all elements
265        unsafe {
266            let dest_ptr = self.all_memory.add(start_addr as usize) as *mut T;
267            ptr::copy_nonoverlapping(data_slice.as_ptr(), dest_ptr, data_slice.len());
268
269            // Update the element count
270            (*vec_header_ptr).element_count += data_slice.len() as u16;
271        }
272    }
273
274    /// Write a single element to a specific index in a vector
275    ///
276    /// # Arguments
277    /// * `vec_register` - Register containing the vector header address, usually r0
278    /// * `index` - Index where to write the element (must be `< element_count`)
279    /// * `data` - Data to write at the specified index
280    ///
281    /// # Panics
282    /// * If the vector header is invalid or corrupted
283    /// * If the index is out of bounds
284    /// * If the element size doesn't match the vector's element size
285    pub fn write_to_vector_at_index<T>(&mut self, vec_register: u8, index: u16, data: &T)
286    where
287        T: Copy,
288    {
289        // Get the vector header address and read the header
290        let vec_header_addr = self.register(vec_register);
291        let vec_header =
292            unsafe { *(self.all_memory.add(vec_header_addr as usize) as *const VecHeader) };
293
294        // Validate the vector header
295        if vec_header.padding != VEC_HEADER_MAGIC_CODE {
296            panic!("Invalid vector header - memory corruption detected");
297        }
298
299        if vec_header.capacity == 0 {
300            panic!("Vector was never initialized");
301        }
302
303        // Bounds check
304        if index >= vec_header.element_count {
305            panic!(
306                "Index {} out of bounds for vector of length {}",
307                index, vec_header.element_count
308            );
309        }
310
311        // Verify element size matches
312        if size_of::<T>() != vec_header.element_size as usize {
313            panic!(
314                "Element size mismatch: expected {}, got {}",
315                vec_header.element_size,
316                size_of::<T>()
317            );
318        }
319
320        // Calculate the address of the element to write
321        let element_addr = vec_header_addr
322            + VEC_HEADER_PAYLOAD_OFFSET.0
323            + (index as u32) * vec_header.element_size;
324
325        // Write the data
326        unsafe {
327            let element_ptr = self.all_memory.add(element_addr as usize) as *mut T;
328            ptr::write(element_ptr, *data);
329        }
330    }
331
332    /// Append a single element to the end of a vector
333    ///
334    /// # Arguments
335    /// * `vec_register` - Register containing the vector header address, usually r0
336    /// * `data` - Data to append to the vector
337    ///
338    /// # Panics
339    /// * If the vector header is invalid or corrupted
340    /// * If there's insufficient capacity for the new element
341    /// * If the element size doesn't match the vector's element size
342    pub fn push_to_vector<T>(&mut self, vec_register: u8, data: &T)
343    where
344        T: Copy,
345    {
346        self.write_to_vector_bulk(vec_register, slice::from_ref(data));
347    }
348
349    /// Read an element from a vector at a specific index
350    ///
351    /// # Arguments
352    /// * `vec_register` - Register containing the vector header address, usually r0
353    /// * `index` - Index of the element to read
354    ///
355    /// # Returns
356    /// A reference to the element at the specified index
357    ///
358    /// # Panics
359    /// * If the vector header is invalid or corrupted
360    /// * If the index is out of bounds
361    /// * If the element size doesn't match the requested type size
362    pub fn read_from_vector_at_index<T>(&self, vec_register: u8, index: u16) -> &T
363    where
364        T: Copy,
365    {
366        // Get the vector header
367        let vec_header = self.get::<VecHeader>(vec_register);
368
369        // Validate the vector header
370        if vec_header.padding != VEC_HEADER_MAGIC_CODE {
371            panic!("Invalid vector header - memory corruption detected");
372        }
373
374        if vec_header.capacity == 0 {
375            panic!("Vector was never initialized");
376        }
377
378        // Bounds check
379        if index >= vec_header.element_count {
380            panic!(
381                "Index {} out of bounds for vector of length {}",
382                index, vec_header.element_count
383            );
384        }
385
386        // Verify element size matches
387        if size_of::<T>() != vec_header.element_size as usize {
388            panic!(
389                "Element size mismatch: expected {}, got {}",
390                vec_header.element_size,
391                size_of::<T>()
392            );
393        }
394
395        // Calculate the address of the element to read
396        let vec_header_addr = self.register(vec_register);
397        let element_addr = vec_header_addr
398            + VEC_HEADER_PAYLOAD_OFFSET.0
399            + (index as u32) * vec_header.element_size;
400
401        // Read the data
402        unsafe {
403            let element_ptr = self.all_memory.add(element_addr as usize) as *const T;
404            &*element_ptr
405        }
406    }
407
408    /// Get the current length (element count) of a vector
409    ///
410    /// # Arguments
411    /// * `vec_register` - Register containing the vector header address
412    ///
413    /// # Returns
414    /// The number of elements currently in the vector
415    ///
416    /// # Panics
417    /// * If the vector header is invalid or corrupted
418    pub fn vector_len(&self, vec_register: u8) -> u16 {
419        let vec_header = self.get::<VecHeader>(vec_register);
420
421        // Validate the vector header
422        if vec_header.padding != VEC_HEADER_MAGIC_CODE {
423            panic!("Invalid vector header - memory corruption detected");
424        }
425
426        if vec_header.capacity == 0 {
427            panic!("Vector was never initialized");
428        }
429
430        vec_header.element_count
431    }
432
433    /// Get the capacity of a vector
434    ///
435    /// # Arguments
436    /// * `vec_register` - Register containing the vector header address
437    ///
438    /// # Returns
439    /// The maximum number of elements the vector can hold
440    ///
441    /// # Panics
442    /// * If the vector header is invalid or corrupted
443    pub fn vector_capacity(&self, vec_register: u8) -> u16 {
444        let vec_header = self.get::<VecHeader>(vec_register);
445
446        // Validate the vector header
447        if vec_header.padding != VEC_HEADER_MAGIC_CODE {
448            panic!("Invalid vector header - memory corruption detected");
449        }
450
451        if vec_header.capacity == 0 {
452            panic!("Vector was never initialized");
453        }
454
455        vec_header.capacity
456    }
457
458    /// Check if a vector is empty
459    ///
460    /// # Arguments
461    /// * `vec_register` - Register containing the vector header address
462    ///
463    /// # Returns
464    /// True if the vector has no elements, false otherwise
465    pub fn vector_is_empty(&self, vec_register: u8) -> bool {
466        self.vector_len(vec_register) == 0
467    }
468
469    /// Check if a vector is full (at capacity)
470    ///
471    /// # Arguments
472    /// * `vec_register` - Register containing the vector header address
473    ///
474    /// # Returns
475    /// True if the vector is at capacity, false otherwise
476    pub fn vector_is_full(&self, vec_register: u8) -> bool {
477        let vec_header = self.get::<VecHeader>(vec_register);
478        vec_header.element_count >= vec_header.capacity
479    }
480}
481
/// Callback interface for servicing host function calls.
///
/// NOTE(review): presumably implemented by the embedder and invoked by the VM when code
/// calls out to the host; `args.function_id` identifies the requested function — confirm.
pub trait HostFunctionCallback {
    /// Handles a single host call. `args` is taken by value and gives access to the
    /// calling VM's registers and memory through its accessor methods.
    fn dispatch_host_call(&mut self, args: HostArgs);
}