Skip to content

Commit

Permalink
Merge branch 'main' into array-rework
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso authored Nov 24, 2023
2 parents 3d7f855 + 0fe43fc commit 062a8bc
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 13 deletions.
23 changes: 19 additions & 4 deletions compiler/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body
// assignments that use := to assignments that use =. Constant decls are
// hoisted and also have their value assigned in the function prologue.
decls, frameType, frameInit := extractDecls(p, typ, body, recv, defers, p.TypesInfo)
renameObjects(body, p.TypesInfo, decls, frameName, frameType, frameInit, scope)
renameObjects(typ, body, p.TypesInfo, decls, frameName, frameType, frameInit, scope)

// var _f{n} F = coroutine.Push[F](&_c.Stack)
gen.List = append(gen.List, &ast.DeclStmt{Decl: &ast.GenDecl{
Expand Down Expand Up @@ -632,22 +632,37 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body
compiledBody := compileDispatch(body, frameName, spans, mayYield).(*ast.BlockStmt)
gen.List = append(gen.List, compiledBody.List...)

// If the function returns one or more values, it must end with a return statement;
// we inject it if the function body does not already has one.
// If the function returns one or more values, it must end with a return
// statement. Since the input Go code is valid, the last entry in the
// dispatch table should already contain a return statement. We inject a
// panic at the end of the function in case this invariant does not hold
// anymore.
if typ.Results != nil && len(typ.Results.List) > 0 {
needsReturn := len(gen.List) == 0
if !needsReturn {
_, endsWithReturn := gen.List[len(gen.List)-1].(*ast.ReturnStmt)
needsReturn = !endsWithReturn
}
if needsReturn {
gen.List = append(gen.List, &ast.ReturnStmt{})
gen.List = append(gen.List, &ast.ExprStmt{X: panicCall("unreachable")})
}
}

return gen
}

func panicCall(s string) ast.Expr {
return &ast.CallExpr{
Fun: &ast.Ident{Name: "panic"},
Args: []ast.Expr{
&ast.BasicLit{
Kind: token.STRING,
Value: "\"" + s + "\"",
},
},
}
}

// This function returns true if a function body is composed of at most one
// expression.
func isExpr(body *ast.BlockStmt) bool {
Expand Down
7 changes: 7 additions & 0 deletions compiler/coroutine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,13 @@ func TestCoroutineYield(t *testing.T) {
yields: []int{1, 2, 3, 2, 4, 6, 3, 6, 9, 2, 4, 6, 4, 8, 12, 6, 12, 18, 3, 6, 9, 6, 12, 18, 9, 18, 27},
result: 27,
},

{
name: "return named values",
coroR: func() int { return ReturnNamedValue() },
yields: []int{11},
result: 42,
},
}

// This emulates the installation of function type information by the
Expand Down
55 changes: 52 additions & 3 deletions compiler/decls.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func extractDecls(p *packages.Package, typ *ast.FuncType, body *ast.BlockStmt, r
// renameObjects renames types, constants and variables declared within
// a function. Each is given a unique name, so that declarations are safe
// to hoist into the function prologue.
func renameObjects(tree ast.Node, info *types.Info, decls []*ast.GenDecl, frameName *ast.Ident, frameType *ast.StructType, frameInit *ast.CompositeLit, scope *scope) {
func renameObjects(fntype *ast.FuncType, tree ast.Node, info *types.Info, decls []*ast.GenDecl, frameName *ast.Ident, frameType *ast.StructType, frameInit *ast.CompositeLit, scope *scope) {
// Scan decls to find objects, giving each new object a unique name.
names := make(map[types.Object]*ast.Ident, len(decls))
selectors := make(map[types.Object]*ast.SelectorExpr, len(frameType.Fields.List))
Expand Down Expand Up @@ -238,7 +238,7 @@ func renameObjects(tree ast.Node, info *types.Info, decls []*ast.GenDecl, frameN
// replacing if they are removed from the tree too early.
//
// Note that replacing identifiers is a recursive operation which traverses
// function literls.
// function literals.

astutil.Apply(tree,
func(cursor *astutil.Cursor) bool {
Expand Down Expand Up @@ -326,9 +326,55 @@ func renameObjects(tree ast.Node, info *types.Info, decls []*ast.GenDecl, frameN
},
nil,
)

// Perform a last pass to assigned named results before unnamed. It cannot
// be done in the renaming pass because it should not recurse into function
// literals, which the renaming pass does.
if hasNamedResults(fntype) {
astutil.Apply(tree,
func(cursor *astutil.Cursor) bool {
switch n := cursor.Node().(type) {
case *ast.FuncLit:
return false
case *ast.ReturnStmt:
if len(n.Results) > 0 {
return true
}

// Transform
// return
// into
// return (selector1), (selector2)...
for _, t := range fntype.Results.List {
ident := t.Names[0]
obj := info.ObjectOf(ident)
n.Results = append(n.Results, selectors[obj])
}
}

return true
}, nil)
}
}

func renameFuncRecvParamsResults(typ *ast.FuncType, recv *ast.FieldList, body *ast.BlockStmt, info *types.Info) {
func hasNamedResults(t *ast.FuncType) bool {
if t.Results == nil || len(t.Results.List) == 0 {
return false
}

for _, result := range t.Results.List {
for _, name := range result.Names {
if name == nil || name.Name == "" || name.Name == "_" {
continue
}
return true
}
}
return false
}

func renameFuncRecvParamsResults(typ *ast.FuncType, recv *ast.FieldList, body *ast.BlockStmt, info *types.Info) []*ast.Ident {
var namedResults []*ast.Ident
names := map[types.Object]*ast.Ident{}

fieldLists := []*ast.FieldList{recv, typ.Params, typ.Results}
Expand All @@ -345,6 +391,7 @@ func renameFuncRecvParamsResults(typ *ast.FuncType, recv *ast.FieldList, body *a
newIdent := ast.NewIdent("_fn" + strconv.Itoa(len(names)))
names[obj] = newIdent
info.Defs[newIdent] = obj
namedResults = append(namedResults, newIdent)
}
}
}
Expand All @@ -366,4 +413,6 @@ func renameFuncRecvParamsResults(typ *ast.FuncType, recv *ast.FieldList, body *a
}
return true
}, nil)

return namedResults
}
7 changes: 7 additions & 0 deletions compiler/testdata/coroutine.go
Original file line number Diff line number Diff line change
Expand Up @@ -550,3 +550,10 @@ func varArgs(args ...int) {
coroutine.Yield[int, any](arg)
}
}

func ReturnNamedValue() (out int) {
out = 5
coroutine.Yield[int, any](11)
out = 42
return
}
53 changes: 47 additions & 6 deletions compiler/testdata/coroutine_durable.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ func NestedLoops(_fn0 int) (_ int) {

return _f0.X1
}
return
panic("unreachable")
}

//go:noinline
Expand Down Expand Up @@ -1616,7 +1616,7 @@ func Range10ClosureCapturingValues() {

return false
}
return
panic("unreachable")
}
_f1.IP = 4
fallthrough
Expand Down Expand Up @@ -1733,7 +1733,7 @@ func Range10ClosureCapturingPointers() {

return false
}
return
panic("unreachable")
}
_f1.IP = 5
fallthrough
Expand Down Expand Up @@ -1953,7 +1953,7 @@ func Range10ClosureHeterogenousCapture() {
case _f0.IP < 16:
return _f1.X10 < 10
}
return
panic("unreachable")
}
_f1.IP = 13
fallthrough
Expand Down Expand Up @@ -2837,7 +2837,7 @@ func a(_fn0 int) (_ int) {
case _f0.IP < 3:
return _f0.X0
}
return
panic("unreachable")
}

