Skip to content

Commit

Permalink
Validator: Support SPV_NV_raw_access_chains (#5568)
Browse files Browse the repository at this point in the history
  • Loading branch information
rlocatti-nv authored Apr 10, 2024
1 parent 3983d15 commit 6761288
Show file tree
Hide file tree
Showing 6 changed files with 644 additions and 4 deletions.
3 changes: 3 additions & 0 deletions source/opcode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ bool spvOpcodeReturnsLogicalVariablePointer(const spv::Op opcode) {
case spv::Op::OpPtrAccessChain:
case spv::Op::OpLoad:
case spv::Op::OpConstantNull:
case spv::Op::OpRawAccessChainNV:
return true;
default:
return false;
Expand All @@ -309,6 +310,7 @@ int32_t spvOpcodeReturnsLogicalPointer(const spv::Op opcode) {
case spv::Op::OpFunctionParameter:
case spv::Op::OpImageTexelPointer:
case spv::Op::OpCopyObject:
case spv::Op::OpRawAccessChainNV:
return true;
default:
return false;
Expand Down Expand Up @@ -754,6 +756,7 @@ bool spvOpcodeIsAccessChain(spv::Op opcode) {
case spv::Op::OpInBoundsAccessChain:
case spv::Op::OpPtrAccessChain:
case spv::Op::OpInBoundsPtrAccessChain:
case spv::Op::OpRawAccessChainNV:
return true;
default:
return false;
Expand Down
3 changes: 2 additions & 1 deletion source/val/validate_annotation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ spv_result_t ValidateDecorationTarget(ValidationState_t& _, spv::Decoration dec,
case spv::Decoration::RestrictPointer:
case spv::Decoration::AliasedPointer:
if (target->opcode() != spv::Op::OpVariable &&
target->opcode() != spv::Op::OpFunctionParameter) {
target->opcode() != spv::Op::OpFunctionParameter &&
target->opcode() != spv::Op::OpRawAccessChainNV) {
return fail(0) << "must be a memory object declaration";
}
if (_.GetIdOpcode(target->type_id()) != spv::Op::OpTypePointer) {
Expand Down
8 changes: 5 additions & 3 deletions source/val/validate_decorations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1556,7 +1556,8 @@ spv_result_t CheckNonWritableDecoration(ValidationState_t& vstate,
const auto opcode = inst.opcode();
const auto type_id = inst.type_id();
if (opcode != spv::Op::OpVariable &&
opcode != spv::Op::OpFunctionParameter) {
opcode != spv::Op::OpFunctionParameter &&
opcode != spv::Op::OpRawAccessChainNV) {
return vstate.diag(SPV_ERROR_INVALID_ID, &inst)
<< "Target of NonWritable decoration must be a memory object "
"declaration (a variable or a function parameter)";
Expand All @@ -1569,10 +1570,11 @@ spv_result_t CheckNonWritableDecoration(ValidationState_t& vstate,
vstate.features().nonwritable_var_in_function_or_private) {
// New permitted feature in SPIR-V 1.4.
} else if (
// It may point to a UBO, SSBO, or storage image.
// It may point to a UBO, SSBO, storage image, or raw access chain.
vstate.IsPointerToUniformBlock(type_id) ||
vstate.IsPointerToStorageBuffer(type_id) ||
vstate.IsPointerToStorageImage(type_id)) {
vstate.IsPointerToStorageImage(type_id) ||
opcode == spv::Op::OpRawAccessChainNV) {
} else {
return vstate.diag(SPV_ERROR_INVALID_ID, &inst)
<< "Target of NonWritable decoration is invalid: must point to a "
Expand Down
123 changes: 123 additions & 0 deletions source/val/validate_memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1427,6 +1427,126 @@ spv_result_t ValidateAccessChain(ValidationState_t& _,
return SPV_SUCCESS;
}

spv_result_t ValidateRawAccessChain(ValidationState_t& _,
const Instruction* inst) {
std::string instr_name = "Op" + std::string(spvOpcodeString(inst->opcode()));

// The result type must be OpTypePointer.
const auto result_type = _.FindDef(inst->type_id());
if (spv::Op::OpTypePointer != result_type->opcode()) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "The Result Type of " << instr_name << " <id> "
<< _.getIdName(inst->id()) << " must be OpTypePointer. Found Op"
<< spvOpcodeString(result_type->opcode()) << '.';
}

// The pointed storage class must be valid.
const auto storage_class = result_type->GetOperandAs<spv::StorageClass>(1);
if (storage_class != spv::StorageClass::StorageBuffer &&
storage_class != spv::StorageClass::PhysicalStorageBuffer &&
storage_class != spv::StorageClass::Uniform) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "The Result Type of " << instr_name << " <id> "
<< _.getIdName(inst->id())
<< " must point to a storage class of "
"StorageBuffer, PhysicalStorageBuffer, or Uniform.";
}

// The pointed type must not be one in the list below.
const auto result_type_pointee =
_.FindDef(result_type->GetOperandAs<uint32_t>(2));
if (result_type_pointee->opcode() == spv::Op::OpTypeArray ||
result_type_pointee->opcode() == spv::Op::OpTypeMatrix ||
result_type_pointee->opcode() == spv::Op::OpTypeStruct) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "The Result Type of " << instr_name << " <id> "
<< _.getIdName(inst->id())
<< " must not point to "
"OpTypeArray, OpTypeMatrix, or OpTypeStruct.";
}

// Validate Stride is a OpConstant.
const auto stride = _.FindDef(inst->GetOperandAs<uint32_t>(3));
if (stride->opcode() != spv::Op::OpConstant) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "The Stride of " << instr_name << " <id> "
<< _.getIdName(inst->id()) << " must be OpConstant. Found Op"
<< spvOpcodeString(stride->opcode()) << '.';
}
// Stride type must be OpTypeInt
const auto stride_type = _.FindDef(stride->type_id());
if (stride_type->opcode() != spv::Op::OpTypeInt) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "The type of Stride of " << instr_name << " <id> "
<< _.getIdName(inst->id()) << " must be OpTypeInt. Found Op"
<< spvOpcodeString(stride_type->opcode()) << '.';
}

// Index and Offset type must be OpTypeInt with a width of 32
const auto ValidateType = [&](const char* name,
int operandIndex) -> spv_result_t {
const auto value = _.FindDef(inst->GetOperandAs<uint32_t>(operandIndex));
const auto value_type = _.FindDef(value->type_id());
if (value_type->opcode() != spv::Op::OpTypeInt) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "The type of " << name << " of " << instr_name << " <id> "
<< _.getIdName(inst->id()) << " must be OpTypeInt. Found Op"
<< spvOpcodeString(value_type->opcode()) << '.';
}
const auto width = value_type->GetOperandAs<uint32_t>(1);
if (width != 32) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "The integer width of " << name << " of " << instr_name
<< " <id> " << _.getIdName(inst->id()) << " must be 32. Found "
<< width << '.';
}
return SPV_SUCCESS;
};
spv_result_t result;
result = ValidateType("Index", 4);
if (result != SPV_SUCCESS) {
return result;
}
result = ValidateType("Offset", 5);
if (result != SPV_SUCCESS) {
return result;
}

uint32_t access_operands = 0;
if (inst->operands().size() >= 7) {
access_operands = inst->GetOperandAs<uint32_t>(6);
}
if (access_operands &
uint32_t(spv::RawAccessChainOperandsMask::RobustnessPerElementNV)) {
uint64_t stride_value = 0;
if (_.EvalConstantValUint64(stride->id(), &stride_value) &&
stride_value == 0) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Stride must not be zero when per-element robustness is used.";
}
}
if (access_operands &
uint32_t(spv::RawAccessChainOperandsMask::RobustnessPerComponentNV) ||
access_operands &
uint32_t(spv::RawAccessChainOperandsMask::RobustnessPerElementNV)) {
if (storage_class == spv::StorageClass::PhysicalStorageBuffer) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Storage class cannot be PhysicalStorageBuffer when "
"raw access chain robustness is used.";
}
}
if (access_operands &
uint32_t(spv::RawAccessChainOperandsMask::RobustnessPerComponentNV) &&
access_operands &
uint32_t(spv::RawAccessChainOperandsMask::RobustnessPerElementNV)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Per-component robustness and per-element robustness are "
"mutually exclusive.";
}

return SPV_SUCCESS;
}

spv_result_t ValidatePtrAccessChain(ValidationState_t& _,
const Instruction* inst) {
if (_.addressing_model() == spv::AddressingModel::Logical) {
Expand Down Expand Up @@ -1866,6 +1986,9 @@ spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
case spv::Op::OpInBoundsPtrAccessChain:
if (auto error = ValidateAccessChain(_, inst)) return error;
break;
case spv::Op::OpRawAccessChainNV:
if (auto error = ValidateRawAccessChain(_, inst)) return error;
break;
case spv::Op::OpArrayLength:
if (auto error = ValidateArrayLength(_, inst)) return error;
break;
Expand Down
1 change: 1 addition & 0 deletions test/val/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ add_spvtools_unittest(TARGET val_abcde
val_extension_spv_khr_bit_instructions_test.cpp
val_extension_spv_khr_terminate_invocation_test.cpp
val_extension_spv_khr_subgroup_rotate_test.cpp
val_extension_spv_nv_raw_access_chains.cpp
val_ext_inst_test.cpp
val_ext_inst_debug_test.cpp
${VAL_TEST_COMMON_SRCS}
Expand Down
Loading

0 comments on commit 6761288

Please sign in to comment.