swamp_vm/
host.rs

1/*
2 * Copyright (c) Peter Bjorklund. All rights reserved. https://github.com/swamp/swamp
3 * Licensed under the MIT License. See LICENSE in the project root for license information.
4 */
5use crate::RegContents;
6use std::{
7    mem::{align_of, size_of},
8    ptr, slice,
9};
10use swamp_vm_isa::{AnyHeader, VEC_HEADER_MAGIC_CODE, VEC_HEADER_PAYLOAD_OFFSET, VecHeader};
11
/// An owned copy of a dynamically typed ("any") value read out of VM memory.
///
/// Produced by `HostArgs::any`, which copies the payload bytes so the value can outlive
/// the host call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AnyValue {
    /// Raw payload bytes, copied verbatim from VM memory.
    pub bytes: Vec<u8>,
    /// Type hash stored alongside the payload in the VM's any-header.
    pub type_hash: u32,
}
17
/// A non-owning, mutable view of a dynamically typed ("any") value inside VM memory.
///
/// Produced by `HostArgs::any_mut`. It only describes the location, so it is cheap to
/// copy. Dereferencing `data_ptr` is the caller's (unsafe) responsibility and is only
/// meaningful while the VM memory it points into is alive.
#[derive(Debug, Clone, Copy)]
pub struct AnyValueMut {
    /// Host pointer to the first payload byte inside VM memory.
    pub data_ptr: *mut u8,
    /// Payload length in bytes.
    pub size: usize,
    /// Type hash stored alongside the payload in the VM's any-header.
    pub type_hash: u32,
}
24
25pub struct HostArgs {
26    // references into the Vm
27    all_memory: *mut u8,
28    all_memory_len: usize,
29    registers: *mut RegContents,
30    register_count: usize,
31    stack_offset: usize,
32    pub function_id: u16,
33}
34
35impl HostArgs {
36    #[must_use]
37    pub unsafe fn new(
38        function_id: u16,
39        all_memory: *mut u8,
40        all_memory_len: usize,
41        stack_offset: usize,
42        registers: *mut RegContents,
43        register_count: usize,
44    ) -> Self {
45        // Ensure alignment
46        debug_assert_eq!(
47            all_memory.addr() % align_of::<u64>(),
48            0,
49            "Unaligned frame pointer",
50        );
51
52        Self {
53            all_memory,
54            all_memory_len,
55            registers,
56            stack_offset,
57            register_count,
58            function_id,
59        }
60    }
61
62    /// Get a raw pointer from a register value
63    pub fn ptr(&mut self, register: u8) -> *mut u8 {
64        let addr = unsafe { *self.registers.add(register as usize) };
65        unsafe { self.all_memory.add(addr as usize) }
66    }
67
68    pub fn print_bytes(label: &str, bytes: &[u8]) {
69        print!("{label}: [");
70        for (i, &b) in bytes.iter().enumerate() {
71            print!("{b:02X}");
72            if i < bytes.len() - 1 {
73                print!(" ");
74            }
75        }
76        println!("]");
77    }
78
79    pub unsafe fn ptr_to_slice<'a>(ptr: *const u8, len: usize) -> &'a [u8] {
80        unsafe { std::slice::from_raw_parts(ptr, len) }
81    }
82
83    pub unsafe fn ptr_to_slice_mut<'a>(ptr: *mut u8, len: usize) -> &'a mut [u8] {
84        unsafe { std::slice::from_raw_parts_mut(ptr, len) }
85    }
86
87    /// Get a typed pointer from a register
88    pub fn ptr_as<T>(&mut self, register_id: u8) -> *const T {
89        self.ptr(register_id) as *const T
90    }
91
92    /// Get a safe reference to T from a register with bounds and alignment checks
93    pub fn get<T>(&self, register_id: u8) -> &T {
94        assert!(
95            (register_id as usize) < self.register_count,
96            "Host call register out of bounds: register {} requested, but only {} registers available",
97            register_id,
98            self.register_count
99        );
100
101        let addr = unsafe { *self.registers.add(register_id as usize) } as usize;
102        let size_of_t = size_of::<T>();
103
104        // Bounds check: ensure the entire T fits within memory
105        assert!(
106            addr + size_of_t <= self.all_memory_len,
107            "Host call bounds violation: trying to read {} bytes at address {:#x}, but memory size is {:#x}",
108            size_of_t,
109            addr,
110            self.all_memory_len
111        );
112
113        assert_eq!(
114            addr % align_of::<T>(),
115            0,
116            "Host call alignment violation: address {:#x} is not aligned for type {} (requires {}-byte alignment)",
117            addr,
118            std::any::type_name::<T>(),
119            align_of::<T>()
120        );
121
122        unsafe { &*(self.all_memory.add(addr) as *const T) }
123    }
124
125    /// Get a reference to T from a register without safety checks (for performance)
126    pub unsafe fn get_unchecked<T>(&mut self, register_id: u8) -> &T {
127        let addr = unsafe { *self.registers.add(register_id as usize) } as usize;
128        unsafe { &*(self.all_memory.add(addr) as *const T) }
129    }
130
131    /// Get a mutable reference to T from a register with bounds and alignment checks
132    pub fn get_mut<T>(&mut self, register_id: u8) -> &mut T {
133        assert!(
134            (register_id as usize) < self.register_count,
135            "Host call register out of bounds: register {} requested, but only {} registers available",
136            register_id,
137            self.register_count
138        );
139
140        let addr = unsafe { *self.registers.add(register_id as usize) } as usize;
141        let size_of_t = size_of::<T>();
142
143        // Bounds check: ensure the entire T fits within memory
144        assert!(
145            addr + size_of_t <= self.all_memory_len,
146            "Host call bounds violation: trying to read {} bytes at address {:#x}, but memory size is {:#x}",
147            size_of_t,
148            addr,
149            self.all_memory_len
150        );
151
152        // Alignment check: ensure T is properly aligned
153        assert_eq!(
154            addr % align_of::<T>(),
155            0,
156            "Host call alignment violation: address {:#x} is not aligned for type {} (requires {}-byte alignment)",
157            addr,
158            std::any::type_name::<T>(),
159            align_of::<T>()
160        );
161
162        // Safe to create mutable reference - all checks passed
163        unsafe { &mut *self.all_memory.add(addr).cast::<T>() }
164    }
165
166    /// Get a mutable reference to T from a register without safety checks (for performance)
167    pub unsafe fn get_mut_unchecked<T>(&mut self, register_id: u8) -> &mut T {
168        let addr = unsafe { *self.registers.add(register_id as usize) } as usize;
169        unsafe { &mut *(self.all_memory.add(addr) as *mut T) }
170    }
171
172    /// Get the raw register value as u32
173    pub fn register(&self, register_id: u8) -> u32 {
174        unsafe { *self.registers.add(register_id as usize) }
175    }
176
177    /// Get the register value as i32
178    pub fn register_i32(&self, register_id: u8) -> i32 {
179        unsafe { *self.registers.add(register_id as usize) as i32 }
180    }
181
182    /// Set a register to a u32 value
183    pub fn set_register(&mut self, register_id: u8, data: u32) {
184        unsafe {
185            *self.registers.add(register_id as usize) = data;
186        }
187    }
188
189    /// Write data to the memory location pointed to by a register
190    pub fn write<T>(&mut self, register_id: u8, data: &T) {
191        let dest_ptr = self.ptr(register_id) as *mut T;
192        let src_ptr = data as *const T;
193
194        unsafe {
195            ptr::copy_nonoverlapping(src_ptr, dest_ptr, 1);
196        }
197    }
198
199    /// Get a string from a register
200    pub fn string(&self, register_id: u8) -> &str {
201        let string_header_addr = unsafe { *self.registers.add(register_id as usize) };
202        unsafe {
203            let string_header =
204                *(self.all_memory.add(string_header_addr as usize) as *const VecHeader);
205
206            // String data follows directly after the header
207            let string_data_ptr = self
208                .all_memory
209                .add(string_header_addr as usize + VEC_HEADER_PAYLOAD_OFFSET.0 as usize);
210            let string_byte_length = string_header.element_count as usize;
211
212            debug_assert!(
213                string_header_addr as usize + size_of::<VecHeader>() + string_byte_length
214                    <= self.all_memory_len,
215                "String read out-of-bounds in memory"
216            );
217
218            let bytes = slice::from_raw_parts(string_data_ptr, string_byte_length);
219
220            std::str::from_utf8_unchecked(bytes)
221        }
222    }
223
224    pub fn any(&self, register_id: u8) -> AnyValue {
225        let any_header = self.get::<AnyHeader>(register_id);
226
227        unsafe {
228            let any_data_ptr = self.all_memory.add(any_header.data_ptr as usize);
229
230            let data_slice = slice::from_raw_parts(any_data_ptr, any_header.size as usize);
231
232            AnyValue {
233                bytes: data_slice.to_vec(),
234                type_hash: any_header.type_hash,
235            }
236        }
237    }
238
239    pub fn any_mut(&self, register_id: u8) -> AnyValueMut {
240        let any_header = self.get::<AnyHeader>(register_id);
241        let any_data_ptr = unsafe { self.all_memory.add(any_header.data_ptr as usize) };
242        AnyValueMut {
243            data_ptr: any_data_ptr,
244            size: any_header.size as usize,
245            type_hash: any_header.type_hash,
246        }
247    }
248
249    /// Write multiple elements to the end of a vector (bulk append)
250    ///
251    /// # Arguments
252    /// * `vec_register` - Register containing the vector header address, usually r0
253    /// * `data_slice` - Slice of data to append to the vector
254    ///
255    /// # Panics
256    /// * If the vector header is invalid or corrupted
257    /// * If there's insufficient capacity for all elements
258    /// * If the element size doesn't match the vector's element size
259    pub fn write_to_vector_bulk<T>(&mut self, vec_register: u8, data_slice: &[T])
260    where
261        T: Copy,
262    {
263        // Get the vector header address and read the header
264        let vec_header_addr = self.register(vec_register);
265        let vec_header_ptr =
266            unsafe { self.all_memory.add(vec_header_addr as usize) as *mut VecHeader };
267
268        let (element_count, capacity, element_size) = unsafe {
269            let header = &*vec_header_ptr;
270
271            // Validate the vector header
272            if header.padding != VEC_HEADER_MAGIC_CODE {
273                panic!("Invalid vector header - memory corruption detected");
274            }
275
276            (header.element_count, header.capacity, header.element_size)
277        };
278
279        // Check if there's enough capacity
280        let required_capacity = element_count as usize + data_slice.len();
281        if required_capacity > capacity as usize {
282            panic!("Not enough capacity: need {required_capacity}, have {capacity}");
283        }
284
285        // Verify element size matches
286        if size_of::<T>() != element_size as usize {
287            panic!(
288                "Element size mismatch: expected {}, got {}",
289                element_size,
290                size_of::<T>()
291            );
292        }
293
294        // Calculate the starting address for the new elements
295        let start_addr =
296            vec_header_addr + VEC_HEADER_PAYLOAD_OFFSET.0 + (element_count as u32) * element_size;
297
298        // Write all elements
299        unsafe {
300            let dest_ptr = self.all_memory.add(start_addr as usize) as *mut T;
301            ptr::copy_nonoverlapping(data_slice.as_ptr(), dest_ptr, data_slice.len());
302
303            // Update the element count
304            (*vec_header_ptr).element_count += data_slice.len() as u16;
305        }
306    }
307
308    /// Write a single element to a specific index in a vector
309    ///
310    /// # Arguments
311    /// * `vec_register` - Register containing the vector header address, usually r0
312    /// * `index` - Index where to write the element (must be `< element_count`)
313    /// * `data` - Data to write at the specified index
314    ///
315    /// # Panics
316    /// * If the vector header is invalid or corrupted
317    /// * If the index is out of bounds
318    /// * If the element size doesn't match the vector's element size
319    pub fn write_to_vector_at_index<T>(&mut self, vec_register: u8, index: u16, data: &T)
320    where
321        T: Copy,
322    {
323        // Get the vector header address and read the header
324        let vec_header_addr = self.register(vec_register);
325        let vec_header =
326            unsafe { *(self.all_memory.add(vec_header_addr as usize) as *const VecHeader) };
327
328        // Validate the vector header
329        if vec_header.padding != VEC_HEADER_MAGIC_CODE {
330            panic!("Invalid vector header - memory corruption detected");
331        }
332
333        // Bounds check
334        if index >= vec_header.element_count {
335            panic!(
336                "Index {} out of bounds for vector of length {}",
337                index, vec_header.element_count
338            );
339        }
340
341        // Verify element size matches
342        if size_of::<T>() != vec_header.element_size as usize {
343            panic!(
344                "Element size mismatch: expected {}, got {}",
345                vec_header.element_size,
346                size_of::<T>()
347            );
348        }
349
350        // Calculate the address of the element to write
351        let element_addr = vec_header_addr
352            + VEC_HEADER_PAYLOAD_OFFSET.0
353            + (index as u32) * vec_header.element_size;
354
355        // Write the data
356        unsafe {
357            let element_ptr = self.all_memory.add(element_addr as usize) as *mut T;
358            ptr::write(element_ptr, *data);
359        }
360    }
361
362    /// Append a single element to the end of a vector
363    ///
364    /// # Arguments
365    /// * `vec_register` - Register containing the vector header address, usually r0
366    /// * `data` - Data to append to the vector
367    ///
368    /// # Panics
369    /// * If the vector header is invalid or corrupted
370    /// * If there's insufficient capacity for the new element
371    /// * If the element size doesn't match the vector's element size
372    pub fn push_to_vector<T>(&mut self, vec_register: u8, data: &T)
373    where
374        T: Copy,
375    {
376        self.write_to_vector_bulk(vec_register, slice::from_ref(data));
377    }
378
379    /// Read an element from a vector at a specific index
380    ///
381    /// # Arguments
382    /// * `vec_register` - Register containing the vector header address, usually r0
383    /// * `index` - Index of the element to read
384    ///
385    /// # Returns
386    /// A reference to the element at the specified index
387    ///
388    /// # Panics
389    /// * If the vector header is invalid or corrupted
390    /// * If the index is out of bounds
391    /// * If the element size doesn't match the requested type size
392    pub fn read_from_vector_at_index<T>(&self, vec_register: u8, index: u16) -> &T
393    where
394        T: Copy,
395    {
396        // Get the vector header
397        let vec_header = self.get::<VecHeader>(vec_register);
398
399        // Validate the vector header
400        if vec_header.padding != VEC_HEADER_MAGIC_CODE {
401            panic!("Invalid vector header - memory corruption detected");
402        }
403
404        // Bounds check
405        if index >= vec_header.element_count {
406            panic!(
407                "Index {} out of bounds for vector of length {}",
408                index, vec_header.element_count
409            );
410        }
411
412        // Verify element size matches
413        if size_of::<T>() != vec_header.element_size as usize {
414            panic!(
415                "Element size mismatch: expected {}, got {}",
416                vec_header.element_size,
417                size_of::<T>()
418            );
419        }
420
421        // Calculate the address of the element to read
422        let vec_header_addr = self.register(vec_register);
423        let element_addr = vec_header_addr
424            + VEC_HEADER_PAYLOAD_OFFSET.0
425            + (index as u32) * vec_header.element_size;
426
427        // Read the data
428        unsafe {
429            let element_ptr = self.all_memory.add(element_addr as usize) as *const T;
430            &*element_ptr
431        }
432    }
433
434    /// Get the current length (element count) of a vector
435    ///
436    /// # Arguments
437    /// * `vec_register` - Register containing the vector header address
438    ///
439    /// # Returns
440    /// The number of elements currently in the vector
441    ///
442    /// # Panics
443    /// * If the vector header is invalid or corrupted
444    pub fn vector_len(&self, vec_register: u8) -> u16 {
445        let vec_header = self.get::<VecHeader>(vec_register);
446
447        // Validate the vector header
448        if vec_header.padding != VEC_HEADER_MAGIC_CODE {
449            panic!("Invalid vector header - memory corruption detected");
450        }
451
452        vec_header.element_count
453    }
454
455    /// Get the capacity of a vector
456    ///
457    /// # Arguments
458    /// * `vec_register` - Register containing the vector header address
459    ///
460    /// # Returns
461    /// The maximum number of elements the vector can hold
462    ///
463    /// # Panics
464    /// * If the vector header is invalid or corrupted
465    pub fn vector_capacity(&self, vec_register: u8) -> u16 {
466        let vec_header = self.get::<VecHeader>(vec_register);
467
468        // Validate the vector header
469        if vec_header.padding != VEC_HEADER_MAGIC_CODE {
470            panic!("Invalid vector header - memory corruption detected");
471        }
472
473        vec_header.capacity
474    }
475
476    /// Check if a vector is empty
477    ///
478    /// # Arguments
479    /// * `vec_register` - Register containing the vector header address
480    ///
481    /// # Returns
482    /// True if the vector has no elements, false otherwise
483    pub fn vector_is_empty(&self, vec_register: u8) -> bool {
484        self.vector_len(vec_register) == 0
485    }
486
487    /// Check if a vector is full (at capacity)
488    ///
489    /// # Arguments
490    /// * `vec_register` - Register containing the vector header address
491    ///
492    /// # Returns
493    /// True if the vector is at capacity, false otherwise
494    pub fn vector_is_full(&self, vec_register: u8) -> bool {
495        let vec_header = self.get::<VecHeader>(vec_register);
496        vec_header.element_count >= vec_header.capacity
497    }
498}
499
500pub trait HostFunctionCallback {
501    fn dispatch_host_call(&mut self, args: HostArgs);
502}