1use alloy_primitives::{Address, B256, U256};
42use alloy_provider::Provider;
43pub use calls::{errors::Error, CallAccess, MutatingCallContext, StaticCallContext, ValueTransfer};
44use deploy::DeploymentAccess;
45use std::cell::RefCell;
46use std::rc::Rc;
47use tokio::runtime::Runtime;
48
49pub use stylus_core::*;
50
51use crate::state::VMState;
52
/// Test double that implements the `stylus_core` host traits entirely in memory.
///
/// The state sits behind `Rc<RefCell<_>>`, so the derived `Clone` yields a second
/// handle onto the *same* [`VMState`] (mutations through one clone are visible
/// through the other). Use [`TestVM::snapshot`] to obtain a copy of the state.
#[derive(Clone)]
pub struct TestVM {
    // Shared, interior-mutable VM state; every accessor goes through this cell.
    state: Rc<RefCell<VMState>>,
}
79
80impl Default for TestVM {
81 fn default() -> Self {
82 Self::new()
83 }
84}
85
86impl From<VMState> for TestVM {
87 fn from(state: VMState) -> Self {
88 Self {
89 state: Rc::new(RefCell::new(state)),
90 }
91 }
92}
93
94impl TestVM {
95 pub fn new() -> Self {
103 Self {
104 state: Rc::new(RefCell::new(VMState::default())),
105 }
106 }
107
108 pub fn snapshot(&self) -> VMState {
112 self.state.borrow().clone()
113 }
114
115 pub fn set_block_number(&self, block_number: u64) {
124 self.state.borrow_mut().block_number = block_number;
125 }
126
127 pub fn set_block_timestamp(&self, timestamp: u64) {
136 self.state.borrow_mut().block_timestamp = timestamp;
137 }
138
139 pub fn set_tx_origin(&self, origin: Address) {
141 self.state.borrow_mut().tx_origin = origin;
142 }
143
144 pub fn set_balance(&self, address: Address, balance: U256) {
155 self.state.borrow_mut().balances.insert(address, balance);
156 }
157
158 pub fn set_contract_address(&self, address: Address) {
160 self.state.borrow_mut().contract_address = address;
161 }
162
163 pub fn set_code(&self, address: Address, code: Vec<u8>) {
165 self.state.borrow_mut().code_storage.insert(address, code);
166 }
167
168 pub fn set_gas_left(&self, gas: u64) {
170 self.state.borrow_mut().gas_left = gas;
171 }
172
173 pub fn set_ink_left(&self, ink: u64) {
175 self.state.borrow_mut().ink_left = ink;
176 }
177
178 pub fn set_sender(&self, sender: Address) {
180 self.state.borrow_mut().msg_sender = sender;
181 }
182
183 pub fn set_value(&self, value: U256) {
185 self.state.borrow_mut().msg_value = value;
186 }
187
188 pub fn get_storage(&self, key: U256) -> B256 {
200 self.state
201 .borrow()
202 .storage
203 .get(&key)
204 .copied()
205 .unwrap_or_default()
206 }
207
208 pub fn set_storage(&self, key: U256, value: B256) {
210 self.state.borrow_mut().storage.insert(key, value);
211 }
212
213 pub fn clear_storage(&self) {
215 self.state.borrow_mut().storage.clear();
216 }
217
218 pub fn mock_call(&self, to: Address, data: Vec<u8>, return_data: Result<Vec<u8>, Vec<u8>>) {
235 self.state
236 .borrow_mut()
237 .call_returns
238 .insert((to, data), return_data);
239 }
240
241 pub fn mock_delegate_call(
243 &self,
244 to: Address,
245 data: Vec<u8>,
246 return_data: Result<Vec<u8>, Vec<u8>>,
247 ) {
248 self.state
249 .borrow_mut()
250 .delegate_call_returns
251 .insert((to, data), return_data);
252 }
253
254 pub fn mock_static_call(
256 &self,
257 to: Address,
258 data: Vec<u8>,
259 return_data: Result<Vec<u8>, Vec<u8>>,
260 ) {
261 self.state
262 .borrow_mut()
263 .static_call_returns
264 .insert((to, data), return_data);
265 }
266
267 pub fn mock_deploy(&self, code: Vec<u8>, salt: Option<B256>, result: Result<Address, Vec<u8>>) {
281 self.state
282 .borrow_mut()
283 .deploy_returns
284 .insert((code, salt), result);
285 }
286
287 pub fn get_emitted_logs(&self) -> Vec<(Vec<B256>, Vec<u8>)> {
289 self.state.borrow().emitted_logs.clone()
290 }
291
292 pub fn clear_mocks(&self) {
294 let mut state = self.state.borrow_mut();
295 state.call_returns.clear();
296 state.delegate_call_returns.clear();
297 state.static_call_returns.clear();
298 state.deploy_returns.clear();
299 state.emitted_logs.clear();
300 }
301}
302
// `Host` needs no items of its own here; TestVM's behaviour comes from the
// individual access-trait impls in this file.
// NOTE(review): presumably `Host` bundles those traits as supertraits — confirm in `stylus_core`.
impl Host for TestVM {}
304
impl CryptographyAccess for TestVM {
    /// Computes the Keccak-256 digest of `input` in-process via
    /// `alloy_primitives::keccak256`; no VM state is consulted.
    fn native_keccak256(&self, input: &[u8]) -> B256 {
        alloy_primitives::keccak256(input)
    }
}
310
311impl CalldataAccess for TestVM {
312 fn read_args(&self, _len: usize) -> Vec<u8> {
313 unimplemented!("read_args not yet implemented for TestVM")
314 }
315 fn read_return_data(&self, _offset: usize, _size: Option<usize>) -> Vec<u8> {
316 unimplemented!("read_return_data not yet implemented for TestVM")
317 }
318 fn return_data_size(&self) -> usize {
319 unimplemented!("return_data_size not yet implemented for TestVM")
320 }
321 fn write_result(&self, _data: &[u8]) {
322 unimplemented!("write_result not yet implemented for TestVM")
323 }
324}
325
326unsafe impl UnsafeDeploymentAccess for TestVM {
327 unsafe fn create1(
328 &self,
329 _code: *const u8,
330 _code_len: usize,
331 _endowment: *const u8,
332 _contract: *mut u8,
333 _revert_data_len: *mut usize,
334 ) {
335 unimplemented!("unsafe create1 not yet implemented for TestVM")
336 }
337 unsafe fn create2(
338 &self,
339 _code: *const u8,
340 _code_len: usize,
341 _endowment: *const u8,
342 _salt: *const u8,
343 _contract: *mut u8,
344 _revert_data_len: *mut usize,
345 ) {
346 unimplemented!("unsafe create2 not yet implemented for TestVM")
347 }
348}
349
350impl StorageAccess for TestVM {
351 unsafe fn storage_cache_bytes32(&self, key: U256, value: B256) {
352 self.state.borrow_mut().storage.insert(key, value);
353 }
354
355 fn flush_cache(&self, _clear: bool) {}
356 fn storage_load_bytes32(&self, key: U256) -> B256 {
357 if let Some(provider) = self.state.borrow().provider.clone() {
358 let rt = Runtime::new().expect("Failed to create runtime");
359 let addr = self.state.borrow().contract_address;
360 let storage = rt
361 .block_on(async { provider.get_storage_at(addr, key).await })
362 .unwrap_or_default();
363 return B256::from(storage);
364 }
365 self.state
366 .borrow()
367 .storage
368 .get(&key)
369 .copied()
370 .unwrap_or(B256::ZERO)
371 }
372}
373
374unsafe impl UnsafeCallAccess for TestVM {
375 unsafe fn call_contract(
376 &self,
377 _to: *const u8,
378 _data: *const u8,
379 _data_len: usize,
380 _value: *const u8,
381 _gas: u64,
382 _outs_len: &mut usize,
383 ) -> u8 {
384 unimplemented!("unsafe call_contract not yet implemented for TestVM")
385 }
386 unsafe fn delegate_call_contract(
387 &self,
388 _to: *const u8,
389 _data: *const u8,
390 _data_len: usize,
391 _gas: u64,
392 _outs_len: &mut usize,
393 ) -> u8 {
394 unimplemented!("unsafe delegate_call_contract not yet implemented for TestVM")
395 }
396 unsafe fn static_call_contract(
397 &self,
398 _to: *const u8,
399 _data: *const u8,
400 _data_len: usize,
401 _gas: u64,
402 _outs_len: &mut usize,
403 ) -> u8 {
404 unimplemented!("unsafe static_call_contract not yet implemented for TestVM")
405 }
406}
407
408impl BlockAccess for TestVM {
409 fn block_basefee(&self) -> U256 {
410 self.state.borrow().block_basefee
411 }
412
413 fn block_coinbase(&self) -> Address {
414 self.state.borrow().coinbase
415 }
416
417 fn block_gas_limit(&self) -> u64 {
418 self.state.borrow().block_gas_limit
419 }
420
421 fn block_number(&self) -> u64 {
422 self.state.borrow().block_number
423 }
424
425 fn block_timestamp(&self) -> u64 {
426 self.state.borrow().block_timestamp
427 }
428}
429
430impl ChainAccess for TestVM {
431 fn chain_id(&self) -> u64 {
432 self.state.borrow().chain_id
433 }
434}
435
436impl AccountAccess for TestVM {
437 fn balance(&self, account: Address) -> U256 {
438 self.state
439 .borrow()
440 .balances
441 .get(&account)
442 .copied()
443 .unwrap_or_default()
444 }
445
446 fn code(&self, account: Address) -> Vec<u8> {
447 self.state
448 .borrow()
449 .code_storage
450 .get(&account)
451 .cloned()
452 .unwrap_or_default()
453 }
454
455 fn code_hash(&self, account: Address) -> B256 {
456 if let Some(code) = self.state.borrow().code_storage.get(&account) {
457 alloy_primitives::keccak256(code)
458 } else {
459 B256::ZERO
460 }
461 }
462
463 fn code_size(&self, account: Address) -> usize {
464 self.state
465 .borrow()
466 .code_storage
467 .get(&account)
468 .map_or(0, |code| code.len())
469 }
470
471 fn contract_address(&self) -> Address {
472 self.state.borrow().contract_address
473 }
474}
475
476impl MemoryAccess for TestVM {
477 fn pay_for_memory_grow(&self, _pages: u16) {}
478}
479
480impl MessageAccess for TestVM {
481 fn msg_reentrant(&self) -> bool {
482 self.state.borrow().reentrant
483 }
484
485 fn msg_sender(&self) -> Address {
486 self.state.borrow().msg_sender
487 }
488
489 fn msg_value(&self) -> U256 {
490 self.state.borrow().msg_value
491 }
492
493 fn tx_origin(&self) -> Address {
494 self.state.borrow().tx_origin
495 }
496}
497
498impl MeteringAccess for TestVM {
499 fn evm_gas_left(&self) -> u64 {
500 self.state.borrow().gas_left
501 }
502
503 fn evm_ink_left(&self) -> u64 {
504 self.state.borrow().ink_left
505 }
506
507 fn tx_gas_price(&self) -> U256 {
508 self.state.borrow().tx_gas_price
509 }
510
511 fn tx_ink_price(&self) -> u32 {
512 self.state.borrow().tx_ink_price
513 }
514}
515
516impl CallAccess for TestVM {
517 fn call(
518 &self,
519 _context: &dyn MutatingCallContext,
520 to: Address,
521 data: &[u8],
522 ) -> Result<Vec<u8>, Error> {
523 self.state
524 .borrow()
525 .call_returns
526 .get(&(to, data.to_vec()))
527 .cloned()
528 .map(|opt| match opt {
529 Ok(data) => Ok(data),
530 Err(data) => Err(Error::Revert(data)),
531 })
532 .unwrap_or(Ok(Vec::new()))
533 }
534
535 unsafe fn delegate_call(
536 &self,
537 _context: &dyn MutatingCallContext,
538 to: Address,
539 data: &[u8],
540 ) -> Result<Vec<u8>, Error> {
541 self.state
542 .borrow()
543 .delegate_call_returns
544 .get(&(to, data.to_vec()))
545 .cloned()
546 .map(|opt| match opt {
547 Ok(data) => Ok(data),
548 Err(data) => Err(Error::Revert(data)),
549 })
550 .unwrap_or(Ok(Vec::new()))
551 }
552
553 fn static_call(
554 &self,
555 _context: &dyn StaticCallContext,
556 to: Address,
557 data: &[u8],
558 ) -> Result<Vec<u8>, Error> {
559 self.state
560 .borrow()
561 .static_call_returns
562 .get(&(to, data.to_vec()))
563 .cloned()
564 .map(|opt| match opt {
565 Ok(data) => Ok(data),
566 Err(data) => Err(Error::Revert(data)),
567 })
568 .unwrap_or(Ok(Vec::new()))
569 }
570}
571
572impl ValueTransfer for TestVM {
573 #[cfg(feature = "reentrant")]
574 fn transfer_eth(
575 &self,
576 _storage: &mut dyn stylus_core::storage::TopLevelStorage,
577 to: Address,
578 amount: U256,
579 ) -> Result<(), Vec<u8>> {
580 let mut state = self.state.borrow_mut();
581 let from = state.contract_address;
582
583 let from_balance = state.balances.get(&from).copied().unwrap_or_default();
584 let to_balance = state.balances.get(&to).copied().unwrap_or_default();
585
586 if from_balance < amount {
587 return Err(b"insufficient funds for transfer".to_vec());
588 }
589
590 let new_to_balance = to_balance
591 .checked_add(amount)
592 .ok_or_else(|| b"balance overflow".to_vec())?;
593
594 state.balances.insert(from, from_balance - amount);
595 state.balances.insert(to, new_to_balance);
596
597 Ok(())
598 }
599
600 #[cfg(not(feature = "reentrant"))]
601 fn transfer_eth(&self, to: Address, amount: U256) -> Result<(), Vec<u8>> {
602 let mut state = self.state.borrow_mut();
603 let from = state.contract_address;
604
605 let from_balance = state.balances.get(&from).copied().unwrap_or_default();
606 let to_balance = state.balances.get(&to).copied().unwrap_or_default();
607
608 if from_balance < amount {
609 return Err(b"insufficient funds for transfer".to_vec());
610 }
611
612 let new_to_balance = to_balance
613 .checked_add(amount)
614 .ok_or_else(|| b"balance overflow".to_vec())?;
615
616 state.balances.insert(from, from_balance - amount);
617 state.balances.insert(to, new_to_balance);
618
619 Ok(())
620 }
621}
622
623impl DeploymentAccess for TestVM {
624 #[cfg(feature = "reentrant")]
625 unsafe fn deploy(
626 &self,
627 code: &[u8],
628 _endowment: U256,
629 salt: Option<B256>,
630 _cache_policy: stylus_core::deploy::CachePolicy,
631 ) -> Result<Address, Vec<u8>> {
632 self.state
633 .borrow()
634 .deploy_returns
635 .get(&(code.to_vec(), salt))
636 .cloned()
637 .unwrap_or(Ok(Address::ZERO))
638 }
639
640 #[cfg(not(feature = "reentrant"))]
641 unsafe fn deploy(
642 &self,
643 code: &[u8],
644 _endowment: U256,
645 salt: Option<B256>,
646 ) -> Result<Address, Vec<u8>> {
647 self.state
648 .borrow()
649 .deploy_returns
650 .get(&(code.to_vec(), salt))
651 .cloned()
652 .unwrap_or(Ok(Address::ZERO))
653 }
654}
655
656impl LogAccess for TestVM {
657 fn emit_log(&self, input: &[u8], num_topics: usize) {
658 let (topics_data, data) = input.split_at(num_topics * 32);
659 let mut topics = Vec::with_capacity(num_topics);
660
661 for chunk in topics_data.chunks(32) {
662 let mut bytes = [0u8; 32];
663 bytes.copy_from_slice(chunk);
664 topics.push(B256::from(bytes));
665 }
666
667 self.state
668 .borrow_mut()
669 .emitted_logs
670 .push((topics, data.to_vec()));
671 }
672
673 fn raw_log(&self, topics: &[B256], data: &[u8]) -> Result<(), &'static str> {
674 self.state
675 .borrow_mut()
676 .emitted_logs
677 .push((topics.to_vec(), data.to_vec()));
678 Ok(())
679 }
680}
681
#[cfg(all(test, not(feature = "reentrant")))]
mod tests {
    use super::*;

