Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor inferred Future code for async functions. #28486

Merged
merged 1 commit into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
28 changes: 15 additions & 13 deletions compiler/passes/src/code_generation/visit_program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ impl<'a> CodeGenerator<'a> {
// Note that in the function inlining pass, we reorder the functions such that they are in post-order.
// In other words, a callee function precedes its caller function in the program scope.
for (_symbol, function) in program_scope.functions.iter() {
// program_string.push_str(&program_scope.functions.iter().map(|(_, function)| {
if function.variant != Variant::AsyncFunction {
let mut function_string = self.visit_function(function);

Expand All @@ -98,11 +97,11 @@ impl<'a> CodeGenerator<'a> {
.unwrap()
.clone()
.finalize
.unwrap()
.name;
.unwrap();
// Write the finalize string.
function_string.push_str(&self.visit_function(
&program_scope.functions.iter().find(|(name, _f)| name == finalize).unwrap().1,
function_string.push_str(&self.visit_function_with(
&program_scope.functions.iter().find(|(name, _f)| name == &finalize.location.name).unwrap().1,
&finalize.future_inputs,
));
}

Expand Down Expand Up @@ -167,7 +166,7 @@ impl<'a> CodeGenerator<'a> {
output_string
}

fn visit_function(&mut self, function: &'a Function) -> String {
fn visit_function_with(&mut self, function: &'a Function, futures: &[Location]) -> String {
// Initialize the state of `self` with the appropriate values before visiting `function`.
self.next_register = 0;
self.variable_mapping = IndexMap::new();
Expand All @@ -189,13 +188,9 @@ impl<'a> CodeGenerator<'a> {
Variant::Inline => return String::new(),
};

let mut futures = futures.iter();

// Construct and append the input declarations of the function.
let mut futures = self
.symbol_table
.lookup_fn_symbol(Location::new(Some(self.program_id.unwrap().name.name), function.identifier.name))
.unwrap()
.future_inputs
.clone();
for input in function.input.iter() {
let register_string = format!("r{}", self.next_register);
self.next_register += 1;
Expand All @@ -210,7 +205,10 @@ impl<'a> CodeGenerator<'a> {
};
// Futures are displayed differently in the input section. `input r0 as foo.aleo/bar.future;`
if matches!(input.type_, Type::Future(_)) {
let location = futures.remove(0);
let location = futures
.next()
.expect("Type checking guarantees we have future locations for each future input")
.clone();
format!("{}.aleo/{}.future", location.program.unwrap(), location.name)
} else {
self.visit_type_with_visibility(&input.type_, visibility)
Expand All @@ -235,6 +233,10 @@ impl<'a> CodeGenerator<'a> {
function_string
}

fn visit_function(&mut self, function: &'a Function) -> String {
self.visit_function_with(function, &[])
}

fn visit_mapping(&mut self, mapping: &'a Mapping) -> String {
// Create the prefix of the mapping string, e.g. `mapping foo:`.
let mut mapping_string = format!("\nmapping {}:\n", mapping.identifier);
Expand Down
19 changes: 14 additions & 5 deletions compiler/passes/src/common/symbol_table/function_symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,20 @@ pub struct FunctionSymbol {
pub(crate) _span: Span,
/// The inputs to the function.
pub(crate) input: Vec<Input>,
/// Future inputs.
pub(crate) future_inputs: Vec<Location>,
/// The finalize block associated with the function.
pub(crate) finalize: Option<Location>,
/// The finalizer associated with this async transition.
pub(crate) finalize: Option<Finalizer>,
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Finalizer {
/// The name of the async function this async transition calls.
pub location: Location,

/// The locations of the futures passed to the async function called by this async transition.
pub future_inputs: Vec<Location>,

/// The types passed to the async function called by this async transition.
pub inferred_inputs: Vec<Type>,
}

impl SymbolTable {
Expand All @@ -48,7 +58,6 @@ impl SymbolTable {
variant: func.variant,
_span: func.span,
input: func.input.clone(),
future_inputs: Vec::new(),
finalize: None,
}
}
Expand Down
26 changes: 10 additions & 16 deletions compiler/passes/src/common/symbol_table/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub use variable_symbol::*;

use std::cell::RefCell;

use leo_ast::{Composite, Function, Location, normalize_json_value, remove_key_from_json};
use leo_ast::{Composite, Function, Location, Type, normalize_json_value, remove_key_from_json};
use leo_errors::{AstError, Result};
use leo_span::{Span, Symbol};

Expand Down Expand Up @@ -148,12 +148,18 @@ impl SymbolTable {
}

/// Attach a finalize to a function.
pub fn attach_finalize(&mut self, caller: Location, callee: Location) -> Result<()> {
pub fn attach_finalize(
&mut self,
caller: Location,
callee: Location,
future_inputs: Vec<Location>,
inferred_inputs: Vec<Type>,
) -> Result<()> {
if let Some(func) = self.functions.get_mut(&caller) {
func.finalize = Some(callee);
func.finalize = Some(Finalizer { location: callee, future_inputs, inferred_inputs });
Ok(())
} else if let Some(parent) = self.parent.as_mut() {
parent.attach_finalize(caller, callee)
parent.attach_finalize(caller, callee, future_inputs, inferred_inputs)
} else {
Err(AstError::function_not_found(caller.name).into())
}
Expand All @@ -171,18 +177,6 @@ impl SymbolTable {
Ok(())
}

/// Inserts futures into the function definition.
pub fn insert_futures(&mut self, program: Symbol, function: Symbol, futures: Vec<Location>) -> Result<()> {
if let Some(func) = self.functions.get_mut(&Location::new(Some(program), function)) {
func.future_inputs = futures;
Ok(())
} else if let Some(parent) = self.parent.as_mut() {
parent.insert_futures(program, function, futures)
} else {
Err(AstError::function_not_found(function).into())
}
}

/// Removes a variable from the symbol table.
pub fn remove_variable_from_current_scope(&mut self, location: Location) {
self.variables.shift_remove(&location);
Expand Down
4 changes: 2 additions & 2 deletions compiler/passes/src/static_analysis/analyzer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ impl<'a, N: Network> StaticAnalyzer<'a, N> {
}
// Otherwise, get the location of the finalize block.
let location = match &function.finalize {
Some(location) => location.clone(),
Some(finalizer) => finalizer.location.clone(),
None => {
unreachable!("Typechecking guarantees that all async transitions have an associated `finalize` field.");
}
Expand All @@ -136,7 +136,7 @@ impl<'a, N: Network> StaticAnalyzer<'a, N> {
}
};
// If the async function takes a future as an argument, emit an error.
if !async_function.future_inputs.is_empty() {
if async_function.input.iter().any(|input| matches!(input.type_(), Type::Future(..))) {
self.emit_err(StaticAnalyzerError::async_transition_call_with_future_argument(function_name, span));
}
}
Expand Down
27 changes: 10 additions & 17 deletions compiler/passes/src/symbol_table_creation/creator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,27 +121,20 @@ impl<'a> ProgramVisitor<'a> for SymbolTableCreator<'a> {
if let Err(err) = self.symbol_table.insert_fn(location.clone(), &Function::from(input.clone())) {
self.handler.emit_err(err);
}

// If the `FunctionStub` is an async transition, attach the finalize logic to the function.
// NOTE - for an external function like this, we really only need to attach the finalizer
// for the use of `assert_simple_async_transition_call` in the static analyzer.
// In principle that could be handled differently.
if matches!(input.variant, Variant::AsyncTransition) {
// This matches the logic in the disassembler.
let name = Symbol::intern(&format!("finalize/{}", input.name()));
if let Err(err) = self.symbol_table.attach_finalize(location, Location::new(self.program_name, name)) {
self.handler.emit_err(err);
}
}
// Otherwise is the `FunctionStub` is an async function, attach the future inputs.
else if matches!(input.variant, Variant::AsyncFunction) {
d0cd marked this conversation as resolved.
Show resolved Hide resolved
let future_inputs = input
.input
.iter()
.filter_map(|input| match &input.type_ {
Type::Future(future_type) => future_type.location.clone(),
_ => None,
})
.collect();
// Note that this unwrap is safe, because `self.program_name` is set before traversing the AST.
if let Err(err) = self.symbol_table.insert_futures(self.program_name.unwrap(), input.name(), future_inputs)
{
if let Err(err) = self.symbol_table.attach_finalize(
location,
Location::new(self.program_name, name),
Vec::new(),
Vec::new(),
) {
self.handler.emit_err(err);
}
}
Expand Down
25 changes: 18 additions & 7 deletions compiler/passes/src/type_checking/check_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,12 @@ impl<'a, N: Network> ExpressionVisitor<'a> for TypeChecker<'a, N> {
return Type::Err;
};

// If all inferred types weren't the same, the member will be of type `Type::Err`.
if let Type::Err = actual {
self.emit_err(TypeCheckerError::future_error_member(access.index.value(), access.span()));
return Type::Err;
}

self.maybe_assert_type(actual, expected, access.span());

actual.clone()
Expand Down Expand Up @@ -563,7 +569,7 @@ impl<'a, N: Network> ExpressionVisitor<'a> for TypeChecker<'a, N> {

let future_type =
Type::Future(FutureType::new(inputs.clone(), Some(Location::new(input.program, ident.name)), true));
let fully_inferred_type = match func.output_type {
let fully_inferred_type = match &func.output_type {
Type::Tuple(tup) => Type::Tuple(TupleType::new(
tup.elements()
.iter()
Expand Down Expand Up @@ -687,15 +693,20 @@ impl<'a, N: Network> ExpressionVisitor<'a> for TypeChecker<'a, N> {
}
// Add future locations to symbol table. Unwrap safe since insert function into symbol table during previous pass.
let mut st = self.symbol_table.borrow_mut();
// Insert futures into symbol table.
st.insert_futures(input.program.unwrap(), ident.name, input_futures).unwrap();
// Link async transition to the async function that finalizes it.
st.attach_finalize(self.scope_state.location(), Location::new(self.scope_state.program_name, ident.name))
.unwrap();
st.attach_finalize(
self.scope_state.location(),
Location::new(self.scope_state.program_name, ident.name),
input_futures,
inferred_finalize_inputs.clone(),
)
.expect("Failed to attach finalize");
drop(st);
// Create expectation for finalize inputs that will be checked when checking corresponding finalize function signature.
self.async_function_input_types
.insert(Location::new(self.scope_state.program_name, ident.name), inferred_finalize_inputs.clone());
self.async_function_callers
.entry(Location::new(self.scope_state.program_name, ident.name))
.or_default()
.insert(self.scope_state.location());

// Set scope state flag.
self.scope_state.has_called_finalize = true;
Expand Down
2 changes: 1 addition & 1 deletion compiler/passes/src/type_checking/check_program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ impl<'a, N: Network> ProgramVisitor<'a> for TypeChecker<'a, N> {
Type::Future(f) => {
// Since we traverse stubs in post-order, we can assume that the corresponding finalize stub has already been traversed.
Type::Future(FutureType::new(
finalize_input_map.get(&f.location.clone().unwrap()).unwrap().clone(),
finalize_input_map.get(f.location.as_ref().unwrap()).unwrap().clone(),
f.location.clone(),
true,
))
Expand Down
18 changes: 6 additions & 12 deletions compiler/passes/src/type_checking/check_statements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -383,18 +383,12 @@ impl<'a, N: Network> StatementVisitor<'a> for TypeChecker<'a, N> {

// Fully type the expected return value.
if self.scope_state.variant == Some(Variant::AsyncTransition) && self.scope_state.has_called_finalize {
let inferred_future_type =
match self.async_function_input_types.get(&func.unwrap().finalize.clone().unwrap()) {
Some(types) => Future(FutureType::new(
types.clone(),
Some(Location::new(self.scope_state.program_name, parent)),
true,
)),
None => {
return self
.emit_err(TypeCheckerError::async_transition_missing_future_to_return(input.span()));
}
};
let inferred_future_type = Future(FutureType::new(
func.unwrap().finalize.as_ref().unwrap().inferred_inputs.clone(),
Some(Location::new(self.scope_state.program_name, parent)),
true,
));

// Need to modify return type since the function signature is just default future, but the actual return type is the fully inferred future of the finalize input type.
let inferred = match return_type.clone() {
Some(Future(_)) => Some(inferred_future_type),
Expand Down
Loading
Loading