Skip to content

Commit

Permalink
[GR-17457] We no longer need to look into a Refinement's ancestors fo…
Browse files Browse the repository at this point in the history
…r method lookup

PullRequest: truffleruby/4257
  • Loading branch information
eregon authored and andrykonchin committed Apr 29, 2024
2 parents 98ba377 + dc9ac5a commit 47b99a3
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 138 deletions.
178 changes: 55 additions & 123 deletions src/main/java/org/truffleruby/core/module/ModuleOperations.java
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,7 @@ public static ConstantLookupResult lookupConstant(RubyContext context, RubyModul
private static ConstantLookupResult lookupConstant(RubyContext context, RubyModule module, String name,
ArrayList<Assumption> assumptions) {
// Look in the current module
ModuleFields fields = module.fields;
ConstantEntry constantEntry = fields.getOrComputeConstantEntry(name);
ConstantEntry constantEntry = module.fields.getOrComputeConstantEntry(name);
assumptions.add(constantEntry.getAssumption());
if (constantExists(constantEntry, assumptions)) {
return new ConstantLookupResult(constantEntry.getConstant(), toArray(assumptions));
Expand All @@ -162,8 +161,7 @@ private static ConstantLookupResult lookupConstant(RubyContext context, RubyModu
if (ancestor == module) {
continue;
}
fields = ancestor.fields;
constantEntry = fields.getOrComputeConstantEntry(name);
constantEntry = ancestor.fields.getOrComputeConstantEntry(name);
assumptions.add(constantEntry.getAssumption());
if (constantExists(constantEntry, assumptions)) {
return new ConstantLookupResult(constantEntry.getConstant(), toArray(assumptions));
Expand All @@ -179,16 +177,14 @@ public static ConstantLookupResult lookupConstantInObject(RubyContext context, S
ArrayList<Assumption> assumptions) {
final RubyClass objectClass = context.getCoreLibrary().objectClass;

ModuleFields fields = objectClass.fields;
ConstantEntry constantEntry = fields.getOrComputeConstantEntry(name);
ConstantEntry constantEntry = objectClass.fields.getOrComputeConstantEntry(name);
assumptions.add(constantEntry.getAssumption());
if (constantExists(constantEntry, assumptions)) {
return new ConstantLookupResult(constantEntry.getConstant(), toArray(assumptions));
}

for (RubyModule ancestor : objectClass.fields.prependedAndIncludedModules()) {
fields = ancestor.fields;
constantEntry = fields.getOrComputeConstantEntry(name);
constantEntry = ancestor.fields.getOrComputeConstantEntry(name);
assumptions.add(constantEntry.getAssumption());
if (constantExists(constantEntry, assumptions)) {
return new ConstantLookupResult(constantEntry.getConstant(), toArray(assumptions));
Expand All @@ -202,15 +198,13 @@ public static ConstantLookupResult lookupConstantInObject(RubyContext context, S
public static RubyConstant lookupConstantInObjectUncached(RubyContext context, String name) {
final RubyClass objectClass = context.getCoreLibrary().objectClass;

ModuleFields fields = objectClass.fields;
RubyConstant constant = fields.getConstant(name);
RubyConstant constant = objectClass.fields.getConstant(name);
if (constantExists(constant, null)) {
return constant;
}

for (RubyModule ancestor : objectClass.fields.prependedAndIncludedModules()) {
fields = ancestor.fields;
constant = fields.getConstant(name);
constant = ancestor.fields.getConstant(name);
if (constantExists(constant, null)) {
return constant;
}
Expand Down Expand Up @@ -243,8 +237,7 @@ public static ConstantLookupResult lookupConstantWithLexicalScope(RubyContext co

// Look in lexical scope
while (lexicalScope != context.getRootLexicalScope()) {
final ModuleFields fields = lexicalScope.getLiveModule().fields;
final ConstantEntry constantEntry = fields.getOrComputeConstantEntry(name);
final ConstantEntry constantEntry = lexicalScope.getLiveModule().fields.getOrComputeConstantEntry(name);
assumptions.add(constantEntry.getAssumption());
if (constantExists(constantEntry, assumptions)) {
return new ConstantLookupResult(constantEntry.getConstant(), toArray(assumptions));
Expand Down Expand Up @@ -328,8 +321,7 @@ public static ConstantLookupResult lookupConstantWithInherit(RubyContext context
return ModuleOperations.lookupConstant(context, module, name, assumptions);
}
} else {
final ModuleFields fields = module.fields;
final ConstantEntry constantEntry = fields.getOrComputeConstantEntry(name);
final ConstantEntry constantEntry = module.fields.getOrComputeConstantEntry(name);
assumptions.add(constantEntry.getAssumption());
if (constantExists(constantEntry, assumptions)) {
return new ConstantLookupResult(constantEntry.getConstant(), toArray(assumptions));
Expand Down Expand Up @@ -426,49 +418,25 @@ public static Map<String, InternalMethod> withoutUndefinedMethods(Map<String, In
return definedMethods;
}

public static MethodLookupResult lookupMethodCached(RubyModule module, String name,
DeclarationContext declarationContext) {
return lookupMethodCached(module, null, name, declarationContext);
}

@TruffleBoundary
private static MethodLookupResult lookupMethodCached(RubyModule module, RubyModule lookupTo, String name,
public static MethodLookupResult lookupMethodCached(RubyModule module, String name,
DeclarationContext declarationContext) {
final ArrayList<Assumption> assumptions = new ArrayList<>();
var assumptions = new ArrayList<Assumption>();

// Look in ancestors
for (RubyModule ancestor : module.fields.ancestors()) {
if (ancestor == lookupTo) {
return new MethodLookupResult(null, toArray(assumptions));
}
final RubyModule[] refinements = getRefinementsFor(declarationContext, ancestor);

var refinements = getRefinementsFor(declarationContext, ancestor);
if (refinements != null) {
for (RubyModule refinement : refinements) {
// If we have more then one active refinement for C (where C is refined module):
// R1.ancestors = [R1, A, C, ...]
// R2.ancestors = [R2, B, C, ...]
// R3.ancestors = [R3, D, C, ...]
// we are only looking up to C
// R3 -> D -> R2 -> B -> R1 -> A
final MethodLookupResult refinedMethod = lookupMethodCached(
refinement,
ancestor,
name,
null);
for (Assumption assumption : refinedMethod.getAssumptions()) {
assumptions.add(assumption);
}
if (refinedMethod.isDefined()) {
InternalMethod method = rememberUsedRefinements(refinedMethod.getMethod(), declarationContext);
var refinedMethod = refinement.fields.getMethodAndAssumption(name, assumptions);
if (refinedMethod != null) {
InternalMethod method = rememberUsedRefinements(refinedMethod, declarationContext);
return new MethodLookupResult(method, toArray(assumptions));
}
}
}

final ModuleFields fields = ancestor.fields;
final InternalMethod method = fields.getMethodAndAssumption(name, assumptions);

var method = ancestor.fields.getMethodAndAssumption(name, assumptions);
if (method != null) {
return new MethodLookupResult(method, toArray(assumptions));
}
Expand All @@ -478,37 +446,22 @@ private static MethodLookupResult lookupMethodCached(RubyModule module, RubyModu
return new MethodLookupResult(null, toArray(assumptions));
}

public static InternalMethod lookupMethodUncached(RubyModule module, String name,
DeclarationContext declarationContext) {
return lookupMethodUncached(module, null, name, declarationContext);
}

@TruffleBoundary
private static InternalMethod lookupMethodUncached(RubyModule module, RubyModule lookupTo, String name,
public static InternalMethod lookupMethodUncached(RubyModule module, String name,
DeclarationContext declarationContext) {

// Look in ancestors
for (RubyModule ancestor : module.fields.ancestors()) {
if (ancestor == lookupTo) {
return null;
}
final RubyModule[] refinements = getRefinementsFor(declarationContext, ancestor);

var refinements = getRefinementsFor(declarationContext, ancestor);
if (refinements != null) {
for (RubyModule refinement : refinements) {
final InternalMethod refinedMethod = lookupMethodUncached(
refinement,
ancestor,
name,
null);
var refinedMethod = refinement.fields.getMethod(name);
if (refinedMethod != null) {
return rememberUsedRefinements(refinedMethod, declarationContext);
}
}
}

final ModuleFields fields = ancestor.fields;
final InternalMethod method = fields.getMethod(name);

var method = ancestor.fields.getMethod(name);
if (method != null) {
return method;
}
Expand All @@ -518,82 +471,62 @@ private static InternalMethod lookupMethodUncached(RubyModule module, RubyModule
return null;
}

@TruffleBoundary
public static MethodLookupResult lookupSuperMethod(InternalMethod currentMethod, RubyModule objectMetaClass) {
final String name = currentMethod.getSharedMethodInfo().getMethodNameForNotBlock(); // use the original name
var name = currentMethod.getSharedMethodInfo().getMethodNameForNotBlock(); // use the original name

Memo<Boolean> foundDeclaringModule = new Memo<>(false);
return lookupSuperMethod(
currentMethod.getDeclaringModule(),
null,
name,
objectMetaClass,
foundDeclaringModule,
currentMethod.getDeclarationContext(),
currentMethod.getActiveRefinements());
}


@TruffleBoundary
private static MethodLookupResult lookupSuperMethod(RubyModule declaringModule, RubyModule lookupTo,
String name, RubyModule objectMetaClass, Memo<Boolean> foundDeclaringModule,
DeclarationContext declarationContext, DeclarationContext callerDeclaringContext) {
final ArrayList<Assumption> assumptions = new ArrayList<>();
final boolean isRefinedMethod = declaringModule.fields.isRefinement();
var foundDeclaringModule = new Memo<>(false);
var declaringModule = currentMethod.getDeclaringModule();
var declarationContext = currentMethod.getDeclarationContext();
var assumptions = new ArrayList<Assumption>();

// First we need to skip all ancestors until we find declaringModule,
// and then we return the first ancestor after declaringModule which has the method defined.
for (RubyModule ancestor : objectMetaClass.fields.ancestors()) {
if (ancestor == lookupTo) {
return new MethodLookupResult(null, toArray(assumptions));
}

final RubyModule[] refinements = getRefinementsFor(declarationContext, callerDeclaringContext, ancestor);

var refinements = getRefinementsFor(declarationContext, currentMethod.getActiveRefinements(), ancestor);
if (refinements != null) {
for (RubyModule refinement : refinements) {
final MethodLookupResult superMethodInRefinement = lookupSuperMethod(
declaringModule,
ancestor,
name,
refinement,
foundDeclaringModule,
null,
null);
for (Assumption assumption : superMethodInRefinement.getAssumptions()) {
assumptions.add(assumption);
}
if (superMethodInRefinement.isDefined()) {
InternalMethod method = superMethodInRefinement.getMethod();
var refinedMethod = lookupSuperMethodInModule(declaringModule, name, foundDeclaringModule,
refinement, assumptions);
if (refinedMethod != null) {
return new MethodLookupResult(
rememberUsedRefinements(method, declarationContext, refinements, ancestor),
rememberUsedRefinements(refinedMethod, declarationContext, refinements, ancestor),
toArray(assumptions));
}
if (foundDeclaringModule.get() && isRefinedMethod) {
if (foundDeclaringModule.get() && declaringModule.fields.isRefinement()) {
// if method is defined in refinement module (R)
// we should lookup only in this active refinement and skip other
// we should lookup only in this active refinement and skip others
break;
}
}
}

if (!foundDeclaringModule.get()) {
if (ancestor == declaringModule) {
// The declaring module's assumption needs to appended for cases where a newly included module
// should invalidate previous super lookups.
ancestor.fields.getMethodAndAssumption(name, assumptions);
foundDeclaringModule.set(true);
}
} else {
final ModuleFields fields = ancestor.fields;
final InternalMethod method = fields.getMethodAndAssumption(name, assumptions);
if (method != null) {
return new MethodLookupResult(method, toArray(assumptions));
}
var method = lookupSuperMethodInModule(declaringModule, name, foundDeclaringModule, ancestor, assumptions);
if (method != null) {
return new MethodLookupResult(method, toArray(assumptions));
}
}

// Nothing found
return new MethodLookupResult(null, toArray(assumptions));
}


private static InternalMethod lookupSuperMethodInModule(RubyModule declaringModule, String name,
Memo<Boolean> foundDeclaringModule, RubyModule module, ArrayList<Assumption> assumptions) {
if (!foundDeclaringModule.get()) {
if (module == declaringModule) {
// The declaring module's assumption needs to appended for cases where a newly included module
// should invalidate previous super lookups.
module.fields.getMethodAndAssumption(name, assumptions);
foundDeclaringModule.set(true);
}
return null;
} else {
return module.fields.getMethodAndAssumption(name, assumptions);
}
}

private static InternalMethod rememberUsedRefinements(InternalMethod method,
DeclarationContext declarationContext) {
return method.withActiveRefinements(declarationContext);
Expand All @@ -603,8 +536,7 @@ private static InternalMethod rememberUsedRefinements(InternalMethod method,
DeclarationContext declarationContext, RubyModule[] refinements, RubyModule ancestor) {
assert refinements != null;

final Map<RubyModule, RubyModule[]> currentRefinements = new HashMap<>(
declarationContext.getRefinements());
final Map<RubyModule, RubyModule[]> currentRefinements = new HashMap<>(declarationContext.getRefinements());
currentRefinements.put(ancestor, refinements);

return method.withActiveRefinements(declarationContext.withRefinements(currentRefinements));
Expand Down
25 changes: 10 additions & 15 deletions test/mri/excludes/TestRefinement.rb
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
exclude :test_eval_with_binding_scoping, "needs investigation"
exclude :test_include_refinement, "needs investigation"
exclude :test_prepend_after_refine, "needs investigation"
exclude :test_refine_prepended_class, "needs investigation"
exclude :test_refine_with_proc, "needs investigation"
exclude :test_undef_original_method, "needs investigation"
exclude :test_warn_setconst_in_refinmenet, "needs investigation"
exclude :test_refine_in_using, "needs investigation"
exclude :test_used_modules, "needs investigation"
exclude :test_unbound_refine_method, "needs investigation"
exclude :test_ancestors, "[ruby-core:86949] [Bug #14744]."
exclude :test_import_methods, "NoMethodError: undefined method `bar' for #<TestRefinement::TestImport::A:0x17bee78>"
exclude :test_eval_with_binding_scoping, "pid 123017 exit 0."
exclude :test_import_methods, "ArgumentError expected but nothing was raised."
exclude :test_prepend_after_refine, "<\"refined\"> expected but was"
exclude :test_refine_prepended_class, "<[:c, :m1, :m2]> expected but was"
exclude :test_refine_with_proc, "ArgumentError expected but nothing was raised."
exclude :test_unbound_refine_method, "TypeError expected but nothing was raised."
exclude :test_used_modules, "<[TestRefinement::VisibleRefinements::RefB,"
exclude :test_refinements, "TruffleRuby does not guarantee refinement list ordering"
exclude :test_refined_class, "TruffleRuby does not guarantee refinement list ordering"
exclude :test_prepend_into_refinement, "TypeError expected but nothing was raised."
exclude :test_include_into_refinement, "TypeError expected but nothing was raised."
exclude :test_refined_protected_methods, "assert_separately failed with error message"
exclude :test_warn_setconst_in_refinmenet, "[ruby-core:64143] [Bug #10103]"
exclude :test_refine_in_using, "NoMethodError: undefined method `foo' for #<TestRefinement::RefineInUsing:0xd88>"
exclude :test_refined_protected_methods, "NoMethodError: protected method `foo' called for #<C:0x2c8>"

0 comments on commit 47b99a3

Please sign in to comment.