Skip to content

Commit

Permalink
feat: prank cheatcode (#265)
Browse files Browse the repository at this point in the history
Co-authored-by: Dustin Brickwood <[email protected]>
  • Loading branch information
Karrq and dutterbutter authored Mar 18, 2024
1 parent fd46281 commit 4307bfb
Show file tree
Hide file tree
Showing 3 changed files with 409 additions and 53 deletions.
199 changes: 179 additions & 20 deletions crates/era-cheatcodes/src/cheatcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,29 @@ const BROADCAST_IGNORED_CONTRACTS: [H160; 19] = [
H160::zero(),
];

//same as above, except without
// CHEATCODE_ADDRESS
const PRANK_IGNORED_CONTRACTS: [H160; 18] = [
zksync_types::BOOTLOADER_ADDRESS,
zksync_types::ACCOUNT_CODE_STORAGE_ADDRESS,
zksync_types::NONCE_HOLDER_ADDRESS,
zksync_types::KNOWN_CODES_STORAGE_ADDRESS,
zksync_types::IMMUTABLE_SIMULATOR_STORAGE_ADDRESS,
zksync_types::CONTRACT_FORCE_DEPLOYER_ADDRESS,
zksync_types::L1_MESSENGER_ADDRESS,
zksync_types::KECCAK256_PRECOMPILE_ADDRESS,
zksync_types::L2_ETH_TOKEN_ADDRESS,
zksync_types::SYSTEM_CONTEXT_ADDRESS,
zksync_types::BOOTLOADER_UTILITIES_ADDRESS,
zksync_types::EVENT_WRITER_ADDRESS,
zksync_types::COMPRESSOR_ADDRESS,
zksync_types::COMPLEX_UPGRADER_ADDRESS,
zksync_types::ECRECOVER_PRECOMPILE_ADDRESS,
zksync_types::SHA256_PRECOMPILE_ADDRESS,
zksync_types::MINT_AND_BURN_ADDRESS,
H160::zero(),
];

#[derive(Debug, Clone)]
struct EraEnv {
l1_batch_env: L1BatchEnv,
Expand Down Expand Up @@ -257,6 +280,7 @@ enum ActionOnReturn {
prev_exception_handler_pc: Option<PcOrImm>,
reason: Vec<u8>,
},
StopPrank,
}

#[derive(Debug, Default, Clone)]
Expand All @@ -268,7 +292,9 @@ struct FinishCyclePermanentActions {
#[derive(Debug, Clone)]
struct StartPrankOpts {
sender: H160,
origin: Option<H160>,
origin: Option<(H160, H160)>,
depth: usize,
prankster: H160,
}

/// Tracks the expected calls per address.
Expand Down Expand Up @@ -432,7 +458,7 @@ impl<S: DatabaseExt + revm::DatabaseCommit + Send, H: HistoryMode>
}

// Checks returns from contracts for expectRevert cheatcode
self.handle_return(&state, &data, memory);
self.handle_return(&state, &data, memory, &storage);

// Checks contract calls for expectCall cheatcode
if let Opcode::FarCall(_call) = data.opcode.variant.opcode {
Expand Down Expand Up @@ -477,6 +503,44 @@ impl<S: DatabaseExt + revm::DatabaseCommit + Send, H: HistoryMode>
}
}

// ---- START: prank origin only at desired depth
if let Opcode::FarCall(_) | Opcode::NearCall(_) = data.opcode.variant.opcode {
let current = state.vm_local_state.callstack.depth();

if !PRANK_IGNORED_CONTRACTS
.contains(&state.vm_local_state.callstack.current.this_address)
{
if let Some(StartPrankOpts { depth, prankster, .. }) =
self.permanent_actions.start_prank.as_ref()
{
if &current == depth &&
&state.vm_local_state.callstack.inner.last().unwrap().this_address ==
prankster
{
self.enable_prank(&mut storage.borrow_mut());
}
}
}
}

if let Opcode::Ret(_) = data.opcode.variant.opcode {
let current = state.vm_local_state.callstack.depth();
if !PRANK_IGNORED_CONTRACTS
.contains(&state.vm_local_state.callstack.current.this_address)
{
if let Some(StartPrankOpts { depth, prankster, .. }) =
self.permanent_actions.start_prank.as_ref()
{
if *depth == current + 1 &&
&state.vm_local_state.callstack.current.this_address == prankster
{
self.disable_prank(&mut storage.borrow_mut());
}
}
}
}
// ---- END: prank origin only at desired depth

