From 197f87e73750e5452e88bf295fdbaa8f66d91b04 Mon Sep 17 00:00:00 2001 From: Jack Kleeman Date: Thu, 11 Jul 2024 11:53:30 +0200 Subject: [PATCH 1/6] Journal mismatch via panics --- internal/state/awakeable.go | 41 +++++++++++++++++--------- internal/state/call.go | 34 ++++++++++++++------- internal/state/state.go | 38 ++++++++++++++++++++++-- internal/state/sys.go | 59 +++++++++++++++++++++++++------------ 4 files changed, 126 insertions(+), 46 deletions(-) diff --git a/internal/state/awakeable.go b/internal/state/awakeable.go index 5396615..d8114ca 100644 --- a/internal/state/awakeable.go +++ b/internal/state/awakeable.go @@ -17,8 +17,8 @@ type indexedEntry struct { func (c *Machine) awakeable() (restate.Awakeable[[]byte], error) { indexedEntry, err := replayOrNew( c, - func(entry *wire.AwakeableEntryMessage) (indexedEntry, error) { - return indexedEntry{entry, c.entryIndex}, nil + func(entry *wire.AwakeableEntryMessage) indexedEntry { + return indexedEntry{entry, c.entryIndex} }, c._awakeable, ) @@ -37,18 +37,23 @@ func (c *Machine) _awakeable() (indexedEntry, error) { return indexedEntry{msg, c.entryIndex}, nil } -func (c *Machine) resolveAwakeable(id string, value []byte) error { +func (m *Machine) resolveAwakeable(id string, value []byte) error { _, err := replayOrNew( - c, - func(entry *wire.CompleteAwakeableEntryMessage) (restate.Void, error) { + m, + func(entry *wire.CompleteAwakeableEntryMessage) restate.Void { messageValue, ok := entry.Result.(*protocol.CompleteAwakeableEntryMessage_Value) if entry.Id != id || !ok || !bytes.Equal(messageValue.Value, value) { - return restate.Void{}, errEntryMismatch + panic(m.newEntryMismatch(&wire.CompleteAwakeableEntryMessage{ + CompleteAwakeableEntryMessage: protocol.CompleteAwakeableEntryMessage{ + Id: id, + Result: &protocol.CompleteAwakeableEntryMessage_Value{Value: value}, + }, + }, entry)) } - return restate.Void{}, nil + return restate.Void{} }, func() (restate.Void, error) { - if err := c._resolveAwakeable(id, value); err != nil { + if err := m._resolveAwakeable(id, value); err != nil { return restate.Void{}, err } return restate.Void{}, nil @@ -69,18 +74,26 @@ func (c *Machine) _resolveAwakeable(id string, value []byte) error { return nil } -func (c *Machine) rejectAwakeable(id string, reason error) error { +func (m *Machine) rejectAwakeable(id string, reason error) error { _, err := replayOrNew( - c, - func(entry *wire.CompleteAwakeableEntryMessage) (restate.Void, error) { + m, + func(entry *wire.CompleteAwakeableEntryMessage) restate.Void { messageFailure, ok := entry.Result.(*protocol.CompleteAwakeableEntryMessage_Failure) if entry.Id != id || !ok || messageFailure.Failure.Code != uint32(restate.ErrorCode(reason)) || messageFailure.Failure.Message != reason.Error() { - return restate.Void{}, errEntryMismatch + panic(m.newEntryMismatch(&wire.CompleteAwakeableEntryMessage{ + CompleteAwakeableEntryMessage: protocol.CompleteAwakeableEntryMessage{ + Id: id, + Result: &protocol.CompleteAwakeableEntryMessage_Failure{Failure: &protocol.Failure{ + Code: uint32(restate.ErrorCode(reason)), + Message: reason.Error(), + }}, + }, + }, entry)) } - return restate.Void{}, nil + return restate.Void{} }, func() (restate.Void, error) { - if err := c._rejectAwakeable(id, reason); err != nil { + if err := m._rejectAwakeable(id, reason); err != nil { return restate.Void{}, err } return restate.Void{}, nil diff --git a/internal/state/call.go b/internal/state/call.go index faa544d..c6d46cb 100644 --- a/internal/state/call.go +++ b/internal/state/call.go @@ -94,15 +94,22 @@ func (m *Machine) doCall(service, key, method string, params []byte) (*wire.Call return replayOrNew( m, - func(entry *wire.CallEntryMessage) (*wire.CallEntryMessage, error) { + func(entry *wire.CallEntryMessage) *wire.CallEntryMessage { if entry.ServiceName != service || entry.Key != key || entry.HandlerName != method || !bytes.Equal(entry.Parameter, params) { - return nil, errEntryMismatch + panic(m.newEntryMismatch(&wire.CallEntryMessage{ + CallEntryMessage: protocol.CallEntryMessage{ + ServiceName: service, + HandlerName: method, + Parameter: params, + Key: key, + }, + }, entry)) } - return entry, nil + return entry }, func() (*wire.CallEntryMessage, error) { return m._doCall(service, key, method, params) }) @@ -124,8 +131,8 @@ func (m *Machine) _doCall(service, key, method string, params []byte) (*wire.Cal return msg, nil } -func (c *Machine) sendCall(service, key, method string, body any, delay time.Duration) error { - c.log.Debug().Str("service", service).Str("method", method).Str("key", key).Msg("executing async call") +func (m *Machine) sendCall(service, key, method string, body any, delay time.Duration) error { + m.log.Debug().Str("service", service).Str("method", method).Str("key", key).Msg("executing async call") params, err := json.Marshal(body) if err != nil { @@ -133,19 +140,26 @@ func (c *Machine) sendCall(service, key, method string, body any, delay time.Dur } _, err = replayOrNew( - c, - func(entry *wire.OneWayCallEntryMessage) (restate.Void, error) { + m, + func(entry *wire.OneWayCallEntryMessage) restate.Void { if entry.ServiceName != service || entry.Key != key || entry.HandlerName != method || !bytes.Equal(entry.Parameter, params) { - return restate.Void{}, errEntryMismatch + panic(m.newEntryMismatch(&wire.OneWayCallEntryMessage{ + OneWayCallEntryMessage: protocol.OneWayCallEntryMessage{ + ServiceName: service, + HandlerName: method, + Parameter: params, + Key: key, + }, + }, entry)) } - return restate.Void{}, nil + return restate.Void{} }, func() (restate.Void, error) { - return restate.Void{}, c._sendCall(service, key, method, params, delay) + return restate.Void{}, m._sendCall(service, key, method, params, delay) }, ) diff --git a/internal/state/state.go b/internal/state/state.go index d32d88a..89f2e16 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -2,6 +2,7 @@ package state import ( "context" + "encoding/json" "fmt" "io" "runtime/debug" @@ -10,6 +11,7 @@ import ( restate "github.com/restatedev/sdk-go" "github.com/restatedev/sdk-go/generated/proto/protocol" + "github.com/restatedev/sdk-go/internal/errors" "github.com/restatedev/sdk-go/internal/wire" "github.com/rs/zerolog" @@ -164,6 +166,8 @@ type Machine struct { pendingCompletions map[uint32]wire.CompleteableMessage pendingAcks map[uint32]wire.AckableMessage pendingMutex sync.RWMutex + + failure any } func NewMachine(handler restate.Handler, conn io.ReadWriter) *Machine { @@ -221,6 +225,27 @@ func (m *Machine) invoke(ctx *Context, input []byte, outputSeen bool) error { case nil: // nothing to do, just exit return + case *entryMismatch: + expected, _ := json.Marshal(typ.expectedEntry) + actual, _ := json.Marshal(typ.actualEntry) + msg := fmt.Sprintf(`Journal mismatch: Replayed journal entries did not correspond to the user code. The user code has to be deterministic! +The journal entry at position %d was: +- In the user code: type: %T, message: %s +- In the replayed messages: type: %T, message %s`, + typ.entryIndex, typ.expectedEntry, string(expected), typ.actualEntry, string(actual)) + + m.log.Error().Msg(msg) + + // journal entry mismatch + if err := m.protocol.Write(&wire.ErrorMessage{ + ErrorMessage: protocol.ErrorMessage{ + Code: uint32(errors.ErrJournalMismatch), + Message: msg, + Description: string(debug.Stack()), + }, + }); err != nil { + m.log.Error().Err(err).Msg("error sending failure message") + } default: // unknown panic! // send an error message (retryable) @@ -355,13 +380,18 @@ func (c *Machine) currentEntry() (wire.Message, bool) { // by sending the proper runtime messages func replayOrNew[M wire.Message, O any]( m *Machine, - replay func(msg M) (O, error), + replay func(msg M) O, new func() (O, error), ) (output O, err error) { // lock around preparing the entry, but we would never await an ack or completion with this held. m.entryMutex.Lock() defer m.entryMutex.Unlock() + if m.failure != nil { + // maybe the user will try to catch our panics, but we will just keep producing them + panic(m.failure) + } + m.entryIndex += 1 // check if there is an entry as this index @@ -371,9 +401,11 @@ func replayOrNew[M wire.Message, O any]( // by calling the replay function if ok { if entry, ok := entry.(M); !ok { - return output, errEntryMismatch + // will be eg *wire.CallEntryMessage(nil) + var expectedEntry M + panic(m.newEntryMismatch(expectedEntry, entry)) } else { - return replay(entry) + return replay(entry), nil } } diff --git a/internal/state/sys.go b/internal/state/sys.go index 6ceb40c..eec15a2 100644 --- a/internal/state/sys.go +++ b/internal/state/sys.go @@ -14,18 +14,31 @@ import ( "google.golang.org/protobuf/proto" ) -var ( - errEntryMismatch = restate.WithErrorCode(fmt.Errorf("log entry mismatch"), 32) -) +type entryMismatch struct { + entryIndex uint32 + // this can be satisfied by a nil pointer in the case that there is an entry type mismatch + expectedEntry wire.Message + actualEntry wire.Message +} + +func (m *Machine) newEntryMismatch(expectedEntry wire.Message, actualEntry wire.Message) *entryMismatch { + e := &entryMismatch{m.entryIndex, expectedEntry, actualEntry} + m.failure = e + return e +} func (m *Machine) set(key string, value []byte) error { _, err := replayOrNew( m, - func(entry *wire.SetStateEntryMessage) (void restate.Void, err error) { + func(entry *wire.SetStateEntryMessage) (void restate.Void) { if string(entry.Key) != key || !bytes.Equal(entry.Value, value) { - return void, errEntryMismatch + panic(m.newEntryMismatch(&wire.SetStateEntryMessage{ + SetStateEntryMessage: protocol.SetStateEntryMessage{ + Key: []byte(key), + Value: value, + }, + }, entry)) } - return }, func() (void restate.Void, err error) { return void, m._set(key, value) @@ -52,12 +65,16 @@ func (m *Machine) _set(key string, value []byte) error { func (m *Machine) clear(key string) error { _, err := replayOrNew( m, - func(entry *wire.ClearStateEntryMessage) (void restate.Void, err error) { + func(entry *wire.ClearStateEntryMessage) (void restate.Void) { if string(entry.Key) != key { - return void, errEntryMismatch + panic(m.newEntryMismatch(&wire.ClearStateEntryMessage{ + ClearStateEntryMessage: protocol.ClearStateEntryMessage{ + Key: []byte(key), + }, + }, entry)) } - return void, nil + return }, func() (restate.Void, error) { return restate.Void{}, m._clear(key) }, @@ -85,7 +102,7 @@ func (m *Machine) _clear(key string) error { func (m *Machine) clearAll() error { _, err := replayOrNew( m, - func(entry *wire.ClearAllStateEntryMessage) (void restate.Void, err error) { + func(entry *wire.ClearAllStateEntryMessage) (void restate.Void) { return }, func() (restate.Void, error) { return restate.Void{}, m._clearAll() @@ -111,11 +128,15 @@ func (m *Machine) _clearAll() error { func (m *Machine) get(key string) ([]byte, error) { entry, err := replayOrNew( m, - func(entry *wire.GetStateEntryMessage) (*wire.GetStateEntryMessage, error) { + func(entry *wire.GetStateEntryMessage) *wire.GetStateEntryMessage { if string(entry.Key) != key { - return nil, errEntryMismatch + panic(m.newEntryMismatch(&wire.GetStateEntryMessage{ + GetStateEntryMessage: protocol.GetStateEntryMessage{ + Key: []byte(key), + }, + }, entry)) } - return entry, nil + return entry }, func() (*wire.GetStateEntryMessage, error) { return m._get(key) }) @@ -189,8 +210,8 @@ func (m *Machine) _get(key string) (*wire.GetStateEntryMessage, error) { func (m *Machine) keys() ([]string, error) { entry, err := replayOrNew( m, - func(entry *wire.GetStateKeysEntryMessage) (*wire.GetStateKeysEntryMessage, error) { - return entry, nil + func(entry *wire.GetStateKeysEntryMessage) *wire.GetStateKeysEntryMessage { + return entry }, m._keys, ) @@ -263,9 +284,9 @@ func (m *Machine) _keys() (*wire.GetStateKeysEntryMessage, error) { func (m *Machine) after(d time.Duration) (restate.After, error) { entry, err := replayOrNew( m, - func(entry *wire.SleepEntryMessage) (*wire.SleepEntryMessage, error) { + func(entry *wire.SleepEntryMessage) *wire.SleepEntryMessage { // we shouldn't verify the time because this would be different every time - return entry, nil + return entry }, func() (*wire.SleepEntryMessage, error) { return m._sleep(d) }, @@ -303,8 +324,8 @@ func (m *Machine) _sleep(d time.Duration) (*wire.SleepEntryMessage, error) { func (m *Machine) sideEffect(fn func() ([]byte, error)) ([]byte, error) { entry, err := replayOrNew( m, - func(entry *wire.RunEntryMessage) (*wire.RunEntryMessage, error) { - return entry, nil + func(entry *wire.RunEntryMessage) *wire.RunEntryMessage { + return entry }, func() (*wire.RunEntryMessage, error) { return m._sideEffect(fn) From 628dc8ff025df9fc083aa3275e811627ed2aba5c Mon Sep 17 00:00:00 2001 From: Jack Kleeman Date: Thu, 11 Jul 2024 12:13:59 +0200 Subject: [PATCH 2/6] Failed writes should also panic --- internal/state/awakeable.go | 16 ++++------- internal/state/call.go | 11 ++------ internal/state/completion.go | 17 +++++++++-- internal/state/state.go | 34 ++++++++++++++++------ internal/state/sys.go | 55 +++++++++++++----------------------- 5 files changed, 67 insertions(+), 66 deletions(-) diff --git a/internal/state/awakeable.go b/internal/state/awakeable.go index d8114ca..0b0e1c0 100644 --- a/internal/state/awakeable.go +++ b/internal/state/awakeable.go @@ -31,9 +31,7 @@ func (c *Machine) awakeable() (restate.Awakeable[[]byte], error) { func (c *Machine) _awakeable() (indexedEntry, error) { msg := &wire.AwakeableEntryMessage{} - if err := c.Write(msg); err != nil { - return indexedEntry{}, err - } + c.Write(msg) return indexedEntry{msg, c.entryIndex}, nil } @@ -63,14 +61,12 @@ func (m *Machine) resolveAwakeable(id string, value []byte) error { } func (c *Machine) _resolveAwakeable(id string, value []byte) error { - if err := c.Write(&wire.CompleteAwakeableEntryMessage{ + c.Write(&wire.CompleteAwakeableEntryMessage{ CompleteAwakeableEntryMessage: protocol.CompleteAwakeableEntryMessage{ Id: id, Result: &protocol.CompleteAwakeableEntryMessage_Value{Value: value}, }, - }); err != nil { - return err - } + }) return nil } @@ -103,7 +99,7 @@ func (m *Machine) rejectAwakeable(id string, reason error) error { } func (c *Machine) _rejectAwakeable(id string, reason error) error { - if err := c.Write(&wire.CompleteAwakeableEntryMessage{ + c.Write(&wire.CompleteAwakeableEntryMessage{ CompleteAwakeableEntryMessage: protocol.CompleteAwakeableEntryMessage{ Id: id, Result: &protocol.CompleteAwakeableEntryMessage_Failure{Failure: &protocol.Failure{ @@ -111,8 +107,6 @@ func (c *Machine) _rejectAwakeable(id string, reason error) error { Message: reason.Error(), }}, }, - }); err != nil { - return err - } + }) return nil } diff --git a/internal/state/call.go b/internal/state/call.go index c6d46cb..d0a3cc4 100644 --- a/internal/state/call.go +++ b/internal/state/call.go @@ -3,7 +3,6 @@ package state import ( "bytes" "encoding/json" - "fmt" "time" restate "github.com/restatedev/sdk-go" @@ -124,9 +123,7 @@ func (m *Machine) _doCall(service, key, method string, params []byte) (*wire.Cal Key: key, }, } - if err := m.Write(msg); err != nil { - return nil, fmt.Errorf("failed to send request message: %w", err) - } + m.Write(msg) return msg, nil } @@ -172,7 +169,7 @@ func (c *Machine) _sendCall(service, key, method string, params []byte, delay ti invokeTime = uint64(time.Now().Add(delay).UnixMilli()) } - err := c.Write(&wire.OneWayCallEntryMessage{ + c.Write(&wire.OneWayCallEntryMessage{ OneWayCallEntryMessage: protocol.OneWayCallEntryMessage{ ServiceName: service, HandlerName: method, @@ -182,9 +179,5 @@ func (c *Machine) _sendCall(service, key, method string, params []byte, delay ti }, }) - if err != nil { - return fmt.Errorf("failed to send request message: %w", err) - } - return nil } diff --git a/internal/state/completion.go b/internal/state/completion.go index 8f5ee09..d45b46c 100644 --- a/internal/state/completion.go +++ b/internal/state/completion.go @@ -26,7 +26,7 @@ func (m *Machine) ackable(entryIndex uint32) wire.AckableMessage { return m.pendingAcks[entryIndex] } -func (m *Machine) Write(message wire.Message) error { +func (m *Machine) Write(message wire.Message) { if message, ok := message.(wire.CompleteableMessage); ok && !message.Completed() { m.pendingMutex.Lock() m.pendingCompletions[m.entryIndex] = message @@ -37,7 +37,20 @@ func (m *Machine) Write(message wire.Message) error { m.pendingAcks[m.entryIndex] = message m.pendingMutex.Unlock() } - return m.protocol.Write(message) + if err := m.protocol.Write(message); err != nil { + panic(m.newWriteError(message, err)) + } +} + +type writeError struct { + entry wire.Message + err error +} + +func (m *Machine) newWriteError(entry wire.Message, err error) *writeError { + w := &writeError{entry, err} + m.failure = w + return w } func (m *Machine) handleCompletionsAcks() { diff --git a/internal/state/state.go b/internal/state/state.go index 89f2e16..332726c 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -228,24 +228,42 @@ func (m *Machine) invoke(ctx *Context, input []byte, outputSeen bool) error { case *entryMismatch: expected, _ := json.Marshal(typ.expectedEntry) actual, _ := json.Marshal(typ.actualEntry) - msg := fmt.Sprintf(`Journal mismatch: Replayed journal entries did not correspond to the user code. The user code has to be deterministic! -The journal entry at position %d was: -- In the user code: type: %T, message: %s -- In the replayed messages: type: %T, message %s`, - typ.entryIndex, typ.expectedEntry, string(expected), typ.actualEntry, string(actual)) - m.log.Error().Msg(msg) + m.log.Error(). + Type("expectedType", typ.expectedEntry). + RawJSON("expectedMessage", expected). + Type("actualType", typ.actualEntry). + RawJSON("actualMessage", actual). + Msg("Journal mismatch: Replayed journal entries did not correspond to the user code. The user code has to be deterministic!") // journal entry mismatch if err := m.protocol.Write(&wire.ErrorMessage{ ErrorMessage: protocol.ErrorMessage{ - Code: uint32(errors.ErrJournalMismatch), - Message: msg, + Code: uint32(errors.ErrJournalMismatch), + Message: fmt.Sprintf(`Journal mismatch: Replayed journal entries did not correspond to the user code. The user code has to be deterministic! +The journal entry at position %d was: +- In the user code: type: %T, message: %s +- In the replayed messages: type: %T, message %s`, + typ.entryIndex, typ.expectedEntry, string(expected), typ.actualEntry, string(actual)), Description: string(debug.Stack()), }, }); err != nil { m.log.Error().Err(err).Msg("error sending failure message") } + + return + case *writeError: + m.log.Error().Err(typ.err).Msg("Failed to write entry to Restate, shutting down state machine") + // don't even check for failure here because most likely the http2 conn is closed anyhow + _ = m.protocol.Write(&wire.ErrorMessage{ + ErrorMessage: protocol.ErrorMessage{ + Code: uint32(errors.ErrProtocolViolation), + Message: typ.err.Error(), + Description: string(debug.Stack()), + }, + }) + + return default: // unknown panic! // send an error message (retryable) diff --git a/internal/state/sys.go b/internal/state/sys.go index eec15a2..6869ba0 100644 --- a/internal/state/sys.go +++ b/internal/state/sys.go @@ -41,7 +41,8 @@ func (m *Machine) set(key string, value []byte) error { } return }, func() (void restate.Void, err error) { - return void, m._set(key, value) + m._set(key, value) + return void, nil }) if err != nil { return err @@ -52,8 +53,8 @@ func (m *Machine) set(key string, value []byte) error { return nil } -func (m *Machine) _set(key string, value []byte) error { - return m.Write( +func (m *Machine) _set(key string, value []byte) { + m.Write( &wire.SetStateEntryMessage{ SetStateEntryMessage: protocol.SetStateEntryMessage{ Key: []byte(key), @@ -76,7 +77,8 @@ func (m *Machine) clear(key string) error { return }, func() (restate.Void, error) { - return restate.Void{}, m._clear(key) + m._clear(key) + return restate.Void{}, nil }, ) @@ -89,8 +91,8 @@ func (m *Machine) clear(key string) error { return err } -func (m *Machine) _clear(key string) error { - return m.Write( +func (m *Machine) _clear(key string) { + m.Write( &wire.ClearStateEntryMessage{ ClearStateEntryMessage: protocol.ClearStateEntryMessage{ Key: []byte(key), @@ -105,7 +107,8 @@ func (m *Machine) clearAll() error { func(entry *wire.ClearAllStateEntryMessage) (void restate.Void) { return }, func() (restate.Void, error) { - return restate.Void{}, m._clearAll() + m._clearAll() + return restate.Void{}, nil }, ) if err != nil { @@ -119,8 +122,8 @@ func (m *Machine) clearAll() error { } // clearAll drops all associated keys -func (m *Machine) _clearAll() error { - return m.Write( +func (m *Machine) _clearAll() { + m.Write( &wire.ClearAllStateEntryMessage{}, ) } @@ -178,9 +181,7 @@ func (m *Machine) _get(key string) (*wire.GetStateEntryMessage, error) { // value to the runtime msg.Complete(&protocol.CompletionMessage{Result: &protocol.CompletionMessage_Value{Value: value}}) - if err := m.Write(msg); err != nil { - return nil, err - } + m.Write(msg) return msg, nil } @@ -191,18 +192,14 @@ func (m *Machine) _get(key string) (*wire.GetStateEntryMessage, error) { // but also send an empty get state entry message msg.Complete(&protocol.CompletionMessage{Result: &protocol.CompletionMessage_Empty{Empty: &protocol.Empty{}}}) - if err := m.Write(msg); err != nil { - return nil, err - } + m.Write(msg) return msg, nil } // we didn't see the value and we don't know for sure there isn't one; ask the runtime for it - if err := m.Write(msg); err != nil { - return nil, err - } + m.Write(msg) return msg, nil } @@ -266,17 +263,9 @@ func (m *Machine) _keys() (*wire.GetStateKeysEntryMessage, error) { msg.Complete(&protocol.CompletionMessage{Result: &protocol.CompletionMessage_Value{ Value: value, }}) - - if err := m.Write(msg); err != nil { - return nil, err - } - - return nil, nil } - if err := m.Write(msg); err != nil { - return nil, err - } + m.Write(msg) return msg, nil } @@ -314,9 +303,7 @@ func (m *Machine) _sleep(d time.Duration) (*wire.SleepEntryMessage, error) { WakeUpTime: uint64(time.Now().Add(d).UnixMilli()), }, } - if err := m.Write(msg); err != nil { - return nil, err - } + m.Write(msg) return msg, nil } @@ -369,9 +356,7 @@ func (m *Machine) _sideEffect(fn func() ([]byte, error)) (*wire.RunEntryMessage, }, }, } - if err := m.Write(msg); err != nil { - return nil, err - } + m.Write(msg) // don't return the original error, we will turn the entry back into an error later // that way its not different replay vs non-replay @@ -399,9 +384,7 @@ func (m *Machine) _sideEffect(fn func() ([]byte, error)) (*wire.RunEntryMessage, }, }, } - if err := m.Write(msg); err != nil { - return nil, err - } + m.Write(msg) return msg, nil } From 271a955ec566766487e41e362933b2b3511572d9 Mon Sep 17 00:00:00 2001 From: Jack Kleeman Date: Thu, 11 Jul 2024 12:38:52 +0200 Subject: [PATCH 3/6] Creating journal entries should panic, not return error --- internal/state/awakeable.go | 23 ++++------ internal/state/call.go | 11 ++--- internal/state/completion.go | 7 +-- internal/state/state.go | 32 +++++++++++--- internal/state/sys.go | 69 ++++++++++++++--------------- internal/wire/wire.go | 86 +++++++++++++++++++----------------- 6 files changed, 124 insertions(+), 104 deletions(-) diff --git a/internal/state/awakeable.go b/internal/state/awakeable.go index 0b0e1c0..98454c9 100644 --- a/internal/state/awakeable.go +++ b/internal/state/awakeable.go @@ -29,10 +29,10 @@ func (c *Machine) awakeable() (restate.Awakeable[[]byte], error) { return futures.NewAwakeable(c.ctx, c.id, indexedEntry.entryIndex, indexedEntry.entry), nil } -func (c *Machine) _awakeable() (indexedEntry, error) { +func (c *Machine) _awakeable() indexedEntry { msg := &wire.AwakeableEntryMessage{} c.Write(msg) - return indexedEntry{msg, c.entryIndex}, nil + return indexedEntry{msg, c.entryIndex} } func (m *Machine) resolveAwakeable(id string, value []byte) error { @@ -50,24 +50,21 @@ func (m *Machine) resolveAwakeable(id string, value []byte) error { } return restate.Void{} }, - func() (restate.Void, error) { - if err := m._resolveAwakeable(id, value); err != nil { - return restate.Void{}, err - } - return restate.Void{}, nil + func() restate.Void { + m._resolveAwakeable(id, value) + return restate.Void{} }, ) return err } -func (c *Machine) _resolveAwakeable(id string, value []byte) error { +func (c *Machine) _resolveAwakeable(id string, value []byte) { c.Write(&wire.CompleteAwakeableEntryMessage{ CompleteAwakeableEntryMessage: protocol.CompleteAwakeableEntryMessage{ Id: id, Result: &protocol.CompleteAwakeableEntryMessage_Value{Value: value}, }, }) - return nil } func (m *Machine) rejectAwakeable(id string, reason error) error { @@ -88,11 +85,9 @@ func (m *Machine) rejectAwakeable(id string, reason error) error { } return restate.Void{} }, - func() (restate.Void, error) { - if err := m._rejectAwakeable(id, reason); err != nil { - return restate.Void{}, err - } - return restate.Void{}, nil + func() restate.Void { + m._rejectAwakeable(id, reason) + return restate.Void{} }, ) return err diff --git a/internal/state/call.go b/internal/state/call.go index d0a3cc4..fd25a12 100644 --- a/internal/state/call.go +++ b/internal/state/call.go @@ -109,12 +109,12 @@ func (m *Machine) doCall(service, key, method string, params []byte) (*wire.Call } return entry - }, func() (*wire.CallEntryMessage, error) { + }, func() *wire.CallEntryMessage { return m._doCall(service, key, method, params) }) } -func (m *Machine) _doCall(service, key, method string, params []byte) (*wire.CallEntryMessage, error) { +func (m *Machine) _doCall(service, key, method string, params []byte) *wire.CallEntryMessage { msg := &wire.CallEntryMessage{ CallEntryMessage: protocol.CallEntryMessage{ ServiceName: service, @@ -125,7 +125,7 @@ func (m *Machine) _doCall(service, key, method string, params []byte) (*wire.Cal } m.Write(msg) - return msg, nil + return msg } func (m *Machine) sendCall(service, key, method string, body any, delay time.Duration) error { @@ -155,8 +155,9 @@ func (m *Machine) sendCall(service, key, method string, body any, delay time.Dur return restate.Void{} }, - func() (restate.Void, error) { - return restate.Void{}, m._sendCall(service, key, method, params, delay) + func() restate.Void { + m._sendCall(service, key, method, params, delay) + return restate.Void{} }, ) diff --git a/internal/state/completion.go b/internal/state/completion.go index d45b46c..138d0d2 100644 --- a/internal/state/completion.go +++ b/internal/state/completion.go @@ -43,12 +43,13 @@ func (m *Machine) Write(message wire.Message) { } type writeError struct { - entry wire.Message - err error + entryIndex uint32 + entry wire.Message + err error } func (m *Machine) newWriteError(entry wire.Message, err error) *writeError { - w := &writeError{entry, err} + w := &writeError{m.entryIndex, entry, err} m.failure = w return w } diff --git a/internal/state/state.go b/internal/state/state.go index 332726c..86995e9 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -245,7 +245,9 @@ The journal entry at position %d was: - In the user code: type: %T, message: %s - In the replayed messages: type: %T, message %s`, typ.entryIndex, typ.expectedEntry, string(expected), typ.actualEntry, string(actual)), - Description: string(debug.Stack()), + Description: string(debug.Stack()), + RelatedEntryIndex: &typ.entryIndex, + RelatedEntryType: wire.MessageType(typ.actualEntry).UInt32(), }, }); err != nil { m.log.Error().Err(err).Msg("error sending failure message") @@ -257,12 +259,30 @@ The journal entry at position %d was: // don't even check for failure here because most likely the http2 conn is closed anyhow _ = m.protocol.Write(&wire.ErrorMessage{ ErrorMessage: protocol.ErrorMessage{ - Code: uint32(errors.ErrProtocolViolation), - Message: typ.err.Error(), - Description: string(debug.Stack()), + Code: uint32(errors.ErrProtocolViolation), + Message: typ.err.Error(), + Description: string(debug.Stack()), + RelatedEntryIndex: &typ.entryIndex, + RelatedEntryType: wire.MessageType(typ.entry).UInt32(), }, }) + return + case *sideEffectFailure: + m.log.Error().Err(typ.err).Msg("Side effect returned a failure, returning error to Restate") + + if err := m.protocol.Write(&wire.ErrorMessage{ + ErrorMessage: protocol.ErrorMessage{ + Code: uint32(restate.ErrorCode(typ.err)), + Message: typ.err.Error(), + Description: string(debug.Stack()), + RelatedEntryIndex: &typ.entryIndex, + RelatedEntryType: wire.AwakeableEntryMessageType.UInt32(), + }, + }); err != nil { + m.log.Error().Err(err).Msg("error sending failure message") + } + return default: // unknown panic! @@ -399,7 +419,7 @@ func (c *Machine) currentEntry() (wire.Message, bool) { func replayOrNew[M wire.Message, O any]( m *Machine, replay func(msg M) O, - new func() (O, error), + new func() O, ) (output O, err error) { // lock around preparing the entry, but we would never await an ack or completion with this held. m.entryMutex.Lock() @@ -428,5 +448,5 @@ func replayOrNew[M wire.Message, O any]( } // other wise call the new function - return new() + return new(), nil } diff --git a/internal/state/sys.go b/internal/state/sys.go index 6869ba0..4d18828 100644 --- a/internal/state/sys.go +++ b/internal/state/sys.go @@ -40,9 +40,9 @@ func (m *Machine) set(key string, value []byte) error { }, entry)) } return - }, func() (void restate.Void, err error) { + }, func() (void restate.Void) { m._set(key, value) - return void, nil + return void }) if err != nil { return err @@ -76,9 +76,9 @@ func (m *Machine) clear(key string) error { } return - }, func() (restate.Void, error) { + }, func() restate.Void { m._clear(key) - return restate.Void{}, nil + return restate.Void{} }, ) @@ -106,9 +106,9 @@ func (m *Machine) clearAll() error { m, func(entry *wire.ClearAllStateEntryMessage) (void restate.Void) { return - }, func() (restate.Void, error) { + }, func() restate.Void { m._clearAll() - return restate.Void{}, nil + return restate.Void{} }, ) if err != nil { @@ -140,7 +140,7 @@ func (m *Machine) get(key string) ([]byte, error) { }, entry)) } return entry - }, func() (*wire.GetStateEntryMessage, error) { + }, func() *wire.GetStateEntryMessage { return m._get(key) }) if err != nil { @@ -167,7 +167,7 @@ func (m *Machine) get(key string) ([]byte, error) { return nil, restate.TerminalError(fmt.Errorf("get state had invalid result: %v", entry.Result), errors.ErrProtocolViolation) } -func (m *Machine) _get(key string) (*wire.GetStateEntryMessage, error) { +func (m *Machine) _get(key string) *wire.GetStateEntryMessage { msg := &wire.GetStateEntryMessage{ GetStateEntryMessage: protocol.GetStateEntryMessage{ Key: []byte(key), @@ -183,7 +183,7 @@ func (m *Machine) _get(key string) (*wire.GetStateEntryMessage, error) { m.Write(msg) - return msg, nil + return msg } // key is not in map! there are 2 cases. @@ -194,14 +194,14 @@ func (m *Machine) _get(key string) (*wire.GetStateEntryMessage, error) { m.Write(msg) - return msg, nil + return msg } // we didn't see the value and we don't know for sure there isn't one; ask the runtime for it m.Write(msg) - return msg, nil + return msg } func (m *Machine) keys() ([]string, error) { @@ -237,7 +237,7 @@ func (m *Machine) keys() ([]string, error) { return nil, nil } -func (m *Machine) _keys() (*wire.GetStateKeysEntryMessage, error) { +func (m *Machine) _keys() *wire.GetStateKeysEntryMessage { msg := &wire.GetStateKeysEntryMessage{} if !m.partial { keys := make([]string, 0, len(m.current)) @@ -254,7 +254,7 @@ func (m *Machine) _keys() (*wire.GetStateKeysEntryMessage, error) { stateKeys := &protocol.GetStateKeysEntryMessage_StateKeys{Keys: byteKeys} value, err := proto.Marshal(stateKeys) if err != nil { - return nil, err + panic(err) // this is pretty much impossible } // we can return keys entirely from cache @@ -267,7 +267,7 @@ func (m *Machine) _keys() (*wire.GetStateKeysEntryMessage, error) { m.Write(msg) - return msg, nil + return msg } func (m *Machine) after(d time.Duration) (restate.After, error) { @@ -276,7 +276,7 @@ func (m *Machine) after(d time.Duration) (restate.After, error) { func(entry *wire.SleepEntryMessage) *wire.SleepEntryMessage { // we shouldn't verify the time because this would be different every time return entry - }, func() (*wire.SleepEntryMessage, error) { + }, func() *wire.SleepEntryMessage { return m._sleep(d) }, ) @@ -297,7 +297,7 @@ func (m *Machine) sleep(d time.Duration) error { } // _sleep creating a new sleep entry. -func (m *Machine) _sleep(d time.Duration) (*wire.SleepEntryMessage, error) { +func (m *Machine) _sleep(d time.Duration) *wire.SleepEntryMessage { msg := &wire.SleepEntryMessage{ SleepEntryMessage: protocol.SleepEntryMessage{ WakeUpTime: uint64(time.Now().Add(d).UnixMilli()), @@ -305,7 +305,7 @@ func (m *Machine) _sleep(d time.Duration) (*wire.SleepEntryMessage, error) { } m.Write(msg) - return msg, nil + return msg } func (m *Machine) sideEffect(fn func() ([]byte, error)) ([]byte, error) { @@ -314,7 +314,7 @@ func (m *Machine) sideEffect(fn func() ([]byte, error)) ([]byte, error) { func(entry *wire.RunEntryMessage) *wire.RunEntryMessage { return entry }, - func() (*wire.RunEntryMessage, error) { + func() *wire.RunEntryMessage { return m._sideEffect(fn) }, ) @@ -341,7 +341,7 @@ func (m *Machine) sideEffect(fn func() ([]byte, error)) ([]byte, error) { return nil, restate.TerminalError(fmt.Errorf("side effect entry had invalid result: %v", entry.Result), errors.ErrProtocolViolation) } -func (m *Machine) _sideEffect(fn func() ([]byte, error)) (*wire.RunEntryMessage, error) { +func (m *Machine) _sideEffect(fn func() ([]byte, error)) *wire.RunEntryMessage { bytes, err := fn() if err != nil { @@ -358,23 +358,9 @@ func (m *Machine) _sideEffect(fn func() ([]byte, error)) (*wire.RunEntryMessage, } m.Write(msg) - // don't return the original error, we will turn the entry back into an error later - // that way its not different replay vs non-replay - return msg, nil + return msg } else { - ty := uint32(wire.RunEntryMessageType) - msg := wire.ErrorMessage{ - ErrorMessage: protocol.ErrorMessage{ - Code: uint32(restate.ErrorCode(err)), - Message: err.Error(), - RelatedEntryType: &ty, - }, - } - if err := m.protocol.Write(&msg); err != nil { - return nil, err - } - - return nil, err + panic(m.newSideEffectFailure(err)) } } else { msg := &wire.RunEntryMessage{ @@ -386,6 +372,17 @@ func (m *Machine) _sideEffect(fn func() ([]byte, error)) (*wire.RunEntryMessage, } m.Write(msg) - return msg, nil + return msg } } + +type sideEffectFailure struct { + entryIndex uint32 + err error +} + +func (m *Machine) newSideEffectFailure(err error) *sideEffectFailure { + s := &sideEffectFailure{m.entryIndex, err} + m.failure = s + return s +} diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 27f1c2e..62eb741 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -63,6 +63,11 @@ func (t Type) String() string { return fmt.Sprintf("0x%04X", uint16(t)) } +func (t Type) UInt32() *uint32 { + u := uint32(t) + return &u +} + // Flag section of the header this can have // a different meaning based on message type. type Flag uint16 @@ -93,6 +98,46 @@ type Message interface { proto.Message } +func MessageType(message Message) Type { + switch message.(type) { + case *StartMessage: + return StartMessageType + case *SuspensionMessage: + return SuspensionMessageType + case *InputEntryMessage: + return InputEntryMessageType + case *OutputEntryMessage: + return OutputEntryMessageType + case *ErrorMessage: + return ErrorMessageType + case *EndMessage: + return EndMessageType + case *GetStateEntryMessage: + return GetStateEntryMessageType + case *SetStateEntryMessage: + return SetStateEntryMessageType + case *ClearStateEntryMessage: + return ClearStateEntryMessageType + case *ClearAllStateEntryMessage: + return ClearAllStateEntryMessageType + case *GetStateKeysEntryMessage: + return GetStateKeysEntryMessageType + case *SleepEntryMessage: + return SleepEntryMessageType + case *CallEntryMessage: + return CallEntryMessageType + case *OneWayCallEntryMessage: + return OneWayCallEntryMessageType + case *AwakeableEntryMessage: + return AwakeableEntryMessageType + case *CompleteAwakeableEntryMessage: + return CompleteAwakeableEntryMessageType + case *RunEntryMessage: + return RunEntryMessageType + } + panic(fmt.Sprintf("unknown message type %T", message)) +} + type ReaderMessage struct { Message Message Err error @@ -172,46 +217,7 @@ func (s *Protocol) Write(message Message) error { flag |= FlagRequiresAck } - // all possible types sent by the sdk - var typ Type - switch message.(type) { - case *StartMessage: - typ = StartMessageType - case *SuspensionMessage: - typ = SuspensionMessageType - case *InputEntryMessage: - typ = InputEntryMessageType - case *OutputEntryMessage: - typ = OutputEntryMessageType - case *ErrorMessage: - typ = ErrorMessageType - case *EndMessage: - typ = EndMessageType - case *GetStateEntryMessage: - typ = GetStateEntryMessageType - case *SetStateEntryMessage: - typ = SetStateEntryMessageType - case *ClearStateEntryMessage: - typ = ClearStateEntryMessageType - case *ClearAllStateEntryMessage: - typ = ClearAllStateEntryMessageType - case *GetStateKeysEntryMessage: - typ = GetStateKeysEntryMessageType - case *SleepEntryMessage: - typ = SleepEntryMessageType - case *CallEntryMessage: - typ = CallEntryMessageType - case *OneWayCallEntryMessage: - typ = OneWayCallEntryMessageType - case *AwakeableEntryMessage: - typ = AwakeableEntryMessageType - case *CompleteAwakeableEntryMessage: - typ = CompleteAwakeableEntryMessageType - case *RunEntryMessage: - typ = RunEntryMessageType - default: - return fmt.Errorf("can not send message of unknown message type") - } + typ := MessageType(message) s.log.Trace().Stringer("type", typ).Interface("msg", message).Msg("sending message to runtime") From 1da28e5ab9fae907def99aed282a38886f26e26c Mon Sep 17 00:00:00 2001 From: Jack Kleeman Date: Thu, 11 Jul 2024 12:46:22 +0200 Subject: [PATCH 4/6] Make replayOrNew infallible except by panic --- example/ticket_service.go | 3 +- example/user_session.go | 3 +- internal/futures/select_test.go | 9 ++---- internal/state/awakeable.go | 22 ++++++--------- internal/state/call.go | 12 ++++---- internal/state/state.go | 28 +++++++++--------- internal/state/sys.go | 50 +++++++-------------------------- router.go | 27 ++++++++---------- 8 files changed, 56 insertions(+), 98 deletions(-) diff --git a/example/ticket_service.go b/example/ticket_service.go index 15db39a..98c6cf9 100644 --- a/example/ticket_service.go +++ b/example/ticket_service.go @@ -37,7 +37,8 @@ func unreserve(ctx restate.ObjectContext, _ restate.Void) (void restate.Void, er } if status != TicketSold { - return void, ctx.Clear("status") + ctx.Clear("status") + return void, nil } return void, nil diff --git a/example/user_session.go b/example/user_session.go index e69e28b..02e2ae6 100644 --- a/example/user_session.go +++ b/example/user_session.go @@ -96,7 +96,8 @@ func checkout(ctx restate.ObjectContext, _ restate.Void) (bool, error) { } } - return true, ctx.Clear("tickets") + ctx.Clear("tickets") + return true, nil } var ( diff --git a/internal/futures/select_test.go b/internal/futures/select_test.go index 74abba1..a9ca09b 100644 --- a/internal/futures/select_test.go +++ b/internal/futures/select_test.go @@ -13,8 +13,8 @@ type fakeContext struct { restate.Context } -func (f *fakeContext) Awakeable() (restate.Awakeable[[]byte], error) { - return futures.NewAwakeable(context.TODO(), nil, 0, nil), nil +func (f *fakeContext) Awakeable() restate.Awakeable[[]byte] { + return futures.NewAwakeable(context.TODO(), nil, 0, nil) } var _ restate.Context = (*fakeContext)(nil) @@ -22,10 +22,7 @@ var _ restate.Context = (*fakeContext)(nil) func TestSelect(t *testing.T) { after := futures.NewAfter(context.TODO(), nil) awakeableOne := futures.NewAwakeable(context.TODO(), nil, 0, nil) - awakeableTwo, err := restate.AwakeableAs[string](&fakeContext{}) - if err != nil { - t.Fatal(err) - } + awakeableTwo := restate.AwakeableAs[string](&fakeContext{}) responseFut := futures.NewResponseFuture(context.TODO(), nil) // one-off (race) diff --git a/internal/state/awakeable.go b/internal/state/awakeable.go index 98454c9..710da25 100644 --- a/internal/state/awakeable.go +++ b/internal/state/awakeable.go @@ -14,19 +14,16 @@ type indexedEntry struct { entryIndex uint32 } -func (c *Machine) awakeable() (restate.Awakeable[[]byte], error) { - indexedEntry, err := replayOrNew( +func (c *Machine) awakeable() restate.Awakeable[[]byte] { + indexedEntry := replayOrNew( c, func(entry *wire.AwakeableEntryMessage) indexedEntry { return indexedEntry{entry, c.entryIndex} }, c._awakeable, ) - if err != nil { - return nil, err - } - return futures.NewAwakeable(c.ctx, c.id, indexedEntry.entryIndex, indexedEntry.entry), nil + return futures.NewAwakeable(c.ctx, c.id, indexedEntry.entryIndex, indexedEntry.entry) } func (c *Machine) _awakeable() indexedEntry { @@ -35,8 +32,8 @@ func (c *Machine) _awakeable() indexedEntry { return indexedEntry{msg, c.entryIndex} } -func (m *Machine) resolveAwakeable(id string, value []byte) error { - _, err := replayOrNew( +func (m *Machine) resolveAwakeable(id string, value []byte) { + _ = replayOrNew( m, func(entry *wire.CompleteAwakeableEntryMessage) restate.Void { messageValue, ok := entry.Result.(*protocol.CompleteAwakeableEntryMessage_Value) @@ -55,7 +52,6 @@ func (m *Machine) resolveAwakeable(id string, value []byte) error { return restate.Void{} }, ) - return err } func (c *Machine) _resolveAwakeable(id string, value []byte) { @@ -67,8 +63,8 @@ func (c *Machine) _resolveAwakeable(id string, value []byte) { }) } -func (m *Machine) rejectAwakeable(id string, reason error) error { - _, err := replayOrNew( +func (m *Machine) rejectAwakeable(id string, reason error) { + _ = replayOrNew( m, func(entry *wire.CompleteAwakeableEntryMessage) restate.Void { messageFailure, ok := entry.Result.(*protocol.CompleteAwakeableEntryMessage_Failure) @@ -90,10 +86,9 @@ func (m *Machine) rejectAwakeable(id string, reason error) error { return restate.Void{} }, ) - return err } -func (c *Machine) _rejectAwakeable(id string, reason error) error { +func (c *Machine) _rejectAwakeable(id string, reason error) { c.Write(&wire.CompleteAwakeableEntryMessage{ CompleteAwakeableEntryMessage: protocol.CompleteAwakeableEntryMessage{ Id: id, @@ -103,5 +98,4 @@ func (c *Machine) _rejectAwakeable(id string, reason error) error { }}, }, }) - return nil } diff --git a/internal/state/call.go b/internal/state/call.go index fd25a12..6f33f10 100644 --- a/internal/state/call.go +++ b/internal/state/call.go @@ -85,10 +85,10 @@ func (m *Machine) doDynCall(service, key, method string, input any) (*wire.CallE return nil, err } - return m.doCall(service, key, method, params) + return m.doCall(service, key, method, params), nil } -func (m *Machine) doCall(service, key, method string, params []byte) (*wire.CallEntryMessage, error) { +func (m *Machine) doCall(service, key, method string, params []byte) *wire.CallEntryMessage { m.log.Debug().Str("service", service).Str("method", method).Str("key", key).Msg("executing sync call") return replayOrNew( @@ -136,7 +136,7 @@ func (m *Machine) sendCall(service, key, method string, body any, delay time.Dur return err } - _, err = replayOrNew( + _ = replayOrNew( m, func(entry *wire.OneWayCallEntryMessage) restate.Void { if entry.ServiceName != service || @@ -161,10 +161,10 @@ func (m *Machine) sendCall(service, key, method string, body any, delay time.Dur }, ) - return err + return nil } -func (c *Machine) _sendCall(service, key, method string, params []byte, delay time.Duration) error { +func (c *Machine) _sendCall(service, key, method string, params []byte, delay time.Duration) { var invokeTime uint64 if delay != 0 { invokeTime = uint64(time.Now().Add(delay).UnixMilli()) @@ -179,6 +179,4 @@ func (c *Machine) _sendCall(service, key, method string, params []byte, delay ti InvokeTime: invokeTime, }, }) - - return nil } diff --git a/internal/state/state.go b/internal/state/state.go index 86995e9..b290ec4 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -42,18 +42,18 @@ func (c *Context) Ctx() context.Context { return c.ctx } -func (c *Context) Set(key string, value []byte) error { - return c.machine.set(key, value) +func (c *Context) Set(key string, value []byte) { + c.machine.set(key, value) } -func (c *Context) Clear(key string) error { - return c.machine.clear(key) +func (c *Context) Clear(key string) { + c.machine.clear(key) } // ClearAll drops all associated keys -func (c *Context) ClearAll() error { - return c.machine.clearAll() +func (c *Context) ClearAll() { + c.machine.clearAll() } @@ -113,16 +113,16 @@ func (c *Context) SideEffect(fn func() ([]byte, error)) ([]byte, error) { return c.machine.sideEffect(fn) } -func (c *Context) Awakeable() (restate.Awakeable[[]byte], error) { +func (c *Context) Awakeable() restate.Awakeable[[]byte] { return c.machine.awakeable() } -func (c *Context) ResolveAwakeable(id string, value []byte) error { - return c.machine.resolveAwakeable(id, value) +func (c *Context) ResolveAwakeable(id string, value []byte) { + c.machine.resolveAwakeable(id, value) } -func (c *Context) RejectAwakeable(id string, reason error) error { - return c.machine.rejectAwakeable(id, reason) +func (c *Context) RejectAwakeable(id string, reason error) { + c.machine.rejectAwakeable(id, reason) } func (c *Context) Key() string { @@ -420,7 +420,7 @@ func replayOrNew[M wire.Message, O any]( m *Machine, replay func(msg M) O, new func() O, -) (output O, err error) { +) (output O) { // lock around preparing the entry, but we would never await an ack or completion with this held. m.entryMutex.Lock() defer m.entryMutex.Unlock() @@ -443,10 +443,10 @@ func replayOrNew[M wire.Message, O any]( var expectedEntry M panic(m.newEntryMismatch(expectedEntry, entry)) } else { - return replay(entry), nil + return replay(entry) } } // other wise call the new function - return new(), nil + return new() } diff --git a/internal/state/sys.go b/internal/state/sys.go index 4d18828..a2d9ca2 100644 --- a/internal/state/sys.go +++ b/internal/state/sys.go @@ -27,8 +27,8 @@ func (m *Machine) newEntryMismatch(expectedEntry wire.Message, actualEntry wire. return e } -func (m *Machine) set(key string, value []byte) error { - _, err := replayOrNew( +func (m *Machine) set(key string, value []byte) { + _ = replayOrNew( m, func(entry *wire.SetStateEntryMessage) (void restate.Void) { if string(entry.Key) != key || !bytes.Equal(entry.Value, value) { @@ -44,13 +44,8 @@ func (m *Machine) set(key string, value []byte) error { m._set(key, value) return void }) - if err != nil { - return err - } m.current[key] = value - - return nil } func (m *Machine) _set(key string, value []byte) { @@ -63,8 +58,8 @@ func (m *Machine) _set(key string, value []byte) { }) } -func (m *Machine) clear(key string) error { - _, err := replayOrNew( +func (m *Machine) clear(key string) { + _ = replayOrNew( m, func(entry *wire.ClearStateEntryMessage) (void restate.Void) { if string(entry.Key) != key { @@ -82,13 +77,7 @@ func (m *Machine) clear(key string) error { }, ) - if err != nil { - return err - } - delete(m.current, key) - - return err } func (m *Machine) _clear(key string) { @@ -101,8 +90,8 @@ func (m *Machine) _clear(key string) { ) } -func (m *Machine) clearAll() error { - _, err := replayOrNew( +func (m *Machine) clearAll() { + _ = replayOrNew( m, func(entry *wire.ClearAllStateEntryMessage) (void restate.Void) { return @@ -111,14 +100,8 @@ func (m *Machine) clearAll() error { return restate.Void{} }, ) - if err != nil { - return err - } - m.current = map[string][]byte{} m.partial = false - - return nil } // clearAll drops all associated keys @@ -129,7 +112,7 @@ func (m *Machine) _clearAll() { } func (m *Machine) get(key string) ([]byte, error) { - entry, err := replayOrNew( + entry := replayOrNew( m, func(entry *wire.GetStateEntryMessage) *wire.GetStateEntryMessage { if string(entry.Key) != key { @@ -143,9 +126,6 @@ func (m *Machine) get(key string) ([]byte, error) { }, func() *wire.GetStateEntryMessage { return m._get(key) }) - if err != nil { - return nil, err - } if err := entry.Await(m.ctx); err != nil { return nil, err @@ -205,16 +185,13 @@ func (m *Machine) _get(key string) *wire.GetStateEntryMessage { } func (m *Machine) keys() ([]string, error) { - entry, err := replayOrNew( + entry := replayOrNew( m, func(entry *wire.GetStateKeysEntryMessage) *wire.GetStateKeysEntryMessage { return entry }, m._keys, ) - if err != nil { - return nil, err - } if err := entry.Await(m.ctx); err != nil { return nil, err @@ -271,7 +248,7 @@ func (m *Machine) _keys() *wire.GetStateKeysEntryMessage { } func (m *Machine) after(d time.Duration) (restate.After, error) { - entry, err := replayOrNew( + entry := replayOrNew( m, func(entry *wire.SleepEntryMessage) *wire.SleepEntryMessage { // we shouldn't verify the time because this would be different every time @@ -280,9 +257,6 @@ func (m *Machine) after(d time.Duration) (restate.After, error) { return m._sleep(d) }, ) - if err != nil { - return nil, err - } return futures.NewAfter(m.ctx, entry), nil } @@ -309,7 +283,7 @@ func (m *Machine) _sleep(d time.Duration) *wire.SleepEntryMessage { } func (m *Machine) sideEffect(fn func() ([]byte, error)) ([]byte, error) { - entry, err := replayOrNew( + entry := replayOrNew( m, func(entry *wire.RunEntryMessage) *wire.RunEntryMessage { return entry @@ -318,10 +292,6 @@ func (m *Machine) sideEffect(fn func() ([]byte, error)) ([]byte, error) { return m._sideEffect(fn) }, ) - if err != nil { - // either a transient error from the fn or from our sending of the result - return nil, err - } // side effect must be acknowledged before proceeding if err := entry.Await(m.ctx); err != nil { diff --git a/router.go b/router.go index 9744dd7..450813e 100644 --- a/router.go +++ b/router.go @@ -76,9 +76,9 @@ type Context interface { // Note: use the SideEffectAs helper function SideEffect(fn func() ([]byte, error)) ([]byte, error) - Awakeable() (Awakeable[[]byte], error) - ResolveAwakeable(id string, value []byte) error - RejectAwakeable(id string, reason error) error + Awakeable() Awakeable[[]byte] + ResolveAwakeable(id string, value []byte) + RejectAwakeable(id string, reason error) } // Router interface @@ -123,7 +123,7 @@ type KeyValueStore interface { // Set sets key value to bytes array. You can // Note: Use SetAs helper function to seamlessly store // a value of specific type. - Set(key string, value []byte) error + Set(key string, value []byte) // Get gets value (bytes array) associated with key // If key does not exist, this function return a nil bytes array // and a nil error @@ -131,9 +131,9 @@ type KeyValueStore interface { // as specific type. Get(key string) ([]byte, error) // Clear deletes a key - Clear(key string) error + Clear(key string) // ClearAll drops all stored state associated with key - ClearAll() error + ClearAll() // Keys returns a list of all associated key Keys() ([]string, error) } @@ -232,7 +232,8 @@ func SetAs[T any](ctx ObjectContext, key string, value T) error { return err } - return ctx.Set(key, bytes) + ctx.Set(key, bytes) + return nil } // SideEffectAs helper function runs a side effect function with specific concrete type as a result @@ -279,13 +280,8 @@ func (d decodingAwakeable[T]) Result() (out T, err error) { return } -func AwakeableAs[T any](ctx Context) (Awakeable[T], error) { - inner, err := ctx.Awakeable() - if err != nil { - return nil, err - } - - return decodingAwakeable[T]{Awakeable: inner}, nil +func AwakeableAs[T any](ctx Context) Awakeable[T] { + return decodingAwakeable[T]{Awakeable: ctx.Awakeable()} } func ResolveAwakeableAs[T any](ctx Context, id string, value T) error { @@ -293,7 +289,8 @@ func ResolveAwakeableAs[T any](ctx Context, id string, value T) error { if err != nil { return TerminalError(err) } - return ctx.ResolveAwakeable(id, bytes) + ctx.ResolveAwakeable(id, bytes) + return nil } type After interface { From 7e49dcf364f2f2cb8f08f6be3f1cf8a6381d81d0 Mon Sep 17 00:00:00 2001 From: Jack Kleeman Date: Thu, 11 Jul 2024 15:16:59 +0200 Subject: [PATCH 5/6] Implement suspension --- internal/futures/futures.go | 65 ++++++++++++++-------------- internal/futures/select_test.go | 8 ++-- internal/state/awakeable.go | 21 ++++----- internal/state/call.go | 33 +++++++------- internal/state/completion.go | 1 + internal/state/state.go | 76 +++++++++++++++++++++------------ internal/state/sys.go | 39 ++++++----------- internal/wire/wire.go | 27 +++++++----- router.go | 9 ++-- 9 files changed, 145 insertions(+), 134 deletions(-) diff --git a/internal/futures/futures.go b/internal/futures/futures.go index 76d6da8..07675ef 100644 --- a/internal/futures/futures.go +++ b/internal/futures/futures.go @@ -19,16 +19,17 @@ var ( ) type After struct { - ctx context.Context - entry *wire.SleepEntryMessage + suspensionCtx context.Context + entry *wire.SleepEntryMessage + entryIndex uint32 } -func NewAfter(ctx context.Context, entry *wire.SleepEntryMessage) *After { - return &After{ctx, entry} +func NewAfter(suspensionCtx context.Context, entry *wire.SleepEntryMessage, entryIndex uint32) *After { + return &After{suspensionCtx, entry, entryIndex} } -func (a *After) Done() error { - return a.entry.Await(a.ctx) +func (a *After) Done() { + a.entry.Await(a.suspensionCtx, a.entryIndex) } func (a *After) getEntry() (wire.CompleteableMessage, error) { @@ -38,29 +39,27 @@ func (a *After) getEntry() (wire.CompleteableMessage, error) { const AWAKEABLE_IDENTIFIER_PREFIX = "prom_1" type Awakeable struct { - ctx context.Context - invocationID []byte - entryIndex uint32 - entry *wire.AwakeableEntryMessage + suspensionCtx context.Context + invocationID []byte + entry *wire.AwakeableEntryMessage + entryIndex uint32 } -func NewAwakeable(ctx context.Context, invocationID []byte, entryIndex uint32, entry *wire.AwakeableEntryMessage) *Awakeable { - return &Awakeable{ctx, invocationID, entryIndex, entry} +func NewAwakeable(suspensionCtx context.Context, invocationID []byte, entry *wire.AwakeableEntryMessage, entryIndex uint32) *Awakeable { + return &Awakeable{suspensionCtx, invocationID, entry, entryIndex} } func (c *Awakeable) Id() string { return awakeableID(c.invocationID, c.entryIndex) } func (c *Awakeable) Result() ([]byte, error) { - if err := c.entry.Await(c.ctx); err != nil { - return nil, err - } else { - switch result := c.entry.Result.(type) { - case *protocol.AwakeableEntryMessage_Value: - return result.Value, nil - case *protocol.AwakeableEntryMessage_Failure: - return nil, errors.ErrorFromFailure(result.Failure) - default: - return nil, fmt.Errorf("unexpected result in completed awakeable entry: %v", c.entry.Result) - } + c.entry.Await(c.suspensionCtx, c.entryIndex) + + switch result := c.entry.Result.(type) { + case *protocol.AwakeableEntryMessage_Value: + return result.Value, nil + case *protocol.AwakeableEntryMessage_Failure: + return nil, errors.ErrorFromFailure(result.Failure) + default: + return nil, fmt.Errorf("unexpected result in completed awakeable entry: %v", c.entry.Result) } } func (c *Awakeable) getEntry() (wire.CompleteableMessage, error) { return c.entry, nil } @@ -73,17 +72,18 @@ func awakeableID(invocationID []byte, entryIndex uint32) string { } type ResponseFuture struct { - ctx context.Context - err error - entry *wire.CallEntryMessage + suspensionCtx context.Context + err error + entry *wire.CallEntryMessage + entryIndex uint32 } -func NewResponseFuture(ctx context.Context, entry *wire.CallEntryMessage) *ResponseFuture { - return &ResponseFuture{ctx, nil, entry} +func NewResponseFuture(suspensionCtx context.Context, entry *wire.CallEntryMessage, entryIndex uint32) *ResponseFuture { + return &ResponseFuture{suspensionCtx, nil, entry, entryIndex} } -func NewFailedResponseFuture(ctx context.Context, err error) *ResponseFuture { - return &ResponseFuture{ctx, err, nil} +func NewFailedResponseFuture(err error) *ResponseFuture { + return &ResponseFuture{nil, err, nil, 0} } func (r *ResponseFuture) Err() error { @@ -95,10 +95,7 @@ func (r *ResponseFuture) Response(output any) error { return r.err } - if err := r.entry.Await(r.ctx); err != nil { - r.err = err - return r.err - } + r.entry.Await(r.suspensionCtx, r.entryIndex) var bytes []byte switch result := r.entry.Result.(type) { diff --git a/internal/futures/select_test.go b/internal/futures/select_test.go index a9ca09b..882b8fc 100644 --- a/internal/futures/select_test.go +++ b/internal/futures/select_test.go @@ -14,16 +14,16 @@ type fakeContext struct { } func (f *fakeContext) Awakeable() restate.Awakeable[[]byte] { - return futures.NewAwakeable(context.TODO(), nil, 0, nil) + return futures.NewAwakeable(context.TODO(), nil, nil, 0) } var _ restate.Context = (*fakeContext)(nil) func TestSelect(t *testing.T) { - after := futures.NewAfter(context.TODO(), nil) - awakeableOne := futures.NewAwakeable(context.TODO(), nil, 0, nil) + after := futures.NewAfter(context.TODO(), nil, 0) + awakeableOne := futures.NewAwakeable(context.TODO(), nil, nil, 0) awakeableTwo := restate.AwakeableAs[string](&fakeContext{}) - responseFut := futures.NewResponseFuture(context.TODO(), nil) + responseFut := futures.NewResponseFuture(context.TODO(), nil, 0) // one-off (race) selector := futures.Select(context.TODO(), after, awakeableOne, awakeableTwo, responseFut) diff --git a/internal/state/awakeable.go b/internal/state/awakeable.go index 710da25..05b575d 100644 --- a/internal/state/awakeable.go +++ b/internal/state/awakeable.go @@ -9,31 +9,26 @@ import ( "github.com/restatedev/sdk-go/internal/wire" ) -type indexedEntry struct { - entry *wire.AwakeableEntryMessage - entryIndex uint32 -} - func (c *Machine) awakeable() restate.Awakeable[[]byte] { - indexedEntry := replayOrNew( + entry, entryIndex := replayOrNew( c, - func(entry *wire.AwakeableEntryMessage) indexedEntry { - return indexedEntry{entry, c.entryIndex} + func(entry *wire.AwakeableEntryMessage) *wire.AwakeableEntryMessage { + return entry }, c._awakeable, ) - return futures.NewAwakeable(c.ctx, c.id, indexedEntry.entryIndex, indexedEntry.entry) + return futures.NewAwakeable(c.suspensionCtx, c.id, entry, entryIndex) } -func (c *Machine) _awakeable() indexedEntry { +func (c *Machine) _awakeable() *wire.AwakeableEntryMessage { msg := &wire.AwakeableEntryMessage{} c.Write(msg) - return indexedEntry{msg, c.entryIndex} + return msg } func (m *Machine) resolveAwakeable(id string, value []byte) { - _ = replayOrNew( + _, _ = replayOrNew( m, func(entry *wire.CompleteAwakeableEntryMessage) restate.Void { messageValue, ok := entry.Result.(*protocol.CompleteAwakeableEntryMessage_Value) @@ -64,7 +59,7 @@ func (c *Machine) _resolveAwakeable(id string, value []byte) { } func (m *Machine) rejectAwakeable(id string, reason error) { - _ = replayOrNew( + _, _ = replayOrNew( m, func(entry *wire.CompleteAwakeableEntryMessage) restate.Void { messageFailure, ok := entry.Result.(*protocol.CompleteAwakeableEntryMessage_Failure) diff --git a/internal/state/call.go b/internal/state/call.go index 6f33f10..4c987b6 100644 --- a/internal/state/call.go +++ b/internal/state/call.go @@ -19,14 +19,14 @@ var ( ) type serviceProxy struct { - *Context + machine *Machine service string key string } func (c *serviceProxy) Method(fn string) restate.CallClient { return &serviceCall{ - Context: c.Context, + machine: c.machine, service: c.service, key: c.key, method: fn, @@ -34,7 +34,7 @@ func (c *serviceProxy) Method(fn string) restate.CallClient { } type serviceSendProxy struct { - *Context + machine *Machine service string key string delay time.Duration @@ -42,15 +42,16 @@ type serviceSendProxy struct { func (c *serviceSendProxy) Method(fn string) restate.SendClient { return &serviceSend{ - Context: c.Context, + machine: c.machine, service: c.service, key: c.key, method: fn, + delay: c.delay, } } type serviceCall struct { - *Context + machine *Machine service string key string method string @@ -58,15 +59,15 @@ type serviceCall struct { // Do makes a call and wait for the response func (c *serviceCall) Request(input any) restate.ResponseFuture { - if msg, err := c.machine.doDynCall(c.service, c.key, c.method, input); err != nil { - return futures.NewFailedResponseFuture(c.ctx, err) + if entry, entryIndex, err := c.machine.doDynCall(c.service, c.key, c.method, input); err != nil { + return futures.NewFailedResponseFuture(err) } else { - return futures.NewResponseFuture(c.ctx, msg) + return futures.NewResponseFuture(c.machine.suspensionCtx, entry, entryIndex) } } type serviceSend struct { - *Context + machine *Machine service string key string method string @@ -79,19 +80,20 @@ func (c *serviceSend) Request(input any) error { return c.machine.sendCall(c.service, c.key, c.method, input, c.delay) } -func (m *Machine) doDynCall(service, key, method string, input any) (*wire.CallEntryMessage, error) { +func (m *Machine) doDynCall(service, key, method string, input any) (*wire.CallEntryMessage, uint32, error) { params, err := json.Marshal(input) if err != nil { - return nil, err + return nil, 0, err } - return m.doCall(service, key, method, params), nil + entry, entryIndex := m.doCall(service, key, method, params) + return entry, entryIndex, nil } -func (m *Machine) doCall(service, key, method string, params []byte) *wire.CallEntryMessage { +func (m *Machine) doCall(service, key, method string, params []byte) (*wire.CallEntryMessage, uint32) { m.log.Debug().Str("service", service).Str("method", method).Str("key", key).Msg("executing sync call") - return replayOrNew( + entry, entryIndex := replayOrNew( m, func(entry *wire.CallEntryMessage) *wire.CallEntryMessage { if entry.ServiceName != service || @@ -112,6 +114,7 @@ func (m *Machine) doCall(service, key, method string, params []byte) *wire.CallE }, func() *wire.CallEntryMessage { return m._doCall(service, key, method, params) }) + return entry, entryIndex } func (m *Machine) _doCall(service, key, method string, params []byte) *wire.CallEntryMessage { @@ -136,7 +139,7 @@ func (m *Machine) sendCall(service, key, method string, body any, delay time.Dur return err } - _ = replayOrNew( + _, _ = replayOrNew( m, func(entry *wire.OneWayCallEntryMessage) restate.Void { if entry.ServiceName != service || diff --git a/internal/state/completion.go b/internal/state/completion.go index 138d0d2..050cb1c 100644 --- a/internal/state/completion.go +++ b/internal/state/completion.go @@ -58,6 +58,7 @@ func (m *Machine) handleCompletionsAcks() { for { msg, err := m.protocol.Read() if err != nil { + m.suspend(err) return } switch msg := msg.(type) { diff --git a/internal/state/state.go b/internal/state/state.go index b290ec4..0479877 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -3,6 +3,7 @@ package state import ( "context" "encoding/json" + stderrors "errors" "fmt" "io" "runtime/debug" @@ -31,17 +32,13 @@ var ( ) type Context struct { - ctx context.Context + context.Context machine *Machine } var _ restate.ObjectContext = &Context{} var _ restate.Context = &Context{} -func (c *Context) Ctx() context.Context { - return c.ctx -} - func (c *Context) Set(key string, value []byte) { c.machine.set(key, value) } @@ -65,28 +62,24 @@ func (c *Context) Keys() ([]string, error) { return c.machine.keys() } -func (c *Context) Sleep(d time.Duration) error { - after, err := c.machine.after(d) - if err != nil { - return err - } - return after.Done() +func (c *Context) Sleep(d time.Duration) { + c.machine.sleep(d) } -func (c *Context) After(d time.Duration) (restate.After, error) { +func (c *Context) After(d time.Duration) restate.After { return c.machine.after(d) } func (c *Context) Service(service string) restate.ServiceClient { return &serviceProxy{ - Context: c, + machine: c.machine, service: service, } } func (c *Context) ServiceSend(service string, delay time.Duration) restate.ServiceSendClient { return &serviceSendProxy{ - Context: c, + machine: c.machine, service: service, delay: delay, } @@ -94,7 +87,7 @@ func (c *Context) ServiceSend(service string, delay time.Duration) restate.Servi func (c *Context) Object(service, key string) restate.ServiceClient { return &serviceProxy{ - Context: c, + machine: c.machine, service: service, key: key, } @@ -102,7 +95,7 @@ func (c *Context) Object(service, key string) restate.ServiceClient { func (c *Context) ObjectSend(service, key string, delay time.Duration) restate.ServiceSendClient { return &serviceSendProxy{ - Context: c, + machine: c.machine, service: service, key: key, delay: delay, @@ -130,14 +123,11 @@ func (c *Context) Key() string { } func newContext(inner context.Context, machine *Machine) *Context { - - // state := make(map[string][]byte) - // for _, entry := range start.Payload.StateMap { - // state[string(entry.Key)] = entry.Value - // } - + // will be cancelled when the http2 stream is cancelled + // but NOT when we just suspend - just because we can't get completions doesn't mean we can't make + // progress towards producing an output message ctx := &Context{ - ctx: inner, + Context: inner, machine: machine, } @@ -145,7 +135,9 @@ func newContext(inner context.Context, machine *Machine) *Context { } type Machine struct { - ctx context.Context + ctx context.Context + suspensionCtx context.Context + suspend func(error) handler restate.Handler protocol *wire.Protocol @@ -197,6 +189,7 @@ func (m *Machine) Start(inner context.Context, trace string) error { } m.ctx = inner + m.suspensionCtx, m.suspend = context.WithCancelCause(m.ctx) m.id = start.Id m.key = start.Key @@ -283,6 +276,35 @@ The journal entry at position %d was: m.log.Error().Err(err).Msg("error sending failure message") } + return + case *wire.SuspensionPanic: + if m.ctx.Err() != nil { + // the http2 request has been cancelled; just return because we can't send a response + return + } + if stderrors.Is(typ.Err, io.EOF) { + m.log.Info().Uints32("entryIndexes", typ.EntryIndexes).Msg("Suspending") + + if err := m.protocol.Write(&wire.SuspensionMessage{ + SuspensionMessage: protocol.SuspensionMessage{ + EntryIndexes: typ.EntryIndexes, + }, + }); err != nil { + m.log.Error().Err(err).Msg("error sending suspension message") + } + } else { + m.log.Error().Err(typ.Err).Uints32("entryIndexes", typ.EntryIndexes).Msg("Unexpected error reading completions; shutting down state machine") + + // don't check for error here, most likely we will fail to send if we are in such a bad state + _ = m.protocol.Write(&wire.ErrorMessage{ + ErrorMessage: protocol.ErrorMessage{ + Code: uint32(restate.ErrorCode(typ.Err)), + Message: fmt.Sprintf("problem reading completions: %v", typ.Err), + Description: string(debug.Stack()), + }, + }) + } + return default: // unknown panic! @@ -420,7 +442,7 @@ func replayOrNew[M wire.Message, O any]( m *Machine, replay func(msg M) O, new func() O, -) (output O) { +) (output O, entryIndex uint32) { // lock around preparing the entry, but we would never await an ack or completion with this held. m.entryMutex.Lock() defer m.entryMutex.Unlock() @@ -443,10 +465,10 @@ func replayOrNew[M wire.Message, O any]( var expectedEntry M panic(m.newEntryMismatch(expectedEntry, entry)) } else { - return replay(entry) + return replay(entry), m.entryIndex } } // other wise call the new function - return new() + return new(), m.entryIndex } diff --git a/internal/state/sys.go b/internal/state/sys.go index a2d9ca2..6564eaf 100644 --- a/internal/state/sys.go +++ b/internal/state/sys.go @@ -28,7 +28,7 @@ func (m *Machine) newEntryMismatch(expectedEntry wire.Message, actualEntry wire. } func (m *Machine) set(key string, value []byte) { - _ = replayOrNew( + _, _ = replayOrNew( m, func(entry *wire.SetStateEntryMessage) (void restate.Void) { if string(entry.Key) != key || !bytes.Equal(entry.Value, value) { @@ -59,7 +59,7 @@ func (m *Machine) _set(key string, value []byte) { } func (m *Machine) clear(key string) { - _ = replayOrNew( + _, _ = replayOrNew( m, func(entry *wire.ClearStateEntryMessage) (void restate.Void) { if string(entry.Key) != key { @@ -91,7 +91,7 @@ func (m *Machine) _clear(key string) { } func (m *Machine) clearAll() { - _ = replayOrNew( + _, _ = replayOrNew( m, func(entry *wire.ClearAllStateEntryMessage) (void restate.Void) { return @@ -112,7 +112,7 @@ func (m *Machine) _clearAll() { } func (m *Machine) get(key string) ([]byte, error) { - entry := replayOrNew( + entry, entryIndex := replayOrNew( m, func(entry *wire.GetStateEntryMessage) *wire.GetStateEntryMessage { if string(entry.Key) != key { @@ -127,9 +127,7 @@ func (m *Machine) get(key string) ([]byte, error) { return m._get(key) }) - if err := entry.Await(m.ctx); err != nil { - return nil, err - } + entry.Await(m.suspensionCtx, entryIndex) switch value := entry.Result.(type) { case *protocol.GetStateEntryMessage_Empty: @@ -185,7 +183,7 @@ func (m *Machine) _get(key string) *wire.GetStateEntryMessage { } func (m *Machine) keys() ([]string, error) { - entry := replayOrNew( + entry, entryIndex := replayOrNew( m, func(entry *wire.GetStateKeysEntryMessage) *wire.GetStateKeysEntryMessage { return entry @@ -193,9 +191,7 @@ func (m *Machine) keys() ([]string, error) { m._keys, ) - if err := entry.Await(m.ctx); err != nil { - return nil, err - } + entry.Await(m.suspensionCtx, entryIndex) switch value := entry.Result.(type) { case *protocol.GetStateKeysEntryMessage_Failure: @@ -247,8 +243,8 @@ func (m *Machine) _keys() *wire.GetStateKeysEntryMessage { return msg } -func (m *Machine) after(d time.Duration) (restate.After, error) { - entry := replayOrNew( +func (m *Machine) after(d time.Duration) restate.After { + entry, entryIndex := replayOrNew( m, func(entry *wire.SleepEntryMessage) *wire.SleepEntryMessage { // we shouldn't verify the time because this would be different every time @@ -258,16 +254,11 @@ func (m *Machine) after(d time.Duration) (restate.After, error) { }, ) - return futures.NewAfter(m.ctx, entry), nil + return futures.NewAfter(m.suspensionCtx, entry, entryIndex) } -func (m *Machine) sleep(d time.Duration) error { - after, err := m.after(d) - if err != nil { - return err - } - - return after.Done() +func (m *Machine) sleep(d time.Duration) { + m.after(d).Done() } // _sleep creating a new sleep entry. @@ -283,7 +274,7 @@ func (m *Machine) _sleep(d time.Duration) *wire.SleepEntryMessage { } func (m *Machine) sideEffect(fn func() ([]byte, error)) ([]byte, error) { - entry := replayOrNew( + entry, entryIndex := replayOrNew( m, func(entry *wire.RunEntryMessage) *wire.RunEntryMessage { return entry @@ -294,9 +285,7 @@ func (m *Machine) sideEffect(fn func() ([]byte, error)) ([]byte, error) { ) // side effect must be acknowledged before proceeding - if err := entry.Await(m.ctx); err != nil { - return nil, err - } + entry.Await(m.suspensionCtx, entryIndex) switch result := entry.Result.(type) { case *protocol.RunEntryMessage_Failure: diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 62eb741..b958ee6 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -548,7 +548,7 @@ type EntryAckMessage struct { type CompleteableMessage interface { Message Completed() bool - Await(ctx context.Context) error + Await(suspensionCtx context.Context, entryIndex uint32) Complete(*protocol.CompletionMessage) } @@ -570,17 +570,17 @@ func (c *completable) Completed() bool { return c.completed.Load() } -func (c *completable) Await(ctx context.Context) error { +func (c *completable) Await(suspensionCtx context.Context, entryIndex uint32) { c.init() if c.completed.Load() { // fast path - return nil + return } select { - case <-ctx.Done(): - return ctx.Err() + case <-suspensionCtx.Done(): + panic(&SuspensionPanic{EntryIndexes: []uint32{entryIndex}, Err: suspensionCtx.Err()}) case <-c.done: - return nil + return } } @@ -619,17 +619,17 @@ func (c *ackable) Acked() bool { return c.acked.Load() } -func (c *ackable) Await(ctx context.Context) error { +func (c *ackable) Await(suspensionCtx context.Context, entryIndex uint32) { c.init() if c.acked.Load() { // fast path - return nil + return } select { - case <-ctx.Done(): - return ctx.Err() + case <-suspensionCtx.Done(): + panic(&SuspensionPanic{EntryIndexes: []uint32{entryIndex}, Err: suspensionCtx.Err()}) case <-c.done: - return nil + return } } @@ -642,3 +642,8 @@ func (c *ackable) Ack() { // already completed } } + +type SuspensionPanic struct { + EntryIndexes []uint32 + Err error +} diff --git a/router.go b/router.go index 450813e..b1253ed 100644 --- a/router.go +++ b/router.go @@ -44,13 +44,12 @@ type ServiceSendClient interface { } type Context interface { - // Context of request. - Ctx() context.Context + context.Context // Sleep for the duration d - Sleep(d time.Duration) error + Sleep(d time.Duration) // Return a handle on a sleep duration which can be combined - After(d time.Duration) (After, error) + After(d time.Duration) After // Service gets a Service accessor by name where service // must be another service known by restate runtime @@ -294,6 +293,6 @@ func ResolveAwakeableAs[T any](ctx Context, id string, value T) error { } type After interface { - Done() error + Done() futures.Selectable } From fe268eea939d785306eac14b7676932e434cc9dc Mon Sep 17 00:00:00 2001 From: Jack Kleeman Date: Thu, 11 Jul 2024 17:51:00 +0200 Subject: [PATCH 6/6] Implement selector --- generated/proto/go/go.pb.go | 180 +++++++++++++++++ generated/proto/javascript/javascript.pb.go | 186 ------------------ internal/futures/futures.go | 16 +- internal/futures/select.go | 106 ++++++---- internal/futures/select_test.go | 78 -------- internal/state/completion.go | 6 + internal/state/select.go | 74 +++++++ internal/state/state.go | 5 + internal/wire/wire.go | 55 +++++- .../javascript.proto => go/go.proto} | 11 +- router.go | 9 +- 11 files changed, 402 insertions(+), 324 deletions(-) create mode 100644 generated/proto/go/go.pb.go delete mode 100644 generated/proto/javascript/javascript.pb.go delete mode 100644 internal/futures/select_test.go create mode 100644 internal/state/select.go rename proto/{javascript/javascript.proto => go/go.proto} (72%) diff --git a/generated/proto/go/go.pb.go b/generated/proto/go/go.pb.go new file mode 100644 index 0000000..1e21a8b --- /dev/null +++ b/generated/proto/go/go.pb.go @@ -0,0 +1,180 @@ +// +// Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate SDK for Node.js/TypeScript, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-typescript/blob/main/LICENSE + +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.26.0 +// protoc (unknown) +// source: proto/go/go.proto + +package _go + +import ( + _ "github.com/restatedev/sdk-go/generated/proto/protocol" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// Type: 0xFC00 + 3 +type SelectorEntryMessage struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + JournalEntries []uint32 `protobuf:"varint,1,rep,packed,name=journal_entries,json=journalEntries,proto3" json:"journal_entries,omitempty"` + WinningEntryIndex uint32 `protobuf:"varint,2,opt,name=winning_entry_index,json=winningEntryIndex,proto3" json:"winning_entry_index,omitempty"` +} + +func (x *SelectorEntryMessage) Reset() { + *x = SelectorEntryMessage{} + if protoimpl.UnsafeEnabled { + mi := &file_proto_go_go_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SelectorEntryMessage) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SelectorEntryMessage) ProtoMessage() {} + +func (x *SelectorEntryMessage) ProtoReflect() protoreflect.Message { + mi := &file_proto_go_go_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SelectorEntryMessage.ProtoReflect.Descriptor instead. +func (*SelectorEntryMessage) Descriptor() ([]byte, []int) { + return file_proto_go_go_proto_rawDescGZIP(), []int{0} +} + +func (x *SelectorEntryMessage) GetJournalEntries() []uint32 { + if x != nil { + return x.JournalEntries + } + return nil +} + +func (x *SelectorEntryMessage) GetWinningEntryIndex() uint32 { + if x != nil { + return x.WinningEntryIndex + } + return 0 +} + +var File_proto_go_go_proto protoreflect.FileDescriptor + +var file_proto_go_go_proto_rawDesc = []byte{ + 0x0a, 0x11, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x67, 0x6f, 0x2f, 0x67, 0x6f, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x12, 0x12, 0x64, 0x65, 0x76, 0x2e, 0x72, 0x65, 0x73, 0x74, 0x61, 0x74, 0x65, + 0x2e, 0x73, 0x64, 0x6b, 0x2e, 0x67, 0x6f, 0x1a, 0x1d, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, + 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x6f, 0x0a, 0x14, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, + 0x6f, 0x72, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x27, + 0x0a, 0x0f, 0x6a, 0x6f, 0x75, 0x72, 0x6e, 0x61, 0x6c, 0x5f, 0x65, 0x6e, 0x74, 0x72, 0x69, 0x65, + 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x0e, 0x6a, 0x6f, 0x75, 0x72, 0x6e, 0x61, 0x6c, + 0x45, 0x6e, 0x74, 0x72, 0x69, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x13, 0x77, 0x69, 0x6e, 0x6e, 0x69, + 0x6e, 0x67, 0x5f, 0x65, 0x6e, 0x74, 0x72, 0x79, 0x5f, 0x69, 0x6e, 0x64, 0x65, 0x78, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0d, 0x52, 0x11, 0x77, 0x69, 0x6e, 0x6e, 0x69, 0x6e, 0x67, 0x45, 0x6e, 0x74, + 0x72, 0x79, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x42, 0xbe, 0x01, 0x0a, 0x16, 0x63, 0x6f, 0x6d, 0x2e, + 0x64, 0x65, 0x76, 0x2e, 0x72, 0x65, 0x73, 0x74, 0x61, 0x74, 0x65, 0x2e, 0x73, 0x64, 0x6b, 0x2e, + 0x67, 0x6f, 0x42, 0x07, 0x47, 0x6f, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x2f, 0x67, + 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x73, 0x74, 0x61, 0x74, + 0x65, 0x64, 0x65, 0x76, 0x2f, 0x73, 0x64, 0x6b, 0x2d, 0x67, 0x6f, 0x2f, 0x67, 0x65, 0x6e, 0x65, + 0x72, 0x61, 0x74, 0x65, 0x64, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x67, 0x6f, 0xa2, 0x02, + 0x04, 0x44, 0x52, 0x53, 0x47, 0xaa, 0x02, 0x12, 0x44, 0x65, 0x76, 0x2e, 0x52, 0x65, 0x73, 0x74, + 0x61, 0x74, 0x65, 0x2e, 0x53, 0x64, 0x6b, 0x2e, 0x47, 0x6f, 0xca, 0x02, 0x12, 0x44, 0x65, 0x76, + 0x5c, 0x52, 0x65, 0x73, 0x74, 0x61, 0x74, 0x65, 0x5c, 0x53, 0x64, 0x6b, 0x5c, 0x47, 0x6f, 0xe2, + 0x02, 0x1e, 0x44, 0x65, 0x76, 0x5c, 0x52, 0x65, 0x73, 0x74, 0x61, 0x74, 0x65, 0x5c, 0x53, 0x64, + 0x6b, 0x5c, 0x47, 0x6f, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, + 0xea, 0x02, 0x15, 0x44, 0x65, 0x76, 0x3a, 0x3a, 0x52, 0x65, 0x73, 0x74, 0x61, 0x74, 0x65, 0x3a, + 0x3a, 0x53, 0x64, 0x6b, 0x3a, 0x3a, 0x47, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_proto_go_go_proto_rawDescOnce sync.Once + file_proto_go_go_proto_rawDescData = file_proto_go_go_proto_rawDesc +) + +func file_proto_go_go_proto_rawDescGZIP() []byte { + file_proto_go_go_proto_rawDescOnce.Do(func() { + file_proto_go_go_proto_rawDescData = protoimpl.X.CompressGZIP(file_proto_go_go_proto_rawDescData) + }) + return file_proto_go_go_proto_rawDescData +} + +var file_proto_go_go_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_proto_go_go_proto_goTypes = []interface{}{ + (*SelectorEntryMessage)(nil), // 0: dev.restate.sdk.go.SelectorEntryMessage +} +var file_proto_go_go_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_proto_go_go_proto_init() } +func file_proto_go_go_proto_init() { + if File_proto_go_go_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_proto_go_go_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SelectorEntryMessage); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_proto_go_go_proto_rawDesc, + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_proto_go_go_proto_goTypes, + DependencyIndexes: file_proto_go_go_proto_depIdxs, + MessageInfos: file_proto_go_go_proto_msgTypes, + }.Build() + File_proto_go_go_proto = out.File + file_proto_go_go_proto_rawDesc = nil + file_proto_go_go_proto_goTypes = nil + file_proto_go_go_proto_depIdxs = nil +} diff --git a/generated/proto/javascript/javascript.pb.go b/generated/proto/javascript/javascript.pb.go deleted file mode 100644 index 114ee38..0000000 --- a/generated/proto/javascript/javascript.pb.go +++ /dev/null @@ -1,186 +0,0 @@ -// -// Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate SDK for Node.js/TypeScript, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-typescript/blob/main/LICENSE - -// Code generated by protoc-gen-go. DO NOT EDIT. -// versions: -// protoc-gen-go v1.26.0 -// protoc (unknown) -// source: proto/javascript/javascript.proto - -package javascript - -import ( - _ "github.com/restatedev/sdk-go/generated/proto/protocol" - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" - reflect "reflect" - sync "sync" -) - -const ( - // Verify that this generated code is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) - // Verify that runtime/protoimpl is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) -) - -// Type: 0xFC00 + 2 -type CombinatorEntryMessage struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - CombinatorId int32 `protobuf:"varint,1,opt,name=combinator_id,json=combinatorId,proto3" json:"combinator_id,omitempty"` - JournalEntriesOrder []int32 `protobuf:"varint,2,rep,packed,name=journal_entries_order,json=journalEntriesOrder,proto3" json:"journal_entries_order,omitempty"` -} - -func (x *CombinatorEntryMessage) Reset() { - *x = CombinatorEntryMessage{} - if protoimpl.UnsafeEnabled { - mi := &file_proto_javascript_javascript_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *CombinatorEntryMessage) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*CombinatorEntryMessage) ProtoMessage() {} - -func (x *CombinatorEntryMessage) ProtoReflect() protoreflect.Message { - mi := &file_proto_javascript_javascript_proto_msgTypes[0] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use CombinatorEntryMessage.ProtoReflect.Descriptor instead. -func (*CombinatorEntryMessage) Descriptor() ([]byte, []int) { - return file_proto_javascript_javascript_proto_rawDescGZIP(), []int{0} -} - -func (x *CombinatorEntryMessage) GetCombinatorId() int32 { - if x != nil { - return x.CombinatorId - } - return 0 -} - -func (x *CombinatorEntryMessage) GetJournalEntriesOrder() []int32 { - if x != nil { - return x.JournalEntriesOrder - } - return nil -} - -var File_proto_javascript_javascript_proto protoreflect.FileDescriptor - -var file_proto_javascript_javascript_proto_rawDesc = []byte{ - 0x0a, 0x21, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x6a, 0x61, 0x76, 0x61, 0x73, 0x63, 0x72, 0x69, - 0x70, 0x74, 0x2f, 0x6a, 0x61, 0x76, 0x61, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x2e, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x12, 0x1a, 0x64, 0x65, 0x76, 0x2e, 0x72, 0x65, 0x73, 0x74, 0x61, 0x74, 0x65, - 0x2e, 0x73, 0x64, 0x6b, 0x2e, 0x6a, 0x61, 0x76, 0x61, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x1a, - 0x1d, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2f, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x71, - 0x0a, 0x16, 0x43, 0x6f, 0x6d, 0x62, 0x69, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x45, 0x6e, 0x74, 0x72, - 0x79, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x23, 0x0a, 0x0d, 0x63, 0x6f, 0x6d, 0x62, - 0x69, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, - 0x0c, 0x63, 0x6f, 0x6d, 0x62, 0x69, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x49, 0x64, 0x12, 0x32, 0x0a, - 0x15, 0x6a, 0x6f, 0x75, 0x72, 0x6e, 0x61, 0x6c, 0x5f, 0x65, 0x6e, 0x74, 0x72, 0x69, 0x65, 0x73, - 0x5f, 0x6f, 0x72, 0x64, 0x65, 0x72, 0x18, 0x02, 0x20, 0x03, 0x28, 0x05, 0x52, 0x13, 0x6a, 0x6f, - 0x75, 0x72, 0x6e, 0x61, 0x6c, 0x45, 0x6e, 0x74, 0x72, 0x69, 0x65, 0x73, 0x4f, 0x72, 0x64, 0x65, - 0x72, 0x42, 0xf6, 0x01, 0x0a, 0x1e, 0x63, 0x6f, 0x6d, 0x2e, 0x64, 0x65, 0x76, 0x2e, 0x72, 0x65, - 0x73, 0x74, 0x61, 0x74, 0x65, 0x2e, 0x73, 0x64, 0x6b, 0x2e, 0x6a, 0x61, 0x76, 0x61, 0x73, 0x63, - 0x72, 0x69, 0x70, 0x74, 0x42, 0x0f, 0x4a, 0x61, 0x76, 0x61, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, - 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x37, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, - 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x73, 0x74, 0x61, 0x74, 0x65, 0x64, 0x65, 0x76, 0x2f, 0x73, - 0x64, 0x6b, 0x2d, 0x67, 0x6f, 0x2f, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x74, 0x65, 0x64, 0x2f, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x6a, 0x61, 0x76, 0x61, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, - 0xa2, 0x02, 0x04, 0x44, 0x52, 0x53, 0x4a, 0xaa, 0x02, 0x1a, 0x44, 0x65, 0x76, 0x2e, 0x52, 0x65, - 0x73, 0x74, 0x61, 0x74, 0x65, 0x2e, 0x53, 0x64, 0x6b, 0x2e, 0x4a, 0x61, 0x76, 0x61, 0x73, 0x63, - 0x72, 0x69, 0x70, 0x74, 0xca, 0x02, 0x1a, 0x44, 0x65, 0x76, 0x5c, 0x52, 0x65, 0x73, 0x74, 0x61, - 0x74, 0x65, 0x5c, 0x53, 0x64, 0x6b, 0x5c, 0x4a, 0x61, 0x76, 0x61, 0x73, 0x63, 0x72, 0x69, 0x70, - 0x74, 0xe2, 0x02, 0x26, 0x44, 0x65, 0x76, 0x5c, 0x52, 0x65, 0x73, 0x74, 0x61, 0x74, 0x65, 0x5c, - 0x53, 0x64, 0x6b, 0x5c, 0x4a, 0x61, 0x76, 0x61, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x5c, 0x47, - 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x1d, 0x44, 0x65, 0x76, - 0x3a, 0x3a, 0x52, 0x65, 0x73, 0x74, 0x61, 0x74, 0x65, 0x3a, 0x3a, 0x53, 0x64, 0x6b, 0x3a, 0x3a, - 0x4a, 0x61, 0x76, 0x61, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x33, -} - -var ( - file_proto_javascript_javascript_proto_rawDescOnce sync.Once - file_proto_javascript_javascript_proto_rawDescData = file_proto_javascript_javascript_proto_rawDesc -) - -func file_proto_javascript_javascript_proto_rawDescGZIP() []byte { - file_proto_javascript_javascript_proto_rawDescOnce.Do(func() { - file_proto_javascript_javascript_proto_rawDescData = protoimpl.X.CompressGZIP(file_proto_javascript_javascript_proto_rawDescData) - }) - return file_proto_javascript_javascript_proto_rawDescData -} - -var file_proto_javascript_javascript_proto_msgTypes = make([]protoimpl.MessageInfo, 1) -var file_proto_javascript_javascript_proto_goTypes = []interface{}{ - (*CombinatorEntryMessage)(nil), // 0: dev.restate.sdk.javascript.CombinatorEntryMessage -} -var file_proto_javascript_javascript_proto_depIdxs = []int32{ - 0, // [0:0] is the sub-list for method output_type - 0, // [0:0] is the sub-list for method input_type - 0, // [0:0] is the sub-list for extension type_name - 0, // [0:0] is the sub-list for extension extendee - 0, // [0:0] is the sub-list for field type_name -} - -func init() { file_proto_javascript_javascript_proto_init() } -func file_proto_javascript_javascript_proto_init() { - if File_proto_javascript_javascript_proto != nil { - return - } - if !protoimpl.UnsafeEnabled { - file_proto_javascript_javascript_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CombinatorEntryMessage); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - } - type x struct{} - out := protoimpl.TypeBuilder{ - File: protoimpl.DescBuilder{ - GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_proto_javascript_javascript_proto_rawDesc, - NumEnums: 0, - NumMessages: 1, - NumExtensions: 0, - NumServices: 0, - }, - GoTypes: file_proto_javascript_javascript_proto_goTypes, - DependencyIndexes: file_proto_javascript_javascript_proto_depIdxs, - MessageInfos: file_proto_javascript_javascript_proto_msgTypes, - }.Build() - File_proto_javascript_javascript_proto = out.File - file_proto_javascript_javascript_proto_rawDesc = nil - file_proto_javascript_javascript_proto_goTypes = nil - file_proto_javascript_javascript_proto_depIdxs = nil -} diff --git a/internal/futures/futures.go b/internal/futures/futures.go index 07675ef..148f078 100644 --- a/internal/futures/futures.go +++ b/internal/futures/futures.go @@ -32,8 +32,8 @@ func (a *After) Done() { a.entry.Await(a.suspensionCtx, a.entryIndex) } -func (a *After) getEntry() (wire.CompleteableMessage, error) { - return a.entry, nil +func (a *After) getEntry() (wire.CompleteableMessage, uint32, error) { + return a.entry, a.entryIndex, nil } const AWAKEABLE_IDENTIFIER_PREFIX = "prom_1" @@ -62,7 +62,9 @@ func (c *Awakeable) Result() ([]byte, error) { return nil, fmt.Errorf("unexpected result in completed awakeable entry: %v", c.entry.Result) } } -func (c *Awakeable) getEntry() (wire.CompleteableMessage, error) { return c.entry, nil } +func (c *Awakeable) getEntry() (wire.CompleteableMessage, uint32, error) { + return c.entry, c.entryIndex, nil +} func awakeableID(invocationID []byte, entryIndex uint32) string { bytes := make([]byte, 0, len(invocationID)+4) @@ -86,10 +88,6 @@ func NewFailedResponseFuture(err error) *ResponseFuture { return &ResponseFuture{nil, err, nil, 0} } -func (r *ResponseFuture) Err() error { - return r.err -} - func (r *ResponseFuture) Response(output any) error { if r.err != nil { return r.err @@ -116,6 +114,6 @@ func (r *ResponseFuture) Response(output any) error { return nil } -func (r *ResponseFuture) getEntry() (wire.CompleteableMessage, error) { - return r.entry, r.err +func (r *ResponseFuture) getEntry() (wire.CompleteableMessage, uint32, error) { + return r.entry, r.entryIndex, r.err } diff --git a/internal/futures/select.go b/internal/futures/select.go index 87d12db..fc3674d 100644 --- a/internal/futures/select.go +++ b/internal/futures/select.go @@ -2,66 +2,98 @@ package futures import ( "context" + "reflect" + "slices" "github.com/restatedev/sdk-go/internal/wire" ) type Selectable interface { - getEntry() (wire.CompleteableMessage, error) + getEntry() (wire.CompleteableMessage, uint32, error) } -type Selector interface { - Select() bool - Err() error - Result() Selectable +type Selector struct { + suspensionCtx context.Context + indexedFuts map[uint32]Selectable + indexedChans map[uint32]<-chan struct{} + chosen Selectable + err error } -type selector struct { - ctx context.Context - futs []Selectable - chosen Selectable - err error -} - -func (s *selector) Select() bool { +func (s *Selector) Select() (uint32, bool) { if s.err != nil { - return false + return 0, false } - if len(s.futs) == 0 { - return false + if len(s.indexedFuts) == 0 { + return 0, false } - // todo pick what element to taike - i := len(s.futs) - 1 + indexes := s.Indexes() + cases := make([]reflect.SelectCase, len(indexes)+1) + for i, entryIndex := range indexes { + cases[i] = reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(s.indexedChans[entryIndex]), + } - // pick future - s.chosen = s.futs[i] - // swap it to the end - s.futs[i], s.futs[len(s.futs)-1] = s.futs[len(s.futs)-1], s.futs[i] - // trim - s.futs = s.futs[:len(s.futs)-1] + } + cases[len(indexes)] = reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(s.suspensionCtx.Done()), + } + chosen, _, _ := reflect.Select(cases) + switch chosen { + case len(indexes): + // suspensionCtx won + panic(&wire.SuspensionPanic{EntryIndexes: indexes, Err: context.Cause(s.suspensionCtx)}) + default: + return indexes[chosen], true + } +} - return true +func (s *Selector) Take(winningEntryIndex uint32) Selectable { + selectable := s.indexedFuts[winningEntryIndex] + if selectable == nil { + return nil + } + entry, _, err := selectable.getEntry() + if err != nil { + return nil + } + if !entry.Completed() { + return nil + } + delete(s.indexedFuts, winningEntryIndex) + delete(s.indexedChans, winningEntryIndex) + return selectable } -func (s *selector) Result() Selectable { - // TODO - return s.chosen +func (s *Selector) Remaining() bool { + return len(s.indexedFuts) > 0 } -func (s *selector) Err() error { - // TODO - return s.err +func (s *Selector) Indexes() []uint32 { + indexes := make([]uint32, 0, len(s.indexedFuts)) + for i := range s.indexedFuts { + indexes = append(indexes, i) + } + slices.Sort(indexes) + return indexes } -func Select(ctx context.Context, futs ...Selectable) Selector { - s := &selector{ctx: ctx, futs: futs} +func Select(suspensionCtx context.Context, futs ...Selectable) (*Selector, error) { + s := &Selector{ + suspensionCtx: suspensionCtx, + indexedFuts: make(map[uint32]Selectable, len(futs)), + indexedChans: make(map[uint32]<-chan struct{}, len(futs)), + } for i := range futs { - _, err := futs[i].getEntry() + entry, entryIndex, err := futs[i].getEntry() if err != nil { - s.err = err - break + return nil, err } + s.indexedFuts[entryIndex] = futs[i] + s.indexedChans[entryIndex] = entry.Done() } - return s + return s, nil } diff --git a/internal/futures/select_test.go b/internal/futures/select_test.go deleted file mode 100644 index 882b8fc..0000000 --- a/internal/futures/select_test.go +++ /dev/null @@ -1,78 +0,0 @@ -package futures_test - -import ( - "context" - "fmt" - "testing" - - restate "github.com/restatedev/sdk-go" - "github.com/restatedev/sdk-go/internal/futures" -) - -type fakeContext struct { - restate.Context -} - -func (f *fakeContext) Awakeable() restate.Awakeable[[]byte] { - return futures.NewAwakeable(context.TODO(), nil, nil, 0) -} - -var _ restate.Context = (*fakeContext)(nil) - -func TestSelect(t *testing.T) { - after := futures.NewAfter(context.TODO(), nil, 0) - awakeableOne := futures.NewAwakeable(context.TODO(), nil, nil, 0) - awakeableTwo := restate.AwakeableAs[string](&fakeContext{}) - responseFut := futures.NewResponseFuture(context.TODO(), nil, 0) - - // one-off (race) - selector := futures.Select(context.TODO(), after, awakeableOne, awakeableTwo, responseFut) - if !selector.Select() { - t.Fatal(selector.Err()) - } - switch selector.Result() { - case after: - t.Log("after won") - case awakeableOne: - t.Log("awakeable one won") - case awakeableTwo: - t.Log("awakeable two won") - case responseFut: - t.Log("response won") - } - - // or as a loop (all or any) - selector = futures.Select(context.TODO(), after, awakeableOne, awakeableTwo, responseFut) - for selector.Select() { - switch selector.Result() { - case after: - t.Log("after") - case awakeableOne: - t.Log("awakeable one") - case awakeableTwo: - t.Log("awakeable two") - case responseFut: - t.Log("response") - } - } - - if selector.Err() != nil { - t.Fatal(selector.Err()) - } - -} - -func TestFailedSelect(t *testing.T) { - err := fmt.Errorf("oops") - failedResponseFut := futures.NewFailedResponseFuture(context.TODO(), err) - selector := futures.Select(context.TODO(), failedResponseFut) - if selector.Select() { - t.Fatal("Select() should return false immediately") - } - if selector.Err() == nil { - t.Fatal("Err() should return an error") - } - if selector.Err() != err { - t.Fatalf("Err() returned an unexpected err: %v", err) - } -} diff --git a/internal/state/completion.go b/internal/state/completion.go index 050cb1c..111e597 100644 --- a/internal/state/completion.go +++ b/internal/state/completion.go @@ -1,6 +1,9 @@ package state import ( + "errors" + "io" + "github.com/restatedev/sdk-go/internal/wire" ) @@ -58,6 +61,9 @@ func (m *Machine) handleCompletionsAcks() { for { msg, err := m.protocol.Read() if err != nil { + if errors.Is(err, io.EOF) { + m.log.Trace().Err(err).Msg("request body closed; next blocking operation will suspend") + } m.suspend(err) return } diff --git a/internal/state/select.go b/internal/state/select.go new file mode 100644 index 0000000..1ea21e1 --- /dev/null +++ b/internal/state/select.go @@ -0,0 +1,74 @@ +package state + +import ( + "slices" + + _go "github.com/restatedev/sdk-go/generated/proto/go" + "github.com/restatedev/sdk-go/internal/futures" + "github.com/restatedev/sdk-go/internal/wire" +) + +type selector struct { + machine *Machine + inner *futures.Selector +} + +func (m *Machine) selector(futs ...futures.Selectable) (*selector, error) { + inner, err := futures.Select(m.suspensionCtx, futs...) + if err != nil { + return nil, err + } + return &selector{m, inner}, nil +} + +func (s *selector) Select() futures.Selectable { + entry, entryIndex := replayOrNew( + s.machine, + func(entry *wire.SelectorEntryMessage) *wire.SelectorEntryMessage { + indexes := s.inner.Indexes() + if !slices.Equal(entry.JournalEntries, indexes) { + panic(s.machine.newEntryMismatch(&wire.SelectorEntryMessage{ + SelectorEntryMessage: _go.SelectorEntryMessage{ + JournalEntries: indexes, + }, + }, entry)) + } + return entry + }, + func() *wire.SelectorEntryMessage { + return s._select() + }, + ) + + if entry == nil { + // no futures left to select + return nil + } + + // selector entry must be acknowledged before proceeding + entry.Await(s.machine.suspensionCtx, entryIndex) + return s.inner.Take(entry.WinningEntryIndex) +} + +func (s *selector) Remaining() bool { + return s.inner.Remaining() +} + +func (s *selector) _select() *wire.SelectorEntryMessage { + indexes := s.inner.Indexes() + winningEntryIndex, ok := s.inner.Select() + if !ok { + // no more promises left, we don't need to write this to the journal + return nil + } + + entry := &wire.SelectorEntryMessage{ + SelectorEntryMessage: _go.SelectorEntryMessage{ + JournalEntries: indexes, + WinningEntryIndex: winningEntryIndex, + }, + } + s.machine.Write(entry) + + return entry +} diff --git a/internal/state/state.go b/internal/state/state.go index 0479877..929f829 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -13,6 +13,7 @@ import ( restate "github.com/restatedev/sdk-go" "github.com/restatedev/sdk-go/generated/proto/protocol" "github.com/restatedev/sdk-go/internal/errors" + "github.com/restatedev/sdk-go/internal/futures" "github.com/restatedev/sdk-go/internal/wire" "github.com/rs/zerolog" @@ -118,6 +119,10 @@ func (c *Context) RejectAwakeable(id string, reason error) { c.machine.rejectAwakeable(id, reason) } +func (c *Context) Selector(futs ...futures.Selectable) (restate.Selector, error) { + return c.machine.selector(futs...) +} + func (c *Context) Key() string { return c.machine.key } diff --git a/internal/wire/wire.go b/internal/wire/wire.go index b958ee6..0eb048e 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -11,6 +11,7 @@ import ( "sync" "sync/atomic" + _go "github.com/restatedev/sdk-go/generated/proto/go" protocol "github.com/restatedev/sdk-go/generated/proto/protocol" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -55,6 +56,9 @@ const ( AwakeableEntryMessageType Type = 0x0C00 + 3 CompleteAwakeableEntryMessageType Type = 0x0C00 + 4 RunEntryMessageType Type = 0x0C00 + 5 + + // Custom + SelectorEntryMessageType Type = 0xFC03 ) type Type uint16 @@ -110,6 +114,8 @@ func MessageType(message Message) Type { return OutputEntryMessageType case *ErrorMessage: return ErrorMessageType + case *EntryAckMessage: + return EntryAckMessageType case *EndMessage: return EndMessageType case *GetStateEntryMessage: @@ -134,6 +140,8 @@ func MessageType(message Message) Type { return CompleteAwakeableEntryMessageType case *RunEntryMessage: return RunEntryMessageType + case *SelectorEntryMessage: + return SelectorEntryMessageType } panic(fmt.Sprintf("unknown message type %T", message)) } @@ -369,6 +377,14 @@ var ( // replayed side effects are inherently acked msg.Ack() + return msg, proto.Unmarshal(bytes, msg) + }, + SelectorEntryMessageType: func(header Header, bytes []byte) (Message, error) { + msg := &SelectorEntryMessage{} + + // replayed selectors are inherently acked + msg.Ack() + return msg, proto.Unmarshal(bytes, msg) }, } @@ -404,11 +420,18 @@ type EndMessage struct { protocol.EndMessage } +type EntryAckMessage struct { + Header + protocol.EntryAckMessage +} + type GetStateEntryMessage struct { completable protocol.GetStateEntryMessage } +var _ CompleteableMessage = (*GetStateEntryMessage)(nil) + func (a *GetStateEntryMessage) Complete(c *protocol.CompletionMessage) { switch result := c.Result.(type) { case *protocol.CompletionMessage_Value: @@ -442,6 +465,8 @@ type GetStateKeysEntryMessage struct { protocol.GetStateKeysEntryMessage } +var _ CompleteableMessage = (*GetStateKeysEntryMessage)(nil) + func (a *GetStateKeysEntryMessage) Complete(c *protocol.CompletionMessage) { switch result := c.Result.(type) { case *protocol.CompletionMessage_Value: @@ -473,6 +498,8 @@ type SleepEntryMessage struct { protocol.SleepEntryMessage } +var _ CompleteableMessage = (*SleepEntryMessage)(nil) + func (a *SleepEntryMessage) Complete(c *protocol.CompletionMessage) { switch result := c.Result.(type) { case *protocol.CompletionMessage_Empty: @@ -492,6 +519,8 @@ type CallEntryMessage struct { protocol.CallEntryMessage } +var _ CompleteableMessage = (*CallEntryMessage)(nil) + func (a *CallEntryMessage) Complete(c *protocol.CompletionMessage) { switch result := c.Result.(type) { case *protocol.CompletionMessage_Value: @@ -516,6 +545,8 @@ type AwakeableEntryMessage struct { protocol.AwakeableEntryMessage } +var _ CompleteableMessage = (*AwakeableEntryMessage)(nil) + func (a *AwakeableEntryMessage) Complete(c *protocol.CompletionMessage) { switch result := c.Result.(type) { case *protocol.CompletionMessage_Value: @@ -540,13 +571,19 @@ type RunEntryMessage struct { protocol.RunEntryMessage } -type EntryAckMessage struct { - Header - protocol.EntryAckMessage +var _ AckableMessage = (*RunEntryMessage)(nil) + +type SelectorEntryMessage struct { + ackable + _go.SelectorEntryMessage } +var _ AckableMessage = (*SelectorEntryMessage)(nil) + type CompleteableMessage interface { Message + // only for use in selector + Done() <-chan struct{} Completed() bool Await(suspensionCtx context.Context, entryIndex uint32) Complete(*protocol.CompletionMessage) @@ -570,6 +607,12 @@ func (c *completable) Completed() bool { return c.completed.Load() } +func (c *completable) Done() <-chan struct{} { + c.init() + + return c.done +} + func (c *completable) Await(suspensionCtx context.Context, entryIndex uint32) { c.init() if c.completed.Load() { @@ -578,7 +621,7 @@ func (c *completable) Await(suspensionCtx context.Context, entryIndex uint32) { } select { case <-suspensionCtx.Done(): - panic(&SuspensionPanic{EntryIndexes: []uint32{entryIndex}, Err: suspensionCtx.Err()}) + panic(&SuspensionPanic{EntryIndexes: []uint32{entryIndex}, Err: context.Cause(suspensionCtx)}) case <-c.done: return } @@ -597,7 +640,7 @@ func (c *completable) complete() { type AckableMessage interface { Message Acked() bool - Await(ctx context.Context) error + Await(ctx context.Context, entryIndex uint32) Ack() } @@ -627,7 +670,7 @@ func (c *ackable) Await(suspensionCtx context.Context, entryIndex uint32) { } select { case <-suspensionCtx.Done(): - panic(&SuspensionPanic{EntryIndexes: []uint32{entryIndex}, Err: suspensionCtx.Err()}) + panic(&SuspensionPanic{EntryIndexes: []uint32{entryIndex}, Err: context.Cause(suspensionCtx)}) case <-c.done: return } diff --git a/proto/javascript/javascript.proto b/proto/go/go.proto similarity index 72% rename from proto/javascript/javascript.proto rename to proto/go/go.proto index f774986..be0c0ec 100644 --- a/proto/javascript/javascript.proto +++ b/proto/go/go.proto @@ -11,13 +11,12 @@ syntax = "proto3"; -package dev.restate.sdk.javascript; +package dev.restate.sdk.go; import "proto/protocol/protocol.proto"; -// Type: 0xFC00 + 2 -message CombinatorEntryMessage { - int32 combinator_id = 1; - - repeated int32 journal_entries_order = 2; +// Type: 0xFC00 + 3 +message SelectorEntryMessage { + repeated uint32 journal_entries = 1; + uint32 winning_entry_index = 2; } diff --git a/router.go b/router.go index b1253ed..8a97c4f 100644 --- a/router.go +++ b/router.go @@ -26,8 +26,6 @@ type SendClient interface { } type ResponseFuture interface { - // Err returns errors that occurred when sending off the request, without having to wait for the response - Err() error // Response waits for the response to the call and unmarshals it into output Response(output any) error futures.Selectable @@ -43,6 +41,11 @@ type ServiceSendClient interface { Method(method string) SendClient } +type Selector interface { + Remaining() bool + Select() futures.Selectable +} + type Context interface { context.Context @@ -78,6 +81,8 @@ type Context interface { Awakeable() Awakeable[[]byte] ResolveAwakeable(id string, value []byte) RejectAwakeable(id string, reason error) + + Selector(futs ...futures.Selectable) (Selector, error) } // Router interface