Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: Add slogcontext codemod #1

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 184 additions & 0 deletions analyzers/slogcontext/analyzer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
package slogcontext

import (
"fmt"
"go/ast"

"github.com/gostaticanalysis/analysisutil"
"github.com/seatgeek/sgmods-go/pkg/util"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/passes/inspect"
"golang.org/x/tools/go/ast/inspector"
)

const slogPackage = "log/slog"
const analyzerName = "slogcontext"

var SlogContextAnalyzer = &analysis.Analyzer{
Name: analyzerName,
Doc: "check that context is passed to all slog calls",
Requires: []*analysis.Analyzer{inspect.Analyzer},
Run: func(pass *analysis.Pass) (interface{}, error) {
if !util.Imports(pass.Pkg, slogPackage) {
return nil, nil
}

nodeFilter := []ast.Node{
(*ast.CallExpr)(nil),
}

inspector := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
inspector.WithStack(nodeFilter, func(node ast.Node, push bool, stack []ast.Node) bool {
callExpr, ok := node.(*ast.CallExpr)
if !ok {
panic(fmt.Sprintf("unexpected node type %T", node))
}

slogCall, ok := determineSlogCall(callExpr)
if !ok {
return false
}

switch slogCall {
case "Debug", "Info", "Warn", "Error":
break
default:
return false
}

containingFunc := containingFunc(stack)
availableCtx := availableContext(containingFunc)
newCallExpr := *callExpr
newCallExpr.Fun.(*ast.SelectorExpr).Sel.Name += "Context"
switch availableCtx {
case "":
newCallExpr.Args = append([]ast.Expr{ast.NewIdent("context.TODO()")}, newCallExpr.Args...)
case "_":
newCallExpr.Args = append([]ast.Expr{ast.NewIdent("ctx")}, newCallExpr.Args...)

// rename our blank context to ctx
ctxParam := containingFunc.Type.Params.List[0].Names[0]
pos := ctxParam.Pos()
end := ctxParam.End()
ctxParam.Name = "ctx"

analysisutil.ReportWithoutIgnore(pass, analyzerName)(analysis.Diagnostic{
Pos: containingFunc.Pos(),
End: containingFunc.End(),
Message: "context needed by slog call is blank",
SuggestedFixes: []analysis.SuggestedFix{
{
Message: "Rename blank context to 'ctx'",
TextEdits: []analysis.TextEdit{
{
Pos: pos,
End: end,
NewText: []byte(util.Render(ctxParam, pass.Fset)),
},
},
},
},
})
default:
newCallExpr.Args = append([]ast.Expr{ast.NewIdent(availableCtx)}, newCallExpr.Args...)
}
newText := util.Render(&newCallExpr, pass.Fset)

analysisutil.ReportWithoutIgnore(pass, analyzerName)(analysis.Diagnostic{
Pos: node.Pos(),
End: node.End(),
Message: "context not passed to slog call",
SuggestedFixes: []analysis.SuggestedFix{
{
Message: "Add context to slog call",
TextEdits: []analysis.TextEdit{
{
Pos: node.Pos(),
End: node.End(),
NewText: []byte(newText),
},
},
},
},
})

return false
})

return nil, nil
},
}

func availableContext(fn *ast.FuncDecl) string {
if fn == nil || fn.Type.Params.NumFields() == 0 {
return ""
}

// first arg is context
firstArg := fn.Type.Params.List[0]
if selectorMatches(firstArg.Type, "context", "Context") {
return firstArg.Names[0].Name
}

// any arg is a pointer to http.Request
for _, arg := range fn.Type.Params.List {
if starExpr, ok := arg.Type.(*ast.StarExpr); ok {
if selectorMatches(starExpr.X, "http", "Request") {
return fmt.Sprintf("%s.Context()", arg.Names[0].Name)
}
}
}

return ""
}

func containingFunc(stack []ast.Node) *ast.FuncDecl {
for i := 0; i < len(stack); i++ {
if fn, ok := stack[i].(*ast.FuncDecl); ok {
return fn
}
}

return nil
}

// returns the slog call name, false if not an slog call
func determineSlogCall(callExpr *ast.CallExpr) (string, bool) {
selector, ok := callExpr.Fun.(*ast.SelectorExpr)
if !ok {
return "", false
}

x, ok := selector.X.(*ast.Ident)
if !ok {
return "", false
}

if x.Name != "slog" {
return "", false
}

return selector.Sel.Name, true
}

func selectorMatches(node ast.Node, x string, y string) bool {
selector, ok := node.(*ast.SelectorExpr)
if !ok {
return false
}

xIdent, ok := selector.X.(*ast.Ident)
if !ok {
return false
}

if xIdent.Name != x {
return false
}

if selector.Sel.Name != y {
return false
}

return true
}
16 changes: 16 additions & 0 deletions analyzers/slogcontext/analyzer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package slogcontext

import (
"testing"

"golang.org/x/tools/go/analysis/analysistest"
)

func TestWrapErrorAnalyzer(t *testing.T) {
analysistest.RunWithSuggestedFixes(t, analysistest.TestData(), SlogContextAnalyzer,
"basic",
"ignore",
"blankcontext",
"httpreq",
)
}
26 changes: 26 additions & 0 deletions analyzers/slogcontext/testdata/src/basic/basic.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package main

import (
"context"
"log/slog"
)