if self.return_data.is_some() {
if let Opcode::Ret(_call) = data.opcode.variant.opcode {
if self.near_calls == 0 {
Expand Down Expand Up @@ -1148,10 +1212,12 @@ impl<S: DatabaseExt + revm::DatabaseCommit + Send, H: HistoryMode> VmTracer<EraD
}

// Sets the sender address for startPrank cheatcode
if let Some(start_prank_call) = &self.permanent_actions.start_prank {
if let Some(start_prank_opts) = &self.permanent_actions.start_prank {
let this_address = state.local_state.callstack.current.this_address;
if !INTERNAL_CONTRACT_ADDRESSES.contains(&this_address) {
state.local_state.callstack.current.msg_sender = start_prank_call.sender;
if !INTERNAL_CONTRACT_ADDRESSES.contains(&this_address) &&
state.local_state.callstack.current.msg_sender == start_prank_opts.prankster
{
state.local_state.callstack.current.msg_sender = start_prank_opts.sender;
}
}
TracerExecutionStatus::Continue
Expand Down Expand Up @@ -1737,6 +1803,26 @@ impl CheatcodeTracer {
tracing::info!("👷 Clearing all mocked calls");
self.mocked_calls.clear();
}
prank_0(prank_0Call { msgSender: msg_sender }) => {
tracing::info!("👷 Pranking to {msg_sender:?}");
self.prank(
state.vm_local_state.callstack.depth(),
&storage,
msg_sender.to_h160(),
None,
state.vm_local_state.callstack.inner.last().unwrap().this_address,
)
}
prank_1(prank_1Call { msgSender: msg_sender, txOrigin: tx_origin }) => {
tracing::info!("👷 Pranking to {msg_sender:?} with origin {tx_origin:?}");
self.prank(
state.vm_local_state.callstack.depth(),
&storage,
msg_sender.to_h160(),
Some(tx_origin.to_h160()),
state.vm_local_state.callstack.inner.last().unwrap().this_address,
)
}
recordLogs(recordLogsCall {}) => {
tracing::info!("👷 Recording logs");
tracing::info!(
Expand Down Expand Up @@ -2032,11 +2118,25 @@ impl CheatcodeTracer {
}
startPrank_0(startPrank_0Call { msgSender: msg_sender }) => {
tracing::info!("👷 Starting prank to {msg_sender:?}");
self.start_prank(&storage, msg_sender.to_h160(), None);
let target_depth = state.vm_local_state.callstack.depth() - 1;
self.start_prank(
&storage,
msg_sender.to_h160(),
None,
target_depth,
state.vm_local_state.callstack.inner[target_depth - 1].this_address,
);
}
startPrank_1(startPrank_1Call { msgSender: msg_sender, txOrigin: tx_origin }) => {
tracing::info!("👷 Starting prank to {msg_sender:?} with origin {tx_origin:?}");
self.start_prank(&storage, msg_sender.to_h160(), Some(tx_origin.to_h160()))
let target_depth = state.vm_local_state.callstack.depth() - 1;
self.start_prank(
&storage,
msg_sender.to_h160(),
Some(tx_origin.to_h160()),
target_depth,
state.vm_local_state.callstack.inner[target_depth - 1].this_address,
);
}
stopBroadcast(stopBroadcastCall {}) => {
tracing::info!("👷 Stopping broadcast");
Expand Down Expand Up @@ -2492,11 +2592,12 @@ impl CheatcodeTracer {
}
}

fn handle_return<H: HistoryMode>(
fn handle_return<H: HistoryMode, S: DatabaseExt + Send>(
&mut self,
state: &VmLocalStateData<'_>,
data: &AfterExecutionData,
memory: &SimpleMemory<H>,
storage: &StoragePtr<EraDb<S>>,
) {
// Skip check if there are no expected actions
let Some(action) = self.next_return_action.as_mut() else { return };
Expand All @@ -2508,6 +2609,11 @@ impl CheatcodeTracer {

// Skip check if opcode is not Ret
let Opcode::Ret(op) = data.opcode.variant.opcode else { return };
// Check how many returns we need to skip before finding the actual one
if action.returns_to_skip != 0 {
action.returns_to_skip -= 1;
return
}

// The desired return opcode was found
match &action.action {
Expand All @@ -2517,12 +2623,6 @@ impl CheatcodeTracer {
prev_exception_handler_pc: exception_handler,
prev_continue_pc: continue_pc,
} => {
// Check how many returns we need to skip before finding the actual one
if action.returns_to_skip != 0 {
action.returns_to_skip -= 1;
return
}

match op {
RetOpcode::Revert => {
tracing::debug!(wanted = %depth, current_depth = %callstack_depth, opcode = ?data.opcode.variant.opcode, "expectRevert");
Expand Down Expand Up @@ -2578,14 +2678,66 @@ impl CheatcodeTracer {
});
self.next_return_action = None;
}
ActionOnReturn::StopPrank => {
tracing::debug!("Stopping prank");
self.stop_prank(storage);
self.next_return_action = None;
}
}
}

fn enable_prank<S: WriteStorage>(&mut self, storage: &mut RefMut<S>) {
if let Some(StartPrankOpts { origin: Some((_, new_origin)), .. }) =
self.permanent_actions.start_prank.as_ref()
{
let key = StorageKey::new(
AccountTreeId::new(zksync_types::SYSTEM_CONTEXT_ADDRESS),
zksync_types::SYSTEM_CONTEXT_TX_ORIGIN_POSITION,
);
self.write_storage(key, H256::from(*new_origin), storage)
}
}

fn disable_prank<S: WriteStorage>(&mut self, storage: &mut RefMut<S>) {
if let Some(StartPrankOpts { origin: Some((old_origin, _)), .. }) =
self.permanent_actions.start_prank.as_ref()
{
let key = StorageKey::new(
AccountTreeId::new(zksync_types::SYSTEM_CONTEXT_ADDRESS),
zksync_types::SYSTEM_CONTEXT_TX_ORIGIN_POSITION,
);
self.write_storage(key, H256::from(*old_origin), storage)
}
}

fn prank<S: DatabaseExt + Send>(
&mut self,
current_depth: usize,
storage: &StoragePtr<EraDb<S>>,
sender: H160,
origin: Option<H160>,
prankster: H160,
) {
if self.permanent_actions.broadcast.is_some() {
tracing::error!("prank is incompatible with broadcast");
return
}

self.start_prank(storage, sender, origin, current_depth, prankster);
self.next_return_action.replace(NextReturnAction {
target_depth: current_depth - 1,
action: ActionOnReturn::StopPrank,
returns_to_skip: 2,
});
}

fn start_prank<S: DatabaseExt + Send>(
&mut self,
storage: &StoragePtr<EraDb<S>>,
sender: H160,
origin: Option<H160>,
depth: usize,
prankster: H160,
) {
if self.permanent_actions.broadcast.is_some() {
tracing::error!("prank is incompatible with broadcast");
Expand All @@ -2594,7 +2746,12 @@ impl CheatcodeTracer {

match origin {
None => {
self.permanent_actions.start_prank.replace(StartPrankOpts { sender, origin: None });
self.permanent_actions.start_prank.replace(StartPrankOpts {
sender,
origin: None,
depth,
prankster,
});
}
Some(tx_origin) => {
let key = StorageKey::new(
Expand All @@ -2603,18 +2760,20 @@ impl CheatcodeTracer {
);
let storage = &mut storage.borrow_mut();
let original_tx_origin = storage.read_value(&key);
self.write_storage(key, tx_origin.into(), storage);

self.permanent_actions
.start_prank
.replace(StartPrankOpts { sender, origin: Some(original_tx_origin.into()) });
self.permanent_actions.start_prank.replace(StartPrankOpts {
sender,
origin: Some((original_tx_origin.into(), tx_origin)),
depth,
prankster,
});
}
}
}

fn stop_prank<S: DatabaseExt + Send>(&mut self, storage: &StoragePtr<EraDb<S>>) {
if let Some(original_tx_origin) =
self.permanent_actions.start_prank.take().and_then(|v| v.origin)
self.permanent_actions.start_prank.take().and_then(|v| v.origin).map(|origin| origin.0)
{
let key = StorageKey::new(
AccountTreeId::new(zksync_types::SYSTEM_CONTEXT_ADDRESS),
Expand Down
Loading

0 comments on commit 4307bfb

Please sign in to comment.