Skip to content

Commit

Permalink
Improve matrix layout validation (#5662)
Browse files Browse the repository at this point in the history
* Check for matrix decorations on arrays of matrices
  * MatrixStide, RowMajor and ColMajor can be applied to matrix or
    arrays of matrix members
  * Check that matrix stride satisfies alignment in arrays
  • Loading branch information
alan-baker authored May 14, 2024
1 parent 199038f commit ccf3e3c
Show file tree
Hide file tree
Showing 2 changed files with 244 additions and 1 deletion.
19 changes: 18 additions & 1 deletion source/val/validate_decorations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,14 @@ spv_result_t checkLayout(uint32_t struct_id, const char* storage_class_str,

seen[next_offset % 16] = true;
}
} else if (spv::Op::OpTypeMatrix == element_inst->opcode()) {
// Matrix stride would be on the array element in the struct.
const auto stride = constraint.matrix_stride;
if (!IsAlignedTo(stride, alignment)) {
return fail(memberIdx)
<< "is a matrix with stride " << stride
<< " not satisfying alignment to " << alignment;
}
}

// Proceed to the element in case it is an array.
Expand Down Expand Up @@ -675,7 +683,16 @@ bool checkForRequiredDecoration(uint32_t struct_id,
spv::Op type, ValidationState_t& vstate) {
const auto& members = getStructMembers(struct_id, vstate);
for (size_t memberIdx = 0; memberIdx < members.size(); memberIdx++) {
const auto id = members[memberIdx];
auto id = members[memberIdx];
if (type == spv::Op::OpTypeMatrix) {
// Matrix decorations also apply to arrays of matrices.
auto memberInst = vstate.FindDef(id);
while (memberInst->opcode() == spv::Op::OpTypeArray ||
memberInst->opcode() == spv::Op::OpTypeRuntimeArray) {
memberInst = vstate.FindDef(memberInst->GetOperandAs<uint32_t>(1u));
}
id = memberInst->id();
}
if (type != vstate.FindDef(id)->opcode()) continue;
bool found = false;
for (auto& dec : vstate.id_decorations(id)) {
Expand Down
226 changes: 226 additions & 0 deletions test/val/val_decoration_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9444,6 +9444,232 @@ OpFunctionEnd
"contains an array with stride 0, but with an element size of 4"));
}

TEST_F(ValidateDecorations, MatrixArrayMissingMajorness) {
const std::string spirv = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 1 1 1
OpDecorate %var DescriptorSet 0
OpDecorate %var Binding 0
OpDecorate %block Block
OpMemberDecorate %block 0 Offset 0
OpMemberDecorate %block 0 MatrixStride 16
OpDecorate %array ArrayStride 32
%void = OpTypeVoid
%float = OpTypeFloat 32
%int = OpTypeInt 32 0
%int_2 = OpConstant %int 2
%vec = OpTypeVector %float 2
%mat = OpTypeMatrix %vec 2
%array = OpTypeArray %mat %int_2
%block = OpTypeStruct %array
%ptr = OpTypePointer Uniform %block
%var = OpVariable %ptr Uniform
%void_fn = OpTypeFunction %void
%main = OpFunction %void None %void_fn
%entry = OpLabel
OpReturn
OpFunctionEnd
)";

CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
EXPECT_THAT(
getDiagnosticString(),
HasSubstr(
"must be explicitly laid out with RowMajor or ColMajor decorations"));
}

TEST_F(ValidateDecorations, MatrixArrayMissingStride) {
const std::string spirv = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 1 1 1
OpDecorate %var DescriptorSet 0
OpDecorate %var Binding 0
OpDecorate %block Block
OpMemberDecorate %block 0 Offset 0
OpMemberDecorate %block 0 ColMajor
OpDecorate %array ArrayStride 32
%void = OpTypeVoid
%float = OpTypeFloat 32
%int = OpTypeInt 32 0
%int_2 = OpConstant %int 2
%vec = OpTypeVector %float 2
%mat = OpTypeMatrix %vec 2
%array = OpTypeArray %mat %int_2
%block = OpTypeStruct %array
%ptr = OpTypePointer Uniform %block
%var = OpVariable %ptr Uniform
%void_fn = OpTypeFunction %void
%main = OpFunction %void None %void_fn
%entry = OpLabel
OpReturn
OpFunctionEnd
)";

CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("must be explicitly laid out with MatrixStride decorations"));
}