    #[test]
    fn test_basic_vm_operations() {
        let vm = TestVM::new();

        vm.set_block_number(12345);
        assert_eq!(vm.block_number(), 12345);

        let address = Address::from([1u8; 20]);
        let balance = U256::from(1000);
        vm.set_balance(address, balance);
        assert_eq!(vm.balance(address), balance);

        let key = U256::from(1);
        let value = B256::new([1u8; 32]);
        vm.set_storage(key, value);
        assert_eq!(vm.get_storage(key), value);
    }

    #[test]
    fn test_mock_calls() {
        let vm = TestVM::new();
        let target = Address::from([2u8; 20]);
        let data = vec![1, 2, 3, 4];
        let expected_return = vec![5, 6, 7, 8];

        vm.mock_call(target, data.clone(), Ok(expected_return.clone()));

        let ctx = stylus_core::calls::context::Call::new();
        let result = vm.call(&ctx, target, &data).unwrap();
        assert_eq!(result, expected_return);

        // Re-mocking the same (target, calldata) pair replaces the earlier result.
        let error_data = vec![9, 9, 9];
        vm.mock_call(target, data.clone(), Err(error_data.clone()));

        match vm.call(&ctx, target, &data) {
            Err(Error::Revert(returned_data)) => assert_eq!(returned_data, error_data),
            _ => panic!("Expected revert error"),
        }
    }

