diff --git a/src/main/java/org/truffleruby/language/arguments/ProfileArgumentNode.java b/src/main/java/org/truffleruby/language/arguments/ProfileArgumentNode.java index ad2c74b94a66..bcc21648f596 100644 --- a/src/main/java/org/truffleruby/language/arguments/ProfileArgumentNode.java +++ b/src/main/java/org/truffleruby/language/arguments/ProfileArgumentNode.java @@ -9,6 +9,7 @@ */ package org.truffleruby.language.arguments; +import com.oracle.truffle.api.dsl.Idempotent; import com.oracle.truffle.api.dsl.ImportStatic; import com.oracle.truffle.api.dsl.NeverDefault; import org.truffleruby.language.NoImplicitCastsToLong; @@ -18,6 +19,7 @@ import com.oracle.truffle.api.CompilerDirectives; import com.oracle.truffle.api.dsl.Cached; +import com.oracle.truffle.api.dsl.Cached.Shared; import com.oracle.truffle.api.dsl.NodeChild; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.dsl.TypeSystemReference; @@ -34,62 +36,84 @@ public abstract class ProfileArgumentNode extends RubyContextSourceNode { protected abstract RubyNode getChildNode(); - @Specialization(guards = "value == cachedValue", limit = "1") - boolean cacheBoolean(boolean value, - @Cached("value") boolean cachedValue) { - return cachedValue; + @Specialization(guards = "guardBoolean(value, cachedValue)") + boolean doBoolean(boolean value, + @Cached("createCachedValue(value)") @Shared Object cachedValue) { + return (boolean) cachedValue; } - @Specialization(guards = "value == cachedValue", limit = "1") - int cacheInt(int value, - @Cached("value") int cachedValue) { - return cachedValue; + @Specialization(guards = "guardInt(value, cachedValue)") + int doInt(int value, + @Cached("createCachedValue(value)") @Shared Object cachedValue) { + return (int) cachedValue; } - @Specialization(guards = "value == cachedValue", limit = "1") - long cacheLong(long value, - @Cached("value") long cachedValue) { - return cachedValue; + @Specialization(guards = "guardLong(value, cachedValue)") + long doLong(long value, + @Cached("createCachedValue(value)") @Shared Object cachedValue) { + return (long) cachedValue; } - @Specialization(guards = "exactCompare(value, cachedValue)", limit = "1") - double cacheDouble(double value, - @Cached("value") double cachedValue) { - return cachedValue; + @Specialization(guards = "guardDouble(value, cachedValue)") + double doDouble(double value, + @Cached("createCachedValue(value)") @Shared Object cachedValue) { + return (double) cachedValue; } - @Specialization( - guards = { "isExact(object, cachedClass)", "!isPrimitiveClass(cachedClass)" }, - limit = "1") - Object cacheClass(Object object, - @Cached("getClassOrObject(object)") Class cachedClass) { + @Specialization(guards = { "guardClass(value, cachedValue)", "!isPrimitiveClass(cachedValue)" }) + Object doClass(Object value, + @Cached("createCachedValue(value)") @Shared Object cachedValue) { + assert RubyGuards.assertIsValidRubyValue(value); // The cast is only useful for the compiler. if (CompilerDirectives.inInterpreter()) { - return object; + return value; } else { - return CompilerDirectives.castExact(object, cachedClass); + return CompilerDirectives.castExact(value, (Class) cachedValue); } } - @Specialization(replaces = { "cacheBoolean", "cacheInt", "cacheLong", "cacheDouble", "cacheClass" }) - Object unprofiled(Object object) { - assert RubyGuards.assertIsValidRubyValue(object); - return object; - } - - protected static boolean exactCompare(double a, double b) { - // -0.0 == 0.0, but you can tell the difference through other means, so we need to differentiate. - return Double.doubleToRawLongBits(a) == Double.doubleToRawLongBits(b); + @Specialization(replaces = { "doBoolean", "doInt", "doLong", "doDouble", "doClass" }) + Object doGeneric(Object value) { + assert RubyGuards.assertIsValidRubyValue(value); + return value; } + /** The reason this method is used for all cached arguments is that Truffle DSL forces us to use same cache + * initializer for all @Shared arguments. Otherwise, it throws an error. */ @NeverDefault - protected static Class getClassOrObject(Object value) { + static Object createCachedValue(Object value) { + if (RubyGuards.isPrimitive(value)) { + return value; + } + return value == null ? Objects.class : value.getClass(); } - @Override - public String toString() { - return "Profiled(" + getChildNode() + ")"; + static boolean guardBoolean(boolean value, Object cachedValue) { + return cachedValue instanceof Boolean cachedBoolean && value == cachedBoolean; + } + + static boolean guardInt(int value, Object cachedValue) { + return cachedValue instanceof Integer cachedInt && value == cachedInt; + } + + static boolean guardLong(long value, Object cachedValue) { + return cachedValue instanceof Long cachedLong && value == cachedLong; + } + + static boolean guardDouble(double value, Object cachedValue) { + return cachedValue instanceof Double cachedDouble && + // -0.0 == 0.0, but you can tell the difference through other means, so we need to differentiate. + Double.doubleToRawLongBits(value) == Double.doubleToRawLongBits(cachedDouble); + } + + static boolean guardClass(Object value, Object cachedValue) { + return cachedValue instanceof Class cachedClass && CompilerDirectives.isExact(value, cachedClass); + } + + @Idempotent + static boolean isPrimitiveClass(Object cachedValue) { + return RubyGuards.isPrimitiveClass((Class) cachedValue); } @Override @@ -97,5 +121,4 @@ public RubyNode cloneUninitialized() { var copy = ProfileArgumentNodeGen.create(getChildNode().cloneUninitialized()); return copy.copyFlags(this); } - }