TEST_F(ValidateDecorations, MatrixArrayBadStride) {
const std::string spirv = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 1 1 1
OpDecorate %var DescriptorSet 0
OpDecorate %var Binding 0
OpDecorate %block Block
OpMemberDecorate %block 0 Offset 0
OpMemberDecorate %block 0 ColMajor
OpMemberDecorate %block 0 MatrixStride 8
OpDecorate %array ArrayStride 32
%void = OpTypeVoid
%float = OpTypeFloat 32
%int = OpTypeInt 32 0
%int_2 = OpConstant %int 2
%vec = OpTypeVector %float 2
%mat = OpTypeMatrix %vec 2
%array = OpTypeArray %mat %int_2
%block = OpTypeStruct %array
%ptr = OpTypePointer Uniform %block
%var = OpVariable %ptr Uniform
%void_fn = OpTypeFunction %void
%main = OpFunction %void None %void_fn
%entry = OpLabel
OpReturn
OpFunctionEnd
)";

CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("is a matrix with stride 8 not satisfying alignment to 16"));
}

TEST_F(ValidateDecorations, MatrixArrayArrayMissingMajorness) {
const std::string spirv = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 1 1 1
OpDecorate %var DescriptorSet 0
OpDecorate %var Binding 0
OpDecorate %block Block
OpMemberDecorate %block 0 Offset 0
OpMemberDecorate %block 0 MatrixStride 16
OpDecorate %array ArrayStride 32
OpDecorate %rta ArrayStride 64
%void = OpTypeVoid
%float = OpTypeFloat 32
%int = OpTypeInt 32 0
%int_2 = OpConstant %int 2
%vec = OpTypeVector %float 2
%mat = OpTypeMatrix %vec 2
%array = OpTypeArray %mat %int_2
%rta = OpTypeRuntimeArray %array
%block = OpTypeStruct %rta
%ptr = OpTypePointer StorageBuffer %block
%var = OpVariable %ptr StorageBuffer
%void_fn = OpTypeFunction %void
%main = OpFunction %void None %void_fn
%entry = OpLabel
OpReturn
OpFunctionEnd
)";

CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
EXPECT_THAT(
getDiagnosticString(),
HasSubstr(
"must be explicitly laid out with RowMajor or ColMajor decorations"));
}

TEST_F(ValidateDecorations, MatrixArrayArrayMissingStride) {
const std::string spirv = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 1 1 1
OpDecorate %var DescriptorSet 0
OpDecorate %var Binding 0
OpDecorate %block Block
OpMemberDecorate %block 0 Offset 0
OpMemberDecorate %block 0 ColMajor
OpDecorate %array ArrayStride 32
OpDecorate %rta ArrayStride 64
%void = OpTypeVoid
%float = OpTypeFloat 32
%int = OpTypeInt 32 0
%int_2 = OpConstant %int 2
%vec = OpTypeVector %float 2
%mat = OpTypeMatrix %vec 2
%array = OpTypeArray %mat %int_2
%rta = OpTypeRuntimeArray %array
%block = OpTypeStruct %rta
%ptr = OpTypePointer StorageBuffer %block
%var = OpVariable %ptr StorageBuffer
%void_fn = OpTypeFunction %void
%main = OpFunction %void None %void_fn
%entry = OpLabel
OpReturn
OpFunctionEnd
)";

CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("must be explicitly laid out with MatrixStride decorations"));
}

TEST_F(ValidateDecorations, MatrixArrayArrayBadStride) {
const std::string spirv = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 1 1 1
OpDecorate %var DescriptorSet 0
OpDecorate %var Binding 0
OpDecorate %block Block
OpMemberDecorate %block 0 Offset 0
OpMemberDecorate %block 0 ColMajor
OpMemberDecorate %block 0 MatrixStride 8
OpDecorate %array ArrayStride 32
OpDecorate %a ArrayStride 64
%void = OpTypeVoid
%float = OpTypeFloat 32
%int = OpTypeInt 32 0
%int_2 = OpConstant %int 2
%vec = OpTypeVector %float 2
%mat = OpTypeMatrix %vec 2
%array = OpTypeArray %mat %int_2
%a = OpTypeArray %array %int_2
%block = OpTypeStruct %a
%ptr = OpTypePointer Uniform %block
%var = OpVariable %ptr Uniform
%void_fn = OpTypeFunction %void
%main = OpFunction %void None %void_fn
%entry = OpLabel
OpReturn
OpFunctionEnd
)";

CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("is a matrix with stride 8 not satisfying alignment to 16"));
}

} // namespace
} // namespace val
} // namespace spvtools

0 comments on commit ccf3e3c

Please sign in to comment.