Skip to content

Commit

Permalink
[naga wgsl-in] Allow abstract literals to be used as return values
Browse files Browse the repository at this point in the history
When lowering a return statement, call expression_for_abstract()
rather than expression() to avoid concretizing the return value. Then,
if the function has a return type, call try_automatic_conversions() to
attempt to convert our return value to the correct type.

This has the unfortunate side effect that some errors that would have
been caught by the validator are instead encountered as conversion
errors by the parser. This may result in a slightly less descriptive
error message in some cases. (See the change to the invalid_functions()
test, for example.)
  • Loading branch information
jamienicol authored and jimblandy committed Feb 7, 2025
1 parent 005bde9 commit c07fab2
Show file tree
Hide file tree
Showing 11 changed files with 292 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ By @brodycj in [#6924](https://github.com/gfx-rs/wgpu/pull/6924).
#### Naga

- Fix some instances of functions which have a return type but don't return a value being incorrectly validated. By @jamienicol in [#7013](https://github.com/gfx-rs/wgpu/pull/7013).
- Allow abstract expressions to be used in WGSL function return statements. By @jamienicol in [#7035](https://github.com/gfx-rs/wgpu/pull/7035).

#### General

Expand Down
23 changes: 19 additions & 4 deletions naga/src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1672,13 +1672,28 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
ast::StatementKind::Break => crate::Statement::Break,
ast::StatementKind::Continue => crate::Statement::Continue,
ast::StatementKind::Return { value } => {
ast::StatementKind::Return { value: ast_value } => {
let mut emitter = Emitter::default();
emitter.start(&ctx.function.expressions);

let value = value
.map(|expr| self.expression(expr, &mut ctx.as_expression(block, &mut emitter)))
.transpose()?;
let value;
if let Some(ast_expr) = ast_value {
let result_ty = ctx.function.result.as_ref().map(|r| r.ty);
let mut ectx = ctx.as_expression(block, &mut emitter);
let expr = self.expression_for_abstract(ast_expr, &mut ectx)?;

if let Some(result_ty) = result_ty {
let mut ectx = ctx.as_expression(block, &mut emitter);
let resolution = crate::proc::TypeResolution::Handle(result_ty);
let converted =
ectx.try_automatic_conversions(expr, &resolution, Span::default())?;
value = Some(converted);
} else {
value = Some(expr);
}
} else {
value = None;
}
block.extend(emitter.finish(&ctx.function.expressions));

crate::Statement::Return { value }
Expand Down
26 changes: 26 additions & 0 deletions naga/tests/in/abstract-types-return.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
@compute @workgroup_size(1)
fn main() {}

fn return_i32_ai() -> i32 {
return 1;
}

fn return_u32_ai() -> u32 {
return 1;
}

fn return_f32_ai() -> f32 {
return 1;
}

fn return_f32_af() -> f32 {
return 1.0;
}

fn return_vec2f32_ai() -> vec2<f32> {
return vec2(1);
}

fn return_arrf32_ai() -> array<f32, 4> {
return array(1, 1, 1, 1);
}
36 changes: 36 additions & 0 deletions naga/tests/out/glsl/abstract-types-return.main.Compute.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#version 310 es

precision highp float;
precision highp int;

layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;


int return_i32_ai() {
return 1;
}

uint return_u32_ai() {
return 1u;
}

float return_f32_ai() {
return 1.0;
}

float return_f32_af() {
return 1.0;
}

vec2 return_vec2f32_ai() {
return vec2(1.0);
}

float[4] return_arrf32_ai() {
return float[4](1.0, 1.0, 1.0, 1.0);
}

void main() {
return;
}

42 changes: 42 additions & 0 deletions naga/tests/out/hlsl/abstract-types-return.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
int return_i32_ai()
{
return 1;
}

uint return_u32_ai()
{
return 1u;
}

float return_f32_ai()
{
return 1.0;
}

float return_f32_af()
{
return 1.0;
}

float2 return_vec2f32_ai()
{
return (1.0).xx;
}

typedef float ret_Constructarray4_float_[4];
ret_Constructarray4_float_ Constructarray4_float_(float arg0, float arg1, float arg2, float arg3) {
float ret[4] = { arg0, arg1, arg2, arg3 };
return ret;
}

typedef float ret_return_arrf32_ai[4];
ret_return_arrf32_ai return_arrf32_ai()
{
return Constructarray4_float_(1.0, 1.0, 1.0, 1.0);
}

[numthreads(1, 1, 1)]
void main()
{
return;
}
12 changes: 12 additions & 0 deletions naga/tests/out/hlsl/abstract-types-return.ron
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
(
vertex:[
],
fragment:[
],
compute:[
(
entry_point:"main",
target_profile:"cs_5_1",
),
],
)
44 changes: 44 additions & 0 deletions naga/tests/out/msl/abstract-types-return.msl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// language: metal1.0
#include <metal_stdlib>
#include <simd/simd.h>

using metal::uint;

struct type_4 {
float inner[4];
};

int return_i32_ai(
) {
return 1;
}

uint return_u32_ai(
) {
return 1u;
}

float return_f32_ai(
) {
return 1.0;
}

float return_f32_af(
) {
return 1.0;
}

metal::float2 return_vec2f32_ai(
) {
return metal::float2(1.0);
}

type_4 return_arrf32_ai(
) {
return type_4 {1.0, 1.0, 1.0, 1.0};
}

kernel void main_(
) {
return;
}
70 changes: 70 additions & 0 deletions naga/tests/out/spv/abstract-types-return.spvasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 41
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %38 "main"
OpExecutionMode %38 LocalSize 1 1 1
OpDecorate %7 ArrayStride 4
%2 = OpTypeVoid
%3 = OpTypeInt 32 1
%4 = OpTypeInt 32 0
%5 = OpTypeFloat 32
%6 = OpTypeVector %5 2
%8 = OpConstant %4 4
%7 = OpTypeArray %5 %8
%11 = OpTypeFunction %3
%12 = OpConstant %3 1
%16 = OpTypeFunction %4
%17 = OpConstant %4 1
%21 = OpTypeFunction %5
%22 = OpConstant %5 1.0
%29 = OpTypeFunction %6
%30 = OpConstantComposite %6 %22 %22
%34 = OpTypeFunction %7
%35 = OpConstantComposite %7 %22 %22 %22 %22
%39 = OpTypeFunction %2
%10 = OpFunction %3 None %11
%9 = OpLabel
OpBranch %13
%13 = OpLabel
OpReturnValue %12
OpFunctionEnd
%15 = OpFunction %4 None %16
%14 = OpLabel
OpBranch %18
%18 = OpLabel
OpReturnValue %17
OpFunctionEnd
%20 = OpFunction %5 None %21
%19 = OpLabel
OpBranch %23
%23 = OpLabel
OpReturnValue %22
OpFunctionEnd
%25 = OpFunction %5 None %21
%24 = OpLabel
OpBranch %26
%26 = OpLabel
OpReturnValue %22
OpFunctionEnd
%28 = OpFunction %6 None %29
%27 = OpLabel
OpBranch %31
%31 = OpLabel
OpReturnValue %30
OpFunctionEnd
%33 = OpFunction %7 None %34
%32 = OpLabel
OpBranch %36
%36 = OpLabel
OpReturnValue %35
OpFunctionEnd
%38 = OpFunction %2 None %39
%37 = OpLabel
OpBranch %40
%40 = OpLabel
OpReturn
OpFunctionEnd
28 changes: 28 additions & 0 deletions naga/tests/out/wgsl/abstract-types-return.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
fn return_i32_ai() -> i32 {
return 1i;
}

fn return_u32_ai() -> u32 {
return 1u;
}

fn return_f32_ai() -> f32 {
return 1f;
}

fn return_f32_af() -> f32 {
return 1f;
}

fn return_vec2f32_ai() -> vec2<f32> {
return vec2(1f);
}

fn return_arrf32_ai() -> array<f32, 4> {
return array<f32, 4>(1f, 1f, 1f, 1f);
}

@compute @workgroup_size(1, 1, 1)
fn main() {
return;
}
4 changes: 4 additions & 0 deletions naga/tests/snapshots.rs
Original file line number Diff line number Diff line change
Expand Up @@ -921,6 +921,10 @@ fn convert_wgsl() {
"abstract-types-operators",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::WGSL,
),
(
"abstract-types-return",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
),
(
"int64",
Targets::SPIRV | Targets::HLSL | Targets::WGSL | Targets::METAL,
Expand Down
19 changes: 10 additions & 9 deletions naga/tests/wgsl_errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1174,22 +1174,23 @@ fn invalid_functions() {
if function_name == "return_pointer"
}

check_validation! {
check(
"
@group(0) @binding(0)
var<storage> atom: atomic<u32>;
fn return_atomic() -> atomic<u32> {
return atom;
}
":
Err(naga::valid::ValidationError::Function {
name: function_name,
source: naga::valid::FunctionError::NonConstructibleReturnType,
..
})
if function_name == "return_atomic"
}
",
"error: automatic conversions cannot convert `u32` to `atomic<u32>`
┌─ wgsl:6:19
6 │ return atom;
│ ^^^^ this expression has type u32
",
);
}

#[test]
Expand Down

0 comments on commit c07fab2

Please sign in to comment.