diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 9f85e9d04d187..8cbfacf9c2d1f 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -1462,7 +1462,7 @@ func (t *dropPartitionTask) PreExecute(ctx context.Context) error { return err } if collLoaded { - loaded, err := isPartitionLoaded(ctx, t.queryCoord, collID, []int64{partID}) + loaded, err := isPartitionLoaded(ctx, t.queryCoord, collID, partID) if err != nil { return err } diff --git a/internal/proxy/util.go b/internal/proxy/util.go index 76b477400fea4..bf7ae45b09ee8 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -49,7 +49,6 @@ import ( "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/contextutil" "github.com/milvus-io/milvus/pkg/util/crypto" - "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -1476,11 +1475,11 @@ func isCollectionLoaded(ctx context.Context, qc types.QueryCoordClient, collID i return false, nil } -func isPartitionLoaded(ctx context.Context, qc types.QueryCoordClient, collID int64, partIDs []int64) (bool, error) { +func isPartitionLoaded(ctx context.Context, qc types.QueryCoordClient, collID int64, partID int64) (bool, error) { // get all loading collections resp, err := qc.ShowPartitions(ctx, &querypb.ShowPartitionsRequest{ CollectionID: collID, - PartitionIDs: partIDs, + PartitionIDs: []int64{partID}, }) if err := merr.CheckRPCCall(resp, err); err != nil { // qc returns error if partition not loaded @@ -1490,7 +1489,7 @@ func isPartitionLoaded(ctx context.Context, qc types.QueryCoordClient, collID in return false, err } - return funcutil.SliceSetEqual(partIDs, resp.GetPartitionIDs()), nil + return true, nil } func checkFieldsDataBySchema(schema *schemapb.CollectionSchema, insertMsg *msgstream.InsertMsg, inInsert bool) error { diff --git a/internal/proxy/util_test.go b/internal/proxy/util_test.go index 93bef5a1b0af0..a4824e7ef11f4 100644 --- a/internal/proxy/util_test.go +++ b/internal/proxy/util_test.go @@ -1063,7 +1063,7 @@ func Test_isPartitionIsLoaded(t *testing.T) { Status: merr.Success(), PartitionIDs: []int64{partID}, }, nil) - loaded, err := isPartitionLoaded(ctx, qc, collID, []int64{partID}) + loaded, err := isPartitionLoaded(ctx, qc, collID, partID) assert.NoError(t, err) assert.True(t, loaded) }) @@ -1088,7 +1088,7 @@ func Test_isPartitionIsLoaded(t *testing.T) { Status: merr.Success(), PartitionIDs: []int64{partID}, }, errors.New("error")) - loaded, err := isPartitionLoaded(ctx, qc, collID, []int64{partID}) + loaded, err := isPartitionLoaded(ctx, qc, collID, partID) assert.Error(t, err) assert.False(t, loaded) }) @@ -1116,7 +1116,7 @@ func Test_isPartitionIsLoaded(t *testing.T) { }, PartitionIDs: []int64{partID}, }, nil) - loaded, err := isPartitionLoaded(ctx, qc, collID, []int64{partID}) + loaded, err := isPartitionLoaded(ctx, qc, collID, partID) assert.Error(t, err) assert.False(t, loaded) }) diff --git a/internal/querycoordv2/services.go b/internal/querycoordv2/services.go index 72ed6101e744d..756712cfdfe92 100644 --- a/internal/querycoordv2/services.go +++ b/internal/querycoordv2/services.go @@ -159,15 +159,16 @@ func (s *Server) ShowPartitions(ctx context.Context, req *querypb.ShowPartitions if percentage < 0 { err := meta.GlobalFailedLoadCache.Get(req.GetCollectionID()) if err != nil { - status := merr.Status(err) - log.Warn("show partition failed", zap.Error(err)) + partitionErr := merr.WrapErrPartitionNotLoaded(partitionID, err.Error()) + status := merr.Status(partitionErr) + log.Warn("show partition failed", zap.Error(partitionErr)) return &querypb.ShowPartitionsResponse{ Status: status, }, nil } err = merr.WrapErrPartitionNotLoaded(partitionID) - log.Warn("show partitions failed", zap.Error(err)) + log.Warn("show partition failed", zap.Error(err)) return &querypb.ShowPartitionsResponse{ Status: merr.Status(err), }, nil diff --git a/internal/querycoordv2/services_test.go b/internal/querycoordv2/services_test.go index 77195dce5cd08..badf645debf76 100644 --- a/internal/querycoordv2/services_test.go +++ b/internal/querycoordv2/services_test.go @@ -24,6 +24,7 @@ import ( "github.com/cockroachdb/errors" "github.com/samber/lo" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" @@ -310,7 +311,8 @@ func (suite *ServiceSuite) TestShowPartitions() { meta.GlobalFailedLoadCache.Put(collection, merr.WrapErrServiceMemoryLimitExceeded(100, 10)) resp, err = server.ShowPartitions(ctx, req) suite.NoError(err) - suite.Equal(commonpb.ErrorCode_InsufficientMemoryToLoad, resp.GetStatus().GetErrorCode()) + err := merr.CheckRPCCall(resp, err) + assert.True(suite.T(), errors.Is(err, merr.ErrPartitionNotLoaded)) meta.GlobalFailedLoadCache.Remove(collection) err = suite.meta.CollectionManager.PutCollection(ctx, colBak) suite.NoError(err) @@ -322,7 +324,8 @@ func (suite *ServiceSuite) TestShowPartitions() { meta.GlobalFailedLoadCache.Put(collection, merr.WrapErrServiceMemoryLimitExceeded(100, 10)) resp, err = server.ShowPartitions(ctx, req) suite.NoError(err) - suite.Equal(commonpb.ErrorCode_InsufficientMemoryToLoad, resp.GetStatus().GetErrorCode()) + err := merr.CheckRPCCall(resp, err) + assert.True(suite.T(), errors.Is(err, merr.ErrPartitionNotLoaded)) meta.GlobalFailedLoadCache.Remove(collection) err = suite.meta.CollectionManager.PutPartition(ctx, parBak) suite.NoError(err) diff --git a/pkg/util/merr/utils.go b/pkg/util/merr/utils.go index 1515dbd56ce7b..f86711e7297f9 100644 --- a/pkg/util/merr/utils.go +++ b/pkg/util/merr/utils.go @@ -40,14 +40,14 @@ func Code(err error) int32 { } cause := errors.Cause(err) - switch cause := cause.(type) { + switch specificErr := cause.(type) { case milvusError: - return cause.code() + return specificErr.code() default: - if errors.Is(cause, context.Canceled) { + if errors.Is(specificErr, context.Canceled) { return CanceledCode - } else if errors.Is(cause, context.DeadlineExceeded) { + } else if errors.Is(specificErr, context.DeadlineExceeded) { return TimeoutCode } else { return errUnexpected.code()