    #[test]
    fn test_unmocked_call_and_deploy_defaults() {
        let vm = TestVM::new();

        let ctx = stylus_core::calls::context::Call::new();
        let result = vm.call(&ctx, Address::from([4u8; 20]), &[0xde, 0xad]).unwrap();
        assert!(result.is_empty());

        // SAFETY: TestVM::deploy only looks up a mocked result; it runs no code
        // and cannot re-enter the caller.
        let deployed = unsafe { vm.deploy(&[0x60, 0x00], U256::ZERO, None) };
        assert_eq!(deployed.unwrap(), Address::ZERO);
    }

    #[test]
    fn test_mock_deploys() {
        let vm = TestVM::new();
        let code = vec![1, 2, 3, 4];
        let expected_address = Address::from([3u8; 20]);

        vm.mock_deploy(code.clone(), None, Ok(expected_address));

        // SAFETY: TestVM::deploy only looks up a mocked result; it runs no code
        // and cannot re-enter the caller.
        let result = unsafe { vm.deploy(&code, U256::ZERO, None) };
        assert_eq!(result.unwrap(), expected_address);

        // Re-mocking the same (code, salt) pair replaces the earlier result.
        let error_data = vec![9, 9, 9];
        vm.mock_deploy(code.clone(), None, Err(error_data.clone()));

        // SAFETY: as above — a pure mock lookup.
        let outcome = unsafe { vm.deploy(&code, U256::ZERO, None) };
        match outcome {
            Err(returned_data) => assert_eq!(returned_data, error_data),
            _ => panic!("Expected deployment error"),
        }
    }

