Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve matrix layout validation #5662

Merged
merged 2 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion source/val/validate_decorations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,14 @@ spv_result_t checkLayout(uint32_t struct_id, const char* storage_class_str,

seen[next_offset % 16] = true;
}
} else if (spv::Op::OpTypeMatrix == element_inst->opcode()) {
// Matrix stride would be on the array element in the struct.
const auto stride = constraint.matrix_stride;
if (!IsAlignedTo(stride, alignment)) {
return fail(memberIdx)
<< "is a matrix with stride " << stride
<< " not satisfying alignment to " << alignment;
}
}

// Proceed to the element in case it is an array.
Expand Down Expand Up @@ -675,7 +683,16 @@ bool checkForRequiredDecoration(uint32_t struct_id,
spv::Op type, ValidationState_t& vstate) {
const auto& members = getStructMembers(struct_id, vstate);
for (size_t memberIdx = 0; memberIdx < members.size(); memberIdx++) {
const auto id = members[memberIdx];
auto id = members[memberIdx];
if (type == spv::Op::OpTypeMatrix) {
// Matrix decorations also apply to arrays of matrices.
auto memberInst = vstate.FindDef(id);
while (memberInst->opcode() == spv::Op::OpTypeArray ||
memberInst->opcode() == spv::Op::OpTypeRuntimeArray) {
memberInst = vstate.FindDef(memberInst->GetOperandAs<uint32_t>(1u));
}
id = memberInst->id();
}
if (type != vstate.FindDef(id)->opcode()) continue;
bool found = false;
for (auto& dec : vstate.id_decorations(id)) {
Expand Down
226 changes: 226 additions & 0 deletions test/val/val_decoration_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9444,6 +9444,232 @@ OpFunctionEnd
"contains an array with stride 0, but with an element size of 4"));
}

TEST_F(ValidateDecorations, MatrixArrayMissingMajorness) {
const std::string spirv = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 1 1 1
OpDecorate %var DescriptorSet 0
OpDecorate %var Binding 0
OpDecorate %block Block
OpMemberDecorate %block 0 Offset 0
OpMemberDecorate %block 0 MatrixStride 16
OpDecorate %array ArrayStride 32
%void = OpTypeVoid
%float = OpTypeFloat 32
%int = OpTypeInt 32 0
%int_2 = OpConstant %int 2
%vec = OpTypeVector %float 2
%mat = OpTypeMatrix %vec 2
%array = OpTypeArray %mat %int_2
%block = OpTypeStruct %array
%ptr = OpTypePointer Uniform %block
%var = OpVariable %ptr Uniform
%void_fn = OpTypeFunction %void
%main = OpFunction %void None %void_fn
%entry = OpLabel
OpReturn
OpFunctionEnd
)";

CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
EXPECT_THAT(
getDiagnosticString(),
HasSubstr(
"must be explicitly laid out with RowMajor or ColMajor decorations"));
}

TEST_F(ValidateDecorations, MatrixArrayMissingStride) {
const std::string spirv = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 1 1 1
OpDecorate %var DescriptorSet 0
OpDecorate %var Binding 0
OpDecorate %block Block
OpMemberDecorate %block 0 Offset 0
OpMemberDecorate %block 0 ColMajor
OpDecorate %array ArrayStride 32
%void = OpTypeVoid
%float = OpTypeFloat 32
%int = OpTypeInt 32 0
%int_2 = OpConstant %int 2
%vec = OpTypeVector %float 2
%mat = OpTypeMatrix %vec 2
%array = OpTypeArray %mat %int_2
%block = OpTypeStruct %array
%ptr = OpTypePointer Uniform %block
%var = OpVariable %ptr Uniform
%void_fn = OpTypeFunction %void
%main = OpFunction %void None %void_fn
%entry = OpLabel
OpReturn
OpFunctionEnd
)";

CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("must be explicitly laid out with MatrixStride decorations"));
}

