diff --git a/pkg/client/gomock_reflect_3526400147/prog.go b/pkg/client/gomock_reflect_3526400147/prog.go new file mode 100644 index 000000000..6003ba81a --- /dev/null +++ b/pkg/client/gomock_reflect_3526400147/prog.go @@ -0,0 +1,66 @@ +package main + +import ( + "encoding/gob" + "flag" + "fmt" + "os" + "path" + "reflect" + + "github.com/golang/mock/mockgen/model" + + pkg_ "github.com/pokt-network/poktroll/pkg/client" +) + +var output = flag.String("output", "", "The output file name, or empty to use stdout.") + +func main() { + flag.Parse() + + its := []struct { + sym string + typ reflect.Type + }{ + + {"TxContext", reflect.TypeOf((*pkg_.TxContext)(nil)).Elem()}, + + {"TxClient", reflect.TypeOf((*pkg_.TxClient)(nil)).Elem()}, + } + pkg := &model.Package{ + // NOTE: This behaves contrary to documented behaviour if the + // package name is not the final component of the import path. + // The reflect package doesn't expose the package name, though. + Name: path.Base("github.com/pokt-network/poktroll/pkg/client"), + } + + for _, it := range its { + intf, err := model.InterfaceFromInterfaceType(it.typ) + if err != nil { + fmt.Fprintf(os.Stderr, "Reflection: %v\n", err) + os.Exit(1) + } + intf.Name = it.sym + pkg.Interfaces = append(pkg.Interfaces, intf) + } + + outfile := os.Stdout + if len(*output) != 0 { + var err error + outfile, err = os.Create(*output) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to open output file %q", *output) + } + defer func() { + if err := outfile.Close(); err != nil { + fmt.Fprintf(os.Stderr, "failed to close output file %q", *output) + os.Exit(1) + } + }() + } + + if err := gob.NewEncoder(outfile).Encode(pkg); err != nil { + fmt.Fprintf(os.Stderr, "gob encode: %v\n", err) + os.Exit(1) + } +} diff --git a/x/supplier/types/errors.go b/x/supplier/types/errors.go index 9403950e1..bc5a32be9 100644 --- a/x/supplier/types/errors.go +++ b/x/supplier/types/errors.go @@ -15,4 +15,6 @@ var ( ErrSupplierInvalidServiceConfig = sdkerrors.Register(ModuleName, 5, "invalid service config") ErrSupplierInvalidSessionStartHeight = sdkerrors.Register(ModuleName, 6, "invalid session start height") ErrSupplierInvalidSessionId = sdkerrors.Register(ModuleName, 7, "invalid session ID") + ErrSupplierInvalidService = sdkerrors.Register(ModuleName, 8, "invalid service in supplier") + ErrSupplierInvalidClaimRootHash = sdkerrors.Register(ModuleName, 9, "invalid root hash") ) diff --git a/x/supplier/types/message_create_claim.go b/x/supplier/types/message_create_claim.go index 7d68c6a94..0e702c956 100644 --- a/x/supplier/types/message_create_claim.go +++ b/x/supplier/types/message_create_claim.go @@ -5,6 +5,7 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" sessiontypes "github.com/pokt-network/poktroll/x/session/types" + sharedhelpers "github.com/pokt-network/poktroll/x/shared/helpers" ) const TypeMsgCreateClaim = "create_claim" @@ -41,9 +42,29 @@ func (msg *MsgCreateClaim) GetSignBytes() []byte { } func (msg *MsgCreateClaim) ValidateBasic() error { + // Validate the supplier address _, err := sdk.AccAddressFromBech32(msg.SupplierAddress) if err != nil { return sdkerrors.Wrapf(ErrSupplierInvalidAddress, "invalid supplierAddress address (%s)", err) } + + // Validate the session header + sessionHeader := msg.SessionHeader + if sessionHeader.SessionStartBlockHeight < 1 { + return sdkerrors.Wrapf(ErrSupplierInvalidSessionStartHeight, "invalid session start block height (%d)", sessionHeader.SessionStartBlockHeight) + } + if len(sessionHeader.SessionId) == 0 { + return sdkerrors.Wrapf(ErrSupplierInvalidSessionId, "invalid session ID (%v)", sessionHeader.SessionId) + } + if !sharedhelpers.IsValidService(sessionHeader.Service) { + return sdkerrors.Wrapf(ErrSupplierInvalidService, "invalid service (%v)", sessionHeader.Service) + } + + // Validate the root hash + // TODO_IMPROVE: Only checking to make sure a non-nil hash was provided for now, but we can validate the length as well. + if len(msg.RootHash) == 0 { + return sdkerrors.Wrapf(ErrSupplierInvalidClaimRootHash, "invalid root hash (%v)", msg.RootHash) + } + return nil } diff --git a/x/supplier/types/message_create_claim_test.go b/x/supplier/types/message_create_claim_test.go index 8401c0d3d..c46647ebe 100644 --- a/x/supplier/types/message_create_claim_test.go +++ b/x/supplier/types/message_create_claim_test.go @@ -6,37 +6,104 @@ import ( "github.com/stretchr/testify/require" "github.com/pokt-network/poktroll/testutil/sample" + sessiontypes "github.com/pokt-network/poktroll/x/session/types" + sharedtypes "github.com/pokt-network/poktroll/x/shared/types" ) -// TODO(@bryanchriswhite): Add unit tests for message validation when adding the business logic. - func TestMsgCreateClaim_ValidateBasic(t *testing.T) { tests := []struct { - name string - msg MsgCreateClaim - err error + desc string + + msg MsgCreateClaim + err error }{ { - name: "invalid address", + desc: "invalid address", + msg: MsgCreateClaim{ SupplierAddress: "invalid_address", }, err: ErrSupplierInvalidAddress, - }, { - name: "valid address", + }, + { + desc: "valid address but invalid session start height", + + msg: MsgCreateClaim{ + SupplierAddress: sample.AccAddress(), + SessionHeader: &sessiontypes.SessionHeader{ + SessionStartBlockHeight: 0, // Invalid start height + }, + }, + err: ErrSupplierInvalidSessionStartHeight, + }, + { + desc: "valid address and session start height but invalid session ID", + + msg: MsgCreateClaim{ + SupplierAddress: sample.AccAddress(), + SessionHeader: &sessiontypes.SessionHeader{ + SessionStartBlockHeight: 100, + SessionId: "", // Invalid session ID + }, + }, + err: ErrSupplierInvalidSessionId, + }, + { + desc: "valid address, session start height, session ID but invalid service", + + msg: MsgCreateClaim{ + SupplierAddress: sample.AccAddress(), + SessionHeader: &sessiontypes.SessionHeader{ + SessionStartBlockHeight: 100, + SessionId: "valid_session_id", + Service: &sharedtypes.Service{ + Id: "invalid_service_id", // Assuming this ID is invalid + }, // Should trigger error + }, + }, + err: ErrSupplierInvalidService, + }, + { + desc: "valid address, session start height, session ID, service but invalid root hash", + + msg: MsgCreateClaim{ + SupplierAddress: sample.AccAddress(), + SessionHeader: &sessiontypes.SessionHeader{ + SessionStartBlockHeight: 100, + SessionId: "valid_session_id", + Service: &sharedtypes.Service{ + Id: "svcId", // Assuming this ID is valid + }, + }, + RootHash: []byte(""), // Invalid root hash + }, + err: ErrSupplierInvalidClaimRootHash, + }, + { + desc: "all valid inputs", + msg: MsgCreateClaim{ SupplierAddress: sample.AccAddress(), + SessionHeader: &sessiontypes.SessionHeader{ + SessionStartBlockHeight: 100, + SessionId: "valid_session_id", + Service: &sharedtypes.Service{ + Id: "svcId", // Assuming this ID is valid + }, + }, + RootHash: []byte("valid_root_hash"), // Assuming this is valid }, + err: nil, }, } for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { + t.Run(tt.desc, func(t *testing.T) { err := tt.msg.ValidateBasic() if tt.err != nil { require.ErrorIs(t, err, tt.err) - return + } else { + require.NoError(t, err) } - require.NoError(t, err) }) } }