    #[test]
    fn test_logs() {
        let vm = TestVM::new();
        let topic1 = B256::from([1u8; 32]);
        let topic2 = B256::from([2u8; 32]);
        let data = vec![3, 4, 5];

        vm.raw_log(&[topic1, topic2], &data).unwrap();

        let logs = vm.get_emitted_logs();
        assert_eq!(logs.len(), 1);
        assert_eq!(logs[0].0, vec![topic1, topic2]);
        assert_eq!(logs[0].1, data);
    }

    #[test]
    fn test_transfer_eth_success() {
        let vm = TestVM::new();
        let from = vm.state.borrow().contract_address;
        let to = Address::from([1u8; 20]);
        let initial_balance = U256::from(1000);
        let transfer_amount = U256::from(300);

        vm.set_balance(from, initial_balance);

        let result = vm.transfer_eth(to, transfer_amount);
        assert!(result.is_ok());

        assert_eq!(vm.balance(from), initial_balance - transfer_amount);
        assert_eq!(vm.balance(to), transfer_amount);
    }

    #[test]
    fn test_transfer_eth_insufficient_funds() {
        let vm = TestVM::new();
        let from = vm.state.borrow().contract_address;
        let to = Address::from([1u8; 20]);
        let initial_balance = U256::from(100);
        let transfer_amount = U256::from(200);

        vm.set_balance(from, initial_balance);

        let result = vm.transfer_eth(to, transfer_amount);
        assert!(result.is_err());

        // A failed transfer must leave both balances untouched.
        assert_eq!(vm.balance(from), initial_balance);
        assert_eq!(vm.balance(to), U256::ZERO);
    }

    #[test]
    fn test_transfer_eth_overflow() {
        let vm = TestVM::new();
        let from = vm.state.borrow().contract_address;
        let to = Address::from([1u8; 20]);

        vm.set_balance(from, U256::MAX);
        vm.set_balance(to, U256::MAX);

        let result = vm.transfer_eth(to, U256::from(1));
        assert!(result.is_err());

        assert_eq!(vm.balance(from), U256::MAX);
        assert_eq!(vm.balance(to), U256::MAX);
    }
}
819}