Skip to content

Commit

Permalink
w3vm: fixed call of state-change hooks (#205)
Browse files Browse the repository at this point in the history
* w3vm: fix call of state-change hooks

* w3vm: dropped unused tracing hooks

---------

Co-authored-by: lmittmann <[email protected]>
  • Loading branch information
lmittmann and lmittmann authored Jan 6, 2025
1 parent 6a25c3c commit 9e1e566
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 23 deletions.
11 changes: 11 additions & 0 deletions w3vm/testdata/w3vm/1_19999999.json

Large diffs are not rendered by default.

22 changes: 0 additions & 22 deletions w3vm/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,6 @@ func joinHooks(hooks []*tracing.Hooks) *tracing.Hooks {
}

// vm hooks
var onTxStarts []tracing.TxStartHook
var onTxEnds []tracing.TxEndHook
var onEnters []tracing.EnterHook
var onExits []tracing.ExitHook
var onOpcodes []tracing.OpcodeHook
Expand All @@ -147,12 +145,6 @@ func joinHooks(hooks []*tracing.Hooks) *tracing.Hooks {
continue
}
// vm hooks
if h.OnTxStart != nil {
onTxStarts = append(onTxStarts, h.OnTxStart)
}
if h.OnTxEnd != nil {
onTxEnds = append(onTxEnds, h.OnTxEnd)
}
if h.OnEnter != nil {
onEnters = append(onEnters, h.OnEnter)
}
Expand Down Expand Up @@ -188,20 +180,6 @@ func joinHooks(hooks []*tracing.Hooks) *tracing.Hooks {

hook := new(tracing.Hooks)
// vm hooks
if len(onTxStarts) > 0 {
hook.OnTxStart = func(vm *tracing.VMContext, tx *types.Transaction, from common.Address) {
for _, h := range onTxStarts {
h(vm, tx, from)
}
}
}
if len(onTxEnds) > 0 {
hook.OnTxEnd = func(receipt *types.Receipt, err error) {
for _, h := range onTxEnds {
h(receipt, err)
}
}
}
if len(onEnters) > 0 {
hook.OnEnter = func(depth int, typ byte, from, to common.Address, input []byte, gas uint64, value *big.Int) {
for _, h := range onEnters {
Expand Down
9 changes: 8 additions & 1 deletion w3vm/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,13 @@ func (v *VM) apply(msg *w3types.Message, isCall bool, hooks *tracing.Hooks) (*Re
return nil, ErrFetch
}

var db vm.StateDB
if hooks != nil {
db = state.NewHookedState(v.db, hooks)
} else {
db = v.db
}

coreMsg, txCtx, err := v.buildMessage(msg, isCall)
if err != nil {
return nil, err
Expand All @@ -107,7 +114,7 @@ func (v *VM) apply(msg *w3types.Message, isCall bool, hooks *tracing.Hooks) (*Re
v.txIndex++

gp := new(core.GasPool).AddGas(coreMsg.GasLimit)
evm := vm.NewEVM(*v.opts.blockCtx, *txCtx, v.db, v.opts.chainConfig, vm.Config{
evm := vm.NewEVM(*v.opts.blockCtx, *txCtx, db, v.opts.chainConfig, vm.Config{
Tracer: hooks,
NoBaseFee: v.opts.noBaseFee || isCall,
})
Expand Down
43 changes: 43 additions & 0 deletions w3vm/vm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/tracing"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/crypto"
Expand Down Expand Up @@ -285,6 +286,48 @@ func TestVMApply(t *testing.T) {
}
}

func TestVMApply_Hook(t *testing.T) {
vm, err := w3vm.New(
w3vm.WithNoBaseFee(),
w3vm.WithFork(client, big.NewInt(20_000_000)),
w3vm.WithTB(t),
)
if err != nil {
t.Fatalf("Failed to create VM: %v", err)
}

// setup hook
var hookCount [10]uint
hook := &tracing.Hooks{
// vm event hooks
OnEnter: func(int, byte, common.Address, common.Address, []byte, uint64, *big.Int) { hookCount[0]++ },
OnExit: func(int, []byte, uint64, error, bool) { hookCount[1]++ },
OnOpcode: func(uint64, byte, uint64, uint64, tracing.OpContext, []byte, int, error) { hookCount[2]++ },
OnFault: func(uint64, byte, uint64, uint64, tracing.OpContext, int, error) { hookCount[3]++ },
OnGasChange: func(uint64, uint64, tracing.GasChangeReason) { hookCount[4]++ },
// state hooks
OnBalanceChange: func(common.Address, *big.Int, *big.Int, tracing.BalanceChangeReason) { hookCount[5]++ },
OnNonceChange: func(addr common.Address, prev, new uint64) { hookCount[6]++ },
OnCodeChange: func(common.Address, common.Hash, []byte, common.Hash, []byte) { hookCount[7]++ },
OnStorageChange: func(addr common.Address, slot, prev, new common.Hash) { hookCount[8]++ },
OnLog: func(*types.Log) { hookCount[9]++ },
}

vm.Apply(&w3types.Message{To: &addrWETH, Value: w3.Big1}, hook)
vm.Apply(&w3types.Message{To: nil, Input: w3.B("0xfe")}, hook) // fault
vm.Apply(&w3types.Message{To: nil, Input: w3.B("0x5f5ff3")}, hook) // deploy empty contract

for i, field := range []string{
"OnEnter", "OnExit", "OnOpcode", "OnFault", "OnGasChange", // vm event hooks
"OnBalanceChange", "OnNonceChange", "OnCodeChange", "OnStorageChange", "OnLog", // state hooks
} {
if hookCount[i] > 0 {
continue
}
t.Fatalf("Hook %q was not triggered", field)
}
}

func TestVMSnapshot(t *testing.T) {
vm, _ := w3vm.New(
w3vm.WithState(w3types.State{
Expand Down

0 comments on commit 9e1e566

Please sign in to comment.