diff --git a/java/fury-core/src/main/java/org/apache/fury/builder/BaseObjectCodecBuilder.java b/java/fury-core/src/main/java/org/apache/fury/builder/BaseObjectCodecBuilder.java index 170fdc2df8..ed0d89c4a7 100644 --- a/java/fury-core/src/main/java/org/apache/fury/builder/BaseObjectCodecBuilder.java +++ b/java/fury-core/src/main/java/org/apache/fury/builder/BaseObjectCodecBuilder.java @@ -21,10 +21,13 @@ import static org.apache.fury.codegen.CodeGenerator.getPackage; import static org.apache.fury.codegen.Expression.Invoke.inlineInvoke; +import static org.apache.fury.codegen.Expression.Literal.ofInt; import static org.apache.fury.codegen.Expression.Reference.fieldRef; import static org.apache.fury.codegen.ExpressionOptimizer.invokeGenerated; import static org.apache.fury.codegen.ExpressionUtils.add; import static org.apache.fury.codegen.ExpressionUtils.and; +import static org.apache.fury.codegen.ExpressionUtils.bitand; +import static org.apache.fury.codegen.ExpressionUtils.bitor; import static org.apache.fury.codegen.ExpressionUtils.cast; import static org.apache.fury.codegen.ExpressionUtils.eq; import static org.apache.fury.codegen.ExpressionUtils.eqNull; @@ -34,10 +37,9 @@ import static org.apache.fury.codegen.ExpressionUtils.neq; import static org.apache.fury.codegen.ExpressionUtils.neqNull; import static org.apache.fury.codegen.ExpressionUtils.not; -import static org.apache.fury.codegen.ExpressionUtils.notNull; import static org.apache.fury.codegen.ExpressionUtils.nullValue; -import static org.apache.fury.codegen.ExpressionUtils.ofInt; import static org.apache.fury.codegen.ExpressionUtils.or; +import static org.apache.fury.codegen.ExpressionUtils.shift; import static org.apache.fury.codegen.ExpressionUtils.subtract; import static org.apache.fury.codegen.ExpressionUtils.uninline; import static org.apache.fury.collection.Collections.ofHashSet; @@ -57,6 +59,7 @@ import static org.apache.fury.type.TypeUtils.PRIMITIVE_BOOLEAN_TYPE; import static org.apache.fury.type.TypeUtils.PRIMITIVE_BYTE_TYPE; import static org.apache.fury.type.TypeUtils.PRIMITIVE_INT_TYPE; +import static org.apache.fury.type.TypeUtils.PRIMITIVE_LONG_TYPE; import static org.apache.fury.type.TypeUtils.PRIMITIVE_VOID_TYPE; import static org.apache.fury.type.TypeUtils.SET_TYPE; import static org.apache.fury.type.TypeUtils.getElementType; @@ -93,6 +96,7 @@ import org.apache.fury.codegen.Expression.Literal; import org.apache.fury.codegen.Expression.Reference; import org.apache.fury.codegen.Expression.Return; +import org.apache.fury.codegen.Expression.While; import org.apache.fury.codegen.ExpressionUtils; import org.apache.fury.codegen.ExpressionVisitor.ExprHolder; import org.apache.fury.collection.Tuple2; @@ -261,23 +265,23 @@ public String genCode() { return ctx.genCode(); } - protected static class CutPoint { + protected static class InvokeHint { public boolean genNewMethod; public Set cutPoints = new HashSet<>(); - public CutPoint(boolean genNewMethod, Expression... cutPoints) { + public InvokeHint(boolean genNewMethod, Expression... cutPoints) { this.genNewMethod = genNewMethod; Collections.addAll(this.cutPoints, cutPoints); } - public CutPoint add(Expression cutPoint) { + public InvokeHint add(Expression cutPoint) { cutPoints.add(cutPoint); return this; } @Override public String toString() { - return "CutPoint{" + "genNewMethod=" + genNewMethod + ", cutPoints=" + cutPoints + '}'; + return "InvokeHint{" + "genNewMethod=" + genNewMethod + ", cutPoints=" + cutPoints + '}'; } } @@ -776,18 +780,18 @@ protected Expression writeCollectionData( builder.add( writeContainerElements(elementType, true, null, null, buffer, collection, size)); } else { - Literal hasNullFlag = Literal.ofInt(CollectionFlags.HAS_NULL); + Literal hasNullFlag = ofInt(CollectionFlags.HAS_NULL); Expression hasNull = eq(new BitAnd(flags, hasNullFlag), hasNullFlag, "hasNull"); builder.add( hasNull, writeContainerElements(elementType, false, null, hasNull, buffer, collection, size)); } } else { - Literal flag = Literal.ofInt(CollectionFlags.NOT_SAME_TYPE); + Literal flag = ofInt(CollectionFlags.NOT_SAME_TYPE); Expression sameElementClass = neq(new BitAnd(flags, flag), flag, "sameElementClass"); builder.add(sameElementClass); // if ((flags & Flags.NOT_DECL_ELEMENT_TYPE) == Flags.NOT_DECL_ELEMENT_TYPE) - Literal notDeclTypeFlag = Literal.ofInt(CollectionFlags.NOT_DECL_ELEMENT_TYPE); + Literal notDeclTypeFlag = ofInt(CollectionFlags.NOT_DECL_ELEMENT_TYPE); Expression isDeclType = neq(new BitAnd(flags, notDeclTypeFlag), notDeclTypeFlag); Expression elemSerializer; // make it in scope of `if(sameElementClass)` boolean maybeDecl = visitFury(f -> f.getClassResolver().isSerializable(elemClass)); @@ -820,7 +824,7 @@ protected Expression writeCollectionData( invokeGenerated(ctx, cutPoint, writeBuilder, "sameElementClassWrite", false), writeContainerElements(elementType, true, null, null, buffer, collection, size)); } else { - Literal hasNullFlag = Literal.ofInt(CollectionFlags.HAS_NULL); + Literal hasNullFlag = ofInt(CollectionFlags.HAS_NULL); Expression hasNull = eq(new BitAnd(flags, hasNullFlag), hasNullFlag, "hasNull"); builder.add(hasNull); ListExpression writeBuilder = new ListExpression(elemSerializer); @@ -841,7 +845,7 @@ protected Expression writeCollectionData( builder.add(action); } walkPath.removeLast(); - return new ListExpression(onCollectionWrite, new If(gt(size, Literal.ofInt(0)), builder)); + return new ListExpression(onCollectionWrite, new If(gt(size, ofInt(0)), builder)); } /** @@ -861,8 +865,8 @@ private Tuple2 writeElementsHeader( if (trackingRef) { bitmap = new ListExpression( - new Invoke(buffer, "writeByte", Literal.ofInt(CollectionFlags.TRACKING_REF)), - Literal.ofInt(CollectionFlags.TRACKING_REF)); + new Invoke(buffer, "writeByte", ofInt(CollectionFlags.TRACKING_REF)), + ofInt(CollectionFlags.TRACKING_REF)); } else { bitmap = new Invoke( @@ -1063,22 +1067,12 @@ private Expression jitWriteMap( visitFury(fury -> fury.getClassResolver().needToWriteRef(keyTypeRawType)); boolean trackingValueRef = visitFury(fury -> fury.getClassResolver().needToWriteRef(valueTypeRawType)); - Expression keySerializer, valueSerializer; - if (keyMonomorphic && valueMonomorphic) { - keySerializer = getOrCreateSerializer(keyType.getRawType()); - valueSerializer = getOrCreateSerializer(valueType.getRawType()); - } else if (keyMonomorphic) { - keySerializer = getOrCreateSerializer(keyType.getRawType()); - valueSerializer = nullValue(SERIALIZER_TYPE); - } else if (valueMonomorphic) { - keySerializer = nullValue(SERIALIZER_TYPE); - valueSerializer = getOrCreateSerializer(valueType.getRawType()); - } else { - keySerializer = nullValue(SERIALIZER_TYPE); - valueSerializer = nullValue(SERIALIZER_TYPE); - } - Expression.While whileAction = - new Expression.While( + Tuple2 mapKVSerializer = + getMapKVSerializer(keyTypeRawType, valueTypeRawType); + Expression keySerializer = mapKVSerializer.f0; + Expression valueSerializer = mapKVSerializer.f1; + While whileAction = + new While( neqNull(entry), () -> { String method = "writeJavaNullChunk"; @@ -1105,6 +1099,26 @@ private Expression jitWriteMap( return new If(not(inlineInvoke(map, "isEmpty", PRIMITIVE_BOOLEAN_TYPE)), whileAction); } + private Tuple2 getMapKVSerializer(Class keyType, Class valueType) { + Expression keySerializer, valueSerializer; + boolean keyMonomorphic = isMonomorphic(keyType); + boolean valueMonomorphic = isMonomorphic(valueType); + if (keyMonomorphic && valueMonomorphic) { + keySerializer = getOrCreateSerializer(keyType); + valueSerializer = getOrCreateSerializer(valueType); + } else if (keyMonomorphic) { + keySerializer = getOrCreateSerializer(keyType); + valueSerializer = nullValue(SERIALIZER_TYPE); + } else if (valueMonomorphic) { + keySerializer = nullValue(SERIALIZER_TYPE); + valueSerializer = getOrCreateSerializer(valueType); + } else { + keySerializer = nullValue(SERIALIZER_TYPE); + valueSerializer = nullValue(SERIALIZER_TYPE); + } + return Tuple2.of(keySerializer, valueSerializer); + } + protected Expression writeChunk( Expression buffer, Expression entry, @@ -1130,9 +1144,7 @@ protected Expression writeChunk( Expression writePlaceHolder = new Invoke(buffer, "writeInt16", Literal.ofShort((short) -1)); Expression chunkSizeOffset = subtract( - inlineInvoke(buffer, "writerIndex", PRIMITIVE_INT_TYPE), - Literal.ofInt(1), - "chunkSizeOffset"); + inlineInvoke(buffer, "writerIndex", PRIMITIVE_INT_TYPE), ofInt(1), "chunkSizeOffset"); Expression chunkHeader; Expression keySerializer, valueSerializer; @@ -1154,7 +1166,7 @@ protected Expression writeChunk( if (trackingValueRef) { header |= TRACKING_VALUE_REF; } - chunkHeader = Literal.ofInt(header); + chunkHeader = ofInt(header); } else if (keyMonomorphic) { int header = KEY_DECL_TYPE; if (trackingKeyRef) { @@ -1162,7 +1174,7 @@ protected Expression writeChunk( } keySerializer = getOrCreateSerializer(keyTypeRawType); valueSerializer = writeClassInfo(buffer, valueTypeExpr, valueTypeRawType, true); - chunkHeader = Literal.ofInt(header); + chunkHeader = ofInt(header); if (trackingValueRef) { // value type may be subclass and not track ref. valueWriteRef = @@ -1176,7 +1188,7 @@ protected Expression writeChunk( if (trackingValueRef) { header |= TRACKING_VALUE_REF; } - chunkHeader = Literal.ofInt(header); + chunkHeader = ofInt(header); if (trackingKeyRef) { // key type may be subclass and not track ref. keyWriteRef = @@ -1186,7 +1198,7 @@ protected Expression writeChunk( } else { keySerializer = writeClassInfo(buffer, keyTypeExpr, keyTypeRawType, true); valueSerializer = writeClassInfo(buffer, valueTypeExpr, valueTypeRawType, true); - chunkHeader = Literal.ofInt(0); + chunkHeader = ofInt(0); if (trackingKeyRef) { // key type may be subclass and not track ref. valueWriteRef = @@ -1200,7 +1212,7 @@ protected Expression writeChunk( chunkHeader = and(chunkHeader, keyWriteRef, "chunkHeader"); } } - Expression chunkSize = ofInt("chunkSize", 0); + Expression chunkSize = ExpressionUtils.ofInt("chunkSize", 0); expressions.add( key, value, @@ -1213,12 +1225,12 @@ protected Expression writeChunk( valueSerializer, keyWriteRef, valueWriteRef, - new Invoke(buffer, "putByte", subtract(chunkSizeOffset, Literal.ofInt(1)), chunkHeader), + new Invoke(buffer, "putByte", subtract(chunkSizeOffset, ofInt(1)), chunkHeader), chunkSize); Expression keyWriteRefExpr = keyWriteRef; Expression valueWriteRefExpr = valueWriteRef; - Expression.While writeLoop = - new Expression.While( + While writeLoop = + new While( Literal.ofBoolean(true), () -> { Expression breakCondition; @@ -1278,8 +1290,8 @@ protected Expression writeChunk( new If(breakCondition, new Break()), writeKey, writeValue, - new Assign(chunkSize, add(chunkSize, Literal.ofInt(1))), - new If(eq(chunkSize, Literal.ofInt(MAX_CHUNK_SIZE)), new Break()), + new Assign(chunkSize, add(chunkSize, ofInt(1))), + new If(eq(chunkSize, ofInt(MAX_CHUNK_SIZE)), new Break()), new If( inlineInvoke(iterator, "hasNext", PRIMITIVE_BOOLEAN_TYPE), new ListExpression( @@ -1336,18 +1348,18 @@ protected Expression deserializeFor( Expression buffer, TypeRef typeRef, Function callback, - CutPoint cutPoint) { + InvokeHint invokeHint) { Class rawType = getRawType(typeRef); if (visitFury(f -> f.getClassResolver().needToWriteRef(rawType))) { - return readRef(buffer, callback, () -> deserializeForNotNull(buffer, typeRef, cutPoint)); + return readRef(buffer, callback, () -> deserializeForNotNull(buffer, typeRef, invokeHint)); } else { if (typeRef.isPrimitive()) { - Expression value = deserializeForNotNull(buffer, typeRef, cutPoint); + Expression value = deserializeForNotNull(buffer, typeRef, invokeHint); // Should put value expr ahead to avoid generated code in wrong scope. return new ListExpression(value, callback.apply(value)); } return readNullable( - buffer, typeRef, callback, () -> deserializeForNotNull(buffer, typeRef, cutPoint)); + buffer, typeRef, callback, () -> deserializeForNotNull(buffer, typeRef, invokeHint)); } } @@ -1387,18 +1399,18 @@ private Expression readNullable( } protected Expression deserializeForNotNull( - Expression buffer, TypeRef typeRef, CutPoint cutPoint) { - return deserializeForNotNull(buffer, typeRef, null, cutPoint); + Expression buffer, TypeRef typeRef, InvokeHint invokeHint) { + return deserializeForNotNull(buffer, typeRef, null, invokeHint); } /** * Return an expression that deserialize an not null inputObject from buffer * . * - * @param cutPoint for generate new method to cut off dependencies. + * @param invokeHint for generate new method to cut off dependencies. */ protected Expression deserializeForNotNull( - Expression buffer, TypeRef typeRef, Expression serializer, CutPoint cutPoint) { + Expression buffer, TypeRef typeRef, Expression serializer, InvokeHint invokeHint) { Class cls = getRawType(typeRef); if (isPrimitive(cls) || isBoxed(cls)) { // for primitive, inline call here to avoid java boxing, rather call corresponding serializer. @@ -1427,9 +1439,9 @@ protected Expression deserializeForNotNull( } Expression obj; if (useCollectionSerialization(typeRef)) { - obj = deserializeForCollection(buffer, typeRef, serializer, cutPoint); + obj = deserializeForCollection(buffer, typeRef, serializer, invokeHint); } else if (useMapSerialization(typeRef)) { - obj = deserializeForMap(buffer, typeRef, serializer, cutPoint); + obj = deserializeForMap(buffer, typeRef, serializer, invokeHint); } else { if (isMonomorphic(cls)) { Preconditions.checkState(serializer == null); @@ -1459,7 +1471,7 @@ protected Expression readForNotNullNonFinal( * with {@link BaseObjectCodecBuilder#serializeForCollection} */ protected Expression deserializeForCollection( - Expression buffer, TypeRef typeRef, Expression serializer, CutPoint cutPoint) { + Expression buffer, TypeRef typeRef, Expression serializer, InvokeHint invokeHint) { TypeRef elementType = getElementType(typeRef); if (serializer == null) { Class cls = getRawType(typeRef); @@ -1491,11 +1503,11 @@ protected Expression deserializeForCollection( new ListExpression(collection, hookRead), new Invoke(serializer, "read", OBJECT_TYPE, buffer), false); - if (cutPoint != null && cutPoint.genNewMethod) { - cutPoint.add(buffer); + if (invokeHint != null && invokeHint.genNewMethod) { + invokeHint.add(buffer); return invokeGenerated( ctx, - cutPoint.cutPoints, + invokeHint.cutPoints, new ListExpression(action, new Return(action)), "readCollection", false); @@ -1516,18 +1528,18 @@ protected Expression readCollectionCodegen( if (trackingRef) { builder.add(readContainerElements(elementType, true, null, null, buffer, collection, size)); } else { - Literal hasNullFlag = Literal.ofInt(CollectionFlags.HAS_NULL); + Literal hasNullFlag = ofInt(CollectionFlags.HAS_NULL); Expression hasNull = eq(new BitAnd(flags.inline(), hasNullFlag), hasNullFlag, "hasNull"); builder.add( hasNull, readContainerElements(elementType, false, null, hasNull, buffer, collection, size)); } } else { - Literal notSameTypeFlag = Literal.ofInt(CollectionFlags.NOT_SAME_TYPE); + Literal notSameTypeFlag = ofInt(CollectionFlags.NOT_SAME_TYPE); Expression sameElementClass = neq(new BitAnd(flags, notSameTypeFlag), notSameTypeFlag, "sameElementClass"); // if ((flags & Flags.NOT_DECL_ELEMENT_TYPE) == Flags.NOT_DECL_ELEMENT_TYPE) - Literal notDeclTypeFlag = Literal.ofInt(CollectionFlags.NOT_DECL_ELEMENT_TYPE); + Literal notDeclTypeFlag = ofInt(CollectionFlags.NOT_DECL_ELEMENT_TYPE); Expression isDeclType = neq(new BitAnd(flags, notDeclTypeFlag), notDeclTypeFlag); Invoke serializer = inlineInvoke(readClassInfo(elemClass, buffer), "getSerializer", SERIALIZER_TYPE); @@ -1565,7 +1577,7 @@ protected Expression readCollectionCodegen( false); action = new If(sameElementClass, readBuilder, differentElemTypeRead); } else { - Literal hasNullFlag = Literal.ofInt(CollectionFlags.HAS_NULL); + Literal hasNullFlag = ofInt(CollectionFlags.HAS_NULL); Expression hasNull = eq(new BitAnd(flags, hasNullFlag), hasNullFlag, "hasNull"); builder.add(hasNull); // Same element class read start @@ -1588,8 +1600,7 @@ protected Expression readCollectionCodegen( } walkPath.removeLast(); // place newCollection as last as expr value - return new ListExpression( - size, collection, new If(gt(size, Literal.ofInt(0)), builder), collection); + return new ListExpression(size, collection, new If(gt(size, ofInt(0)), builder), collection); } private Expression readContainerElements( @@ -1635,32 +1646,32 @@ private Expression readContainerElement( Function callback) { boolean genNewMethod = useCollectionSerialization(elementType) || useMapSerialization(elementType); - CutPoint cutPoint = new CutPoint(genNewMethod, buffer); + InvokeHint invokeHint = new InvokeHint(genNewMethod, buffer); Class rawType = getRawType(elementType); boolean finalType = isMonomorphic(rawType); Expression read; if (finalType) { if (trackingRef) { - read = deserializeFor(buffer, elementType, callback, cutPoint); + read = deserializeFor(buffer, elementType, callback, invokeHint); } else { - cutPoint.add(hasNull); + invokeHint.add(hasNull); read = new If( hasNull, - deserializeFor(buffer, elementType, callback, cutPoint), - callback.apply(deserializeForNotNull(buffer, elementType, cutPoint))); + deserializeFor(buffer, elementType, callback, invokeHint), + callback.apply(deserializeForNotNull(buffer, elementType, invokeHint))); } } else { - cutPoint.add(elemSerializer); + invokeHint.add(elemSerializer); if (trackingRef) { // eager callback, no need to use ExprHolder. read = readRef( buffer, callback, - () -> deserializeForNotNull(buffer, elementType, elemSerializer, cutPoint)); + () -> deserializeForNotNull(buffer, elementType, elemSerializer, invokeHint)); } else { - cutPoint.add(hasNull); + invokeHint.add(hasNull); read = new If( hasNull, @@ -1668,9 +1679,9 @@ private Expression readContainerElement( buffer, elementType, callback, - () -> deserializeForNotNull(buffer, elementType, elemSerializer, cutPoint)), + () -> deserializeForNotNull(buffer, elementType, elemSerializer, invokeHint)), callback.apply( - deserializeForNotNull(buffer, elementType, elemSerializer, cutPoint))); + deserializeForNotNull(buffer, elementType, elemSerializer, invokeHint))); } } return read; @@ -1681,7 +1692,7 @@ private Expression readContainerElement( * {@link BaseObjectCodecBuilder#serializeForMap} */ protected Expression deserializeForMap( - Expression buffer, TypeRef typeRef, Expression serializer, CutPoint cutPoint) { + Expression buffer, TypeRef typeRef, Expression serializer, InvokeHint invokeHint) { Tuple2, TypeRef> keyValueType = TypeUtils.getMapKeyValueType(typeRef); TypeRef keyType = keyValueType.f0; TypeRef valueType = keyValueType.f1; @@ -1700,49 +1711,212 @@ protected Expression deserializeForMap( "Expected AbstractMapSerializer but got %s", serializer.type()); } + Expression mapSerializer = serializer; Invoke supportHook = inlineInvoke(serializer, "supportCodegenHook", PRIMITIVE_BOOLEAN_TYPE); + ListExpression expressions = new ListExpression(); Expression newMap = new Invoke(serializer, "newMap", MAP_TYPE, buffer); Expression size = new Invoke(serializer, "getAndClearNumElements", "size", PRIMITIVE_INT_TYPE); - Expression start = new Literal(0, PRIMITIVE_INT_TYPE); - Expression step = new Literal(1, PRIMITIVE_INT_TYPE); - ExprHolder exprHolder = ExprHolder.of("map", newMap, "buffer", buffer); + Expression chunkHeader = + new Invoke(buffer, "readUnsignedByte", "chunkHeader", PRIMITIVE_INT_TYPE); + expressions.add(newMap, size, new If(eq(size, ofInt(0)), new Return()), chunkHeader); + boolean keyMonomorphic = isMonomorphic(keyType); + boolean valueMonomorphic = isMonomorphic(valueType); + boolean inline = keyMonomorphic && valueMonomorphic; + Tuple2 mapKVSerializer = + getMapKVSerializer(keyType.getRawType(), valueType.getRawType()); + Expression keySerializer = mapKVSerializer.f0; + Expression valueSerializer = mapKVSerializer.f1; + While chunksLoop = + new While( + gt(size, ofInt(0)), + () -> { + ListExpression exprs = new ListExpression(); + Expression sizeAndHeader = + new Invoke( + mapSerializer, + "readJavaNullChunk", + "sizeAndHeader", + PRIMITIVE_LONG_TYPE, + false, + buffer, + newMap, + chunkHeader, + size, + keySerializer, + valueSerializer); + exprs.add( + new Assign( + chunkHeader, cast(bitand(sizeAndHeader, ofInt(0xff)), PRIMITIVE_INT_TYPE)), + new Assign(size, cast(shift(">>>", sizeAndHeader, 8), PRIMITIVE_INT_TYPE))); + Expression sizeAndHeader2 = + readChunk(buffer, newMap, size, keyType, valueType, chunkHeader); + if (inline) { + exprs.add(sizeAndHeader2); + } else { + exprs.add( + new Assign( + chunkHeader, cast(bitand(sizeAndHeader2, ofInt(0xff)), PRIMITIVE_INT_TYPE)), + new Assign(size, cast(shift(">>>", sizeAndHeader2, 8), PRIMITIVE_INT_TYPE))); + } + return exprs; + }); + expressions.add(chunksLoop, newMap); + // first newMap to create map, last newMap as expr value + Expression map = inlineInvoke(serializer, "onMapRead", OBJECT_TYPE, expressions); + Expression action = + new If(supportHook, map, new Invoke(serializer, "read", OBJECT_TYPE, buffer), false); + if (invokeHint != null && invokeHint.genNewMethod) { + invokeHint.add(buffer); + return invokeGenerated( + ctx, + invokeHint.cutPoints, + new ListExpression(action, new Return(action)), + "readMap", + false); + } + return action; + } + + private Expression readChunk( + Expression buffer, + Expression map, + Expression size, + TypeRef keyType, + TypeRef valueType, + Expression chunkHeader) { + boolean keyMonomorphic = isMonomorphic(keyType); + boolean valueMonomorphic = isMonomorphic(valueType); + Class keyTypeRawType = keyType.getRawType(); + Class valueTypeRawType = valueType.getRawType(); + boolean inline = keyMonomorphic && valueMonomorphic; + boolean trackingKeyRef = + visitFury(fury -> fury.getClassResolver().needToWriteRef(keyTypeRawType)); + boolean trackingValueRef = + visitFury(fury -> fury.getClassResolver().needToWriteRef(valueTypeRawType)); + ListExpression expressions = new ListExpression(); + Expression trackKeyRef = neq(and(chunkHeader, ofInt(TRACKING_KEY_REF)), ofInt(0)); + Expression trackValueRef = neq(and(chunkHeader, ofInt(TRACKING_VALUE_REF)), ofInt(0)); + Expression keyIsDeclaredType = neq(and(chunkHeader, ofInt(KEY_DECL_TYPE)), ofInt(0)); + Expression valueIsDeclaredType = neq(and(chunkHeader, ofInt(VALUE_DECL_TYPE)), ofInt(0)); + Expression chunkSize = new Invoke(buffer, "readUnsignedByte", "chunkSize", PRIMITIVE_INT_TYPE); + expressions.add(chunkSize); + + Expression keySerializer = getOrCreateSerializer(keyType.getRawType()); + Expression valueSerializer = getOrCreateSerializer(valueType.getRawType()); + if (!keyMonomorphic && !valueMonomorphic) { + keySerializer = + new If( + keyIsDeclaredType, + keySerializer, + new Invoke( + readClassInfo(keyTypeRawType, buffer), + "getSerializer", + "keySerializer", + SERIALIZER_TYPE)); + valueSerializer = + new If( + valueIsDeclaredType, + valueSerializer, + new Invoke( + readClassInfo(valueTypeRawType, buffer), + "getSerializer", + "valueSerializer", + SERIALIZER_TYPE)); + } else if (!keyMonomorphic) { + keySerializer = + new If( + keyIsDeclaredType, + keySerializer, + new Invoke( + readClassInfo(keyTypeRawType, buffer), + "getSerializer", + "keySerializer", + SERIALIZER_TYPE)); + } else if (!valueMonomorphic) { + valueSerializer = + new If( + valueIsDeclaredType, + valueSerializer, + new Invoke( + readClassInfo(valueTypeRawType, buffer), + "getSerializer", + "valueSerializer", + SERIALIZER_TYPE)); + } + expressions.add(keySerializer, valueSerializer); + ExprHolder exprHolder = + ExprHolder.of("keySerializer", keySerializer, "valueSerializer", valueSerializer); + ForLoop readKeyValues = new ForLoop( - start, - size, - step, + ofInt(0), + chunkSize, + ofInt(1), i -> { boolean genKeyMethod = useCollectionSerialization(keyType) || useMapSerialization(keyType); boolean genValueMethod = useCollectionSerialization(valueType) || useMapSerialization(valueType); walkPath.add("key:" + keyType); - Expression keyAction = - deserializeFor( - exprHolder.get("buffer"), keyType, e -> e, new CutPoint(genKeyMethod)); + Expression keyAction, valueAction; + InvokeHint keyHint = new InvokeHint(genKeyMethod); + InvokeHint valueHint = new InvokeHint(genValueMethod); + Expression keySerializerExpr = exprHolder.get("keySerializer"); + if (trackingKeyRef) { + keyAction = + new If( + trackKeyRef, + readRef( + buffer, + expr -> expr, + () -> + deserializeForNotNull(buffer, keyType, keySerializerExpr, keyHint)), + deserializeForNotNull(buffer, keyType, keySerializerExpr, keyHint)); + } else { + keyAction = deserializeForNotNull(buffer, keyType, keySerializerExpr, keyHint); + } walkPath.removeLast(); walkPath.add("value:" + valueType); - Expression valueAction = - deserializeFor( - exprHolder.get("buffer"), valueType, e -> e, new CutPoint(genValueMethod)); + Expression valueSerializerExpr = exprHolder.get("valueSerializer"); + if (trackingValueRef) { + valueAction = + new If( + trackValueRef, + readRef( + buffer, + expr -> expr, + () -> + deserializeForNotNull( + buffer, valueType, valueSerializerExpr, valueHint)), + deserializeForNotNull(buffer, valueType, valueSerializerExpr, valueHint)); + } else { + valueAction = + deserializeForNotNull(buffer, valueType, valueSerializerExpr, valueHint); + } walkPath.removeLast(); - return new Invoke(exprHolder.get("map"), "put", keyAction, valueAction); + return list( + new Invoke(map, "put", keyAction, valueAction), + new Assign(size, subtract(size, ofInt(1)))); }); - // first newMap to create map, last newMap as expr value - Expression hookRead = new ListExpression(newMap, size, readKeyValues, newMap); - hookRead = new Invoke(serializer, "onMapRead", OBJECT_TYPE, hookRead); - Expression action = - new If(supportHook, hookRead, new Invoke(serializer, "read", OBJECT_TYPE, buffer), false); - if (cutPoint != null && cutPoint.genNewMethod) { - cutPoint.add(buffer); - return invokeGenerated( - ctx, - cutPoint.cutPoints, - new ListExpression(action, new Return(action)), - "readMap", - false); + expressions.add(readKeyValues); + + if (inline) { + expressions.add( + new If( + gt(size, ofInt(0)), + new Assign( + chunkHeader, inlineInvoke(buffer, "readUnsignedByte", PRIMITIVE_INT_TYPE)))); + } else { + expressions.add( + new Return( + new If( + gt(size, ofInt(0)), + (bitor( + shift("<<", size, 8), + inlineInvoke(buffer, "readUnsignedByte", PRIMITIVE_INT_TYPE))), + ofInt(0)))); } - return action; + return expressions; } @Override diff --git a/java/fury-core/src/main/java/org/apache/fury/serializer/collection/AbstractMapSerializer.java b/java/fury-core/src/main/java/org/apache/fury/serializer/collection/AbstractMapSerializer.java index 57dc794baf..060b99b45d 100644 --- a/java/fury-core/src/main/java/org/apache/fury/serializer/collection/AbstractMapSerializer.java +++ b/java/fury-core/src/main/java/org/apache/fury/serializer/collection/AbstractMapSerializer.java @@ -844,7 +844,7 @@ protected final void chunkReadElements(MemoryBuffer buffer, int size, Map map) { while (size > 0) { long sizeAndHeader = readJavaNullChunk(buffer, map, chunkHeader, size, keySerializer, valueSerializer); - chunkHeader = (int) (sizeAndHeader & 0b11111111); + chunkHeader = (int) (sizeAndHeader & 0xff); size = (int) (sizeAndHeader >>> 8); if (keySerializer != null || valueSerializer != null) { sizeAndHeader = diff --git a/java/fury-core/src/test/java/org/apache/fury/codegen/ExpressionVisitorTest.java b/java/fury-core/src/test/java/org/apache/fury/codegen/ExpressionVisitorTest.java index 3d6e9ef4f6..d9951ce9fd 100644 --- a/java/fury-core/src/test/java/org/apache/fury/codegen/ExpressionVisitorTest.java +++ b/java/fury-core/src/test/java/org/apache/fury/codegen/ExpressionVisitorTest.java @@ -19,8 +19,10 @@ package org.apache.fury.codegen; +import static org.apache.fury.type.TypeUtils.LIST_TYPE; import static org.testng.Assert.assertEquals; +import com.google.common.collect.ImmutableList; import java.lang.invoke.SerializedLambda; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; @@ -42,18 +44,15 @@ public void testTraverseExpression() throws InvocationTargetException, IllegalAc Expression.Reference ref = new Expression.Reference("a", TypeRef.of(ExpressionVisitorTest.class)); Expression e1 = new Expression.Invoke(ref, "testTraverseExpression"); - Literal start = Literal.ofInt(0); - Literal end = Literal.ofInt(10); - Literal step = Literal.ofInt(1); + Literal literal1 = Literal.ofInt(1); + Expression list = new Expression.StaticInvoke(ImmutableList.class, "of", LIST_TYPE, literal1); ExpressionVisitor.ExprHolder holder = ExpressionVisitor.ExprHolder.of("e1", e1, "e2", new Expression.ListExpression()); // FIXME ListExpression#add in lambda don't get executed, so ListExpression is the last expr. - Expression.ForLoop forLoop = - new Expression.ForLoop( - start, - end, - step, - expr -> ((Expression.ListExpression) (holder.get("e2"))).add(holder.get("e1"))); + Expression.ForEach forLoop = + new Expression.ForEach( + list, + (i, expr) -> ((Expression.ListExpression) (holder.get("e2"))).add(holder.get("e1"))); List expressions = new ArrayList<>(); new ExpressionVisitor() .traverseExpression(forLoop, exprSite -> expressions.add(exprSite.current)); @@ -69,7 +68,7 @@ public void testTraverseExpression() throws InvocationTargetException, IllegalAc // Traversal relies on getDeclaredFields(), nondeterministic order. Set expressionsSet = new HashSet<>(expressions); Set expressionsSet2 = - new HashSet<>(Arrays.asList(forLoop, e1, ref, exprHolder.get("e2"), end, start, step)); + new HashSet<>(Arrays.asList(forLoop, e1, ref, exprHolder.get("e2"), list, literal1)); assertEquals(expressionsSet, expressionsSet2); } }