func main() {
run(context.Background())

slog.Debug("main debug msg") // want `context not passed to slog call`

slog.DebugContext(context.Background(), "main debug msg with ctx")
}

func run(ctx context.Context) {
slog.Debug("debug msg") // want `context not passed to slog call`
slog.Info("info msg") // want `context not passed to slog call`
slog.Warn("warn msg") // want `context not passed to slog call`
slog.Error("error msg") // want `context not passed to slog call`

slog.DebugContext(ctx, "debug msg with ctx")
slog.InfoContext(ctx, "info msg with ctx")
slog.WarnContext(ctx, "warn msg with ctx")
slog.ErrorContext(ctx, "error msg with ctx")
}
26 changes: 26 additions & 0 deletions analyzers/slogcontext/testdata/src/basic/basic.go.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package main

import (
"context"
"log/slog"
)

func main() {
run(context.Background())

slog.DebugContext(context.TODO(), "main debug msg") // want `context not passed to slog call`

slog.DebugContext(context.Background(), "main debug msg with ctx")
}

func run(ctx context.Context) {
slog.DebugContext(ctx, "debug msg") // want `context not passed to slog call`
slog.InfoContext(ctx, "info msg") // want `context not passed to slog call`
slog.WarnContext(ctx, "warn msg") // want `context not passed to slog call`
slog.ErrorContext(ctx, "error msg") // want `context not passed to slog call`

slog.DebugContext(ctx, "debug msg with ctx")
slog.InfoContext(ctx, "info msg with ctx")
slog.WarnContext(ctx, "warn msg with ctx")
slog.ErrorContext(ctx, "error msg with ctx")
}
14 changes: 14 additions & 0 deletions analyzers/slogcontext/testdata/src/blankcontext/blankcontext.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package main

import (
"context"
"log/slog"
)

func main() {
run(context.Background())
}

func run(_ context.Context) { // want `context needed by slog call is blank`
slog.Info("hello world") // want `context not passed to slog call`
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package main

import (
"context"
"log/slog"
)

func main() {
run(context.Background())
}

func run(ctx context.Context) { // want `context needed by slog call is blank`
slog.InfoContext(ctx, "hello world") // want `context not passed to slog call`
}
19 changes: 19 additions & 0 deletions analyzers/slogcontext/testdata/src/httpreq/httpreq.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package main

import (
"log/slog"
"net/http"
)

func main() {
run(nil, &http.Request{})
run2(nil, &http.Request{})
}

func run(w http.ResponseWriter, r *http.Request) {
slog.Info("hello world") // want `context not passed to slog call`
}

func run2(w http.ResponseWriter, req *http.Request) {
slog.Info("hello world") // want `context not passed to slog call`
}
19 changes: 19 additions & 0 deletions analyzers/slogcontext/testdata/src/httpreq/httpreq.go.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package main

import (
"log/slog"
"net/http"
)

func main() {
run(nil, &http.Request{})
run2(nil, &http.Request{})
}

func run(w http.ResponseWriter, r *http.Request) {
slog.InfoContext(r.Context(), "hello world") // want `context not passed to slog call`
}

func run2(w http.ResponseWriter, req *http.Request) {
slog.InfoContext(req.Context(), "hello world") // want `context not passed to slog call`
}
8 changes: 8 additions & 0 deletions analyzers/slogcontext/testdata/src/ignore/ignore.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package main

import "log/slog"

func main() {
slog.Info("hello world!") //lint:ignore slogcontext "no context in main"
slog.Info("hello world!") //lint:ignore otherlinter "another reason" // want `context not passed to slog call`
}
8 changes: 8 additions & 0 deletions analyzers/slogcontext/testdata/src/ignore/ignore.go.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package main

import "log/slog"

func main() {
slog.Info("hello world!") //lint:ignore slogcontext "no context in main"
slog.InfoContext(context.TODO(), "hello world!") //lint:ignore otherlinter "another reason" // want `context not passed to slog call`
}
15 changes: 4 additions & 11 deletions analyzers/wrap_error/analyzer.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
package wrap_error

import (
"bytes"
"fmt"
"go/ast"
"go/printer"
"go/token"
"strings"

"github.com/seatgeek/sgmods-go/analyzers"
"github.com/seatgeek/sgmods-go/pkg/util"

"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/passes/inspect"
Expand Down Expand Up @@ -82,7 +81,7 @@ var WrapErrorAnalyzer = &analysis.Analyzer{
{
Pos: pos,
End: end,
NewText: []byte(render(importGenDecl, pass.Fset)),
NewText: []byte(util.Render(importGenDecl, pass.Fset)),
},
},
},
Expand Down Expand Up @@ -139,8 +138,8 @@ var WrapErrorAnalyzer = &analysis.Analyzer{
},
}

old := render(returnStmt, pass.Fset)
new := render(suggested, pass.Fset)
old := util.Render(returnStmt, pass.Fset)
new := util.Render(suggested, pass.Fset)

pass.Report(analysis.Diagnostic{
Pos: returnStmt.Pos(),
Expand Down Expand Up @@ -213,9 +212,3 @@ func findImportGenDecl(node *ast.File) (*ast.GenDecl, bool) {

return nil, false
}

func render(node interface{}, fset *token.FileSet) string {
buf := bytes.Buffer{}
printer.Fprint(&buf, fset, node)
return buf.String()
}
Loading