//go:noinline
Expand Down Expand Up @@ -2869,7 +2869,7 @@ func b(_fn0 int) (_ int) {
case _f0.IP < 3:
return _f0.X0
}
return
panic("unreachable")
}

//go:noinline
Expand Down Expand Up @@ -3196,6 +3196,46 @@ func varArgs(_fn0 ...int) {
}
}
}

//go:noinline
func ReturnNamedValue() (_fn0 int) {
_c := coroutine.LoadContext[int, any]()
var _f0 *struct {
IP int
X0 int
} = coroutine.Push[struct {
IP int
X0 int
}](&_c.Stack)
if _f0.IP == 0 {
*_f0 = struct {
IP int
X0 int
}{}
}
defer func() {
if !_c.Unwinding() {
coroutine.Pop(&_c.Stack)
}
}()
switch {
case _f0.IP < 2:
_f0.X0 = 5
_f0.IP = 2
fallthrough
case _f0.IP < 3:
coroutine.Yield[int, any](11)
_f0.IP = 3
fallthrough
case _f0.IP < 4:
_f0.X0 = 42
_f0.IP = 4
fallthrough
case _f0.IP < 5:
return _f0.X0
}
panic("unreachable")
}
func init() {
_types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.Double")
_types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.EvenSquareGenerator")
Expand Down Expand Up @@ -3292,6 +3332,7 @@ func init() {
_types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.RangeTripleFuncValue")
_types.RegisterFunc[func(i int)]("github.com/stealthrocket/coroutine/compiler/testdata.RangeTripleFuncValue.func2")
_types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.RangeYieldAndDeferAssign")
_types.RegisterFunc[func() (_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.ReturnNamedValue")
_types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.Select")
_types.RegisterFunc[func(_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.Shadowing")
_types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata.SomeFunctionThatShouldExistInTheCompiledFile")
Expand Down

0 comments on commit 062a8bc

Please sign in to comment.