TEST_F(ValidateDecorations, MatrixArrayBadStride) {
const std::string spirv = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 1 1 1
OpDecorate %var DescriptorSet 0
OpDecorate %var Binding 0
OpDecorate %block Block
OpMemberDecorate %block 0 Offset 0
OpMemberDecorate %block 0 ColMajor
OpMemberDecorate %block 0 MatrixStride 8
OpDecorate %array ArrayStride 32
%void = OpTypeVoid
%float = OpTypeFloat 32
%int = OpTypeInt 32 0
%int_2 = OpConstant %int 2
%vec = OpTypeVector %float 2
%mat = OpTypeMatrix %vec 2
%array = OpTypeArray %mat %int_2
%block = OpTypeStruct %array
%ptr = OpTypePointer Uniform %block
%var = OpVariable %ptr Uniform
%void_fn = OpTypeFunction %void
%main = OpFunction %void None %void_fn
%entry = OpLabel
OpReturn
OpFunctionEnd
)";

CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("is a matrix with stride 8 not satisfying alignment to 16"));
}

TEST_F(ValidateDecorations, MatrixArrayArrayMissingMajorness) {
const std::string spirv = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 1 1 1
OpDecorate %var DescriptorSet 0
OpDecorate %var Binding 0
OpDecorate %block Block
OpMemberDecorate %block 0 Offset 0
OpMemberDecorate %block 0 MatrixStride 16
OpDecorate %array ArrayStride 32
OpDecorate %rta ArrayStride 64
%void = OpTypeVoid
%float = OpTypeFloat 32
%int = OpTypeInt 32 0
%int_2 = OpConstant %int 2
%vec = OpTypeVector %float 2
%mat = OpTypeMatrix %vec 2
%array = OpTypeArray %mat %int_2
%rta = OpTypeRuntimeArray %array
%block = OpTypeStruct %rta
%ptr = OpTypePointer StorageBuffer %block
%var = OpVariable %ptr StorageBuffer
%void_fn = OpTypeFunction %void
%main = OpFunction %void None %void_fn
%entry = OpLabel
OpReturn
OpFunctionEnd
)";

CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
EXPECT_THAT(
getDiagnosticString(),
HasSubstr(
"must be explicitly laid out with RowMajor or ColMajor decorations"));
}

TEST_F(ValidateDecorations, MatrixArrayArrayMissingStride) {
const std::string spirv = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 1 1 1
OpDecorate %var DescriptorSet 0
OpDecorate %var Binding 0
OpDecorate %block Block
OpMemberDecorate %block 0 Offset 0
OpMemberDecorate %block 0 ColMajor
OpDecorate %array ArrayStride 32
OpDecorate %rta ArrayStride 64
%void = OpTypeVoid
%float = OpTypeFloat 32
%int = OpTypeInt 32 0
%int_2 = OpConstant %int 2
%vec = OpTypeVector %float 2
%mat = OpTypeMatrix %vec 2
%array = OpTypeArray %mat %int_2
%rta = OpTypeRuntimeArray %array
%block = OpTypeStruct %rta
%ptr = OpTypePointer StorageBuffer %block
%var = OpVariable %ptr StorageBuffer
%void_fn = OpTypeFunction %void
%main = OpFunction %void None %void_fn
%entry = OpLabel
OpReturn
OpFunctionEnd
)";

CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("must be explicitly laid out with MatrixStride decorations"));
}

TEST_F(ValidateDecorations, MatrixArrayArrayBadStride) {
const std::string spirv = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 1 1 1
OpDecorate %var DescriptorSet 0
OpDecorate %var Binding 0
OpDecorate %block Block
OpMemberDecorate %block 0 Offset 0
OpMemberDecorate %block 0 ColMajor
OpMemberDecorate %block 0 MatrixStride 8
OpDecorate %array ArrayStride 32
OpDecorate %a ArrayStride 64
%void = OpTypeVoid
%float = OpTypeFloat 32
%int = OpTypeInt 32 0
%int_2 = OpConstant %int 2
%vec = OpTypeVector %float 2
%mat = OpTypeMatrix %vec 2
%array = OpTypeArray %mat %int_2
%a = OpTypeArray %array %int_2
%block = OpTypeStruct %a
%ptr = OpTypePointer Uniform %block
%var = OpVariable %ptr Uniform
%void_fn = OpTypeFunction %void
%main = OpFunction %void None %void_fn
%entry = OpLabel
OpReturn
OpFunctionEnd
)";

CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("is a matrix with stride 8 not satisfying alignment to 16"));
}

} // namespace
} // namespace val
} // namespace spvtools