1use crate::RegContents;
6use std::{
7 mem::{align_of, size_of},
8 ptr, slice,
9};
10use swamp_vm_types::{VEC_HEADER_MAGIC_CODE, VEC_HEADER_PAYLOAD_OFFSET, VecHeader};
11
12pub struct HostArgs {
13 all_memory: *mut u8,
15 all_memory_len: usize,
16 registers: *mut RegContents,
17 register_count: usize,
18 stack_offset: usize,
19 pub function_id: u16,
20}
21
22impl HostArgs {
23 #[must_use]
24 pub unsafe fn new(
25 function_id: u16,
26 all_memory: *mut u8,
27 all_memory_len: usize,
28 stack_offset: usize,
29 registers: *mut RegContents,
30 register_count: usize,
31 ) -> Self {
32 debug_assert_eq!(
34 all_memory.addr() % align_of::<u64>(),
35 0,
36 "Unaligned frame pointer",
37 );
38
39 Self {
40 all_memory,
41 all_memory_len,
42 registers,
43 stack_offset,
44 register_count,
45 function_id,
46 }
47 }
48
49 pub fn ptr(&mut self, register: u8) -> *mut u8 {
51 let addr = unsafe { *self.registers.add(register as usize) };
52 unsafe { self.all_memory.add(addr as usize) }
53 }
54
55 pub fn print_bytes(label: &str, bytes: &[u8]) {
56 print!("{label}: [");
57 for (i, &b) in bytes.iter().enumerate() {
58 print!("{b:02X}");
59 if i < bytes.len() - 1 {
60 print!(" ");
61 }
62 }
63 println!("]");
64 }
65
66 pub unsafe fn ptr_to_slice<'a>(ptr: *const u8, len: usize) -> &'a [u8] {
67 unsafe { std::slice::from_raw_parts(ptr, len) }
68 }
69
70 pub unsafe fn ptr_to_slice_mut<'a>(ptr: *mut u8, len: usize) -> &'a mut [u8] {
71 unsafe { std::slice::from_raw_parts_mut(ptr, len) }
72 }
73
74 pub fn ptr_as<T>(&mut self, register_id: u8) -> *const T {
76 self.ptr(register_id) as *const T
77 }
78
79 pub fn get<T>(&self, register_id: u8) -> &T {
81 assert!(
82 (register_id as usize) < self.register_count,
83 "Host call register out of bounds: register {} requested, but only {} registers available",
84 register_id,
85 self.register_count
86 );
87
88 let addr = unsafe { *self.registers.add(register_id as usize) } as usize;
89 let size_of_t = size_of::<T>();
90
91 assert!(
93 addr + size_of_t <= self.all_memory_len,
94 "Host call bounds violation: trying to read {} bytes at address {:#x}, but memory size is {:#x}",
95 size_of_t,
96 addr,
97 self.all_memory_len
98 );
99
100 assert_eq!(
101 addr % align_of::<T>(),
102 0,
103 "Host call alignment violation: address {:#x} is not aligned for type {} (requires {}-byte alignment)",
104 addr,
105 std::any::type_name::<T>(),
106 align_of::<T>()
107 );
108
109 unsafe { &*(self.all_memory.add(addr) as *const T) }
110 }
111
112 pub unsafe fn get_unchecked<T>(&mut self, register_id: u8) -> &T {
114 let addr = unsafe { *self.registers.add(register_id as usize) } as usize;
115 unsafe { &*(self.all_memory.add(addr) as *const T) }
116 }
117
118 pub fn get_mut<T>(&mut self, register_id: u8) -> &mut T {
120 assert!(
121 (register_id as usize) < self.register_count,
122 "Host call register out of bounds: register {} requested, but only {} registers available",
123 register_id,
124 self.register_count
125 );
126
127 let addr = unsafe { *self.registers.add(register_id as usize) } as usize;
128 let size_of_t = size_of::<T>();
129
130 assert!(
132 addr + size_of_t <= self.all_memory_len,
133 "Host call bounds violation: trying to read {} bytes at address {:#x}, but memory size is {:#x}",
134 size_of_t,
135 addr,
136 self.all_memory_len
137 );
138
139 assert_eq!(
141 addr % align_of::<T>(),
142 0,
143 "Host call alignment violation: address {:#x} is not aligned for type {} (requires {}-byte alignment)",
144 addr,
145 std::any::type_name::<T>(),
146 align_of::<T>()
147 );
148
149 unsafe { &mut *self.all_memory.add(addr).cast::<T>() }
151 }
152
153 pub unsafe fn get_mut_unchecked<T>(&mut self, register_id: u8) -> &mut T {
155 let addr = unsafe { *self.registers.add(register_id as usize) } as usize;
156 unsafe { &mut *(self.all_memory.add(addr) as *mut T) }
157 }
158
159 pub fn register(&self, register_id: u8) -> u32 {
161 unsafe { *self.registers.add(register_id as usize) }
162 }
163
164 pub fn register_i32(&self, register_id: u8) -> i32 {
166 unsafe { *self.registers.add(register_id as usize) as i32 }
167 }
168
169 pub fn set_register(&mut self, register_id: u8, data: u32) {
171 unsafe {
172 *self.registers.add(register_id as usize) = data;
173 }
174 }
175
176 pub fn write<T>(&mut self, register_id: u8, data: &T) {
178 let dest_ptr = self.ptr(register_id) as *mut T;
179 let src_ptr = data as *const T;
180
181 unsafe {
182 ptr::copy_nonoverlapping(src_ptr, dest_ptr, 1);
183 }
184 }
185
186 pub fn string(&self, register_id: u8) -> &str {
188 let string_header_addr = unsafe { *self.registers.add(register_id as usize) };
189 unsafe {
190 let string_header =
191 *(self.all_memory.add(string_header_addr as usize) as *const VecHeader);
192
193 let string_data_ptr = self
195 .all_memory
196 .add(string_header_addr as usize + VEC_HEADER_PAYLOAD_OFFSET.0 as usize);
197 let string_byte_length = string_header.element_count as usize;
198
199 debug_assert!(
200 string_header_addr as usize + size_of::<VecHeader>() + string_byte_length
201 <= self.all_memory_len,
202 "String read out-of-bounds in memory"
203 );
204
205 let bytes = slice::from_raw_parts(string_data_ptr, string_byte_length);
206
207 std::str::from_utf8_unchecked(bytes)
208 }
209 }
210
211 pub fn write_to_vector_bulk<T>(&mut self, vec_register: u8, data_slice: &[T])
222 where
223 T: Copy,
224 {
225 let vec_header_addr = self.register(vec_register);
227 let vec_header_ptr =
228 unsafe { self.all_memory.add(vec_header_addr as usize) as *mut VecHeader };
229
230 let (element_count, capacity, element_size) = unsafe {
231 let header = &*vec_header_ptr;
232
233 if header.padding != VEC_HEADER_MAGIC_CODE {
235 panic!("Invalid vector header - memory corruption detected");
236 }
237
238 if header.capacity == 0 {
239 panic!("Vector was never initialized");
240 }
241
242 (header.element_count, header.capacity, header.element_size)
243 };
244
245 let required_capacity = element_count as usize + data_slice.len();
247 if required_capacity > capacity as usize {
248 panic!("Not enough capacity: need {required_capacity}, have {capacity}");
249 }
250
251 if size_of::<T>() != element_size as usize {
253 panic!(
254 "Element size mismatch: expected {}, got {}",
255 element_size,
256 size_of::<T>()
257 );
258 }
259
260 let start_addr =
262 vec_header_addr + VEC_HEADER_PAYLOAD_OFFSET.0 + (element_count as u32) * element_size;
263
264 unsafe {
266 let dest_ptr = self.all_memory.add(start_addr as usize) as *mut T;
267 ptr::copy_nonoverlapping(data_slice.as_ptr(), dest_ptr, data_slice.len());
268
269 (*vec_header_ptr).element_count += data_slice.len() as u16;
271 }
272 }
273
274 pub fn write_to_vector_at_index<T>(&mut self, vec_register: u8, index: u16, data: &T)
286 where
287 T: Copy,
288 {
289 let vec_header_addr = self.register(vec_register);
291 let vec_header =
292 unsafe { *(self.all_memory.add(vec_header_addr as usize) as *const VecHeader) };
293
294 if vec_header.padding != VEC_HEADER_MAGIC_CODE {
296 panic!("Invalid vector header - memory corruption detected");
297 }
298
299 if vec_header.capacity == 0 {
300 panic!("Vector was never initialized");
301 }
302
303 if index >= vec_header.element_count {
305 panic!(
306 "Index {} out of bounds for vector of length {}",
307 index, vec_header.element_count
308 );
309 }
310
311 if size_of::<T>() != vec_header.element_size as usize {
313 panic!(
314 "Element size mismatch: expected {}, got {}",
315 vec_header.element_size,
316 size_of::<T>()
317 );
318 }
319
320 let element_addr = vec_header_addr
322 + VEC_HEADER_PAYLOAD_OFFSET.0
323 + (index as u32) * vec_header.element_size;
324
325 unsafe {
327 let element_ptr = self.all_memory.add(element_addr as usize) as *mut T;
328 ptr::write(element_ptr, *data);
329 }
330 }
331
332 pub fn push_to_vector<T>(&mut self, vec_register: u8, data: &T)
343 where
344 T: Copy,
345 {
346 self.write_to_vector_bulk(vec_register, slice::from_ref(data));
347 }
348
349 pub fn read_from_vector_at_index<T>(&self, vec_register: u8, index: u16) -> &T
363 where
364 T: Copy,
365 {
366 let vec_header = self.get::<VecHeader>(vec_register);
368
369 if vec_header.padding != VEC_HEADER_MAGIC_CODE {
371 panic!("Invalid vector header - memory corruption detected");
372 }
373
374 if vec_header.capacity == 0 {
375 panic!("Vector was never initialized");
376 }
377
378 if index >= vec_header.element_count {
380 panic!(
381 "Index {} out of bounds for vector of length {}",
382 index, vec_header.element_count
383 );
384 }
385
386 if size_of::<T>() != vec_header.element_size as usize {
388 panic!(
389 "Element size mismatch: expected {}, got {}",
390 vec_header.element_size,
391 size_of::<T>()
392 );
393 }
394
395 let vec_header_addr = self.register(vec_register);
397 let element_addr = vec_header_addr
398 + VEC_HEADER_PAYLOAD_OFFSET.0
399 + (index as u32) * vec_header.element_size;
400
401 unsafe {
403 let element_ptr = self.all_memory.add(element_addr as usize) as *const T;
404 &*element_ptr
405 }
406 }
407
408 pub fn vector_len(&self, vec_register: u8) -> u16 {
419 let vec_header = self.get::<VecHeader>(vec_register);
420
421 if vec_header.padding != VEC_HEADER_MAGIC_CODE {
423 panic!("Invalid vector header - memory corruption detected");
424 }
425
426 if vec_header.capacity == 0 {
427 panic!("Vector was never initialized");
428 }
429
430 vec_header.element_count
431 }
432
433 pub fn vector_capacity(&self, vec_register: u8) -> u16 {
444 let vec_header = self.get::<VecHeader>(vec_register);
445
446 if vec_header.padding != VEC_HEADER_MAGIC_CODE {
448 panic!("Invalid vector header - memory corruption detected");
449 }
450
451 if vec_header.capacity == 0 {
452 panic!("Vector was never initialized");
453 }
454
455 vec_header.capacity
456 }
457
458 pub fn vector_is_empty(&self, vec_register: u8) -> bool {
466 self.vector_len(vec_register) == 0
467 }
468
469 pub fn vector_is_full(&self, vec_register: u8) -> bool {
477 let vec_header = self.get::<VecHeader>(vec_register);
478 vec_header.element_count >= vec_header.capacity
479 }
480}
481
/// Implemented by whatever services host calls.
///
/// It receives one [`HostArgs`] per call.
pub trait HostFunctionCallback {
    /// Handles a single host call.
    ///
    /// `args` is taken by value. It exposes the call's registers and VM memory,
    /// and `args.function_id` carries the host function identifier it was built
    /// with.
    fn dispatch_host_call(&mut self, args: HostArgs);
}