use crate::RegContents;
use std::{
    mem::{align_of, size_of},
    ptr, slice,
};
use swamp_vm_isa::{AnyHeader, VEC_HEADER_MAGIC_CODE, VEC_HEADER_PAYLOAD_OFFSET, VecHeader};

/// An owned copy of an `any` value read out of VM memory.
#[derive(Debug)]
pub struct AnyValue {
    pub bytes: Vec<u8>,
    pub type_hash: u32,
}

/// A raw, in-place view of an `any` value that still lives in VM memory.
#[derive(Debug)]
pub struct AnyValueMut {
    pub data_ptr: *mut u8,
    pub size: usize,
    pub type_hash: u32,
}

/// The argument view handed to a host function: raw access to the VM's
/// linear memory and its register file for the duration of one host call.
pub struct HostArgs {
    all_memory: *mut u8,
    all_memory_len: usize,
    registers: *mut RegContents,
    register_count: usize,
    stack_offset: usize,
    pub function_id: u16,
}

impl HostArgs {
    /// Creates a `HostArgs` view over the VM's memory and register file.
    ///
    /// # Safety
    ///
    /// `all_memory` must point to at least `all_memory_len` readable and
    /// writable bytes, and `registers` must point to at least
    /// `register_count` registers. Both must remain valid for as long as the
    /// returned value is used.
    #[must_use]
    pub unsafe fn new(
        function_id: u16,
        all_memory: *mut u8,
        all_memory_len: usize,
        stack_offset: usize,
        registers: *mut RegContents,
        register_count: usize,
    ) -> Self {
        debug_assert_eq!(
            all_memory.addr() % align_of::<u64>(),
            0,
            "Unaligned memory base pointer",
        );

        Self {
            all_memory,
            all_memory_len,
            registers,
            stack_offset,
            register_count,
            function_id,
        }
    }

    /// Translates the VM address held in `register` into a host pointer into
    /// the VM's linear memory. The address is not bounds-checked.
    pub fn ptr(&mut self, register: u8) -> *mut u8 {
        let addr = unsafe { *self.registers.add(register as usize) };
        unsafe { self.all_memory.add(addr as usize) }
    }

    /// Debug helper: prints a byte slice as space-separated hex.
    pub fn print_bytes(label: &str, bytes: &[u8]) {
        print!("{label}: [");
        for (i, &b) in bytes.iter().enumerate() {
            print!("{b:02X}");
            if i < bytes.len() - 1 {
                print!(" ");
            }
        }
        println!("]");
    }

    /// # Safety
    ///
    /// `ptr` must be valid for reads of `len` bytes for the caller-chosen
    /// lifetime `'a`.
    pub unsafe fn ptr_to_slice<'a>(ptr: *const u8, len: usize) -> &'a [u8] {
        unsafe { std::slice::from_raw_parts(ptr, len) }
    }

    /// # Safety
    ///
    /// `ptr` must be valid for reads and writes of `len` bytes for the
    /// caller-chosen lifetime `'a`, with no other aliasing references.
    pub unsafe fn ptr_to_slice_mut<'a>(ptr: *mut u8, len: usize) -> &'a mut [u8] {
        unsafe { std::slice::from_raw_parts_mut(ptr, len) }
    }

    /// Like [`Self::ptr`], but cast to a typed pointer.
    pub fn ptr_as<T>(&mut self, register_id: u8) -> *const T {
        self.ptr(register_id) as *const T
    }

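    /// Returns a checked, typed reference to the value at the VM address held
    /// in `register_id`. Panics if the register index, the address range, or
    /// the alignment of `T` is invalid.
    ///
    /// A minimal usage sketch; the register assignment and the `#[repr(C)]`
    /// argument type are hypothetical, not defined by this crate:
    ///
    /// ```ignore
    /// #[repr(C)]
    /// struct Vec2 {
    ///     x: f32,
    ///     y: f32,
    /// }
    ///
    /// // Assumes the VM placed a `Vec2` argument's address in register 1.
    /// let v: &Vec2 = args.get::<Vec2>(1);
    /// ```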
    pub fn get<T>(&self, register_id: u8) -> &T {
        assert!(
            (register_id as usize) < self.register_count,
            "Host call register out of bounds: register {} requested, but only {} registers available",
            register_id,
            self.register_count
        );

        let addr = unsafe { *self.registers.add(register_id as usize) } as usize;
        let size_of_t = size_of::<T>();

        assert!(
            addr + size_of_t <= self.all_memory_len,
            "Host call bounds violation: trying to read {} bytes at address {:#x}, but memory size is {:#x}",
            size_of_t,
            addr,
            self.all_memory_len
        );

        assert_eq!(
            addr % align_of::<T>(),
            0,
            "Host call alignment violation: address {:#x} is not aligned for type {} (requires {}-byte alignment)",
            addr,
            std::any::type_name::<T>(),
            align_of::<T>()
        );

        unsafe { &*(self.all_memory.add(addr) as *const T) }
    }

    /// # Safety
    ///
    /// The register index and the address it holds are not validated; the
    /// caller must ensure both are in bounds and suitably aligned for `T`.
    pub unsafe fn get_unchecked<T>(&mut self, register_id: u8) -> &T {
        let addr = unsafe { *self.registers.add(register_id as usize) } as usize;
        unsafe { &*(self.all_memory.add(addr) as *const T) }
    }

    /// Mutable counterpart of [`Self::get`], with the same bounds and
    /// alignment checks.
    pub fn get_mut<T>(&mut self, register_id: u8) -> &mut T {
        assert!(
            (register_id as usize) < self.register_count,
            "Host call register out of bounds: register {} requested, but only {} registers available",
            register_id,
            self.register_count
        );

        let addr = unsafe { *self.registers.add(register_id as usize) } as usize;
        let size_of_t = size_of::<T>();

        assert!(
            addr + size_of_t <= self.all_memory_len,
            "Host call bounds violation: trying to access {} bytes at address {:#x}, but memory size is {:#x}",
            size_of_t,
            addr,
            self.all_memory_len
        );

        assert_eq!(
            addr % align_of::<T>(),
            0,
            "Host call alignment violation: address {:#x} is not aligned for type {} (requires {}-byte alignment)",
            addr,
            std::any::type_name::<T>(),
            align_of::<T>()
        );

        unsafe { &mut *self.all_memory.add(addr).cast::<T>() }
    }

    /// # Safety
    ///
    /// The register index and the address it holds are not validated; the
    /// caller must ensure both are in bounds and suitably aligned for `T`.
    pub unsafe fn get_mut_unchecked<T>(&mut self, register_id: u8) -> &mut T {
        let addr = unsafe { *self.registers.add(register_id as usize) } as usize;
        unsafe { &mut *(self.all_memory.add(addr) as *mut T) }
    }

    /// Returns the raw contents of the given register.
    pub fn register(&self, register_id: u8) -> u32 {
        unsafe { *self.registers.add(register_id as usize) }
    }

    /// Returns the contents of the given register reinterpreted as a signed
    /// 32-bit integer.
    pub fn register_i32(&self, register_id: u8) -> i32 {
        unsafe { *self.registers.add(register_id as usize) as i32 }
    }

    /// Overwrites the contents of the given register.
    pub fn set_register(&mut self, register_id: u8, data: u32) {
        unsafe {
            *self.registers.add(register_id as usize) = data;
        }
    }

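    /// Copies `*data` into VM memory at the address held in `register_id`.
    ///
    /// A minimal usage sketch; the register assignment and the `#[repr(C)]`
    /// result type are hypothetical, not defined by this crate:
    ///
    /// ```ignore
    /// #[repr(C)]
    /// struct HitResult {
    ///     hit: u32,
    ///     distance: f32,
    /// }
    ///
    /// // Assumes the VM expects the result at the address held in register 0.
    /// args.write(0, &HitResult { hit: 1, distance: 3.5 });
    /// ```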
    pub fn write<T>(&mut self, register_id: u8, data: &T) {
        let dest_ptr = self.ptr(register_id) as *mut T;
        let src_ptr = data as *const T;

        unsafe {
            ptr::copy_nonoverlapping(src_ptr, dest_ptr, 1);
        }
    }

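    /// Reads a VM string: the register holds the address of a `VecHeader`,
    /// and the UTF-8 bytes follow at `VEC_HEADER_PAYLOAD_OFFSET`. The bytes
    /// are not re-validated as UTF-8 here; the VM is trusted to only store
    /// valid strings.
    ///
    /// A minimal usage sketch (the register assignment is hypothetical):
    ///
    /// ```ignore
    /// // Assumes the VM passed a string argument in register 1.
    /// let name = args.string(1);
    /// println!("called with name = {name}");
    /// ```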
    pub fn string(&self, register_id: u8) -> &str {
        let string_header_addr = unsafe { *self.registers.add(register_id as usize) };
        unsafe {
            let string_header =
                *(self.all_memory.add(string_header_addr as usize) as *const VecHeader);

            let string_data_ptr = self
                .all_memory
                .add(string_header_addr as usize + VEC_HEADER_PAYLOAD_OFFSET.0 as usize);
            let string_byte_length = string_header.element_count as usize;

            debug_assert!(
                string_header_addr as usize
                    + VEC_HEADER_PAYLOAD_OFFSET.0 as usize
                    + string_byte_length
                    <= self.all_memory_len,
                "String read out of bounds in memory"
            );

            let bytes = slice::from_raw_parts(string_data_ptr, string_byte_length);

            std::str::from_utf8_unchecked(bytes)
        }
    }

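    /// Copies the payload of the `any` value whose `AnyHeader` address is
    /// held in `register_id` out of VM memory, together with its type hash.
    ///
    /// A minimal usage sketch; `EXPECTED_TYPE_HASH` and the payload layout
    /// are hypothetical, not defined by this crate:
    ///
    /// ```ignore
    /// let value = args.any(1);
    /// if value.type_hash == EXPECTED_TYPE_HASH {
    ///     // Interpret `value.bytes` according to the layout that hash implies...
    /// }
    /// ```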
    pub fn any(&self, register_id: u8) -> AnyValue {
        let any_header = self.get::<AnyHeader>(register_id);

        unsafe {
            let any_data_ptr = self.all_memory.add(any_header.data_ptr as usize);

            let data_slice = slice::from_raw_parts(any_data_ptr, any_header.size as usize);

            AnyValue {
                bytes: data_slice.to_vec(),
                type_hash: any_header.type_hash,
            }
        }
    }

    /// Like [`Self::any`], but returns a raw pointer into VM memory instead
    /// of copying the payload out, so the caller can modify it in place.
    pub fn any_mut(&self, register_id: u8) -> AnyValueMut {
        let any_header = self.get::<AnyHeader>(register_id);
        let any_data_ptr = unsafe { self.all_memory.add(any_header.data_ptr as usize) };
        AnyValueMut {
            data_ptr: any_data_ptr,
            size: any_header.size as usize,
            type_hash: any_header.type_hash,
        }
    }

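    /// Appends `data_slice` to the VM vector whose `VecHeader` address is
    /// held in `vec_register`, then bumps the vector's element count.
    /// Panics if the header's magic marker is wrong, if capacity would be
    /// exceeded, or if `size_of::<T>()` does not match the vector's element
    /// size.
    ///
    /// A minimal usage sketch; the register assignment and element type are
    /// hypothetical and must match the vector the VM actually passed:
    ///
    /// ```ignore
    /// // Assumes register 1 holds a vector of 4-byte integers.
    /// args.write_to_vector_bulk(1, &[10_i32, 20, 30]);
    /// assert_eq!(args.vector_len(1), 3);
    /// ```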
    pub fn write_to_vector_bulk<T>(&mut self, vec_register: u8, data_slice: &[T])
    where
        T: Copy,
    {
        let vec_header_addr = self.register(vec_register);
        let vec_header_ptr =
            unsafe { self.all_memory.add(vec_header_addr as usize) as *mut VecHeader };

        let (element_count, capacity, element_size) = unsafe {
            let header = &*vec_header_ptr;

            // The header's padding field doubles as a magic marker.
            if header.padding != VEC_HEADER_MAGIC_CODE {
                panic!("Invalid vector header - memory corruption detected");
            }

            (header.element_count, header.capacity, header.element_size)
        };

        let required_capacity = element_count as usize + data_slice.len();
        if required_capacity > capacity as usize {
            panic!("Not enough capacity: need {required_capacity}, have {capacity}");
        }

        if size_of::<T>() != element_size as usize {
            panic!(
                "Element size mismatch: expected {}, got {}",
                element_size,
                size_of::<T>()
            );
        }

        // Copy the new elements just past the existing ones.
        let start_addr =
            vec_header_addr + VEC_HEADER_PAYLOAD_OFFSET.0 + (element_count as u32) * element_size;

        unsafe {
            let dest_ptr = self.all_memory.add(start_addr as usize) as *mut T;
            ptr::copy_nonoverlapping(data_slice.as_ptr(), dest_ptr, data_slice.len());

            (*vec_header_ptr).element_count += data_slice.len() as u16;
        }
    }

    /// Overwrites the element at `index` in the VM vector whose `VecHeader`
    /// address is held in `vec_register`. Panics on a bad header, an
    /// out-of-bounds index, or an element size mismatch.
    pub fn write_to_vector_at_index<T>(&mut self, vec_register: u8, index: u16, data: &T)
    where
        T: Copy,
    {
        let vec_header_addr = self.register(vec_register);
        let vec_header =
            unsafe { *(self.all_memory.add(vec_header_addr as usize) as *const VecHeader) };

        if vec_header.padding != VEC_HEADER_MAGIC_CODE {
            panic!("Invalid vector header - memory corruption detected");
        }

        if index >= vec_header.element_count {
            panic!(
                "Index {} out of bounds for vector of length {}",
                index, vec_header.element_count
            );
        }

        if size_of::<T>() != vec_header.element_size as usize {
            panic!(
                "Element size mismatch: expected {}, got {}",
                vec_header.element_size,
                size_of::<T>()
            );
        }

        let element_addr = vec_header_addr
            + VEC_HEADER_PAYLOAD_OFFSET.0
            + (index as u32) * vec_header.element_size;

        unsafe {
            let element_ptr = self.all_memory.add(element_addr as usize) as *mut T;
            ptr::write(element_ptr, *data);
        }
    }

    /// Appends a single element to the VM vector held in `vec_register`.
    pub fn push_to_vector<T>(&mut self, vec_register: u8, data: &T)
    where
        T: Copy,
    {
        self.write_to_vector_bulk(vec_register, slice::from_ref(data));
    }

    /// Returns a reference to the element at `index` in the VM vector whose
    /// `VecHeader` address is held in `vec_register`. Panics on a bad header,
    /// an out-of-bounds index, or an element size mismatch.
    pub fn read_from_vector_at_index<T>(&self, vec_register: u8, index: u16) -> &T
    where
        T: Copy,
    {
        let vec_header = self.get::<VecHeader>(vec_register);

        if vec_header.padding != VEC_HEADER_MAGIC_CODE {
            panic!("Invalid vector header - memory corruption detected");
        }

        if index >= vec_header.element_count {
            panic!(
                "Index {} out of bounds for vector of length {}",
                index, vec_header.element_count
            );
        }

        if size_of::<T>() != vec_header.element_size as usize {
            panic!(
                "Element size mismatch: expected {}, got {}",
                vec_header.element_size,
                size_of::<T>()
            );
        }

        let vec_header_addr = self.register(vec_register);
        let element_addr = vec_header_addr
            + VEC_HEADER_PAYLOAD_OFFSET.0
            + (index as u32) * vec_header.element_size;

        unsafe {
            let element_ptr = self.all_memory.add(element_addr as usize) as *const T;
            &*element_ptr
        }
    }

    /// Returns the current element count of the VM vector held in
    /// `vec_register`.
    pub fn vector_len(&self, vec_register: u8) -> u16 {
        let vec_header = self.get::<VecHeader>(vec_register);

        if vec_header.padding != VEC_HEADER_MAGIC_CODE {
            panic!("Invalid vector header - memory corruption detected");
        }

        vec_header.element_count
    }

    /// Returns the capacity of the VM vector held in `vec_register`.
    pub fn vector_capacity(&self, vec_register: u8) -> u16 {
        let vec_header = self.get::<VecHeader>(vec_register);

        if vec_header.padding != VEC_HEADER_MAGIC_CODE {
            panic!("Invalid vector header - memory corruption detected");
        }

        vec_header.capacity
    }

    /// Returns `true` if the VM vector held in `vec_register` has no elements.
    pub fn vector_is_empty(&self, vec_register: u8) -> bool {
        self.vector_len(vec_register) == 0
    }

    /// Returns `true` if the VM vector held in `vec_register` is at capacity.
    pub fn vector_is_full(&self, vec_register: u8) -> bool {
        let vec_header = self.get::<VecHeader>(vec_register);
        vec_header.element_count >= vec_header.capacity
    }
}

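/// Callback implemented by the embedder to service host calls coming from
/// the VM.
///
/// A minimal implementation sketch; the function ids and the argument layout
/// in the `match` arms are hypothetical, not defined by this crate:
///
/// ```ignore
/// struct ExampleHost;
///
/// impl HostFunctionCallback for ExampleHost {
///     fn dispatch_host_call(&mut self, mut args: HostArgs) {
///         match args.function_id {
///             // Hypothetical "add two i32s": operands in r1 and r2, result in r0.
///             1 => {
///                 let a = args.register_i32(1);
///                 let b = args.register_i32(2);
///                 args.set_register(0, (a + b) as u32);
///             }
///             // Hypothetical "print string": string header address in r1.
///             2 => {
///                 let text = args.string(1);
///                 println!("{text}");
///             }
///             other => panic!("unknown host function id {other}"),
///         }
///     }
/// }
/// ```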
pub trait HostFunctionCallback {
    fn dispatch_host_call(&mut self, args: HostArgs);
}