Skip to content

Commit

Permalink
memmod: hook RtlPcToFileHeader's invocation from GetModuleHandleEx
Browse files Browse the repository at this point in the history
When GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS) is called
by cfgmgr32.dll's SwCreateDevice on the DLL's callback, it expects to
get the module of the DLL. But of course memory loaded modules means
there is none. This causes SwCreateDevice to fail.

GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS) internally
uses RtlPcToFileHeader. In turn, RtlPcToFileHeader looks things up in
the inverted function table, which has no stable interface across OS
releases. That means adding a proper module isn't going to work.

So instead we hook the IAT, so that we can intercept all calls to
RtlPcToFileHeader that come from GetModuleHandleEx's kernelbase.dll. If
the value to look up is within the range of a module we've memory
loaded, then we change the value to lookup to the hook function itself,
so that it winds up returning the main module.

Signed-off-by: Jason A. Donenfeld <[email protected]>
  • Loading branch information
zx2c4 committed Oct 11, 2021
1 parent 5a02b10 commit afe8594
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 1 deletion.
84 changes: 84 additions & 0 deletions driver/memmod/memmod_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ package memmod
import (
"errors"
"fmt"
"strings"
"sync"
"syscall"
"unsafe"

Expand Down Expand Up @@ -382,6 +384,76 @@ func (module *Module) buildNameExports() error {
return nil
}

type addressRange struct {
start uintptr
end uintptr
}

var loadedAddressRanges []addressRange
var loadedAddressRangesMu sync.RWMutex
var haveHookedRtlPcToFileHeader sync.Once
var hookRtlPcToFileHeaderResult error

func hookRtlPcToFileHeader() error {
var kernelBase windows.Handle
err := windows.GetModuleHandleEx(windows.GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, windows.StringToUTF16Ptr("kernelbase.dll"), &kernelBase)
if err != nil {
return err
}
imageBase := unsafe.Pointer(kernelBase)
dosHeader := (*IMAGE_DOS_HEADER)(imageBase)
ntHeaders := (*IMAGE_NT_HEADERS)(unsafe.Add(imageBase, dosHeader.E_lfanew))
importsDirectory := ntHeaders.OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT]
importDescriptor := (*IMAGE_IMPORT_DESCRIPTOR)(unsafe.Add(imageBase, importsDirectory.VirtualAddress))
for ; importDescriptor.Name != 0; importDescriptor = (*IMAGE_IMPORT_DESCRIPTOR)(unsafe.Add(unsafe.Pointer(importDescriptor), unsafe.Sizeof(*importDescriptor))) {
libraryName := windows.BytePtrToString((*byte)(unsafe.Add(imageBase, importDescriptor.Name)))
if strings.EqualFold(libraryName, "ntdll.dll") {
break
}
}
if importDescriptor.Name == 0 {
return errors.New("ntdll.dll not found")
}
originalThunk := (*uintptr)(unsafe.Add(imageBase, importDescriptor.OriginalFirstThunk()))
thunk := (*uintptr)(unsafe.Add(imageBase, importDescriptor.FirstThunk))
for ; *originalThunk != 0; originalThunk = (*uintptr)(unsafe.Add(unsafe.Pointer(originalThunk), unsafe.Sizeof(*originalThunk))) {
if *originalThunk&IMAGE_ORDINAL_FLAG == 0 {
function := (*IMAGE_IMPORT_BY_NAME)(unsafe.Add(imageBase, *originalThunk))
name := windows.BytePtrToString(&function.Name[0])
if name == "RtlPcToFileHeader" {
break
}
}
thunk = (*uintptr)(unsafe.Add(unsafe.Pointer(thunk), unsafe.Sizeof(*thunk)))
}
if *originalThunk == 0 {
return errors.New("RtlPcToFileHeader not found")
}
var oldProtect uint32
err = windows.VirtualProtect(uintptr(unsafe.Pointer(thunk)), unsafe.Sizeof(*thunk), windows.PAGE_READWRITE, &oldProtect)
if err != nil {
return err
}
originalRtlPcToFileHeader := *thunk
*thunk = windows.NewCallback(func(pcValue uintptr, baseOfImage *uintptr) uintptr {
loadedAddressRangesMu.RLock()
for i := range loadedAddressRanges {
if pcValue >= loadedAddressRanges[i].start && pcValue < loadedAddressRanges[i].end {
pcValue = *thunk
break
}
}
loadedAddressRangesMu.RUnlock()
ret, _, _ := syscall.Syscall(originalRtlPcToFileHeader, 2, pcValue, uintptr(unsafe.Pointer(baseOfImage)), 0)
return ret
})
err = windows.VirtualProtect(uintptr(unsafe.Pointer(thunk)), unsafe.Sizeof(*thunk), oldProtect, &oldProtect)
if err != nil {
return err
}
return nil
}

// LoadLibrary loads module image to memory.
func LoadLibrary(data []byte) (module *Module, err error) {
addr := uintptr(unsafe.Pointer(&data[0]))
Expand Down Expand Up @@ -513,6 +585,18 @@ func LoadLibrary(data []byte) (module *Module, err error) {
// Register exception tables, if they exist.
module.registerExceptionHandlers()

// Register function PCs.
loadedAddressRangesMu.Lock()
loadedAddressRanges = append(loadedAddressRanges, addressRange{module.codeBase, module.codeBase + alignedImageSize})
loadedAddressRangesMu.Unlock()
haveHookedRtlPcToFileHeader.Do(func() {
hookRtlPcToFileHeaderResult = hookRtlPcToFileHeader()
})
err = hookRtlPcToFileHeaderResult
if err != nil {
return
}

// TLS callbacks are executed BEFORE the main loading.
module.executeTLS()

Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module golang.zx2c4.com/wireguard/windows

go 1.16
go 1.17

require (
github.com/lxn/walk v0.0.0-20210112085537-c389da54e794
Expand Down

0 comments on commit afe8594

Please sign in to comment.