diff --git a/crates/era-cheatcodes/src/cheatcodes.rs b/crates/era-cheatcodes/src/cheatcodes.rs index 85c76eb46..20810751f 100644 --- a/crates/era-cheatcodes/src/cheatcodes.rs +++ b/crates/era-cheatcodes/src/cheatcodes.rs @@ -134,6 +134,29 @@ const BROADCAST_IGNORED_CONTRACTS: [H160; 19] = [ H160::zero(), ]; +//same as above, except without +// CHEATCODE_ADDRESS +const PRANK_IGNORED_CONTRACTS: [H160; 18] = [ + zksync_types::BOOTLOADER_ADDRESS, + zksync_types::ACCOUNT_CODE_STORAGE_ADDRESS, + zksync_types::NONCE_HOLDER_ADDRESS, + zksync_types::KNOWN_CODES_STORAGE_ADDRESS, + zksync_types::IMMUTABLE_SIMULATOR_STORAGE_ADDRESS, + zksync_types::CONTRACT_FORCE_DEPLOYER_ADDRESS, + zksync_types::L1_MESSENGER_ADDRESS, + zksync_types::KECCAK256_PRECOMPILE_ADDRESS, + zksync_types::L2_ETH_TOKEN_ADDRESS, + zksync_types::SYSTEM_CONTEXT_ADDRESS, + zksync_types::BOOTLOADER_UTILITIES_ADDRESS, + zksync_types::EVENT_WRITER_ADDRESS, + zksync_types::COMPRESSOR_ADDRESS, + zksync_types::COMPLEX_UPGRADER_ADDRESS, + zksync_types::ECRECOVER_PRECOMPILE_ADDRESS, + zksync_types::SHA256_PRECOMPILE_ADDRESS, + zksync_types::MINT_AND_BURN_ADDRESS, + H160::zero(), +]; + #[derive(Debug, Clone)] struct EraEnv { l1_batch_env: L1BatchEnv, @@ -257,6 +280,7 @@ enum ActionOnReturn { prev_exception_handler_pc: Option, reason: Vec, }, + StopPrank, } #[derive(Debug, Default, Clone)] @@ -268,7 +292,9 @@ struct FinishCyclePermanentActions { #[derive(Debug, Clone)] struct StartPrankOpts { sender: H160, - origin: Option, + origin: Option<(H160, H160)>, + depth: usize, + prankster: H160, } /// Tracks the expected calls per address. @@ -432,7 +458,7 @@ impl } // Checks returns from contracts for expectRevert cheatcode - self.handle_return(&state, &data, memory); + self.handle_return(&state, &data, memory, &storage); // Checks contract calls for expectCall cheatcode if let Opcode::FarCall(_call) = data.opcode.variant.opcode { @@ -477,6 +503,44 @@ impl } } + // ---- START: prank origin only at desired depth + if let Opcode::FarCall(_) | Opcode::NearCall(_) = data.opcode.variant.opcode { + let current = state.vm_local_state.callstack.depth(); + + if !PRANK_IGNORED_CONTRACTS + .contains(&state.vm_local_state.callstack.current.this_address) + { + if let Some(StartPrankOpts { depth, prankster, .. }) = + self.permanent_actions.start_prank.as_ref() + { + if ¤t == depth && + &state.vm_local_state.callstack.inner.last().unwrap().this_address == + prankster + { + self.enable_prank(&mut storage.borrow_mut()); + } + } + } + } + + if let Opcode::Ret(_) = data.opcode.variant.opcode { + let current = state.vm_local_state.callstack.depth(); + if !PRANK_IGNORED_CONTRACTS + .contains(&state.vm_local_state.callstack.current.this_address) + { + if let Some(StartPrankOpts { depth, prankster, .. }) = + self.permanent_actions.start_prank.as_ref() + { + if *depth == current + 1 && + &state.vm_local_state.callstack.current.this_address == prankster + { + self.disable_prank(&mut storage.borrow_mut()); + } + } + } + } + // ---- END: prank origin only at desired depth + if self.return_data.is_some() { if let Opcode::Ret(_call) = data.opcode.variant.opcode { if self.near_calls == 0 { @@ -1148,10 +1212,12 @@ impl VmTracer { + tracing::info!("👷 Pranking to {msg_sender:?}"); + self.prank( + state.vm_local_state.callstack.depth(), + &storage, + msg_sender.to_h160(), + None, + state.vm_local_state.callstack.inner.last().unwrap().this_address, + ) + } + prank_1(prank_1Call { msgSender: msg_sender, txOrigin: tx_origin }) => { + tracing::info!("👷 Pranking to {msg_sender:?} with origin {tx_origin:?}"); + self.prank( + state.vm_local_state.callstack.depth(), + &storage, + msg_sender.to_h160(), + Some(tx_origin.to_h160()), + state.vm_local_state.callstack.inner.last().unwrap().this_address, + ) + } recordLogs(recordLogsCall {}) => { tracing::info!("👷 Recording logs"); tracing::info!( @@ -2032,11 +2118,25 @@ impl CheatcodeTracer { } startPrank_0(startPrank_0Call { msgSender: msg_sender }) => { tracing::info!("👷 Starting prank to {msg_sender:?}"); - self.start_prank(&storage, msg_sender.to_h160(), None); + let target_depth = state.vm_local_state.callstack.depth() - 1; + self.start_prank( + &storage, + msg_sender.to_h160(), + None, + target_depth, + state.vm_local_state.callstack.inner[target_depth - 1].this_address, + ); } startPrank_1(startPrank_1Call { msgSender: msg_sender, txOrigin: tx_origin }) => { tracing::info!("👷 Starting prank to {msg_sender:?} with origin {tx_origin:?}"); - self.start_prank(&storage, msg_sender.to_h160(), Some(tx_origin.to_h160())) + let target_depth = state.vm_local_state.callstack.depth() - 1; + self.start_prank( + &storage, + msg_sender.to_h160(), + Some(tx_origin.to_h160()), + target_depth, + state.vm_local_state.callstack.inner[target_depth - 1].this_address, + ); } stopBroadcast(stopBroadcastCall {}) => { tracing::info!("👷 Stopping broadcast"); @@ -2492,11 +2592,12 @@ impl CheatcodeTracer { } } - fn handle_return( + fn handle_return( &mut self, state: &VmLocalStateData<'_>, data: &AfterExecutionData, memory: &SimpleMemory, + storage: &StoragePtr>, ) { // Skip check if there are no expected actions let Some(action) = self.next_return_action.as_mut() else { return }; @@ -2508,6 +2609,11 @@ impl CheatcodeTracer { // Skip check if opcode is not Ret let Opcode::Ret(op) = data.opcode.variant.opcode else { return }; + // Check how many returns we need to skip before finding the actual one + if action.returns_to_skip != 0 { + action.returns_to_skip -= 1; + return + } // The desired return opcode was found match &action.action { @@ -2517,12 +2623,6 @@ impl CheatcodeTracer { prev_exception_handler_pc: exception_handler, prev_continue_pc: continue_pc, } => { - // Check how many returns we need to skip before finding the actual one - if action.returns_to_skip != 0 { - action.returns_to_skip -= 1; - return - } - match op { RetOpcode::Revert => { tracing::debug!(wanted = %depth, current_depth = %callstack_depth, opcode = ?data.opcode.variant.opcode, "expectRevert"); @@ -2578,7 +2678,57 @@ impl CheatcodeTracer { }); self.next_return_action = None; } + ActionOnReturn::StopPrank => { + tracing::debug!("Stopping prank"); + self.stop_prank(storage); + self.next_return_action = None; + } + } + } + + fn enable_prank(&mut self, storage: &mut RefMut) { + if let Some(StartPrankOpts { origin: Some((_, new_origin)), .. }) = + self.permanent_actions.start_prank.as_ref() + { + let key = StorageKey::new( + AccountTreeId::new(zksync_types::SYSTEM_CONTEXT_ADDRESS), + zksync_types::SYSTEM_CONTEXT_TX_ORIGIN_POSITION, + ); + self.write_storage(key, H256::from(*new_origin), storage) + } + } + + fn disable_prank(&mut self, storage: &mut RefMut) { + if let Some(StartPrankOpts { origin: Some((old_origin, _)), .. }) = + self.permanent_actions.start_prank.as_ref() + { + let key = StorageKey::new( + AccountTreeId::new(zksync_types::SYSTEM_CONTEXT_ADDRESS), + zksync_types::SYSTEM_CONTEXT_TX_ORIGIN_POSITION, + ); + self.write_storage(key, H256::from(*old_origin), storage) + } + } + + fn prank( + &mut self, + current_depth: usize, + storage: &StoragePtr>, + sender: H160, + origin: Option, + prankster: H160, + ) { + if self.permanent_actions.broadcast.is_some() { + tracing::error!("prank is incompatible with broadcast"); + return } + + self.start_prank(storage, sender, origin, current_depth, prankster); + self.next_return_action.replace(NextReturnAction { + target_depth: current_depth - 1, + action: ActionOnReturn::StopPrank, + returns_to_skip: 2, + }); } fn start_prank( @@ -2586,6 +2736,8 @@ impl CheatcodeTracer { storage: &StoragePtr>, sender: H160, origin: Option, + depth: usize, + prankster: H160, ) { if self.permanent_actions.broadcast.is_some() { tracing::error!("prank is incompatible with broadcast"); @@ -2594,7 +2746,12 @@ impl CheatcodeTracer { match origin { None => { - self.permanent_actions.start_prank.replace(StartPrankOpts { sender, origin: None }); + self.permanent_actions.start_prank.replace(StartPrankOpts { + sender, + origin: None, + depth, + prankster, + }); } Some(tx_origin) => { let key = StorageKey::new( @@ -2603,18 +2760,20 @@ impl CheatcodeTracer { ); let storage = &mut storage.borrow_mut(); let original_tx_origin = storage.read_value(&key); - self.write_storage(key, tx_origin.into(), storage); - self.permanent_actions - .start_prank - .replace(StartPrankOpts { sender, origin: Some(original_tx_origin.into()) }); + self.permanent_actions.start_prank.replace(StartPrankOpts { + sender, + origin: Some((original_tx_origin.into(), tx_origin)), + depth, + prankster, + }); } } } fn stop_prank(&mut self, storage: &StoragePtr>) { if let Some(original_tx_origin) = - self.permanent_actions.start_prank.take().and_then(|v| v.origin) + self.permanent_actions.start_prank.take().and_then(|v| v.origin).map(|origin| origin.0) { let key = StorageKey::new( AccountTreeId::new(zksync_types::SYSTEM_CONTEXT_ADDRESS), diff --git a/crates/era-cheatcodes/tests/src/cheatcodes/Prank.t.sol b/crates/era-cheatcodes/tests/src/cheatcodes/Prank.t.sol new file mode 100644 index 000000000..ecbbebd1e --- /dev/null +++ b/crates/era-cheatcodes/tests/src/cheatcodes/Prank.t.sol @@ -0,0 +1,230 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.13; + +import {Test, console2 as console} from "../../lib/forge-std/src/Test.sol"; +import {Constants} from "./Constants.sol"; + +contract Victim { + function assertCallerAndOrigin( + address expectedSender, + string memory senderMessage, + address expectedOrigin, + string memory originMessage + ) public view { + require(msg.sender == expectedSender, senderMessage); + require(tx.origin == expectedOrigin, originMessage); + } +} + +contract ConstructorVictim is Victim { + constructor( + address expectedSender, + string memory senderMessage, + address expectedOrigin, + string memory originMessage + ) { + require(msg.sender == expectedSender, senderMessage); + require(tx.origin == expectedOrigin, originMessage); + } +} + +contract NestedVictim { + Victim innerVictim; + + constructor(Victim victim) { + innerVictim = victim; + } + + function assertCallerAndOrigin( + address expectedSender, + string memory senderMessage, + address expectedOrigin, + string memory originMessage + ) public view { + require(msg.sender == expectedSender, senderMessage); + require(tx.origin == expectedOrigin, originMessage); + innerVictim.assertCallerAndOrigin( + address(this), + "msg.sender was incorrectly set for nested victim", + expectedOrigin, + "tx.origin was incorrectly set for nested victim" + ); + } +} + +contract NestedPranker is Test { + address newSender; + address newOrigin; + address oldOrigin; + + constructor(address _newSender, address _newOrigin) { + newSender = _newSender; + newOrigin = _newOrigin; + oldOrigin = tx.origin; + } + + function incompletePrank() public { + vm.startPrank(newSender, newOrigin); + } + + function completePrank(NestedVictim victim) public { + victim.assertCallerAndOrigin( + newSender, "msg.sender was not set in nested prank", newOrigin, "tx.origin was not set in nested prank" + ); + vm.stopPrank(); + + // Ensure we cleaned up correctly + victim.assertCallerAndOrigin( + address(this), + "msg.sender was not cleaned up in nested prank", + oldOrigin, + "tx.origin was not cleaned up in nested prank" + ); + } +} + +contract PrankTest is Test { + address constant TEST_ADDRESS = 0x6Eb28604685b1F182dAB800A1Bfa4BaFdBA8a79a; + address constant TEST_ORIGIN = 0xdEBe90b7BFD87Af696B1966082F6515a6E72F3d8; + + function testPrankSender() public { + // Perform the prank + Victim victim = new Victim(); + vm.prank(TEST_ADDRESS); + victim.assertCallerAndOrigin( + TEST_ADDRESS, "msg.sender was not set during prank", tx.origin, "tx.origin invariant failed" + ); + + // Ensure we cleaned up correctly + victim.assertCallerAndOrigin( + address(this), "msg.sender was not cleaned up", tx.origin, "tx.origin invariant failed" + ); + } + + function testPrankOrigin() public { + address oldOrigin = tx.origin; + + // Perform the prank + Victim victim = new Victim(); + vm.prank(TEST_ADDRESS, TEST_ORIGIN); + victim.assertCallerAndOrigin( + TEST_ADDRESS, "msg.sender was not set during prank", TEST_ORIGIN, "tx.origin was not set during prank" + ); + + // Ensure we cleaned up correctly + victim.assertCallerAndOrigin( + address(this), "msg.sender was not cleaned up", oldOrigin, "tx.origin was not cleaned up" + ); + } + + function testPrank1AfterPrank0() public { + // Perform the prank + address oldOrigin = tx.origin; + Victim victim = new Victim(); + vm.prank(TEST_ADDRESS); + victim.assertCallerAndOrigin( + TEST_ADDRESS, "msg.sender was not set during prank", oldOrigin, "tx.origin was not set during prank" + ); + + // Ensure we cleaned up correctly + victim.assertCallerAndOrigin( + address(this), "msg.sender was not cleaned up", oldOrigin, "tx.origin invariant failed" + ); + + // Overwrite the prank + vm.prank(TEST_ADDRESS, TEST_ORIGIN); + victim.assertCallerAndOrigin( + TEST_ADDRESS, "msg.sender was not set during prank", TEST_ORIGIN, "tx.origin invariant failed" + ); + + // Ensure we cleaned up correctly + victim.assertCallerAndOrigin( + address(this), "msg.sender was not cleaned up", oldOrigin, "tx.origin invariant failed" + ); + } + + function testPrank0AfterPrank1() public { + // Perform the prank + address oldOrigin = tx.origin; + Victim victim = new Victim(); + vm.prank(TEST_ADDRESS, TEST_ORIGIN); + victim.assertCallerAndOrigin( + TEST_ADDRESS, "msg.sender was not set during prank", TEST_ORIGIN, "tx.origin was not set during prank" + ); + + // Ensure we cleaned up correctly + victim.assertCallerAndOrigin( + address(this), "msg.sender was not cleaned up", oldOrigin, "tx.origin invariant failed" + ); + + // Overwrite the prank + vm.prank(TEST_ADDRESS); + victim.assertCallerAndOrigin( + TEST_ADDRESS, "msg.sender was not set during prank", oldOrigin, "tx.origin invariant failed" + ); + + // Ensure we cleaned up correctly + victim.assertCallerAndOrigin( + address(this), "msg.sender was not cleaned up", oldOrigin, "tx.origin invariant failed" + ); + } + + function testPrankConstructorSender() public { + vm.prank(TEST_ADDRESS); + ConstructorVictim victim = new ConstructorVictim( + TEST_ADDRESS, "msg.sender was not set during prank", tx.origin, "tx.origin invariant failed" + ); + + // Ensure we cleaned up correctly + victim.assertCallerAndOrigin( + address(this), "msg.sender was not cleaned up", tx.origin, "tx.origin invariant failed" + ); + } + + function testPrankConstructorOrigin() public { + // Perform the prank + vm.prank(TEST_ADDRESS, TEST_ORIGIN); + ConstructorVictim victim = new ConstructorVictim( + TEST_ADDRESS, "msg.sender was not set during prank", TEST_ORIGIN, "tx.origin was not set during prank" + ); + + // Ensure we cleaned up correctly + victim.assertCallerAndOrigin( + address(this), "msg.sender was not cleaned up", tx.origin, "tx.origin was not cleaned up" + ); + } + + /// Checks that `tx.origin` is set for all subcalls of a `prank`. + /// + /// Ref: issue #1210 + function testTxOriginInNestedPrank() public { + address oldSender = msg.sender; + address oldOrigin = tx.origin; + + Victim innerVictim = new Victim(); + NestedVictim victim = new NestedVictim(innerVictim); + + vm.prank(TEST_ADDRESS, TEST_ORIGIN); + victim.assertCallerAndOrigin( + TEST_ADDRESS, "msg.sender was not set correctly", TEST_ORIGIN, "tx.origin was not set correctly" + ); + } + + function testPrankComplex() public { + address oldOrigin = tx.origin; + + NestedPranker pranker = new NestedPranker(TEST_ADDRESS, TEST_ORIGIN); + Victim innerVictim = new Victim(); + NestedVictim victim = new NestedVictim(innerVictim); + + pranker.incompletePrank(); + victim.assertCallerAndOrigin( + address(this), + "msg.sender was altered at an incorrect depth", + oldOrigin, + "tx.origin was altered at an incorrect depth" + ); + + pranker.completePrank(victim); + } +} diff --git a/crates/era-cheatcodes/tests/src/cheatcodes/StartPrank.t.sol b/crates/era-cheatcodes/tests/src/cheatcodes/StartPrank.t.sol index 248f0ecbf..91d8e2c29 100644 --- a/crates/era-cheatcodes/tests/src/cheatcodes/StartPrank.t.sol +++ b/crates/era-cheatcodes/tests/src/cheatcodes/StartPrank.t.sol @@ -25,14 +25,6 @@ contract CheatcodeStartPrankTest is Test { // Start prank without tx.origin vm.startPrank(TEST_ADDRESS); - require( - msg.sender == TEST_ADDRESS, - "startPrank failed: msg.sender unchanged" - ); - require( - tx.origin == original_tx_origin, - "startPrank failed tx.origin changed" - ); victim.assertCallerAndOrigin( TEST_ADDRESS, "startPrank failed: victim.assertCallerAndOrigin failed", @@ -43,21 +35,12 @@ contract CheatcodeStartPrankTest is Test { // Stop prank vm.stopPrank(); - require( - msg.sender == original_msg_sender, - "stopPrank failed: msg.sender didn't return to original" - ); - require( - tx.origin == original_tx_origin, - "stopPrank failed tx.origin changed" - ); victim.assertCallerAndOrigin( address(this), "startPrank failed: victim.assertCallerAndOrigin failed", original_tx_origin, "startPrank failed: victim.assertCallerAndOrigin failed" ); - } function testStartPrankWithOrigin() external { @@ -77,14 +60,6 @@ contract CheatcodeStartPrankTest is Test { // Start prank with tx.origin vm.startPrank(TEST_ADDRESS, TEST_ORIGIN); - require( - msg.sender == TEST_ADDRESS, - "startPrank failed: msg.sender unchanged" - ); - require( - tx.origin == TEST_ORIGIN, - "startPrank failed: tx.origin unchanged" - ); victim.assertCallerAndOrigin( TEST_ADDRESS, "startPrank failed: victim.assertCallerAndOrigin failed", @@ -95,14 +70,6 @@ contract CheatcodeStartPrankTest is Test { // Stop prank vm.stopPrank(); - require( - msg.sender == original_msg_sender, - "stopPrank failed: msg.sender didn't return to original" - ); - require( - tx.origin == original_tx_origin, - "stopPrank failed: tx.origin didn't return to original" - ); victim.assertCallerAndOrigin( address(this), "startPrank failed: victim.assertCallerAndOrigin failed",