From 6761288d39e2af51d73a5d8edb328dafc2054b1c Mon Sep 17 00:00:00 2001 From: Rodrigo Locatti Date: Wed, 10 Apr 2024 11:40:10 -0300 Subject: [PATCH] Validator: Support SPV_NV_raw_access_chains (#5568) --- source/opcode.cpp | 3 + source/val/validate_annotation.cpp | 3 +- source/val/validate_decorations.cpp | 8 +- source/val/validate_memory.cpp | 123 +++++ test/val/CMakeLists.txt | 1 + ...val_extension_spv_nv_raw_access_chains.cpp | 510 ++++++++++++++++++ 6 files changed, 644 insertions(+), 4 deletions(-) create mode 100644 test/val/val_extension_spv_nv_raw_access_chains.cpp diff --git a/source/opcode.cpp b/source/opcode.cpp index 38d1a1be6e..787dbb3403 100644 --- a/source/opcode.cpp +++ b/source/opcode.cpp @@ -295,6 +295,7 @@ bool spvOpcodeReturnsLogicalVariablePointer(const spv::Op opcode) { case spv::Op::OpPtrAccessChain: case spv::Op::OpLoad: case spv::Op::OpConstantNull: + case spv::Op::OpRawAccessChainNV: return true; default: return false; @@ -309,6 +310,7 @@ int32_t spvOpcodeReturnsLogicalPointer(const spv::Op opcode) { case spv::Op::OpFunctionParameter: case spv::Op::OpImageTexelPointer: case spv::Op::OpCopyObject: + case spv::Op::OpRawAccessChainNV: return true; default: return false; @@ -754,6 +756,7 @@ bool spvOpcodeIsAccessChain(spv::Op opcode) { case spv::Op::OpInBoundsAccessChain: case spv::Op::OpPtrAccessChain: case spv::Op::OpInBoundsPtrAccessChain: + case spv::Op::OpRawAccessChainNV: return true; default: return false; diff --git a/source/val/validate_annotation.cpp b/source/val/validate_annotation.cpp index 106004d066..dac3585788 100644 --- a/source/val/validate_annotation.cpp +++ b/source/val/validate_annotation.cpp @@ -161,7 +161,8 @@ spv_result_t ValidateDecorationTarget(ValidationState_t& _, spv::Decoration dec, case spv::Decoration::RestrictPointer: case spv::Decoration::AliasedPointer: if (target->opcode() != spv::Op::OpVariable && - target->opcode() != spv::Op::OpFunctionParameter) { + target->opcode() != spv::Op::OpFunctionParameter && + target->opcode() != spv::Op::OpRawAccessChainNV) { return fail(0) << "must be a memory object declaration"; } if (_.GetIdOpcode(target->type_id()) != spv::Op::OpTypePointer) { diff --git a/source/val/validate_decorations.cpp b/source/val/validate_decorations.cpp index caa4a6f184..bb1fea5502 100644 --- a/source/val/validate_decorations.cpp +++ b/source/val/validate_decorations.cpp @@ -1556,7 +1556,8 @@ spv_result_t CheckNonWritableDecoration(ValidationState_t& vstate, const auto opcode = inst.opcode(); const auto type_id = inst.type_id(); if (opcode != spv::Op::OpVariable && - opcode != spv::Op::OpFunctionParameter) { + opcode != spv::Op::OpFunctionParameter && + opcode != spv::Op::OpRawAccessChainNV) { return vstate.diag(SPV_ERROR_INVALID_ID, &inst) << "Target of NonWritable decoration must be a memory object " "declaration (a variable or a function parameter)"; @@ -1569,10 +1570,11 @@ spv_result_t CheckNonWritableDecoration(ValidationState_t& vstate, vstate.features().nonwritable_var_in_function_or_private) { // New permitted feature in SPIR-V 1.4. } else if ( - // It may point to a UBO, SSBO, or storage image. + // It may point to a UBO, SSBO, storage image, or raw access chain. vstate.IsPointerToUniformBlock(type_id) || vstate.IsPointerToStorageBuffer(type_id) || - vstate.IsPointerToStorageImage(type_id)) { + vstate.IsPointerToStorageImage(type_id) || + opcode == spv::Op::OpRawAccessChainNV) { } else { return vstate.diag(SPV_ERROR_INVALID_ID, &inst) << "Target of NonWritable decoration is invalid: must point to a " diff --git a/source/val/validate_memory.cpp b/source/val/validate_memory.cpp index c9ecf51a1e..2d6715f423 100644 --- a/source/val/validate_memory.cpp +++ b/source/val/validate_memory.cpp @@ -1427,6 +1427,126 @@ spv_result_t ValidateAccessChain(ValidationState_t& _, return SPV_SUCCESS; } +spv_result_t ValidateRawAccessChain(ValidationState_t& _, + const Instruction* inst) { + std::string instr_name = "Op" + std::string(spvOpcodeString(inst->opcode())); + + // The result type must be OpTypePointer. + const auto result_type = _.FindDef(inst->type_id()); + if (spv::Op::OpTypePointer != result_type->opcode()) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "The Result Type of " << instr_name << " " + << _.getIdName(inst->id()) << " must be OpTypePointer. Found Op" + << spvOpcodeString(result_type->opcode()) << '.'; + } + + // The pointed storage class must be valid. + const auto storage_class = result_type->GetOperandAs(1); + if (storage_class != spv::StorageClass::StorageBuffer && + storage_class != spv::StorageClass::PhysicalStorageBuffer && + storage_class != spv::StorageClass::Uniform) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "The Result Type of " << instr_name << " " + << _.getIdName(inst->id()) + << " must point to a storage class of " + "StorageBuffer, PhysicalStorageBuffer, or Uniform."; + } + + // The pointed type must not be one in the list below. + const auto result_type_pointee = + _.FindDef(result_type->GetOperandAs(2)); + if (result_type_pointee->opcode() == spv::Op::OpTypeArray || + result_type_pointee->opcode() == spv::Op::OpTypeMatrix || + result_type_pointee->opcode() == spv::Op::OpTypeStruct) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "The Result Type of " << instr_name << " " + << _.getIdName(inst->id()) + << " must not point to " + "OpTypeArray, OpTypeMatrix, or OpTypeStruct."; + } + + // Validate Stride is a OpConstant. + const auto stride = _.FindDef(inst->GetOperandAs(3)); + if (stride->opcode() != spv::Op::OpConstant) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "The Stride of " << instr_name << " " + << _.getIdName(inst->id()) << " must be OpConstant. Found Op" + << spvOpcodeString(stride->opcode()) << '.'; + } + // Stride type must be OpTypeInt + const auto stride_type = _.FindDef(stride->type_id()); + if (stride_type->opcode() != spv::Op::OpTypeInt) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "The type of Stride of " << instr_name << " " + << _.getIdName(inst->id()) << " must be OpTypeInt. Found Op" + << spvOpcodeString(stride_type->opcode()) << '.'; + } + + // Index and Offset type must be OpTypeInt with a width of 32 + const auto ValidateType = [&](const char* name, + int operandIndex) -> spv_result_t { + const auto value = _.FindDef(inst->GetOperandAs(operandIndex)); + const auto value_type = _.FindDef(value->type_id()); + if (value_type->opcode() != spv::Op::OpTypeInt) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "The type of " << name << " of " << instr_name << " " + << _.getIdName(inst->id()) << " must be OpTypeInt. Found Op" + << spvOpcodeString(value_type->opcode()) << '.'; + } + const auto width = value_type->GetOperandAs(1); + if (width != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "The integer width of " << name << " of " << instr_name + << " " << _.getIdName(inst->id()) << " must be 32. Found " + << width << '.'; + } + return SPV_SUCCESS; + }; + spv_result_t result; + result = ValidateType("Index", 4); + if (result != SPV_SUCCESS) { + return result; + } + result = ValidateType("Offset", 5); + if (result != SPV_SUCCESS) { + return result; + } + + uint32_t access_operands = 0; + if (inst->operands().size() >= 7) { + access_operands = inst->GetOperandAs(6); + } + if (access_operands & + uint32_t(spv::RawAccessChainOperandsMask::RobustnessPerElementNV)) { + uint64_t stride_value = 0; + if (_.EvalConstantValUint64(stride->id(), &stride_value) && + stride_value == 0) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Stride must not be zero when per-element robustness is used."; + } + } + if (access_operands & + uint32_t(spv::RawAccessChainOperandsMask::RobustnessPerComponentNV) || + access_operands & + uint32_t(spv::RawAccessChainOperandsMask::RobustnessPerElementNV)) { + if (storage_class == spv::StorageClass::PhysicalStorageBuffer) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Storage class cannot be PhysicalStorageBuffer when " + "raw access chain robustness is used."; + } + } + if (access_operands & + uint32_t(spv::RawAccessChainOperandsMask::RobustnessPerComponentNV) && + access_operands & + uint32_t(spv::RawAccessChainOperandsMask::RobustnessPerElementNV)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Per-component robustness and per-element robustness are " + "mutually exclusive."; + } + + return SPV_SUCCESS; +} + spv_result_t ValidatePtrAccessChain(ValidationState_t& _, const Instruction* inst) { if (_.addressing_model() == spv::AddressingModel::Logical) { @@ -1866,6 +1986,9 @@ spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) { case spv::Op::OpInBoundsPtrAccessChain: if (auto error = ValidateAccessChain(_, inst)) return error; break; + case spv::Op::OpRawAccessChainNV: + if (auto error = ValidateRawAccessChain(_, inst)) return error; + break; case spv::Op::OpArrayLength: if (auto error = ValidateArrayLength(_, inst)) return error; break; diff --git a/test/val/CMakeLists.txt b/test/val/CMakeLists.txt index 62d93bddc1..9d6f6ea6a9 100644 --- a/test/val/CMakeLists.txt +++ b/test/val/CMakeLists.txt @@ -46,6 +46,7 @@ add_spvtools_unittest(TARGET val_abcde val_extension_spv_khr_bit_instructions_test.cpp val_extension_spv_khr_terminate_invocation_test.cpp val_extension_spv_khr_subgroup_rotate_test.cpp + val_extension_spv_nv_raw_access_chains.cpp val_ext_inst_test.cpp val_ext_inst_debug_test.cpp ${VAL_TEST_COMMON_SRCS} diff --git a/test/val/val_extension_spv_nv_raw_access_chains.cpp b/test/val/val_extension_spv_nv_raw_access_chains.cpp new file mode 100644 index 0000000000..f06d7cd4b5 --- /dev/null +++ b/test/val/val_extension_spv_nv_raw_access_chains.cpp @@ -0,0 +1,510 @@ +// Copyright (c) 2024 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gmock/gmock.h" +#include "source/spirv_target_env.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::HasSubstr; + +using ValidateSpvNVRawAccessChains = spvtest::ValidateBase; + +TEST_F(ValidateSpvNVRawAccessChains, Valid) { + const std::string str = R"( + OpCapability Shader + OpCapability RawAccessChainsNV + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpExtension "SPV_NV_raw_access_chains" + OpMemoryModel Logical GLSL450 + + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + OpDecorate %intStruct Block + OpMemberDecorate %intStruct 0 Offset 0 + OpDecorate %ssbo DescriptorSet 0 + OpDecorate %ssbo Binding 0 + + %int = OpTypeInt 32 1 + %void = OpTypeVoid + %mainFunctionType = OpTypeFunction %void + %intStruct = OpTypeStruct %int + %intStructPtr = OpTypePointer StorageBuffer %intStruct + %ssbo = OpVariable %intStructPtr StorageBuffer + %intPtr = OpTypePointer StorageBuffer %int + + %int_0 = OpConstant %int 0 + %int_16 = OpConstant %int 16 + + %main = OpFunction %void None %mainFunctionType + %label = OpLabel + %rawChain = OpRawAccessChainNV %intPtr %ssbo %int_16 %int_0 %int_0 RobustnessPerComponentNV + %unused = OpLoad %int %rawChain + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateSpvNVRawAccessChains, NoCapability) { + const std::string str = R"( + OpCapability Shader + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpExtension "SPV_NV_raw_access_chains" + OpMemoryModel Logical GLSL450 + + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + OpDecorate %intStruct Block + OpMemberDecorate %intStruct 0 Offset 0 + OpDecorate %ssbo DescriptorSet 0 + OpDecorate %ssbo Binding 0 + + %int = OpTypeInt 32 1 + %void = OpTypeVoid + %mainFunctionType = OpTypeFunction %void + %intStruct = OpTypeStruct %int + %intStructPtr = OpTypePointer StorageBuffer %intStruct + %ssbo = OpVariable %intStructPtr StorageBuffer + %intPtr = OpTypePointer StorageBuffer %int + + %int_0 = OpConstant %int 0 + %int_16 = OpConstant %int 16 + + %main = OpFunction %void None %mainFunctionType + %label = OpLabel + %rawChain = OpRawAccessChainNV %intPtr %ssbo %int_16 %int_0 %int_0 RobustnessPerComponentNV + %unused = OpLoad %int %rawChain + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("requires one of these capabilities: RawAccessChainsNV")); +} + +TEST_F(ValidateSpvNVRawAccessChains, NoExtension) { + const std::string str = R"( + OpCapability Shader + OpCapability RawAccessChainsNV + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpMemoryModel Logical GLSL450 + + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + OpDecorate %intStruct Block + OpMemberDecorate %intStruct 0 Offset 0 + OpDecorate %ssbo DescriptorSet 0 + OpDecorate %ssbo Binding 0 + + %int = OpTypeInt 32 1 + %void = OpTypeVoid + %mainFunctionType = OpTypeFunction %void + %intStruct = OpTypeStruct %int + %intStructPtr = OpTypePointer StorageBuffer %intStruct + %ssbo = OpVariable %intStructPtr StorageBuffer + %intPtr = OpTypePointer StorageBuffer %int + + %int_0 = OpConstant %int 0 + %int_16 = OpConstant %int 16 + + %main = OpFunction %void None %mainFunctionType + %label = OpLabel + %rawChain = OpRawAccessChainNV %intPtr %ssbo %int_16 %int_0 %int_0 RobustnessPerComponentNV + %unused = OpLoad %int %rawChain + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str.c_str()); + EXPECT_EQ(SPV_ERROR_MISSING_EXTENSION, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("requires one of these extensions: SPV_NV_raw_access_chains")); +} + +TEST_F(ValidateSpvNVRawAccessChains, ReturnTypeNotPointer) { + const std::string str = R"( + OpCapability Shader + OpCapability RawAccessChainsNV + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpExtension "SPV_NV_raw_access_chains" + OpMemoryModel Logical GLSL450 + + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + OpDecorate %intStruct Block + OpMemberDecorate %intStruct 0 Offset 0 + OpDecorate %ssbo DescriptorSet 0 + OpDecorate %ssbo Binding 0 + + %int = OpTypeInt 32 1 + %void = OpTypeVoid + %mainFunctionType = OpTypeFunction %void + %intStruct = OpTypeStruct %int + %intStructPtr = OpTypePointer StorageBuffer %intStruct + %ssbo = OpVariable %intStructPtr StorageBuffer + + %int_0 = OpConstant %int 0 + %int_16 = OpConstant %int 16 + + %main = OpFunction %void None %mainFunctionType + %label = OpLabel + %rawChain = OpRawAccessChainNV %int %ssbo %int_16 %int_0 %int_0 RobustnessPerComponentNV + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must be OpTypePointer. Found OpTypeInt")); +} + +TEST_F(ValidateSpvNVRawAccessChains, Workgroup) { + const std::string str = R"( + OpCapability Shader + OpCapability RawAccessChainsNV + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpExtension "SPV_NV_raw_access_chains" + OpMemoryModel Logical GLSL450 + + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + OpDecorate %intStruct Block + OpMemberDecorate %intStruct 0 Offset 0 + + %int = OpTypeInt 32 1 + %void = OpTypeVoid + %mainFunctionType = OpTypeFunction %void + %intStruct = OpTypeStruct %int + %intStructPtr = OpTypePointer Workgroup %intStruct + %ssbo = OpVariable %intStructPtr Workgroup + %intPtr = OpTypePointer Workgroup %int + + %int_0 = OpConstant %int 0 + %int_16 = OpConstant %int 16 + + %main = OpFunction %void None %mainFunctionType + %label = OpLabel + %rawChain = OpRawAccessChainNV %intPtr %ssbo %int_16 %int_0 %int_0 RobustnessPerComponentNV + %unused = OpLoad %int %rawChain + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must point to a storage class of")); +} + +TEST_F(ValidateSpvNVRawAccessChains, ReturnTypeArray) { + const std::string str = R"( + OpCapability Shader + OpCapability RawAccessChainsNV + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpExtension "SPV_NV_raw_access_chains" + OpMemoryModel Logical GLSL450 + + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + OpDecorate %intStruct Block + OpMemberDecorate %intStruct 0 Offset 0 + OpDecorate %ssbo DescriptorSet 0 + OpDecorate %ssbo Binding 0 + + %int = OpTypeInt 32 1 + %void = OpTypeVoid + %mainFunctionType = OpTypeFunction %void + %intStruct = OpTypeStruct %int + %intStructPtr = OpTypePointer StorageBuffer %intStruct + %ssbo = OpVariable %intStructPtr StorageBuffer + %int_1 = OpConstant %int 1 + %intArray = OpTypeArray %int %int_1 + %intArrayPtr = OpTypePointer StorageBuffer %intArray + + %int_0 = OpConstant %int 0 + %int_16 = OpConstant %int 16 + + %main = OpFunction %void None %mainFunctionType + %label = OpLabel + %rawChain = OpRawAccessChainNV %intArrayPtr %ssbo %int_16 %int_0 %int_0 RobustnessPerComponentNV + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("must not point to")); +} + +TEST_F(ValidateSpvNVRawAccessChains, VariableStride) { + const std::string str = R"( + OpCapability Shader + OpCapability RawAccessChainsNV + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpExtension "SPV_NV_raw_access_chains" + OpMemoryModel Logical GLSL450 + + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + OpDecorate %intStruct Block + OpMemberDecorate %intStruct 0 Offset 0 + OpDecorate %ssbo DescriptorSet 0 + OpDecorate %ssbo Binding 0 + + %int = OpTypeInt 32 1 + %void = OpTypeVoid + %mainFunctionType = OpTypeFunction %void + %intStruct = OpTypeStruct %int + %intStructPtr = OpTypePointer StorageBuffer %intStruct + %ssbo = OpVariable %intStructPtr StorageBuffer + %intPtr = OpTypePointer StorageBuffer %int + + %int_0 = OpConstant %int 0 + + %main = OpFunction %void None %mainFunctionType + %label = OpLabel + %stride = OpIAdd %int %int_0 %int_0 + %rawChain = OpRawAccessChainNV %intPtr %ssbo %stride %int_0 %int_0 RobustnessPerComponentNV + %unused = OpLoad %int %rawChain + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("must be OpConstant")); +} + +TEST_F(ValidateSpvNVRawAccessChains, RobustnessPerElementZeroStride) { + const std::string str = R"( + OpCapability Shader + OpCapability RawAccessChainsNV + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpExtension "SPV_NV_raw_access_chains" + OpMemoryModel Logical GLSL450 + + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + OpDecorate %intStruct Block + OpMemberDecorate %intStruct 0 Offset 0 + OpDecorate %ssbo DescriptorSet 0 + OpDecorate %ssbo Binding 0 + + %int = OpTypeInt 32 1 + %void = OpTypeVoid + %mainFunctionType = OpTypeFunction %void + %intStruct = OpTypeStruct %int + %intStructPtr = OpTypePointer StorageBuffer %intStruct + %ssbo = OpVariable %intStructPtr StorageBuffer + %intPtr = OpTypePointer StorageBuffer %int + + %int_0 = OpConstant %int 0 + + %main = OpFunction %void None %mainFunctionType + %label = OpLabel + %rawChain = OpRawAccessChainNV %intPtr %ssbo %int_0 %int_0 %int_0 RobustnessPerElementNV + %unused = OpLoad %int %rawChain + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Stride must not be zero when per-element robustness is used")); +} + +TEST_F(ValidateSpvNVRawAccessChains, BothRobustness) { + const std::string str = R"( + OpCapability Shader + OpCapability RawAccessChainsNV + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpExtension "SPV_NV_raw_access_chains" + OpMemoryModel Logical GLSL450 + + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + OpDecorate %intStruct Block + OpMemberDecorate %intStruct 0 Offset 0 + OpDecorate %ssbo DescriptorSet 0 + OpDecorate %ssbo Binding 0 + + %int = OpTypeInt 32 1 + %void = OpTypeVoid + %mainFunctionType = OpTypeFunction %void + %intStruct = OpTypeStruct %int + %intStructPtr = OpTypePointer StorageBuffer %intStruct + %ssbo = OpVariable %intStructPtr StorageBuffer + %intPtr = OpTypePointer StorageBuffer %int + + %int_0 = OpConstant %int 0 + %int_16 = OpConstant %int 16 + + %main = OpFunction %void None %mainFunctionType + %label = OpLabel + %rawChain = OpRawAccessChainNV %intPtr %ssbo %int_16 %int_0 %int_0 RobustnessPerElementNV|RobustnessPerComponentNV + %unused = OpLoad %int %rawChain + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Per-component robustness and per-element robustness " + "are mutually exclusive")); +} + +TEST_F(ValidateSpvNVRawAccessChains, StrideFloat) { + const std::string str = R"( + OpCapability Shader + OpCapability RawAccessChainsNV + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpExtension "SPV_NV_raw_access_chains" + OpMemoryModel Logical GLSL450 + + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + OpDecorate %intStruct Block + OpMemberDecorate %intStruct 0 Offset 0 + OpDecorate %ssbo DescriptorSet 0 + OpDecorate %ssbo Binding 0 + + %int = OpTypeInt 32 1 + %float = OpTypeFloat 32 + %void = OpTypeVoid + %mainFunctionType = OpTypeFunction %void + %intStruct = OpTypeStruct %int + %intStructPtr = OpTypePointer StorageBuffer %intStruct + %ssbo = OpVariable %intStructPtr StorageBuffer + %intPtr = OpTypePointer StorageBuffer %int + + %int_0 = OpConstant %int 0 + %float_16 = OpConstant %float 16 + + %main = OpFunction %void None %mainFunctionType + %label = OpLabel + %rawChain = OpRawAccessChainNV %intPtr %ssbo %float_16 %int_0 %int_0 RobustnessPerComponentNV + %unused = OpLoad %int %rawChain + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("must be OpTypeInt")); +} + +TEST_F(ValidateSpvNVRawAccessChains, IndexType) { + const std::string str = R"( + OpCapability Shader + OpCapability RawAccessChainsNV + OpCapability Int64 + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpExtension "SPV_NV_raw_access_chains" + OpMemoryModel Logical GLSL450 + + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + OpDecorate %intStruct Block + OpMemberDecorate %intStruct 0 Offset 0 + OpDecorate %ssbo DescriptorSet 0 + OpDecorate %ssbo Binding 0 + + %int = OpTypeInt 32 1 + %long = OpTypeInt 64 1 + %void = OpTypeVoid + %mainFunctionType = OpTypeFunction %void + %intStruct = OpTypeStruct %int + %intStructPtr = OpTypePointer StorageBuffer %intStruct + %ssbo = OpVariable %intStructPtr StorageBuffer + %intPtr = OpTypePointer StorageBuffer %int + + %int_0 = OpConstant %int 0 + %int_16 = OpConstant %int 16 + %long_0 = OpConstant %long 0 + + %main = OpFunction %void None %mainFunctionType + %label = OpLabel + %rawChain = OpRawAccessChainNV %intPtr %ssbo %int_16 %long_0 %int_0 RobustnessPerComponentNV + %unused = OpLoad %int %rawChain + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("The integer width of Index")); +} + +TEST_F(ValidateSpvNVRawAccessChains, OffsetType) { + const std::string str = R"( + OpCapability Shader + OpCapability RawAccessChainsNV + OpCapability Int64 + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpExtension "SPV_NV_raw_access_chains" + OpMemoryModel Logical GLSL450 + + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + + OpDecorate %intStruct Block + OpMemberDecorate %intStruct 0 Offset 0 + OpDecorate %ssbo DescriptorSet 0 + OpDecorate %ssbo Binding 0 + + %int = OpTypeInt 32 1 + %long = OpTypeInt 64 1 + %void = OpTypeVoid + %mainFunctionType = OpTypeFunction %void + %intStruct = OpTypeStruct %int + %intStructPtr = OpTypePointer StorageBuffer %intStruct + %ssbo = OpVariable %intStructPtr StorageBuffer + %intPtr = OpTypePointer StorageBuffer %int + + %int_0 = OpConstant %int 0 + %int_16 = OpConstant %int 16 + %long_0 = OpConstant %long 0 + + %main = OpFunction %void None %mainFunctionType + %label = OpLabel + %rawChain = OpRawAccessChainNV %intPtr %ssbo %int_16 %int_0 %long_0 RobustnessPerComponentNV + %unused = OpLoad %int %rawChain + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("The integer width of Offset")); +} + +} // namespace +} // namespace val +} // namespace spvtools